# Gradient-Based Optimization

Optimize circuit parameters using JAX automatic differentiation.

## Overview
Asgard leverages JAX's automatic differentiation for three major capabilities:
- Gradient-Based Parameter Optimization - Fast optimization of circuit parameters
- Jacobian Computation - Sensitivity analysis and input-output dependencies
- Gradient-Based Fitness Evaluation - Evaluate loss and gradients simultaneously
## 1. Parameter Optimization
### What It Does
Automatically optimizes numerical parameters in circuits (like `scalar(c)` or `const(c)`) using gradient descent instead of discrete search methods.
### Why It's Useful
- 10-100x faster than genetic/random search for parameterized circuits
- Deterministic convergence (vs. stochastic search)
- Scalable to many parameters
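Conceptually, the optimizer runs ordinary gradient descent on the circuit's numeric parameters. The standalone sketch below (plain JAX, not the Asgard API) shows the idea for a single parameter playing the role of `scalar(c)`:

```python
# Standalone sketch (plain JAX, not the Asgard API): gradient descent on one
# numeric parameter c, playing the role of scalar(c) fitted to y = 3x.
import jax
import jax.numpy as jnp

xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
ys = 3.0 * xs  # target behaviour: y = 3x

def loss(c):
    # Mean squared error of the parameterized map x -> c * x
    return jnp.mean((c * xs - ys) ** 2)

grad_fn = jax.grad(loss)
c = 1.0  # start from the "wrong" value, like scalar(1.0)
for _ in range(50):
    c = c - 0.01 * grad_fn(c)

print(c)  # approaches 3.0
```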
### Example
```python
from gimle.asgard.circuit.circuit import Circuit
from gimle.asgard.circuit.circuit_fitness import DatasetSample
from gimle.asgard.circuit.circuit_optimizer import GradientBasedOptimizer
from gimle.asgard.runtime.stream import Stream
import jax.numpy as jnp

# Create training data: y = 3.0 * x
dataset = []
for x_val in [1.0, 2.0, 3.0, 4.0, 5.0]:
    input_stream = Stream(data=jnp.array([[x_val]]), dim_labels=(), chunk_size=1)
    output_stream = Stream(data=jnp.array([[3.0 * x_val]]), dim_labels=(), chunk_size=1)
    dataset.append(DatasetSample([input_stream], [output_stream]))

# Start with the wrong parameter: scalar(1.0) instead of scalar(3.0)
initial_circuit = Circuit.from_string("scalar(1.0)")

# Optimize using gradients
optimizer = GradientBasedOptimizer(
    learning_rate=0.01,
    num_iterations=50,
    tolerance=1e-6,
)
optimized_circuit, loss_history = optimizer.optimize(
    initial_circuit,
    dataset,
    verbose=True,
)
# Result: scalar(2.986) - converged to the target in ~20 iterations
```
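To sanity-check the result, the optimized circuit can be re-scored with the fitness evaluator documented in section 3 below (this sketch assumes the same `FitnessResult.loss` field shown there):

```python
from gimle.asgard.circuit.circuit_fitness import GradientBasedFitnessEvaluator

# Re-evaluate the optimized circuit on the same training dataset
evaluator = GradientBasedFitnessEvaluator(dataset)
fitness = evaluator.evaluate(optimized_circuit)
print(f"Final loss: {fitness.loss}")  # should be close to 0 for y = 3x
```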
### API Reference
```python
class GradientBasedOptimizer:
    def __init__(
        self,
        learning_rate: float = 0.01,
        num_iterations: int = 100,
        tolerance: float = 1e-6,
    )

    def optimize(
        self,
        circuit: Circuit,
        dataset: List[DatasetSample],
        verbose: bool = True,
    ) -> Tuple[Circuit, List[float]]
```
## 2. Jacobian Computation
### What It Does
Computes the Jacobian matrix `J[i, j] = ∂output[i] / ∂input[j]`, showing how each output component depends on each input component.
### Why It's Useful
- Sensitivity analysis - Which inputs matter most?
- Stability analysis - How sensitive is the output to perturbations?
- Feature importance - Which dimensions drive the output?
- Gradient fields - Visualize how outputs change with inputs
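For intuition, this is the same matrix that JAX's own `jax.jacfwd` produces for an ordinary function. The standalone sketch below illustrates the idea on a plain function and is not Asgard's internal implementation:

```python
import jax
import jax.numpy as jnp

# For f(x) = 2 * x applied elementwise, the Jacobian is 2 * I
def f(x):
    return 2.0 * x

x = jnp.array([1.0, 2.0, 3.0])
J = jax.jacfwd(f)(x)
print(J)
# [[2. 0. 0.]
#  [0. 2. 0.]
#  [0. 0. 2.]]
```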
### Example
```python
from gimle.asgard.circuit.circuit import Circuit
from gimle.asgard.circuit.circuit_gradients import compute_jacobian
from gimle.asgard.runtime.stream import Stream
import jax.numpy as jnp

# Create circuit: scalar(2.0) multiplies the input by 2
circuit = Circuit.from_string("scalar(2.0)")

# Create input
input_stream = Stream(data=jnp.array([[1.0, 2.0, 3.0]]), dim_labels=(), chunk_size=1)

# Compute Jacobian
jacobian = compute_jacobian(circuit, [input_stream], output_idx=0, input_idx=0)

# Result:
# [[2. 0. 0.]
#  [0. 2. 0.]
#  [0. 0. 2.]]
# Diagonal values = 2.0 (the scalar multiplier)
```
### API Reference
```python
def compute_jacobian(
    circuit: Circuit,
    inputs: List[Stream],
    output_idx: int = 0,
    input_idx: int = 0,
) -> jnp.ndarray
```
Parameters:
- `circuit`: Circuit to analyze
- `inputs`: Input streams
- `output_idx`: Which output stream to compute the Jacobian for
- `input_idx`: Which input stream to differentiate with respect to

Returns:
- Jacobian matrix of shape `(output_size, input_size)`
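One simple way to use the returned matrix for sensitivity analysis, continuing from the `jacobian` computed in the example above, is to rank input dimensions by the norm of their Jacobian columns (a common heuristic, not an Asgard API):

```python
import jax.numpy as jnp

# Column j of the Jacobian holds ∂output/∂input[j]; its norm is a simple
# sensitivity score for input dimension j.
sensitivity = jnp.linalg.norm(jacobian, axis=0)
ranking = jnp.argsort(-sensitivity)  # most influential input first
print(sensitivity)  # [2. 2. 2.] for the scalar(2.0) circuit above
print(ranking)
```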
## 3. Gradient-Based Fitness Evaluation
### What It Does
Evaluates both the loss and the gradient of the loss with respect to circuit parameters simultaneously.
### Why It's Useful
- Gradient-based optimization - Use gradients to guide search
- Inverse problems - Find parameters that produce desired outputs
- Optimal control - Find inputs that minimize/maximize objectives
- Hybrid optimization - Combine with discrete search methods
### Example
```python
from gimle.asgard.circuit.circuit import Circuit
from gimle.asgard.circuit.circuit_fitness import GradientBasedFitnessEvaluator, DatasetSample
import jax.numpy as jnp

# Create dataset
dataset = [...]  # Same as the optimization example

# Create evaluator
evaluator = GradientBasedFitnessEvaluator(dataset)

# Evaluate circuit with gradients
circuit = Circuit.from_string("scalar(1.0)")
fitness, gradients = evaluator.evaluate_with_gradients(circuit)

print(f"Loss: {fitness.loss}")
print(f"Gradients: {gradients}")

# Output:
# Loss: 44.0
# Gradients: {'scalar_0': -43.831}
# Negative gradient → should INCREASE the parameter
```
### API Reference
```python
class GradientBasedFitnessEvaluator:
    def __init__(self, dataset: List[DatasetSample])

    def evaluate(self, circuit: Circuit) -> FitnessResult

    def evaluate_with_gradients(
        self, circuit: Circuit
    ) -> Tuple[FitnessResult, Optional[Dict[str, jnp.ndarray]]]
```
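Because the gradient dictionary is typed as `Optional`, callers should handle the case where it is missing (presumably when a circuit has no optimizable parameters; this is an assumption, not documented behaviour). Continuing from the example above:

```python
fitness, gradients = evaluator.evaluate_with_gradients(circuit)
if gradients is None:
    # No optimizable parameters (assumption): fall back to the plain score
    print(f"Loss: {fitness.loss} (no gradients available)")
else:
    for name, grad in gradients.items():
        print(f"{name}: loss = {fitness.loss}, d(loss)/d(param) = {grad}")
```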
## Performance Comparison
| Method | Evaluations | Speed | Deterministic |
|---|---|---|---|
| Gradient-Based Optimization | 20-100 | Fast | Yes |
| Genetic Algorithm | 1000-10000 | Slow | No |
| Random Search | 10000+ | Very Slow | No |
## Implementation Notes
### Current Limitations
- Numerical Gradients: Currently uses finite differences instead of JAX autodiff through circuit compilation (see the sketch after this list)
  - Reason: JAX can't differentiate through Python control flow (circuit rebuilding)
  - Impact: Slightly slower than pure JAX autodiff, but still much faster than discrete search
  - Future: Could use `jax.custom_vjp` for true autodiff through compilation
- Parameter Types: Only `scalar()` and `const()` parameters are optimizable
  - Other atomic operations (like `var()`) are not yet parameterizable
- Circuit Complexity: Works best with circuits that have few parameters (< 10)
  - Numerical gradient computation scales linearly with the number of parameters
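For reference, the sketch below shows a central-difference gradient estimate on a plain loss function. It is an illustration of the general technique, not Asgard's actual implementation, and it makes the linear scaling explicit: each parameter costs one extra pair of loss evaluations.

```python
import jax.numpy as jnp

def finite_difference_grad(loss_fn, params, eps=1e-5):
    """Central-difference gradient estimate: one pair of loss evaluations
    per parameter, hence cost linear in the number of parameters."""
    grads = []
    for i in range(params.shape[0]):
        bump = jnp.zeros_like(params).at[i].set(eps)
        grads.append((loss_fn(params + bump) - loss_fn(params - bump)) / (2 * eps))
    return jnp.stack(grads)

# Toy example: quadratic loss with known gradient 2 * p
loss = lambda p: jnp.sum(p ** 2)
print(finite_difference_grad(loss, jnp.array([1.0, -2.0, 3.0])))  # ≈ [2., -4., 6.]
```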
## Helper Functions
### Extract Parameters
```python
from gimle.asgard.circuit.circuit_gradients import extract_parameters

params = extract_parameters(circuit)
# Returns CircuitParameters with param_values and param_locations
```
### Rebuild Circuit with Parameters
```python
from gimle.asgard.circuit.circuit_gradients import rebuild_circuit_with_parameters

new_circuit = rebuild_circuit_with_parameters(
    circuit_template,
    param_locations,
    new_params,
)
```
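Together, the two helpers above can be combined with `GradientBasedFitnessEvaluator` into a manual update step. The sketch below assumes that the gradient dictionary keys (e.g. `scalar_0`), when sorted, line up with the order of `param_values` returned by `extract_parameters`; that correspondence is an assumption, not a documented guarantee.

```python
from gimle.asgard.circuit.circuit import Circuit
from gimle.asgard.circuit.circuit_fitness import GradientBasedFitnessEvaluator
from gimle.asgard.circuit.circuit_gradients import (
    extract_parameters,
    rebuild_circuit_with_parameters,
)

circuit = Circuit.from_string("scalar(1.0)")
evaluator = GradientBasedFitnessEvaluator(dataset)  # dataset as in the earlier examples

# One manual gradient-descent step
params = extract_parameters(circuit)
fitness, gradients = evaluator.evaluate_with_gradients(circuit)

# Assumption: sorted gradient keys correspond to param_values order
grad_vector = [gradients[name] for name in sorted(gradients)]
new_values = [v - 0.01 * g for v, g in zip(params.param_values, grad_vector)]

circuit = rebuild_circuit_with_parameters(circuit, params.param_locations, new_values)
```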
## Complete Example
```python
from gimle.asgard.circuit.circuit import Circuit
from gimle.asgard.circuit.circuit_fitness import DatasetSample
from gimle.asgard.circuit.circuit_optimizer import GradientBasedOptimizer
from gimle.asgard.runtime.stream import Stream
import jax.numpy as jnp

# 1. Create training dataset
def create_dataset():
    """Create dataset for y = 2x + 1"""
    dataset = []
    for x in jnp.linspace(-5, 5, 20):
        y = 2.0 * x + 1.0
        input_stream = Stream(data=jnp.array([[float(x)]]), dim_labels=(), chunk_size=1)
        output_stream = Stream(data=jnp.array([[float(y)]]), dim_labels=(), chunk_size=1)
        dataset.append(DatasetSample([input_stream], [output_stream]))
    return dataset

dataset = create_dataset()

# 2. Create initial circuit with wrong parameters
# We want: 2x + 1, start with: 1x + 0
initial_circuit = Circuit.from_string(
    "composition(monoidal(scalar(1.0), const(0.0)), add)"
)

# 3. Optimize
optimizer = GradientBasedOptimizer(
    learning_rate=0.1,
    num_iterations=100,
    tolerance=1e-8,
)
optimized, history = optimizer.optimize(initial_circuit, dataset, verbose=True)

# 4. Verify
print(f"Optimized circuit: {optimized}")
# Should be approximately: composition(monoidal(scalar(2.0), const(1.0)), add)

# 5. Plot convergence
import matplotlib.pyplot as plt

plt.plot(history)
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("Optimization Convergence")
plt.yscale("log")
plt.show()
```
## Next Steps
- Jacobian Computation - Sensitivity analysis
- Core Concepts - Basic circuit usage
- API Reference - Complete API documentation