Warn when closure-captured tensors cause recompilation storms #1715
jeffreyksmithjr wants to merge 4 commits into elixir-nx:main
Closures passed to jit/2 or value_and_grad/2 that capture tensors bake those values as graph constants. When the captured values change every iteration, EXLA recompiles from scratch each time, eventually exhausting resources.
This adds a lightweight runtime check that detects rapid recompilation of the same source-level function and emits a Logger.warning with actionable guidance. The check uses an ETS table with atomic operations (update_counter/4 for counting, select_replace/2 as CAS for the warned flag) to stay concurrency-safe without adding any GenServer or supervision tree changes.
Also enriches telemetry metadata with function_identity and compilation_status fields.
Closes elixir-nx#1714
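For readers unfamiliar with the ETS pattern described above, here is a minimal sketch of a counter plus a compare-and-swap "warned" flag. It is not the code from this PR: the module name, table name, threshold, and return values are all hypothetical, and it assumes keys are plain atoms so they can be used literally in a match specification.

```elixir
defmodule RecompileTrackerSketch do
  # Hypothetical names and threshold, for illustration only.
  @table :recompile_tracker_sketch
  @threshold 3

  # Creates a public named table so any process can record compilations.
  def init do
    :ets.new(@table, [:named_table, :public, :set])
    :ok
  end

  # Returns :warn at most once per key, once the count reaches the threshold.
  def record_compile(key) do
    # Rows are {key, count, warned?}. update_counter/4 inserts the default
    # row when the key is missing, then atomically increments position 2.
    count = :ets.update_counter(@table, key, {2, 1}, {key, 0, false})

    if count >= @threshold and claim_warning(key), do: :warn, else: :ok
  end

  # Compare-and-swap: replace the exact row we read with a copy whose flag
  # is true. If another process changed the row in between, nothing is
  # replaced and this caller does not warn; a later call can try again.
  defp claim_warning(key) do
    case :ets.lookup(@table, key) do
      [{^key, count, false} = old] ->
        new = {key, count, true}
        :ets.select_replace(@table, [{old, [], [{:const, new}]}]) == 1

      _ ->
        false
    end
  end
end
```

Calling `RecompileTrackerSketch.init()` once and then `RecompileTrackerSketch.record_compile(:some_fun)` on each compilation would be enough to exercise it. The sketch counts compilations only and does not model a time window for "rapid" recompilation.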
I am not sure how this solution scales overall. It will catch simple examples like this but, in actual code, it is more likely that this will arise from
exla/lib/exla/defn.ex
Outdated
    function_identity: fun_identity(key),
    compilation_status: if(evaled, do: :compiled, else: :cache_hit)
I think this change is good, and we could ship it separately from when/how we detect recompilation.
I simplified this out, but could restore it or do it elsewhere, if helpful.
Sure. It's just in some private (pre-publication research) code. Roughly, I'm running a continuous RL training loop inside a GenServer. Every couple of seconds, it wakes up, pulls a randomized batch of experiences from an ETS replay buffer, and applies gradients. A somewhat scrubbed version of the code I'm running is:
Because inputs and targets are captured, EXLA bakes them in as graph constants. Since the batch data is completely different every 2 seconds, I get a 100% cache miss rate. This results in cryptic deleted or donated buffer errors once Elixir has GC'd the old batches while EXLA's cache still holds the executables.
And yep, I was on the fence about the ETS tracker. Figured I'd just put it up for comment, since it seems like a possible fix. Happy to refactor if you think that I'm partially on the right track and need to simplify.
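The capture pattern described in this comment can be sketched roughly as follows. This is not the private code referred to above; the loss function, tensor values, and variable names are made up for illustration, and the snippet assumes the `:nx` dependency is available (for example via `Mix.install([:nx])`). It contrasts closing over the batch with passing the batch as arguments to the jitted function.

```elixir
# Hypothetical squared-error loss; any differentiable function would do.
loss = fn w, x, y ->
  diff = Nx.subtract(Nx.dot(x, w), y)
  Nx.mean(Nx.multiply(diff, diff))
end

params = Nx.tensor([[0.5], [0.5]])
inputs = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = Nx.tensor([[1.0], [2.0]])

# Capturing: inputs and targets are closed over, so they enter the traced
# expression as constants rather than as parameters.
captured = fn w -> loss.(w, inputs, targets) end
{loss_captured, grad_captured} = Nx.Defn.value_and_grad(captured).(params)

# Passing as arguments: the batch tensors are parameters of the jitted
# function, so the traced expression does not depend on their values.
step =
  Nx.Defn.jit(fn w, x, y ->
    Nx.Defn.value_and_grad(w, fn w -> loss.(w, x, y) end)
  end)

{loss_args, grad_args} = step.(params, inputs, targets)
```

Both calls compute the same value and gradient. The difference is only in whether the batch is part of the graph or an input to it, which is what matters for cache hits with a compiler such as EXLA.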
Does the cache miss still happen if value_and_grad is wrapped in a defn and you JIT that instead?
Nope. Just trying to help the user solve this mystery themselves when they should know that but don't.
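A minimal sketch of the defn-wrapped form asked about here, assuming the `:nx` dependency is available. The module name, function names, and loss are hypothetical and stand in for whatever the real training step computes.

```elixir
defmodule TrainStepSketch do
  import Nx.Defn

  # Hypothetical squared-error loss over a batch.
  defn loss(w, x, y) do
    diff = Nx.dot(x, w) - y
    Nx.mean(diff * diff)
  end

  # The batch (x, y) arrives as defn arguments, so it is traced as
  # parameters of the graph instead of being baked in as constants.
  defn step(w, x, y) do
    value_and_grad(w, fn w -> loss(w, x, y) end)
  end
end

jitted_step = Nx.Defn.jit(&TrainStepSketch.step/3)

params = Nx.tensor([[0.5], [0.5]])
inputs = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = Nx.tensor([[1.0], [2.0]])

{loss_value, gradient} = jitted_step.(params, inputs, targets)
```

With this shape, calling `jitted_step` on a new batch of the same shape and type should reuse the compiled function, since only the argument values change.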
…acker
Instead of detecting recompilation storms after the fact in EXLA, warn at graph construction time in Nx.Defn.Expr when a non-scalar tensor (10+ elements) is embedded as a constant. This catches all sources of problematic constants (closure captures, deftransform, any code path through to_expr) rather than only closure captures.
The warning fires once per trace via a process dictionary flag that is saved/restored in runtime_fun to handle nested jit correctly.
Removes the ETS-based recompilation tracker from EXLA (table creation, counter logic, CAS warning dedup). Keeps the telemetry metadata enrichment (function_identity, compilation_status).
Per review feedback: key is always a function, so the compiled vs cache_hit distinction is redundant in the metadata.
Per review feedback: new_uniq is a hash, not a unique identifier, so it cannot reliably be used for identification. Reverts telemetry metadata to just the key field.
Maybe this approach is better? It at least seems closer to the point where the user needs to act. The choice of 10 elements is arbitrary, but it seems likely to help people making the same sort of error I made in my code. It definitely takes less added infra.
I think this approach will have more false positives. Tensor constants may appear for other reasons, such as inlining code inside
Well, they are all heuristics. Someone could also have the same issue as you but hit it over a slightly longer window, and then they would not be notified about it. That's why I would prefer to be straightforward about it. For all intents and purposes, you shouldn't be compiling the same anonymous function over and over again. You can with different shapes (so we need to be careful about it and check said case).
Summary
When a non-scalar tensor with 10 or more elements is used as a constant inside a defn expression, this change emits a Logger.warning during graph construction. The warning explains that constants are embedded into the computation graph and will force recompilation if their values change between calls.
This lives entirely in Nx core (Nx.Defn.Expr and Nx.Defn.Compiler). No changes to EXLA.
The original version of this PR used an ETS-based recompilation counter in EXLA. Based on review feedback, I moved the check upstream to where constants actually enter the graph. This catches all sources of tensor constants (closure captures, deftransform, anything that flows through to_expr), not just the closure capture case.
How it works
In Nx.Defn.Expr.to_expr/1, when a BinaryBackend tensor with non-scalar shape and 10+ elements is converted to a :tensor expression node, the warning fires. A process dictionary flag limits it to once per trace. The flag is saved and restored in Nx.Defn.Compiler.runtime_fun/3 so that nested jit calls do not interfere with the outer trace's state.
The 10-element threshold avoids warning on small intentional constants like masks or indices. The warning message notes that intentionally constant tensors can be safely ignored.
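The mechanism described under "How it works" can be approximated in plain Elixir as below. This is a standalone sketch, not the Nx implementation: the module, function names, flag key, and warning text are hypothetical, and shapes are represented as bare tuples (with `{}` for a scalar) so the snippet runs without Nx.

```elixir
defmodule ConstantWarningSketch do
  require Logger

  # Hypothetical process dictionary key and threshold.
  @flag :constant_warning_sketch_fired
  @min_elements 10

  # True for non-scalar shapes with at least @min_elements elements.
  def large_constant?(shape) when is_tuple(shape) do
    shape != {} and Tuple.product(shape) >= @min_elements
  end

  # Warns at most once per trace, using the process dictionary as the flag.
  def maybe_warn(shape) do
    if large_constant?(shape) and Process.get(@flag) != true do
      Process.put(@flag, true)
      Logger.warning("tensor constant of shape #{inspect(shape)} embedded in the graph")
      :warned
    else
      :ok
    end
  end

  # Runs fun as its own trace: the outer flag is saved and cleared before
  # fun runs, then restored afterwards, so a nested trace neither sees nor
  # overwrites the outer trace's state.
  def with_fresh_trace(fun) when is_function(fun, 0) do
    previous = Process.delete(@flag)

    try do
      fun.()
    after
      if previous == nil, do: Process.delete(@flag), else: Process.put(@flag, previous)
    end
  end
end
```

Within one `with_fresh_trace/1` call, only the first large constant produces a warning. A nested `with_fresh_trace/1` starts with a clear flag and leaves the outer flag as it found it.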
Test plan
mix test passes in both nx and exla
Closes #1714