JAX & JIT Compatibility
When JIT compilation helps, when it breaks, and how to work around limitations.
Overview
Asgard compiles circuits to JAX functions. By default JIT is disabled:
```python
from gimle.asgard.runtime.jax_compiler import JAXCircuitCompiler

compiler = JAXCircuitCompiler()
# `circuit` is an Asgard circuit built elsewhere
fn = compiler.compile(circuit, jit=False)  # default
fn = compiler.compile(circuit, jit=True)   # enable JIT
```
JIT adds overhead on the first call (tracing + compilation) but significantly speeds up subsequent calls with the same input shapes.
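The first-call cost is easy to observe with plain jax.jit. The sketch below uses a hypothetical stand-in function (poly) rather than an Asgard circuit, and it calls block_until_ready() because JAX dispatches work asynchronously: without it, the timer would measure only the dispatch. Absolute timings depend on hardware; only the gap between the two calls matters.

```python
import time

import jax
import jax.numpy as jnp


def poly(x):
    # Stand-in for a compiled circuit: a few element-wise ops and a reduction.
    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2)


poly_jit = jax.jit(poly)
x = jnp.linspace(0.0, 1.0, 1_000_000)

t0 = time.perf_counter()
first = poly_jit(x).block_until_ready()   # trace + compile + run
t1 = time.perf_counter()
second = poly_jit(x).block_until_ready()  # same shape and dtype: cached executable
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f} s")
print(f"second call: {t2 - t1:.4f} s")
```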
Compatibility Matrix
| Circuit Type | JIT Compatible | Notes |
|---|---|---|
| Simple atomics (add, scalar, split, ...) | Yes | Pure JAX operations. |
| Composition and monoidal (no trace) | Yes | Array-based, no control flow. |
| Trace with register | No | Coefficient-by-coefficient evaluation uses Python loops. |
| Stochastic simulation | Partial | Time-stepping loop runs outside JIT; individual steps are JIT-safe. |
| Parameter optimization | No | Circuit rebuilding requires Python control flow. |
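Rule of thumb: if the circuit has no trace operator, jit=True is safe, and it pays off when the compiled function is called repeatedly with the same input shapes. If the circuit contains a trace operator, leave JIT disabled.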
Why Trace Breaks JIT
The trace operator solves feedback loops using one of two strategies:
- Stream calculus (coefficient-by-coefficient): builds the result in a Python for loop where each iteration depends on the previous one. JAX cannot trace through this because the loop body has data-dependent control flow. The dependence between iterations is not the problem by itself, since a Python loop with a fixed trip count is simply unrolled during tracing; the value-dependent branching is.
- Fixed-point iteration: converges by checking if diff < tolerance: break. This condition depends on a JAX value, which JIT cannot evaluate at trace time.
Both paths require Python control flow that depends on runtime values, which is exactly what jax.jit prohibits.
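For comparison, plain JAX expresses value-dependent loops with structured control-flow primitives such as jax.lax.while_loop, which evaluates the stopping condition on device instead of in Python. The sketch below illustrates that general technique on a scalar problem (solving x = cos(x)); it is not how Asgard's trace operator is implemented, and TOL, MAX_ITER and fixed_point_cos are hypothetical names.

```python
import jax
import jax.numpy as jnp

TOL = 1e-6
MAX_ITER = 200


@jax.jit
def fixed_point_cos(x0):
    """Solve x = cos(x) by fixed-point iteration, entirely inside JIT."""

    def cond(carry):
        x, x_prev, i = carry
        # A traced boolean: lax.while_loop checks it on device, no Python `if`.
        return (jnp.abs(x - x_prev) >= TOL) & (i < MAX_ITER)

    def body(carry):
        x, _, i = carry
        return jnp.cos(x), x, i + 1

    x_init = jnp.asarray(x0, dtype=jnp.float32)
    init = (jnp.cos(x_init), x_init, jnp.int32(0))
    x, _, n_iter = jax.lax.while_loop(cond, body, init)
    return x, n_iter


x_star, n_iter = fixed_point_cos(1.0)
print(float(x_star), int(n_iter))
```

One caveat relevant to the parameter-optimization row of the matrix: lax.while_loop supports forward-mode differentiation but not reverse mode, so jax.grad cannot be taken through it.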
When to Enable JIT
Enable JIT when:
- The circuit has no trace operator
- You will call the compiled function many times with the same input shapes
- Inputs are large arrays where JAX's XLA compilation pays off
Skip JIT when:
- The circuit contains a trace operator with a register (the two are used together)
- You are running a one-shot simulation
- You are debugging and want clear Python stack traces
Patterns Used in the Codebase
Pre-computed Coefficients
Transcendental functions (exp, sin, cos) pre-compute Taylor coefficients
at compile time and store them as constants. This is JIT-safe because the values
are captured in closures:
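```python
# Inside the compiler
coeffs = exp_coefficients(max_degree)  # computed once, at compile time

def exp_fn(inputs, state):
    # coeffs is captured by the closure, so it is a constant under JIT
    return Stream(data=coeffs, ...), state  # constant lookup (other fields elided)
```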
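A self-contained sketch of the same closure pattern outside Asgard: make_exp_fn is a hypothetical helper, and it evaluates the truncated series on plain arrays instead of returning Asgard's Stream. The coefficients 1/k! are computed in NumPy before any tracing, so JIT embeds them as constants. The Python loop inside the jitted function is safe because its trip count is fixed by the closure, not by a traced value.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np


def make_exp_fn(max_degree):
    """Hypothetical helper: truncated Taylor series for exp with baked-in coefficients."""
    # Computed once, eagerly, before tracing: c_k = 1 / k! for k = 0..max_degree.
    coeffs = np.array(
        [1.0 / math.factorial(k) for k in range(max_degree + 1)], dtype=np.float32
    )

    @jax.jit
    def exp_fn(x):
        # Horner's rule: c0 + x*(c1 + x*(c2 + ...)). The loop over the
        # concrete coeffs array is unrolled at trace time.
        acc = jnp.zeros_like(x)
        for c in coeffs[::-1]:
            acc = acc * x + c
        return acc

    return exp_fn


exp_fn = make_exp_fn(max_degree=12)
x = jnp.linspace(-1.0, 1.0, 5)
print(exp_fn(x))
print(jnp.exp(x))
```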
jnp.where for Conditional Logic
Instead of Python if/else (which breaks JIT), use jnp.where for
element-wise conditionals:
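```python
# JIT-safe: element-wise select, no Python branching on array values
result = jnp.where(condition, value_if_true, value_if_false)

# NOT JIT-safe: Python `if` needs a concrete bool
if condition:
    result = value_if_true
else:
    result = value_if_false
```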
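As a runnable illustration, here is a hypothetical element-wise soft-threshold written with jnp.where so that it compiles under jax.jit.

```python
import jax
import jax.numpy as jnp


@jax.jit
def soft_threshold(x, lam):
    # Both candidate values are computed for every element...
    shrunk = jnp.sign(x) * (jnp.abs(x) - lam)
    # ...and jnp.where selects per element; no Python branch on traced values.
    return jnp.where(jnp.abs(x) > lam, shrunk, 0.0)


out = soft_threshold(jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0]), 1.0)
print(out)  # expected values: -1, 0, 0, 0, 1
```

Because jnp.where evaluates both branch expressions before selecting, a branch that produces inf or NaN in elements it does not select leaves the forward result correct but can still make gradients through jnp.where NaN.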
Stochastic Simulation Pattern
The stochastic executor runs the time-stepping loop in Python and compiles only the per-step circuit evaluation:
```python
# Outer loop: Python (not JIT'd)
for t in time_steps:
    # Inner step: JIT-compiled JAX function
    output = compiled_circuit_fn(inputs, state)
    # Advance the state; the result feeds the next iteration
    x = euler_maruyama_step(x, drift, diffusion, dW)
```
This gives JIT benefits for the expensive numerical work while keeping the control flow in Python.
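A minimal, self-contained version of this pattern, assuming a hypothetical Ornstein-Uhlenbeck model (dX = -theta X dt + sigma dW) in place of a compiled circuit; em_step, simulate, THETA and SIGMA are illustrative names, not the stochastic executor's API. The step function is compiled on the first iteration and reused afterwards, because x keeps the same shape and dtype throughout.

```python
import jax
import jax.numpy as jnp

# Hypothetical model parameters.
THETA, SIGMA = 1.5, 0.3


@jax.jit
def em_step(x, key, dt):
    """One Euler-Maruyama step: x + a(x) dt + b(x) dW, with dW ~ N(0, dt)."""
    drift = -THETA * x
    diffusion = SIGMA
    dW = jnp.sqrt(dt) * jax.random.normal(key, x.shape)
    return x + drift * dt + diffusion * dW


def simulate(x0, dt=0.01, n_steps=500, seed=0):
    key = jax.random.PRNGKey(seed)
    x = x0
    path = [x]
    for _ in range(n_steps):          # outer loop: plain Python, not traced
        key, subkey = jax.random.split(key)
        x = em_step(x, subkey, dt)    # inner step: JIT-compiled, cached
        path.append(x)
    return jnp.stack(path)


path = simulate(jnp.ones(4))
print(path.shape)
```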
Troubleshooting
"TracerBoolConversionError"
```
jax.errors.TracerBoolConversionError: Attempted boolean conversion of Traced<...>
```
Cause: A Python if statement depends on a JAX-traced value inside a
JIT-compiled function.
Fix: Disable JIT (jit=False) or replace the conditional with jnp.where.
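To reproduce the error and check a fix in isolation, a small sketch with hypothetical functions relu_branchy and relu_where. Depending on the JAX version, the failure surfaces as TracerBoolConversionError (newer releases) or as the more general ConcretizationTypeError, so the sketch catches both.

```python
import jax
import jax.numpy as jnp


def relu_branchy(x):
    # Python `if` on an array value: fine eagerly, fails once x is a tracer.
    if x > 0:
        return x
    return jnp.zeros_like(x)


def relu_where(x):
    # JIT-safe replacement: element-wise select instead of a Python branch.
    return jnp.where(x > 0, x, 0.0)


x = jnp.array(-3.0)
eager = relu_branchy(x)  # outside JIT, x is concrete, so `if` works

try:
    jax.jit(relu_branchy)(x)
    error_name = None
except (jax.errors.TracerBoolConversionError, jax.errors.ConcretizationTypeError) as err:
    error_name = type(err).__name__

fixed = jax.jit(relu_where)(x)
print(error_name, float(eager), float(fixed))
```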
"ConcretizationTypeError"
```
jax.errors.ConcretizationTypeError: Abstract tracer value encountered
```
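Cause: Code tries to use a JAX array as a concrete Python value inside JIT, for example int(x), float(x), or an array shape computed from an array value. (len(x) and indexing with a traced integer are fine, because array shapes are static under JIT.)
Fix: Move the concrete operation outside the JIT boundary, use JAX equivalents (jnp.take, static shapes), or mark the argument as static with jax.jit(..., static_argnums=...) so that it becomes a compile-time constant.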
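A sketch of the static-argument fix, using a hypothetical padded function whose output shape depends on an integer argument. Calling int() on a traced value raises ConcretizationTypeError; marking n static with static_argnums turns it into a compile-time constant, at the cost of one compilation per distinct value of n.

```python
import jax
import jax.numpy as jnp


def padded(x, n):
    # The output shape depends on n, so n must be concrete at trace time.
    return jnp.concatenate([x, jnp.zeros(int(n), dtype=x.dtype)])


x = jnp.arange(3.0)

try:
    jax.jit(padded)(x, 2)  # n arrives as a tracer, so int(n) fails
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

# Fix: treat n as static. Each distinct n triggers a separate compilation.
padded_jit = jax.jit(padded, static_argnums=1)
out = padded_jit(x, 2)
print(raised, out)
```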
Slow First Call
JIT compilation happens on the first call and is cached for subsequent calls with matching input shapes and dtypes. If input shapes change frequently, each new shape triggers a recompilation.
Fix: Pad inputs to a fixed shape or batch calls with identical shapes.
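The sketch below compares unpadded and padded inputs with plain jax.jit. It counts traces with a Python side effect: a Python statement inside a jitted function runs only while JAX traces it, not on cached calls, so the counter shows how many times each function was compiled. pad_to, masked_mean and the bucket size of 8 are illustrative choices; padded slots are excluded from the result with a mask.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = {"raw": 0, "padded": 0}


def _raw_mean(x):
    trace_count["raw"] += 1  # Python side effect: executes only during tracing
    return jnp.mean(x)


def _masked_mean(x, mask):
    trace_count["padded"] += 1
    # Padded slots carry mask == 0, so they add nothing to either sum.
    return jnp.sum(x * mask) / jnp.sum(mask)


raw_mean = jax.jit(_raw_mean)
masked_mean = jax.jit(_masked_mean)


def pad_to(x, size):
    """Zero-pad a 1-D array to `size`; return (padded, mask)."""
    padded = np.zeros(size, dtype=np.float32)
    mask = np.zeros(size, dtype=np.float32)
    padded[: len(x)] = x
    mask[: len(x)] = 1.0
    return padded, mask


inputs = [np.arange(n, dtype=np.float32) for n in (3, 4, 5)]
raw = [float(raw_mean(x)) for x in inputs]                    # 3 shapes -> 3 traces
padded = [float(masked_mean(*pad_to(x, 8))) for x in inputs]  # 1 shape -> 1 trace
print(raw, padded, trace_count)
```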
NaN in Stochastic Results
NaN values in stochastic simulations typically indicate numerical instability rather than a JIT issue. The stochastic executor checks for NaN periodically and warns.
Fix: Reduce dt, add clamping for state variables, or use Stratonovich
interpretation (interpretation="stratonovich") for better stability with
multiplicative noise.
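A sketch of the clamping fix, assuming a hypothetical Cox-Ingersoll-Ross-type model dX = kappa (theta - X) dt + sigma sqrt(X) dW, whose diffusion term is undefined for negative X. With a coarse dt, an Euler-Maruyama step can overshoot below zero, and the next sqrt would return NaN. Clamping X at zero inside the drift and diffusion (a "full truncation" scheme) keeps every step finite. cir_step and its parameters are illustrative, not part of the stochastic executor.

```python
import jax
import jax.numpy as jnp

# Hypothetical model parameters.
KAPPA, THETA, SIGMA = 2.0, 0.04, 0.5


@jax.jit
def cir_step(x, key, dt):
    # Clamp the state before it enters sqrt(): x_pos >= 0, so no NaN.
    x_pos = jnp.maximum(x, 0.0)
    dW = jnp.sqrt(dt) * jax.random.normal(key, x.shape)
    return x + KAPPA * (THETA - x_pos) * dt + SIGMA * jnp.sqrt(x_pos) * dW


key = jax.random.PRNGKey(0)
x = jnp.full((1000,), 0.04)
for step in range(200):
    key, subkey = jax.random.split(key)
    x = cir_step(x, subkey, 0.05)  # deliberately coarse dt
    # Periodic NaN check in the Python loop, outside the compiled step.
    if step % 50 == 0 and bool(jnp.any(jnp.isnan(x))):
        raise FloatingPointError(f"NaN at step {step}")
```

For debugging, JAX can also stop at the first operation that produces a NaN via jax.config.update("jax_debug_nans", True), at the cost of slower execution.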
Extending the System
If you add new circuit operations:
- Keep atomic functions pure: use only JAX primitives (jnp.*) and no Python control flow that depends on array values.
- Use jnp.where instead of if/else for conditional logic.
- Pre-compute constants at compile time, not at evaluation time.
- Test with jit=True on simple circuits to catch tracer issues early.
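One way to follow these guidelines in tests is a small helper that runs an operation both eagerly and under jax.jit and compares the results; a tracer problem surfaces as an exception on the jitted call. check_jit_safe, scale_op and branchy_op are hypothetical, and the ops only mimic the (inputs, state) signature used by exp_fn in the compiler snippet.

```python
import jax
import jax.numpy as jnp
import numpy as np


def check_jit_safe(fn, *args, rtol=1e-6):
    """Hypothetical test helper: compare eager and jitted results of fn."""
    eager = fn(*args)
    jitted = jax.jit(fn)(*args)  # raises if fn branches on traced values
    for a, b in zip(jax.tree_util.tree_leaves(eager), jax.tree_util.tree_leaves(jitted)):
        np.testing.assert_allclose(np.asarray(a), np.asarray(b), rtol=rtol)
    return jitted


def scale_op(inputs, state):
    # Pure jnp arithmetic: no Python control flow on array values.
    return inputs * 2.0, state


def branchy_op(inputs, state):
    # Violates the guideline: Python `if` on an array value.
    if jnp.sum(inputs) > 0:
        return inputs, state
    return -inputs, state


out, new_state = check_jit_safe(scale_op, jnp.arange(4.0), jnp.zeros(()))

try:
    check_jit_safe(branchy_op, jnp.arange(4.0), jnp.zeros(()))
    caught = False
except (jax.errors.TracerBoolConversionError, jax.errors.ConcretizationTypeError):
    caught = True
```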