Skip to content

Commit fbef12e

Browse files
committed
also DCE
1 parent 4880e2b commit fbef12e

File tree

1 file changed

+5
-0
lines changed
  • jax/_src/interpreters

1 file changed

+5
-0
lines changed

jax/_src/interpreters/ad.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,11 @@ def new_arg(trace, primal_aval, nz):
194194
tangent_trace.invalidate()
195195
if attrs_tracked:
196196
raise NotImplementedError("TODO: attrs")
197+
tangent_jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
198+
tangent_jaxpr, [True] * len(tangent_jaxpr.outvars),
199+
[False] * len(tangent_jaxpr.constvars) + [True] * len(tangent_jaxpr.invars))
200+
tangent_consts = [c for c, used in zip(tangent_consts, used_consts) if used]
201+
197202
residuals_and_primals = (*tangent_consts, *out_primals)
198203
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals)
199204
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info)

0 commit comments

Comments
 (0)