Gradient-Based Optimization

Optimize circuit parameters using JAX automatic differentiation.

Overview

Asgard leverages JAX's automatic differentiation for three major capabilities:

  1. Gradient-Based Parameter Optimization - Fast optimization of circuit parameters
  2. Jacobian Computation - Sensitivity analysis and input-output dependencies
  3. Gradient-Based Fitness Evaluation - Evaluate loss and gradients simultaneously

1. Parameter Optimization

What It Does

Automatically optimizes numerical parameters in circuits (like scalar(c) or const(c)) using gradient descent instead of discrete search methods.

Why It's Useful

Gradient descent converges in tens of evaluations where discrete search needs thousands, and the result is deterministic (see the Performance Comparison below).

Example

from gimle.asgard.circuit.circuit import Circuit
from gimle.asgard.circuit.circuit_fitness import DatasetSample
from gimle.asgard.circuit.circuit_optimizer import GradientBasedOptimizer
from gimle.asgard.runtime.stream import Stream
import jax.numpy as jnp

# Create training data: y = 3.0 * x
dataset = []
for x_val in [1.0, 2.0, 3.0, 4.0, 5.0]:
    input_stream = Stream(data=jnp.array([[x_val]]), dim_labels=(), chunk_size=1)
    output_stream = Stream(data=jnp.array([[3.0 * x_val]]), dim_labels=(), chunk_size=1)
    dataset.append(DatasetSample([input_stream], [output_stream]))

# Start with wrong parameter: scalar(1.0) instead of scalar(3.0)
initial_circuit = Circuit.from_string("scalar(1.0)")

# Optimize using gradients
optimizer = GradientBasedOptimizer(
    learning_rate=0.01,
    num_iterations=50,
    tolerance=1e-6,
)

optimized_circuit, loss_history = optimizer.optimize(
    initial_circuit,
    dataset,
    verbose=True,
)

# Result: scalar(2.986) - converged to target in ~20 iterations!

API Reference

class GradientBasedOptimizer:
    def __init__(
        self,
        learning_rate: float = 0.01,
        num_iterations: int = 100,
        tolerance: float = 1e-6,
    )

    def optimize(
        self,
        circuit: Circuit,
        dataset: List[DatasetSample],
        verbose: bool = True,
    ) -> Tuple[Circuit, List[float]]
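
Under the hood, each iteration applies the standard gradient descent update θ ← θ − learning_rate · ∇L(θ). The stand-in sketch below reproduces the scalar(1.0) example in plain JAX, bypassing the circuit machinery entirely; the toy mean-squared-error loss is an assumption for illustration, not the library's internal loss function.

import jax
import jax.numpy as jnp

# Toy stand-in for the scalar(theta) circuit on the y = 3x dataset above.
xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
ys = 3.0 * xs

def loss(theta):
    # Mean squared error of theta * x against the targets.
    return jnp.mean((theta * xs - ys) ** 2)

grad_fn = jax.grad(loss)

theta = 1.0  # same wrong starting point as scalar(1.0)
for _ in range(50):  # num_iterations=50
    theta = theta - 0.01 * grad_fn(theta)  # learning_rate=0.01

print(theta)  # ~3.0, matching the optimized scalar parameter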

2. Jacobian Computation

What It Does

Computes the Jacobian matrix J[i,j] = ∂output[i]/∂input[j], showing how each output component depends on each input component.

Why It's Useful

The Jacobian makes input-output dependencies explicit: entry [i, j] tells you how sensitive output i is to input j, which is the basis for sensitivity analysis.

Example

from gimle.asgard.circuit.circuit import Circuit
from gimle.asgard.circuit.circuit_gradients import compute_jacobian
from gimle.asgard.runtime.stream import Stream
import jax.numpy as jnp

# Create circuit: scalar(2.0) multiplies input by 2
circuit = Circuit.from_string("scalar(2.0)")

# Create input
input_stream = Stream(data=jnp.array([[1.0, 2.0, 3.0]]), dim_labels=(), chunk_size=1)

# Compute Jacobian
jacobian = compute_jacobian(circuit, [input_stream], output_idx=0, input_idx=0)

# Result:
# [[2. 0. 0.]
#  [0. 2. 0.]
#  [0. 0. 2.]]
# Diagonal values = 2.0 (the scalar multiplier)
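
As a cross-check, the same matrix falls out of plain JAX: scalar(2.0) computes f(x) = 2x elementwise, whose Jacobian is 2I. The sketch below assumes nothing about Asgard and uses only jax.jacfwd.

import jax
import jax.numpy as jnp

# f mirrors what scalar(2.0) computes: elementwise multiplication by 2.
def f(x):
    return 2.0 * x

x = jnp.array([1.0, 2.0, 3.0])
print(jax.jacfwd(f)(x))
# [[2. 0. 0.]
#  [0. 2. 0.]
#  [0. 0. 2.]]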

API Reference

def compute_jacobian(
    circuit: Circuit,
    inputs: List[Stream],
    output_idx: int = 0,
    input_idx: int = 0,
) -> jnp.ndarray

Parameters:

  • circuit: The circuit to differentiate
  • inputs: Input streams to feed the circuit
  • output_idx: Index of the output stream to differentiate (default 0)
  • input_idx: Index of the input stream to differentiate with respect to (default 0)

Returns:

  A jnp.ndarray whose entry [i, j] is ∂output[i]/∂input[j]

3. Gradient-Based Fitness Evaluation

What It Does

Evaluates the loss and the gradient of the loss with respect to circuit parameters in a single pass.

Why It's Useful

Computing the loss and its gradient together avoids a separate gradient pass and tells an optimization loop which direction to move each parameter.

Example

from gimle.asgard.circuit.circuit import Circuit
from gimle.asgard.circuit.circuit_fitness import GradientBasedFitnessEvaluator, DatasetSample
import jax.numpy as jnp

# Create dataset
dataset = [...]  # Same as optimization example

# Create evaluator
evaluator = GradientBasedFitnessEvaluator(dataset)

# Evaluate circuit with gradients
circuit = Circuit.from_string("scalar(1.0)")
fitness, gradients = evaluator.evaluate_with_gradients(circuit)

print(f"Loss: {fitness.loss}")
print(f"Gradients: {gradients}")
# Output:
# Loss: 44.0
# Gradients: {'scalar_0': -43.831}
# Negative gradient → increase the parameter to reduce the loss

API Reference

class GradientBasedFitnessEvaluator:
    def __init__(self, dataset: List[DatasetSample])

    def evaluate(self, circuit: Circuit) -> FitnessResult

    def evaluate_with_gradients(
        self, circuit: Circuit
    ) -> Tuple[FitnessResult, Optional[Dict[str, jnp.ndarray]]]
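
In plain JAX, this pairing is exactly what jax.value_and_grad provides. The sketch below reuses the toy stand-in loss from earlier, assuming the loss is mean squared error (consistent with the 44.0 shown in the example); it is an illustration, not the evaluator's implementation.

import jax
import jax.numpy as jnp

# Stand-in loss for scalar(theta) on the y = 3x dataset.
xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
ys = 3.0 * xs

def loss(theta):
    return jnp.mean((theta * xs - ys) ** 2)

value, grad = jax.value_and_grad(loss)(1.0)
print(value, grad)
# 44.0 -44.0
# The loss matches the example's 44.0; autodiff gives the exact gradient
# -44.0, close to the numerical finite-difference estimate of -43.831 above.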

Performance Comparison

Method                        Evaluations   Speed      Deterministic
Gradient-Based Optimization   20-100        Fast       Yes
Genetic Algorithm             1000-10000    Slow       No
Random Search                 10000+        Very slow  No

Implementation Notes

Current Limitations

  1. Numerical Gradients: Gradients are currently estimated with finite differences rather than JAX autodiff through circuit compilation (a sketch follows this list)

    • Reason: JAX cannot trace through the Python-level circuit rebuilding step
    • Impact: Slightly slower than pure JAX autodiff, but still much faster than discrete search
    • Future: Could use jax.custom_vjp for true autodiff through compilation
  2. Parameter Types: Only scalar() and const() parameters are optimizable

    • Other atomic operations (like var()) are not yet parameterizable
  3. Circuit Complexity: Works best with circuits that have few parameters (< 10)

    • Numerical gradient computation scales linearly with the number of parameters
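
The numerical gradients from limitation 1 can be sketched as a central difference per parameter, as below. This is hypothetical scaffolding, not Asgard's actual implementation: loss_of stands in for "rebuild the circuit with these parameter values and evaluate it on the dataset", and the helper name is invented.

def finite_difference_grads(loss_of, params, eps=1e-5):
    """Central-difference estimate of d(loss)/d(param) for each parameter.

    Two loss evaluations per parameter, which is why the cost grows
    linearly with the number of parameters (limitation 3).
    """
    grads = {}
    for name, value in params.items():
        plus = loss_of({**params, name: value + eps})
        minus = loss_of({**params, name: value - eps})
        grads[name] = (plus - minus) / (2 * eps)
    return grads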

Helper Functions

Extract Parameters

from gimle.asgard.circuit.circuit_gradients import extract_parameters

params = extract_parameters(circuit)
# Returns CircuitParameters with param_values and param_locations

Rebuild Circuit with Parameters

from gimle.asgard.circuit.circuit_gradients import rebuild_circuit_with_parameters

new_circuit = rebuild_circuit_with_parameters(
    circuit_template,
    param_locations,
    new_params
)
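
Together these two helpers form the kind of round trip an optimization step performs: read the current parameter values, update them, and rebuild the circuit. A minimal sketch, assuming param_values is an array that rebuild_circuit_with_parameters accepts directly:

from gimle.asgard.circuit.circuit_gradients import (
    extract_parameters,
    rebuild_circuit_with_parameters,
)

params = extract_parameters(circuit)
nudged = params.param_values + 0.1  # e.g. a gradient step would go here
new_circuit = rebuild_circuit_with_parameters(
    circuit,
    params.param_locations,
    nudged,
)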

Complete Example

from gimle.asgard.circuit.circuit import Circuit
from gimle.asgard.circuit.circuit_fitness import DatasetSample
from gimle.asgard.circuit.circuit_optimizer import GradientBasedOptimizer
from gimle.asgard.runtime.stream import Stream
import jax.numpy as jnp

# 1. Create training dataset
def create_dataset():
    """Create dataset for y = 2x + 1"""
    dataset = []
    for x in jnp.linspace(-5, 5, 20):
        y = 2.0 * x + 1.0
        input_stream = Stream(data=jnp.array([[float(x)]]), dim_labels=(), chunk_size=1)
        output_stream = Stream(data=jnp.array([[float(y)]]), dim_labels=(), chunk_size=1)
        dataset.append(DatasetSample([input_stream], [output_stream]))
    return dataset

dataset = create_dataset()

# 2. Create initial circuit with wrong parameters
# We want: 2x + 1, start with: 1x + 0
initial_circuit = Circuit.from_string(
    "composition(monoidal(scalar(1.0), const(0.0)), add)"
)

# 3. Optimize
optimizer = GradientBasedOptimizer(
    learning_rate=0.1,
    num_iterations=100,
    tolerance=1e-8,
)

optimized, history = optimizer.optimize(initial_circuit, dataset, verbose=True)

# 4. Verify
print(f"Optimized circuit: {optimized}")
# Should be approximately: composition(monoidal(scalar(2.0), const(1.0)), add)

# 5. Plot convergence
import matplotlib.pyplot as plt
plt.plot(history)
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("Optimization Convergence")
plt.yscale("log")
plt.show()

Next Steps