Runtime
The runtime executes circuits under different mathematical interpretations called calculi. The same circuit can be evaluated as Taylor series, stochastic processes, or discrete sequences.
The Three Calculi
RealCalculus
Taylor series expansion for deterministic computation. Use for ODEs, symbolic differentiation, and classical analysis.
StochasticCalculus
Monte Carlo simulation of stochastic differential equations. Use for finance, physics simulations, and SDEs with Itô or Stratonovich semantics.
DiscreteCalculus
Discrete sequences and finite differences. Use for time series, difference equations, and signal processing.
Execution Flow
Equation
│ compile_equation_to_circuit()
▼
Circuit (Parse Tree)
│ JAXCircuitCompiler.compile()
▼
├─ has_trace + has_register?
│ → StreamCalculusEvaluator (coefficient-by-coefficient)
│
├─ has_stochastic_register?
│ → StochasticCircuitExecutor (Monte Carlo paths)
│
└─ otherwise
→ Standard JAX compilation (fixed-point for trace)
│
▼
Stream (coefficient representation)
│ StreamEvaluator.evaluate() with Calculus
▼
┌───┴────┬──────────────┐
▼ ▼ ▼
RealCalculus Stochastic Discrete
Taylor series SDE paths Sequences
Streams
Streams represent functions as coefficient expansions in multi-dimensional arrays:
from gimle.asgard.runtime.stream import Stream
import jax.numpy as jnp
# Scalar stream: 3 time steps
stream = Stream(data=jnp.array([1.0, 2.0, 3.0]), dim_labels=(), chunk_size=3)
# 1D spatial: f(x) = 2 + 3x + x^2/2
stream = Stream(
data=jnp.array([[2.0, 3.0, 1.0]]),
dim_labels=("x",),
chunk_size=1
)
# 2D field: spatial grid
stream = Stream(
data=jnp.zeros((100, 64, 64)),
dim_labels=("x", "y"),
chunk_size=100
)
Key Properties:
- data: JAX array of coefficients with shape (chunk_size, dim1, dim2, ...)
- dim_labels: Tuple of dimension names (e.g., ("x",) or ("x", "t"))
- chunk_size: Number of samples per chunk
- shape, spatial_shape, ndim, spatial_ndim: Shape introspection
Stream State
StreamState carries persistent state across chunks for stateful operations:
from gimle.asgard.runtime.stream import StreamState
state = StreamState.empty()
It tracks three kinds of state:
- Boundaries: Carry-forward values for register/deregister across chunks
- Brownian increments: dW values for stochastic operations
- Metadata: Time step (dt) and other runtime configuration
Both Stream and StreamState are registered as JAX pytrees, so they flow through jax.jit, jax.grad, and jax.vmap.
RealCalculus (Deterministic)
Represents functions as Taylor series with factorial scaling:
$$f(x) = c_0 + c_1 \cdot (x - x_0) + c_2 \cdot \frac{(x - x_0)^2}{2!} + c_3 \cdot \frac{(x - x_0)^3}{3!} + \cdots$$
from gimle.asgard.runtime.stream_evaluator import StreamEvaluator, RealCalculus
# Coefficients: [2, 3, 1] means f(x) = 2 + 3x + x^2/2
stream = Stream(
data=jnp.array([[2.0, 3.0, 1.0]]),
dim_labels=("x",),
chunk_size=1
)
# Create evaluator
evaluator = StreamEvaluator(stream, {"x": RealCalculus(center=0.0)})
# Evaluate at specific points
x_points = jnp.array([0.0, 1.0, 2.0])
values = evaluator.evaluate(x=x_points)
for x, val in zip(x_points, values):
print(f"f({x}) = {val:.2f}")
# f(0.0) = 2.00
# f(1.0) = 5.50 (2 + 3 + 0.5)
# f(2.0) = 10.00 (2 + 6 + 2)
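Under the hood, this evaluation is an ordinary factorial-scaled Taylor sum. A minimal plain-JAX sketch that reproduces the numbers above (the helper taylor_eval is illustrative, not part of the library):

```python
import jax.numpy as jnp

def taylor_eval(coeffs, x, center=0.0):
    # f(x) = sum_k c_k * (x - center)^k / k!
    k = jnp.arange(len(coeffs))
    # facts = [0!, 1!, 2!, ...]
    facts = jnp.concatenate(
        [jnp.array([1.0]), jnp.cumprod(jnp.arange(1.0, len(coeffs)))]
    )
    dx = jnp.asarray(x)[..., None] - center
    return jnp.sum(coeffs * dx ** k / facts, axis=-1)

coeffs = jnp.array([2.0, 3.0, 1.0])  # f(x) = 2 + 3x + x^2/2
vals = taylor_eval(coeffs, jnp.array([0.0, 1.0, 2.0]))
print(vals)  # [2.0, 5.5, 10.0]
```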
StochasticCalculus
Simulates stochastic differential equations of the form $dX = \mu(X,t)\,dt + \sigma(X,t)\,dW$:
from gimle.asgard.runtime.stream_evaluator import StochasticCalculus
calculus = StochasticCalculus(
drift=0.0,
diffusion=1.0,
n_paths=1000,
dt=0.01,
seed=42,
interpretation="ito" # or "stratonovich"
)
Simulating SDEs
# Ornstein-Uhlenbeck process: dX = -theta(X - mu)dt + sigma dW
theta, mu, sigma, x0 = 0.5, 1.0, 0.3, 2.0
calculus = StochasticCalculus(n_paths=10000, dt=0.01)
paths = calculus.simulate_sde(
x0=x0,
drift_fn=lambda x, t: -theta * (x - mu),
diffusion_fn=lambda x, t: sigma,
t_start=0.0,
t_end=5.0,
n_steps=500
)
Coupled SDEs
Simulate systems of correlated stochastic differential equations:
# Heston model: coupled stock price and volatility
# dS = mu*S*dt + sqrt(V)*S*dW1
# dV = kappa*(theta - V)*dt + xi*sqrt(V)*dW2
# with correlation rho between W1 and W2
# Illustrative parameter values (assumed for this example)
mu, kappa, theta, xi, rho = 0.05, 2.0, 0.04, 0.3, -0.7
paths = calculus.simulate_coupled_sde(
x0_vector=[100.0, 0.04],
drift_fns=[
lambda x, t: mu * x[0],
lambda x, t: kappa * (theta - x[1]),
],
diffusion_fns=[
lambda x, t: jnp.sqrt(x[1]) * x[0],
lambda x, t: xi * jnp.sqrt(x[1]),
],
t_start=0.0,
t_end=1.0,
correlation=[[1.0, rho], [rho, 1.0]],
)
# Returns shape: (2, n_paths, n_time_steps)
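A standard way to realize such a correlation matrix is to Cholesky-factor it and mix independent normal draws. A plain-JAX sketch of generating correlated Brownian increments (illustrative; not necessarily how simulate_coupled_sde does it internally):

```python
import jax
import jax.numpy as jnp

rho, dt, n_steps, n_paths = -0.7, 0.01, 250, 4000
corr = jnp.array([[1.0, rho], [rho, 1.0]])
L = jnp.linalg.cholesky(corr)

key = jax.random.PRNGKey(7)
z = jax.random.normal(key, (2, n_paths * n_steps))  # independent N(0, 1)
dW = (L @ z) * jnp.sqrt(dt)                         # correlated increments
emp_rho = jnp.corrcoef(dW[0], dW[1])[0, 1]
print(emp_rho)  # close to rho
```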
Itô vs Stratonovich
The interpretation parameter controls SDE discretization:
Itô (default) — Euler-Maruyama method:
X(t+dt) = X(t) + drift(X(t), t) * dt + diffusion(X(t), t) * dW
- Single-pass per step (faster)
- Standard in mathematical finance
- Use interpretation="ito"
Stratonovich — Heun predictor-corrector method:
K₁ = drift(X, t)*dt + diffusion(X, t)*dW
X̃ = X + K₁
K₂ = drift(X̃, t+dt)*dt + diffusion(X̃, t+dt)*dW (same dW)
X(t+dt) = X + 0.5 * (K₁ + K₂)
- Two-pass per step (more accurate, ~2x cost)
- Standard in physics (preserves chain rule)
- Use interpretation="stratonovich"
| Scenario | Interpretation |
|---|---|
| Finance (Black-Scholes, interest rates) | Itô |
| Physics (Langevin, Brownian motion) | Stratonovich |
| Fast prototyping / lower accuracy needed | Itô |
| Higher accuracy needed | Stratonovich |
Circuit-Driven SDEs
SDEs can also be defined as equations and compiled to circuits:
from gimle.asgard.equation.equation import Equation
from gimle.asgard.compile.compiler import compile_equation_to_circuit
eq = Equation.from_string("sde($drift, $sigma, t) = Y")
circuit, metadata = compile_equation_to_circuit(eq)
The compiled circuit contains stochastic_register for the diffusion term. At runtime, the StochasticCircuitExecutor detects this and routes to Monte Carlo simulation automatically.
In YAML examples:
equation: "sde($drift, $sigma, t) = Y"
stochastic:
calculus: stratonovich
n_paths: 1000
dt: 0.01
seed: 42
params:
drift: -0.5
sigma: 0.3
DiscreteCalculus (Sequences)
Represents discrete-time sequences:
from gimle.asgard.runtime.stream_evaluator import DiscreteCalculus
# Sequence: f(n) = n^2 for n = 0, 1, 2, 3
stream = Stream(
data=jnp.array([[0.0, 1.0, 4.0, 9.0]]),
dim_labels=("n",),
chunk_size=1
)
evaluator = StreamEvaluator(stream, {"n": DiscreteCalculus()})
result = evaluator.evaluate(n=2.0)
print(f"f(2) = {result}") # 4.0
Operations:
- Register (Integration): Cumulative sum, mapping [a, b, c] -> [0, a, a+b]
- Deregister (Differentiation): Finite differences, mapping [a, b, c] -> [b-a, c-b]
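In plain jnp, these two operations are a shifted cumulative sum and a first difference:

```python
import jax.numpy as jnp

seq = jnp.array([3.0, 5.0, 9.0])  # [a, b, c]

# Register (integration): shifted cumulative sum -> [0, a, a+b]
register = jnp.concatenate([jnp.zeros(1), jnp.cumsum(seq)[:-1]])

# Deregister (differentiation): finite differences -> [b-a, c-b]
deregister = jnp.diff(seq)

print(register)    # [0. 3. 8.]
print(deregister)  # [2. 4.]
```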
Mixed Calculi
Use different calculi for different dimensions:
# Discrete time, continuous space
evaluator = StreamEvaluator(
stream,
calculi={
"n": DiscreteCalculus(),
"x": RealCalculus(center=0.0)
}
)
result = evaluator.evaluate(n=5.0, x=1.5)
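Conceptually, each dimension's calculus is applied independently. A plain-jnp sketch of the idea (eval_mixed is illustrative, not the library API): the discrete n selects a coefficient row by index, while x is evaluated as a factorial-scaled Taylor sum:

```python
import jax.numpy as jnp

# Hypothetical 2D coefficient table: rows indexed by discrete n,
# columns are factorial-scaled Taylor coefficients in x.
data = jnp.array([
    [1.0, 0.0, 0.0],   # n=0: f(x) = 1
    [0.0, 1.0, 0.0],   # n=1: f(x) = x
    [0.0, 0.0, 1.0],   # n=2: f(x) = x^2 / 2
])

def eval_mixed(data, n, x):
    coeffs = data[n]                         # discrete: plain indexing
    k = jnp.arange(coeffs.shape[0])
    facts = jnp.array([1.0, 1.0, 2.0])       # k! for k = 0, 1, 2
    return jnp.sum(coeffs * x ** k / facts)  # real: Taylor sum

print(eval_mixed(data, 2, 1.5))  # (1.5)^2 / 2 = 1.125
```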
Stream Calculus Evaluator
For circuits with trace and register (i.e., differential equations), the runtime uses a specialized coefficient-by-coefficient evaluator rather than standard JAX array operations:
Circuit with trace + register
│
│ StreamCalculusEvaluator
▼
Coefficient 0 computed
Coefficient 1 computed (using coefficient 0)
Coefficient 2 computed (using coefficients 0, 1)
...
This is necessary because each Taylor coefficient depends on previously computed ones via the feedback loop. The evaluator:
- Creates a feedback buffer for the trace operator
- Computes one coefficient per iteration
- Register reads from the buffer at index i-1
- Deregister reads from the buffer at index i+1
- Multiplication uses Cauchy products of previously computed coefficients
This happens automatically — the JAXCircuitCompiler routes to StreamCalculusEvaluator when it detects has_trace and has_register on the circuit.
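The recurrence idea can be demonstrated with plain jnp on the simplest ODE, y' = y with y(0) = 1 (this sketch mimics what the evaluator does, not its actual code). With factorial scaling, differentiation just shifts indices, (y')_k = y_{k+1}, so y' = y becomes the recurrence y_{k+1} = y_k:

```python
import jax.numpy as jnp

# Solve y' = y, y(0) = 1 coefficient-by-coefficient.
n_coeffs = 12
coeffs = jnp.zeros(n_coeffs).at[0].set(1.0)   # initial condition y(0) = 1
for k in range(n_coeffs - 1):
    coeffs = coeffs.at[k + 1].set(coeffs[k])  # each coefficient uses the previous

# Evaluate the truncated series at x = 1: sum of 1/k! approximates e
facts = jnp.concatenate([jnp.ones(1), jnp.cumprod(jnp.arange(1.0, n_coeffs))])
y1 = jnp.sum(coeffs / facts)
print(y1)  # ~2.71828
```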
Differentiable Circuits
Every circuit compiles to JAX, so gradients flow through the entire simulation:
from gimle.asgard.runtime.jax_compiler import JAXCircuitCompiler
compiler = JAXCircuitCompiler()
# Make scalar parameters differentiable
diff_fn, params = compiler.compile_differentiable(circuit, param_locations)
# Now jax.grad works through the full pipeline
grad_fn = jax.grad(lambda p: loss(diff_fn(p, inputs, state)))
gradients = grad_fn(params)
For circuits with trace operators, the backward pass uses the Implicit Function Theorem rather than unrolling through iterations. This gives correct, efficient gradients even for deeply recursive computations.
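The scalar case illustrates the principle. A hedged plain-JAX sketch (not the library's actual backward pass) that differentiates through a Babylonian square-root fixed point via the Implicit Function Theorem, without unrolling the iterations:

```python
import jax
import jax.numpy as jnp

# Babylonian iteration z <- f(z, a) converges to sqrt(a)
f = lambda z, a: 0.5 * (z + a / z)

def solve(a, n_iters=30):
    z = a  # initial guess
    for _ in range(n_iters):
        z = f(z, a)
    return z

def grad_via_ift(a):
    # At the fixed point z* = f(z*, a), the IFT gives
    # dz*/da = (df/da) / (1 - df/dz), with no unrolled backward pass.
    z = jax.lax.stop_gradient(solve(a))
    df_dz = jax.grad(f, argnums=0)(z, a)
    df_da = jax.grad(f, argnums=1)(z, a)
    return df_da / (1.0 - df_dz)

print(grad_via_ift(4.0))  # analytic d sqrt(a)/da at a=4 is 1/(2*sqrt(4)) = 0.25
```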
Two-Phase Execution
Asgard separates compilation and execution for efficiency:
from gimle.asgard.circuit.circuit import Circuit
from gimle.asgard.runtime.stream import Stream, StreamState
# Phase 1: Compile once (slow)
circuit = Circuit.from_string("composition(register(x), deregister(x))")
# Phase 2: Execute many times (fast, JIT-compiled)
for input_data in dataset:
input_stream = Stream(data=input_data, dim_labels=("x",), chunk_size=1)
    outputs, state = circuit.execute([input_stream], StreamState.empty())
Benefits:
- Compile once, run many times
- JAX JIT optimization
- GPU/TPU acceleration
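A minimal illustration of the same compile-once, run-many pattern with plain jax.jit (the cumsum pipeline is a stand-in for a compiled circuit):

```python
import jax
import jax.numpy as jnp

# Phase 1: tracing/compilation happens on the first call
@jax.jit
def pipeline(x):
    return jnp.cumsum(x) ** 2  # stand-in for a compiled circuit

# Phase 2: later calls with the same input shape reuse the compiled code
for batch in [jnp.ones(4), jnp.arange(4.0), jnp.full(4, 2.0)]:
    out = pipeline(batch)
print(out)  # [ 4. 16. 36. 64.]
```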
Performance Comparison
| Calculus | Speed | Memory | Accuracy |
|---|---|---|---|
| RealCalculus | Fast | Low | Exact (up to truncation) |
| StochasticCalculus | Slow | High | Statistical ($1/\sqrt{n}$) |
| DiscreteCalculus | Fast | Low | Exact |
Optimization Tips:
- Use fewer paths for testing (n_paths=100)
- Use more paths for production (n_paths=10000+)
- Leverage JAX's vmap for parallelization
- Use larger dt for faster stochastic simulation (trade-off: accuracy)
- Enable JIT compilation: compiler.compile(circuit, jit=True)
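For example, vmap can fan a single-path simulator out over a batch of PRNG keys (one_path here is an illustrative stand-in, not a library function):

```python
import jax
import jax.numpy as jnp

# Parallelize independent Monte Carlo paths over PRNG keys with vmap.
def one_path(key, n_steps=100, dt=0.01):
    dW = jax.random.normal(key, (n_steps,)) * jnp.sqrt(dt)
    return jnp.cumsum(dW)  # a single Brownian path

keys = jax.random.split(jax.random.PRNGKey(0), 256)
paths = jax.vmap(one_path)(keys)
print(paths.shape)  # (256, 100)
```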