JAX & JIT Compatibility

When JIT compilation helps, when it breaks, and how to work around limitations.

Overview

Asgard compiles circuits to JAX functions. By default JIT is disabled:

from gimle.asgard.runtime.jax_compiler import JAXCircuitCompiler

compiler = JAXCircuitCompiler()
fn = compiler.compile(circuit, jit=False)   # default
fn = compiler.compile(circuit, jit=True)    # enable JIT

JIT adds overhead on the first call (tracing + compilation) but significantly speeds up subsequent calls with the same input shapes.
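
The first-call cost can be observed directly with plain jax.jit. The sketch below is illustrative only: poly is a hypothetical stand-in for a compiled circuit function, not part of Asgard, and the timings depend on hardware and backend.

```python
import time

import jax
import jax.numpy as jnp


def poly(x):
    # Stand-in for a compiled circuit: a small chain of array operations.
    return jnp.sum(3.0 * x**2 + 2.0 * x + 1.0)


poly_jit = jax.jit(poly)
x = jnp.linspace(0.0, 1.0, 1000)

t0 = time.perf_counter()
first = poly_jit(x).block_until_ready()   # traces and compiles, then runs
t1 = time.perf_counter()
second = poly_jit(x).block_until_ready()  # reuses the cached executable
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f} s")
print(f"second call: {t2 - t1:.4f} s")
```

block_until_ready() is needed because JAX dispatches work asynchronously; without it the timer would stop before the computation finishes.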

Compatibility Matrix

| Circuit Type | JIT Compatible | Notes |
|---|---|---|
| Simple atomics (add, scalar, split, ...) | Yes | Pure JAX operations. |
| Composition and monoidal (no trace) | Yes | Array-based, no control flow. |
| Trace with register | No | Coefficient-by-coefficient evaluation uses Python loops. |
| Stochastic simulation | Partial | Time-stepping loop runs outside JIT; individual steps are JIT-safe. |
| Parameter optimization | No | Circuit rebuilding requires Python control flow. |

Rule of thumb: If the circuit has no trace operator, jit=True is safe and beneficial. If it has trace, leave JIT disabled.

Why Trace Breaks JIT

The trace operator solves feedback loops using either:

  1. Stream calculus (coefficient-by-coefficient): builds the result in a Python for loop in which each iteration depends on the previous one. A loop with a fixed trip count would simply be unrolled during tracing; tracing fails here because the loop body contains data-dependent control flow.

  2. Fixed-point iteration: converges by checking if diff < tolerance: break. The condition depends on a JAX value, which has no concrete value at trace time, so JIT cannot evaluate it.

Both paths require Python control flow that depends on runtime values — exactly what jax.jit prohibits.
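
The fixed-point case can be reproduced in a few lines. This is a minimal sketch, not Asgard's trace solver: fixed_point is a hypothetical helper that iterates x <- cos(x) and uses the same break-on-tolerance pattern described above.

```python
import jax
import jax.numpy as jnp


def fixed_point(x0, tolerance=1e-6, max_iters=100):
    # Iterate x <- cos(x) until successive values agree to within tolerance.
    x = x0
    for _ in range(max_iters):
        x_next = jnp.cos(x)
        diff = jnp.abs(x_next - x)
        x = x_next
        if diff < tolerance:   # needs a concrete bool
            break
    return x


eager = fixed_point(jnp.float32(1.0))   # works: diff is a concrete value

try:
    jax.jit(fixed_point)(jnp.float32(1.0))
    jit_failed = False
except jax.errors.TracerBoolConversionError:
    jit_failed = True                   # under jit, diff is a tracer
```

Run eagerly, the loop exits as soon as the tolerance is met. Under jax.jit, the same if statement is reached during tracing, where diff has a shape and dtype but no value.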

When to Enable JIT

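Enable JIT when the circuit has no trace operator and the compiled function is called repeatedly with the same input shapes, so the one-time tracing and compilation cost is amortized.

Skip JIT when the circuit contains a trace operator, when the circuit is rebuilt between calls (as in parameter optimization), when input shapes change so often that most calls recompile, or when the function runs only once.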

Patterns Used in the Codebase

Pre-computed Coefficients

Transcendental functions (exp, sin, cos) pre-compute Taylor coefficients at compile time and store them as constants. This is JIT-safe because the values are captured in closures:

# Inside the compiler
coeffs = exp_coefficients(max_degree)   # computed once
def exp_fn(inputs, state):
    return Stream(data=coeffs, ...), state  # constant lookup
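
A self-contained version of the same idea is sketched below. It is a simplification under stated assumptions: exp_coefficients here is a hypothetical reimplementation (Taylor coefficients 1/k!), make_exp_fn is an illustrative name, and the closure evaluates the polynomial on a plain array rather than returning a Stream.

```python
import math

import jax
import jax.numpy as jnp


def exp_coefficients(max_degree):
    # Taylor coefficients of exp(x) about 0: 1/k! for k = 0..max_degree.
    return jnp.array([1.0 / math.factorial(k) for k in range(max_degree + 1)])


def make_exp_fn(max_degree):
    coeffs = exp_coefficients(max_degree)   # computed once, at build time

    def exp_fn(x):
        # coeffs is a closed-over constant; jnp.polyval expects the
        # highest-degree coefficient first, hence the reversal.
        return jnp.polyval(coeffs[::-1], x)

    return jax.jit(exp_fn)


exp_fn = make_exp_fn(max_degree=10)
approx = exp_fn(jnp.float32(0.5))
```

Because coeffs is captured by the closure rather than passed as an argument, JAX treats it as a constant of the compiled function and never traces it as an input.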

jnp.where for Conditional Logic

A Python if/else whose condition is a traced array breaks JIT. Use jnp.where for element-wise conditionals instead:

# JIT-safe
result = jnp.where(condition, value_if_true, value_if_false)

# NOT JIT-safe when condition is a traced array
if condition:
    result = value_if_true
else:
    result = value_if_false
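
A runnable instance of the JIT-safe form, using a hypothetical clip_negative function for illustration:

```python
import jax
import jax.numpy as jnp


@jax.jit
def clip_negative(x):
    # Element-wise select: keep x where it is non-negative, otherwise 0.
    return jnp.where(x >= 0, x, 0.0)


out = clip_negative(jnp.array([-1.0, 0.5, 2.0]))
```

Note that jnp.where computes both branch values for every element and then selects between them, so neither branch should rely on the condition to avoid an invalid operation.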

Stochastic Simulation Pattern

The stochastic executor runs the time-stepping loop in Python and compiles only the per-step circuit evaluation:

# Outer loop: Python (not JIT'd)
for t in time_steps:
    # Inner step: JIT-compiled JAX function
    output = compiled_circuit_fn(inputs, state)
    x_new = euler_maruyama_step(x, drift, diffusion, dW)

This gives JIT benefits for the expensive numerical work while keeping the control flow in Python.
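
The pattern can be sketched end to end without Asgard. The example below is an assumption-laden illustration, not the stochastic executor: it simulates an Ornstein-Uhlenbeck process dX = -theta * X dt + sigma dW, and euler_maruyama_step and simulate are hypothetical names with signatures chosen for this sketch.

```python
import jax
import jax.numpy as jnp


@jax.jit
def euler_maruyama_step(x, dt, key, theta=1.0, sigma=0.3):
    # One Euler-Maruyama step for dX = -theta * X dt + sigma dW.
    dW = jnp.sqrt(dt) * jax.random.normal(key, x.shape)
    return x + (-theta * x) * dt + sigma * dW


def simulate(x0, n_steps, dt, seed=0):
    key = jax.random.PRNGKey(seed)
    x = x0
    path = [x]
    for _ in range(n_steps):                    # outer loop: plain Python
        key, subkey = jax.random.split(key)
        x = euler_maruyama_step(x, dt, subkey)  # inner step: JIT-compiled
        path.append(x)
    return jnp.stack(path)


path = simulate(jnp.zeros(4), n_steps=50, dt=0.01)
```

The step function is compiled once on the first iteration and reused for the rest, since x, dt, and the key keep the same shapes and dtypes throughout the loop.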

Troubleshooting

"TracerBoolConversionError"

jax.errors.TracerBoolConversionError: Attempted boolean conversion of Traced<...>

Cause: A Python if statement depends on a JAX-traced value inside a JIT-compiled function.

Fix: Disable JIT (jit=False) or replace the conditional with jnp.where.

"ConcretizationTypeError"

jax.errors.ConcretizationTypeError: Abstract tracer value encountered

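Cause: Code inside JIT needs a concrete Python value from a JAX array, e.g. calling int() or float() on it, using it to index a Python list, or passing it as a shape. (len() is fine, because shapes are static under JIT.)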

Fix: Move the concrete operation outside the JIT boundary or use JAX equivalents (jnp.take, static shapes).
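
Both fixes are sketched below with hypothetical functions lookup and head_sum. The exact exception class raised by the unfixed versions can differ between JAX versions, so only the working forms are shown.

```python
from functools import partial

import jax
import jax.numpy as jnp

TABLE = [10.0, 20.0, 30.0]   # a plain Python list


@jax.jit
def lookup(i):
    # TABLE[i] would need a concrete Python int; jnp.take accepts a traced index.
    return jnp.take(jnp.asarray(TABLE), i)


@partial(jax.jit, static_argnames="n")
def head_sum(x, n):
    # n is static, so the slice x[:n] has a shape known at trace time.
    return jnp.sum(x[:n])


picked = lookup(jnp.int32(1))
total = head_sum(jnp.arange(5.0), n=3)
```

Marking an argument static makes it part of the compilation cache key, so each distinct value of n compiles a separate version of head_sum.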

Slow First Call

JIT compilation happens on the first call and is cached for subsequent calls with matching input shapes. If input shapes change frequently, each new shape triggers recompilation.

Fix: Pad inputs to a fixed shape or batch calls with identical shapes.
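
One way to pad to a fixed shape is sketched here. MAX_LEN, masked_sum, and padded_sum are hypothetical names; the trace_log list relies on the fact that Python side effects inside a jitted function run only while it is being traced, which makes it a simple recompilation counter.

```python
import jax
import jax.numpy as jnp

MAX_LEN = 8
trace_log = []   # grows by one entry each time masked_sum is traced


@jax.jit
def masked_sum(padded, mask):
    trace_log.append(1)   # Python side effect: runs only during tracing
    return jnp.sum(jnp.where(mask, padded, 0.0))


def padded_sum(values):
    # Pad to MAX_LEN so every call presents the same input shape to jit.
    n = len(values)
    padded = jnp.zeros(MAX_LEN).at[:n].set(jnp.asarray(values))
    mask = jnp.arange(MAX_LEN) < n
    return masked_sum(padded, mask)


a = padded_sum([1.0, 2.0, 3.0])
b = padded_sum([4.0, 5.0, 6.0, 7.0, 8.0])
print(len(trace_log))   # number of times masked_sum was traced
```

The mask keeps the padding from contributing to the result. The trade-off is wasted work on padded elements, so MAX_LEN should be close to the largest input actually expected.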

NaN in Stochastic Results

NaN values in stochastic simulations typically indicate numerical instability rather than a JIT issue. The stochastic executor checks for NaN periodically and warns.

Fix: Reduce dt, add clamping for state variables, or use Stratonovich interpretation (interpretation="stratonovich") for better stability with multiplicative noise.

Extending the System

If you add new circuit operations:

  1. Keep atomic functions pure — use only JAX primitives (jnp.*), no Python control flow depending on array values.
  2. Use jnp.where instead of if/else for conditional logic.
  3. Pre-compute constants at compile time, not at evaluation time.
  4. Test with jit=True on simple circuits to catch tracer issues early.
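
A small consistency check along the lines of step 4 might look like the sketch below. saturate is a hypothetical new atomic and check_jit_consistency is an illustrative helper; neither is part of Asgard, and the check uses jax.jit directly rather than the compiler's jit=True flag.

```python
import jax
import jax.numpy as jnp


def saturate(x, limit=1.0):
    # Hypothetical new atomic: clamp values to [-limit, limit] using jnp only.
    return jnp.where(jnp.abs(x) > limit, jnp.sign(x) * limit, x)


def check_jit_consistency(fn, *args):
    # Run fn eagerly and under jit; a tracer error or a mismatch surfaces here.
    eager = fn(*args)
    jitted = jax.jit(fn)(*args)
    return bool(jnp.allclose(eager, jitted))


ok = check_jit_consistency(saturate, jnp.array([-2.0, 0.3, 5.0]))
```

If the new operation contains Python control flow on array values, the jax.jit call raises during tracing, which is the early signal the checklist is after.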