Skip to content

Warn when closure-captured tensors cause recompilation storms#1715

Open
jeffreyksmithjr wants to merge 4 commits intoelixir-nx:mainfrom
jeffreyksmithjr:closure-capture-recompilation-warning
Open

Warn when closure-captured tensors cause recompilation storms#1715
jeffreyksmithjr wants to merge 4 commits intoelixir-nx:mainfrom
jeffreyksmithjr:closure-capture-recompilation-warning

Conversation

@jeffreyksmithjr
Copy link

@jeffreyksmithjr jeffreyksmithjr commented Mar 20, 2026

Summary

When a non-scalar tensor with 10 or more elements is used as a constant inside a defn expression, this change emits a Logger.warning during graph construction. The warning explains that constants are embedded into the computation graph and will force recompilation if their values change between calls.

This lives entirely in Nx core (Nx.Defn.Expr and Nx.Defn.Compiler). No changes to EXLA.

The original version of this PR used an ETS-based recompilation counter in EXLA. Based on review feedback, I moved the check upstream to where constants actually enter the graph. This catches all sources of tensor constants (closure captures, deftransform, anything that flows through to_expr), not just the closure capture case.

How it works

In Nx.Defn.Expr.to_expr/1, when a BinaryBackend tensor with non-scalar shape and 10+ elements is converted to a :tensor expression node, the warning fires. A process dictionary flag limits it to once per trace. The flag is saved and restored in Nx.Defn.Compiler.runtime_fun/3 so that nested jit calls do not interfere with the outer trace's state.

The 10-element threshold avoids warning on small intentional constants like masks or indices. The warning message notes that intentionally constant tensors can be safely ignored.

Test plan

  • Warning fires when a large tensor is used as a constant in a closure
  • No warning for scalar constants
  • No warning for small tensor constants (below threshold)
  • No warning when tensors are passed as function arguments
  • Only one warning per trace, even with multiple large constants
  • Full mix test passes in both nx and exla

Closes #1714

Closures passed to jit/2 or value_and_grad/2 that capture tensors
bake those values as graph constants. When the captured values change
every iteration, EXLA recompiles from scratch each time, eventually
exhausting resources. This adds a lightweight runtime check that
detects rapid recompilation of the same source-level function and
emits a Logger.warning with actionable guidance.

The check uses an ETS table with atomic operations (update_counter/4
for counting, select_replace/2 as CAS for the warned flag) to stay
concurrency-safe without adding any GenServer or supervision tree
changes.

Also enriches telemetry metadata with function_identity and
compilation_status fields.

Closes elixir-nx#1714
@josevalim
Copy link
Contributor

I am not sure how this solution scales overall. It will catch simple examples like this but, in actual code, it is more likely that this will arise from deftransform and other hardcoded constants in the graph. Can you please expand a bit more on how you ran into this issue in a production-like example?

Comment on lines +488 to +489
function_identity: fun_identity(key),
compilation_status: if(evaled, do: :compiled, else: :cache_hit)
Copy link
Contributor

@josevalim josevalim Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this change is good and we could ship them separately of when/how we detect recompilation.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I simplified this out, but could restore it or do it elsewhere, if helpful.

@jeffreyksmithjr
Copy link
Author

Can you please expand a bit more on how you ran into this issue in a production-like example?

Sure. It's just in some private (pre-publication research) code.

Roughly, I'm running a continuous RL training loop inside a GenServer. Every couple of seconds, it wakes up, pulls a randomized batch of experiences from an ETS replay buffer, and applies gradients.

A somewhat scrubbed version of the code I'm running is:

def handle_info(:train_tick, state) do
  {inputs, targets} = ExperienceBuffer.sample_batch(32)

  {loss, grads} = Nx.Defn.value_and_grad(state.model_state, fn params ->
    preds = predict_fn.(params, inputs)
    Axon.Losses.mse(targets, preds) # <--- targets and inputs captured from outer scope
  end)

  # apply updates...
end

Because inputs and targets are captured, EXLA bakes them in as graph constants. Since the batch data is completely different every 2 seconds, I get a 100% cache miss rate.

Results in cryptic deleted or donated buffer errors when Elixir GC'd the old batches but EXLA's cache still held the executables.

And yep, I was on the fence about the ETS tracker. Figured I'd just put it up for comment, since it seems like a possible fix.

Happy to refactor if you think that I'm partially on the right track and need to simplify.

@polvalente
Copy link
Contributor

Does the cache miss still happen if value and grad is wrapped into a defn and you JIT that instead?

@jeffreyksmithjr
Copy link
Author

Does the cache miss still happen if value and grad is wrapped into a defn and you JIT that instead?

Nope. Just trying to help the user solve this mystery themselves when they should know that but don't.

…acker

Instead of detecting recompilation storms after the fact in EXLA,
warn at graph construction time in Nx.Defn.Expr when a non-scalar
tensor (10+ elements) is embedded as a constant. This catches all
sources of problematic constants (closure captures, deftransform,
any code path through to_expr) rather than only closure captures.

The warning fires once per trace via a process dictionary flag that
is saved/restored in runtime_fun to handle nested jit correctly.

Removes the ETS-based recompilation tracker from EXLA (table
creation, counter logic, CAS warning dedup). Keeps the telemetry
metadata enrichment (function_identity, compilation_status).
Per review feedback: key is always a function, so the compiled vs
cache_hit distinction is redundant in the metadata.
Per review feedback: new_uniq is a hash, not a unique identifier,
so it cannot reliably be used for identification. Reverts telemetry
metadata to just the key field.
@jeffreyksmithjr
Copy link
Author

Maybe this approach is better? It at least seems closer to the point of the user's needed action. The choice of 10 elements is arbitrary but would seem to be likely to help people making the same sort of error I made in my code.

Definitely less added infra to achieve this approach.

@josevalim
Copy link
Contributor

I think this approach will have more false positives. Tensor constants may appear for other reasons, such as inlining code inside deftransform based on options and so on. I think the previous approach was fine, I would just simplify the logic to not be time based.

This doesn't really address my mysterious failure concern. And I could imagine some scenarios when the 10 constant was reached reasonably over a larger time frame and not my specific concern of rapid reoccurence.

Well, they are all heuristics. Someone could also have the same issue as you, but it is hit over a slightly longer window, and now they are not notified about it. That's why I would prefer to be straight-forward about it. For all intents and pusposes, you shouldn't be compiling the same anonymous function over and over again. You can with different shapes (so we need to be careful about it and check said case).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Closure-captured tensors cause silent recompilation storms and crashes

3 participants