Converting Numba Code to JAX

This guide provides best practices for converting Numba-based economic models to JAX, embracing functional programming patterns that align with JAX’s design philosophy.

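Overview

QuantEcon is transitioning from Numba to JAX to provide cleaner, more performant implementations of computationally intensive economic models. JAX’s JIT compiler is more capable than Numba’s in the ways that matter here. It compiles through XLA, targets CPUs, GPUs, and TPUs from the same code, and composes with transformations such as automatic differentiation and vectorization. Because JAX works with plain containers such as NamedTuple, it also lets us eliminate complex structures like jitclass.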

Note: Pure NumPy lectures that are not computationally intensive can remain as NumPy. NumPy syntax provides a simpler foundation, and jax.numpy mirrors most of the NumPy API.

Our focus is on replacing Numba with JAX for performance-critical code and settling on a uniform style based on JAX functional programming principles.

Style Guidelines

All examples in this guide follow QuantEcon’s variable naming conventions (see Variable Naming Conventions section), particularly the use of Unicode symbols for Greek letters (α, β, γ, etc.) commonly used in economics.

Core Principles

1. Functional Programming Style

JAX encourages pure functions with no side effects. This means:

  • Functions should not modify their inputs

  • Functions should return new data rather than mutating existing data

  • Avoid global state and mutable objects

  • Each function should be deterministic: given the same inputs, it always returns the same outputs

import jax.numpy as jnp

# ❌ Avoid: Mutating input arrays
def bad_update(state, shock):
    state[0] += shock  # Modifies input (and raises TypeError on a JAX array)
    return state

# ✅ Prefer: Pure function returning new data
def good_update(state, shock):
    return state.at[0].add(shock)  # Returns new array; the input is unchanged
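The following quick check uses an illustrative three-element state. It shows that NumPy's in-place update changes the caller's array, while the pure version leaves its input untouched. Applying the mutating version to a JAX array is rejected outright:

```python
import numpy as np
import jax.numpy as jnp

def bad_update(state, shock):
    state[0] += shock  # Mutates the caller's array
    return state

def good_update(state, shock):
    return state.at[0].add(shock)  # Builds a new array

np_state = np.zeros(3)
bad_update(np_state, 1.0)
print(np_state)  # the caller's NumPy array has changed

jax_state = jnp.zeros(3)
new_state = good_update(jax_state, 1.0)
print(jax_state, new_state)  # the original JAX array is still all zeros

try:
    bad_update(jnp.zeros(3), 1.0)
except TypeError as err:
    # JAX arrays are immutable, so item assignment raises TypeError
    print("TypeError:", err)
```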

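2. Model Structure Pattern

Replace large stateful classes with three pieces: a NamedTuple that stores the model primitives, a factory function that builds and validates instances, and a collection of pure functions that take the model as an argument.

Old Class-based Pattern (Avoid)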

import numpy as np

class LargeEconomicModel:
    def __init__(self, α, β, γ):
        self.α = α
        self.β = β
        self.γ = γ
        self.state = np.zeros(10)
    
    def update_state(self, shock):
        self.state += shock * self.α
        
    def compute_value(self):
        return np.sum(self.state * self.β)
    
    def simulate(self, n_periods):
        # Large method with multiple responsibilities
        pass

New JAX Pattern (Preferred)

from typing import NamedTuple
import jax
import jax.numpy as jnp

# 1. Store primitives in a NamedTuple
class EconomicModel(NamedTuple):
    α: float
    β: float
    γ: float
    grid_size: int = 100

# 2. Factory function for creating instances
def create_model_instance(α=0.5, β=0.95, γ=2.0, grid_size=100):
    """Create a model instance with validation."""
    if not 0 < α < 1:
        raise ValueError("α must be between 0 and 1")
    if not 0 < β < 1:
        raise ValueError("β must be between 0 and 1")
    
    return EconomicModel(
        α=α,
        β=β,
        γ=γ,
        grid_size=grid_size
    )

# 3. Collection of pure functions for computations
def update_state(model: EconomicModel, state: jnp.ndarray, shock: float) -> jnp.ndarray:
    """Update state with shock, returning new state."""
    return state + shock * model.α

def compute_value(model: EconomicModel, state: jnp.ndarray) -> float:
    """Compute value function given current state."""
    return jnp.sum(state * model.β)

def simulate_path(model: EconomicModel, initial_state: jnp.ndarray, 
                  shocks: jnp.ndarray) -> jnp.ndarray:
    """Simulate the model for multiple periods."""
    def step(state, shock):
        new_state = update_state(model, state, shock)
        return new_state, new_state
    
    final_state, path = jax.lax.scan(step, initial_state, shocks)
    return path
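Here is a short usage sketch for the pattern above. The definitions are repeated so the snippet runs on its own, and the parameter values are illustrative. NamedTuple instances are immutable, so a variant with different parameters is built with `_replace` rather than by mutating the original:

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

class EconomicModel(NamedTuple):
    α: float
    β: float
    γ: float
    grid_size: int = 100

def update_state(model, state, shock):
    return state + shock * model.α

def simulate_path(model, initial_state, shocks):
    def step(state, shock):
        new_state = update_state(model, state, shock)
        return new_state, new_state  # (carry, collected output)
    _, path = jax.lax.scan(step, initial_state, shocks)
    return path

model = EconomicModel(α=0.5, β=0.95, γ=2.0)
shocks = jnp.ones(5)
path = simulate_path(model, jnp.float32(0.0), shocks)

# Build a new model with a different α; `model` itself is unchanged
model_alt = model._replace(α=0.8)
path_alt = simulate_path(model_alt, jnp.float32(0.0), shocks)
```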

Why Replace Numba with JAX?

JAX JIT Compiler Advantages

JAX’s JIT compiler offers several advantages over Numba:

  1. More Powerful Optimization: JAX’s XLA backend provides sophisticated optimizations including fusion, vectorization, and memory layout optimization.

  2. Eliminate Complex Structures: No need for complex constructs like jitclass - use simple NamedTuple instead.

  3. Functional Programming: Encourages cleaner, more maintainable code through pure functions.

  4. Additional Transformations: Built-in automatic differentiation (grad), vectorization (vmap), and parallelization (pmap).
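Item 4 composes with the NamedTuple pattern. The minimal sketch below uses a value function of the same form as compute_value, which depends only on β. Calling `jax.grad` with respect to a whole model differentiates every field and returns the gradient as a NamedTuple of the same type. Wrapping it in `jax.jit` compiles the derivative. The parameter values are illustrative:

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

class EconomicModel(NamedTuple):
    α: float
    β: float

def compute_value(model, state):
    # Depends on β only, so the α-gradient is zero
    return jnp.sum(state * model.β)

model = EconomicModel(α=0.5, β=0.95)
state = jnp.array([1.0, 2.0, 3.0])

# grad and jit compose; the result is an EconomicModel of gradients
grads = jax.jit(jax.grad(compute_value))(model, state)
print(grads.α, grads.β)  # zero for α, sum(state) for β
```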

Migration from Numba jitclass

# ❌ Old Numba jitclass pattern
import numpy as np
from numba import float64
from numba.experimental import jitclass  # jitclass lives in numba.experimental

spec = [
    ('α', float64),
    ('β', float64),
    ('state', float64[:])
]

@jitclass(spec)
class NumbaModel:
    def __init__(self, α, β):
        self.α = α
        self.β = β
        self.state = np.zeros(100)
    
    def update(self, shock):
        self.state[0] += shock * self.α

# ✅ New JAX pattern - much cleaner!
from typing import NamedTuple
import jax
import jax.numpy as jnp

class JAXModel(NamedTuple):
    α: float
    β: float

def create_jax_model(α=0.5, β=0.95):
    return JAXModel(α=α, β=β)

@jax.jit
def update_state(model: JAXModel, state: jnp.ndarray, shock: float):
    return state.at[0].add(shock * model.α)
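Part of why this works is that JAX treats a NamedTuple as a pytree. Its fields become the leaves that `jax.jit` traces, so a model can be passed straight into a compiled function. The JAXModel definition above is repeated in this sketch so it runs on its own, and the shock values are illustrative. Unlike the jitclass, the state now lives outside the model and is threaded explicitly through each call:

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

class JAXModel(NamedTuple):
    α: float
    β: float

@jax.jit
def update_state(model, state, shock):
    return state.at[0].add(shock * model.α)

model = JAXModel(α=0.5, β=0.95)

# The NamedTuple's fields are its pytree leaves
leaves = jax.tree_util.tree_leaves(model)
print(leaves)  # [0.5, 0.95]

# The caller owns the state: pass it in and get a new one back
state = jnp.zeros(100)
state = update_state(model, state, 2.0)
state = update_state(model, state, 2.0)
```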

Migration Patterns

From Numba to JAX

# Numba function decoration
from numba import jit

@jit                                    # ❌ Numba JIT
def compute_value_numba(α, β, data):
    result = 0.0
    for i in range(len(data)):
        result += α * data[i] + β
    return result

# ✅ JAX equivalent
import jax
import jax.numpy as jnp

@jax.jit
def compute_value(α: float, β: float, data: jnp.ndarray) -> float:
    return jnp.sum(α * data + β)  # Vectorized operation replaces the loop
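The sketch below checks that the vectorized version matches the loop. A plain-Python loop stands in for the Numba function, so Numba need not be installed. The test data are illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp

def compute_value_loop(α, β, data):
    # Same loop body as the Numba version, run as plain Python
    result = 0.0
    for i in range(len(data)):
        result += α * data[i] + β
    return result

@jax.jit
def compute_value(α, β, data):
    return jnp.sum(α * data + β)

data = np.linspace(0.0, 1.0, 11)
loop_result = compute_value_loop(0.5, 0.1, data)
jax_result = compute_value(0.5, 0.1, jnp.asarray(data))
print(loop_result, float(jax_result))
```

JAX computes in 32-bit floats by default, whereas Numba follows NumPy's 64-bit default. If the two must agree to more digits, enable 64-bit mode at startup with `jax.config.update("jax_enable_x64", True)`.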

NumPy Array Operations

# NumPy → JAX conversions
import numpy as np           # ❌ Old
import jax.numpy as jnp      # ✅ New

# Array creation
np.zeros(10)                 # ❌
jnp.zeros(10)                # ✅

# In-place operations → functional updates
arr[0] = 5                   # ❌ Mutation (raises TypeError on a JAX array)
arr = arr.at[0].set(5)       # ✅ Functional update

arr += 1                     # ❌ In-place on NumPy (on a JAX array it only rebinds arr)
arr = arr + 1                # ✅ Pure function
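The `.at` property supports more than `set`. This brief sketch shows a few other indexed updates on an illustrative array. Each one returns a new array and leaves the original untouched:

```python
import jax.numpy as jnp

arr = jnp.arange(5.0)                           # [0., 1., 2., 3., 4.]

doubled_tail = arr.at[2:].multiply(2.0)         # scale a slice
marked = arr.at[jnp.array([0, 4])].set(-1.0)    # update several indices at once
running_min = arr.at[1].min(0.5)                # elementwise min at index 1

print(arr)  # unchanged
```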

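Loop Patterns

# Replace explicit loops with JAX constructs
import jax
import jax.numpy as jnp

# Hypothetical update rule and inputs, for illustration only
def update_function(state, x):
    return 0.9 * state + x

initial_state = jnp.float32(0.0)
data = jnp.ones(10)
n_periods = len(data)

# ❌ Explicit loop
result = []
state = initial_state
for i in range(n_periods):
    state = update_function(state, data[i])
    result.append(state)

# ✅ JAX scan
def step(state, data_i):
    new_state = update_function(state, data_i)
    return new_state, new_state  # (carry for the next step, output to collect)

final_state, path = jax.lax.scan(step, initial_state, data)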
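When a loop has no per-period input, `scan` can still be used by passing `xs=None` together with `length`. The minimal sketch below assumes a hypothetical growth rate g and generates a deterministic capital path where k_{t+1} = (1 + g) k_t:

```python
import jax
import jax.numpy as jnp

g = 0.02  # hypothetical growth rate

def step(k, _):
    # The second argument is None because there is no per-step input
    k_next = (1 + g) * k
    return k_next, k_next  # (carry, output)

k0 = jnp.float32(1.0)
k_final, k_path = jax.lax.scan(step, k0, xs=None, length=5)
print(k_path)  # k_1, ..., k_5
```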

Advanced Loop Patterns

For more complex loop patterns, JAX provides jax.lax.while_loop and jax.lax.fori_loop:

import jax
import jax.numpy as jnp

# Hypothetical update rule and parameters, for illustration only
def update_function(val, i):
    return 1.1 * val + 1.0  # i is unused in this toy rule

threshold = 50.0
initial_value = jnp.float32(1.0)
n_iterations = 10

# ✅ JAX while_loop for conditional iterations
def cond_fun(carry):
    i, val = carry
    return (i < 100) & (val < threshold)

def body_fun(carry):
    i, val = carry
    new_val = update_function(val, i)
    return (i + 1, new_val)

# Run while loop
initial_carry = (0, initial_value)
final_i, final_val = jax.lax.while_loop(cond_fun, body_fun, initial_carry)

# ✅ JAX fori_loop for fixed iterations
def fori_body(i, val):
    return update_function(val, i)

# Run for-loop from 0 to n_iterations
final_val = jax.lax.fori_loop(0, n_iterations, fori_body, initial_value)

Best Practices for JAX Loops:

  • Prefer jax.lax.scan when you need to collect intermediate results

  • Use jax.lax.fori_loop for simple fixed-iteration loops

  • Use jax.lax.while_loop for conditional loops

  • Always ensure loop bodies are pure functions

  • Keep the carry's structure, shapes, and dtypes identical on every iteration; JAX raises an error if the loop body returns a carry of a different type
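The following sketch illustrates a tuple carry whose types stay fixed. It uses `fori_loop` to compute the discounted sum of β^t r_t, carrying both the running total and the current discount factor as float32 scalars. The value of β and the reward series are hypothetical:

```python
import jax
import jax.numpy as jnp

β = 0.9
rewards = jnp.array([1.0, 2.0, 3.0, 4.0])  # hypothetical reward series

def body(t, carry):
    total, discount = carry
    # Both carry entries remain float32 scalars on every iteration
    return total + discount * rewards[t], discount * β

init = (jnp.float32(0.0), jnp.float32(1.0))
total, _ = jax.lax.fori_loop(0, rewards.shape[0], body, init)
```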

Random Number Generation

# NumPy random → JAX random
import numpy as np                    # ❌
np.random.seed(42)
shocks = np.random.normal(0, 1, 100)

# ✅ JAX random with explicit key management
import jax.random as jr
key = jr.PRNGKey(42)
shocks = jr.normal(key, (100,))
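JAX keys are not advanced in place, so reusing a key reproduces the same draws. The usual pattern is to split a key into fresh subkeys before each use and keep one output as the key for later splits. The following is a minimal sketch:

```python
import jax.random as jr

key = jr.PRNGKey(42)

# Reusing a key gives identical draws
a = jr.normal(key, (3,))
b = jr.normal(key, (3,))

# Split into a new key plus two independent subkeys
key, subkey_1, subkey_2 = jr.split(key, 3)
shocks_1 = jr.normal(subkey_1, (100,))
shocks_2 = jr.normal(subkey_2, (100,))
```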

Complete Example: Asset Pricing Model

Here’s a complete example showing the conversion of an asset pricing model from Numba to JAX:

Before (Numba/Class-based)

from numba import float64
from numba.experimental import jitclass
import numpy as np

spec = [
    ('β', float64),
    ('α', float64),
    ('prices', float64[:])
]

@jitclass(spec)
class AssetPricingModel:
    def __init__(self, β=0.95, α=0.5):
        self.β = β
        self.α = α
        self.prices = np.array([0.0])  # placeholder
        
    def solve_prices(self, dividends, tolerance=1e-6):
        prices = np.zeros_like(dividends)
        # Iterative solution with mutation
        for iteration in range(1000):
            new_prices = self.β * (dividends + self.α * prices)
            error = np.max(np.abs(new_prices - prices))
            prices[:] = new_prices  # In-place mutation
            if error < tolerance:
                break
        self.prices = prices
        return prices

After (JAX/Functional)

from typing import NamedTuple
import jax.numpy as jnp
import jax

class AssetPricingModel(NamedTuple):
    β: float = 0.95
    α: float = 0.5

def create_asset_model(β=0.95, α=0.5):
    """Create asset pricing model with validation."""
    if not 0 < β < 1:
        raise ValueError("β must be between 0 and 1")
    if not 0 < α < 1:
        raise ValueError("α must be between 0 and 1")
    return AssetPricingModel(β=β, α=α)

@jax.jit
def solve_prices(model: AssetPricingModel, dividends: jnp.ndarray,
                 tolerance: float = 1e-6, max_iter: int = 1000) -> jnp.ndarray:
    """Solve for asset prices using fixed point iteration."""

    def update_prices(prices):
        return model.β * (dividends + model.α * prices)

    def convergence_check(prices_old, prices_new):
        return jnp.max(jnp.abs(prices_new - prices_old)) < tolerance

    def step(state):
        i, prices, converged = state
        new_prices = update_prices(prices)
        new_converged = convergence_check(prices, new_prices)
        return (i + 1, new_prices, new_converged)

    def cond(state):
        i, prices, converged = state
        # Stop at convergence or after max_iter iterations, like the Numba loop
        return (~converged) & (i < max_iter)

    # Fixed point iteration starting from zero prices
    initial_state = (0, jnp.zeros_like(dividends), False)
    _, final_prices, _ = jax.lax.while_loop(cond, step, initial_state)
    return final_prices

# Usage
model = create_asset_model(β=0.95, α=0.5)
dividends = jnp.array([1.0, 1.1, 0.9, 1.2])
prices = solve_prices(model, dividends)
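Because the pricing equation is linear, its fixed point has a closed form, which gives a convenient check on the iterative solver. From p = β(d + αp), it follows that p = βd / (1 − αβ). The self-contained sketch below compares the two. It uses `fori_loop` with a fixed number of iterations instead of the tolerance-based loop:

```python
import jax
import jax.numpy as jnp

β, α = 0.95, 0.5
dividends = jnp.array([1.0, 1.1, 0.9, 1.2])

# Closed-form fixed point of p = β (d + α p), valid when α β < 1
closed_form = β * dividends / (1 - α * β)

# Fixed point iteration: the error shrinks by a factor of α β = 0.475 per step
def body(_, prices):
    return β * (dividends + α * prices)

iterated = jax.lax.fori_loop(0, 200, body, jnp.zeros_like(dividends))
print(closed_form, iterated)
```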

Best Practices Summary

  1. Use NamedTuple for model parameters - Immutable, type-annotated parameter storage

  2. Create factory functions - Centralized model creation with validation

  3. Write pure functions - No side effects, return new data

  4. Leverage JAX transformations - Use jit and vmap for performance, and grad for derivatives

  5. Explicit randomness - Manage PRNG keys explicitly

  6. Functional updates - Use .at[].set(), .at[].add() instead of mutations

  7. Use JAX control flow - lax.scan, lax.while_loop instead of Python loops
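Item 4 also applies to whole models. A NamedTuple is a pytree, so `jax.vmap` can map over a model whose fields are arrays and evaluate a batch of parameterizations in one call. The sketch below combines the update_state and compute_value rules from the model structure section. The parameter values are hypothetical:

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

class EconomicModel(NamedTuple):
    α: float
    β: float

def update_and_value(model, state, shock):
    # Same rules as update_state and compute_value earlier in this guide
    new_state = state + shock * model.α
    return jnp.sum(new_state * model.β)

# Three parameterizations stored as one NamedTuple of arrays
models = EconomicModel(α=jnp.array([0.2, 0.5, 0.8]),
                       β=jnp.array([0.95, 0.95, 0.95]))
state = jnp.zeros(4)

# in_axes=(0, None, None): map over every model field, share state and shock
values = jax.vmap(update_and_value, in_axes=(0, None, None))(models, state, 1.0)
```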

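Performance Considerations

  • Use the @jax.jit decorator to compile functions with XLA

  • Batch operations with jax.vmap instead of looping over inputs

  • Avoid Python loops inside JIT-compiled functions: they are unrolled at trace time, which can make compilation slow for long loops

  • Keep array shapes static: shapes must be known at trace time, so values that determine a shape (such as grid_size) should be passed as static arguments

  • Use jax.lax operations (cond, scan, while_loop, fori_loop) for control flow that depends on traced values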
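A small sketch illustrates the last two points. The grid size determines an array shape, so it is marked static with `static_argnames`, and each new value triggers a recompile. A branch that depends on the traced parameter γ uses `jax.lax.cond` instead of a Python `if`. The CRRA utility function and grid bounds are hypothetical choices for illustration:

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("grid_size",))
def utility_on_grid(γ, grid_size):
    # grid_size is static, so linspace receives a concrete length
    c = jnp.linspace(0.1, 2.0, grid_size)
    # Branch on a traced value with lax.cond instead of a Python if
    return jax.lax.cond(
        γ == 1.0,
        lambda c: jnp.log(c),
        lambda c: c ** (1 - γ) / (1 - γ),
        c,
    )

u_log = utility_on_grid(1.0, grid_size=50)
u_crra = utility_on_grid(2.0, grid_size=50)
```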

Additional Resources

Following these patterns will result in cleaner, more maintainable, and higher-performance economic models that take full advantage of JAX’s capabilities.