# Converting Numba Code to JAX
This guide provides best practices for converting Numba-based economic models to JAX, embracing functional programming patterns that align with JAX’s design philosophy.
## Overview

QuantEcon is transitioning from Numba to JAX to provide cleaner, more performant implementations of computationally intensive economic models. JAX's JIT compiler is significantly more powerful than Numba's and allows us to eliminate complex constructs such as `jitclass`.

Note: Pure NumPy lectures that are not computationally intensive can remain as NumPy, since NumPy syntax provides a simpler foundation and JAX exposes the same NumPy API (`jax.numpy`).

Our focus is on replacing Numba with JAX for performance-critical code, settling on a uniform style based on JAX functional programming principles.
## Style Guidelines

All examples in this guide follow QuantEcon's variable naming conventions (see the Variable Naming Conventions section), particularly the use of Unicode symbols for Greek letters (α, β, γ, etc.) commonly used in economics.

## Core Principles

### 1. Functional Programming Style

JAX encourages pure functions with no side effects. This means:

- Functions should not modify their inputs
- Functions should return new data rather than mutating existing data
- Avoid global state and mutable objects
- Each function should be deterministic
```python
# ❌ Avoid: Mutating input arrays
def bad_update(state, shock):
    state[0] += shock  # Modifies input
    return state

# ✅ Prefer: Pure function returning new data
def good_update(state, shock):
    return state.at[0].add(shock)  # Returns new array
```
### 2. Model Structure Pattern

Replace large classes with a structured approach:

#### Old Class-based Pattern (Avoid)
```python
import numpy as np

class LargeEconomicModel:

    def __init__(self, α, β, γ):
        self.α = α
        self.β = β
        self.γ = γ
        self.state = np.zeros(10)

    def update_state(self, shock):
        self.state += shock * self.α

    def compute_value(self):
        return np.sum(self.state * self.β)

    def simulate(self, n_periods):
        # Large method with multiple responsibilities
        pass
```
#### New JAX Pattern (Preferred)

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# 1. Store primitives in a NamedTuple
class EconomicModel(NamedTuple):
    α: float
    β: float
    γ: float
    grid_size: int = 100

# 2. Factory function for creating instances
def create_model_instance(α=0.5, β=0.95, γ=2.0, grid_size=100):
    """Create a model instance with validation."""
    if not 0 < α < 1:
        raise ValueError("α must be between 0 and 1")
    if not 0 < β < 1:
        raise ValueError("β must be between 0 and 1")
    return EconomicModel(α=α, β=β, γ=γ, grid_size=grid_size)

# 3. Collection of pure functions for computations
def update_state(model: EconomicModel, state: jnp.ndarray,
                 shock: float) -> jnp.ndarray:
    """Update state with shock, returning new state."""
    return state + shock * model.α

def compute_value(model: EconomicModel, state: jnp.ndarray) -> float:
    """Compute value function given current state."""
    return jnp.sum(state * model.β)

def simulate_path(model: EconomicModel, initial_state: jnp.ndarray,
                  shocks: jnp.ndarray) -> jnp.ndarray:
    """Simulate the model for multiple periods."""
    def step(state, shock):
        new_state = update_state(model, state, shock)
        return new_state, new_state

    final_state, path = jax.lax.scan(step, initial_state, shocks)
    return path
```
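As a quick usage sketch of this pattern (the state dimension, shock draws, and parameter values below are illustrative, not prescribed by the guide):

```python
import jax.numpy as jnp
import jax.random as jr

model = create_model_instance(α=0.5, β=0.95, γ=2.0)

initial_state = jnp.zeros(3)              # illustrative 3-dimensional state
shocks = jr.normal(jr.PRNGKey(0), (5,))   # illustrative shock sequence

path = simulate_path(model, initial_state, shocks)   # shape (5, 3)
value = compute_value(model, path[-1])               # value at the final state
```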
## Why Replace Numba with JAX?

### JAX JIT Compiler Advantages

JAX's JIT compiler offers several advantages over Numba:

- **More Powerful Optimization**: JAX's XLA backend provides sophisticated optimizations including fusion, vectorization, and memory layout optimization.
- **Eliminate Complex Structures**: No need for complex constructs like `jitclass` - use a simple `NamedTuple` instead.
- **Functional Programming**: Encourages cleaner, more maintainable code through pure functions.
- **Additional Transformations**: Built-in automatic differentiation (`grad`), vectorization (`vmap`), and parallelization (`pmap`); see the sketch below.
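To make `grad` and `vmap` concrete, here is a minimal sketch using an illustrative CRRA utility function that is not part of the migration guide itself:

```python
import jax
import jax.numpy as jnp

def crra_utility(c, γ=2.0):
    "CRRA utility of consumption c (illustrative example)."
    return c**(1 - γ) / (1 - γ)

# Automatic differentiation: marginal utility u'(c) = c**(-γ)
marginal_utility = jax.grad(crra_utility)
print(marginal_utility(2.0))   # -> 0.25

# Vectorization: evaluate utility over a grid of consumption values
c_grid = jnp.linspace(0.5, 2.0, 5)
print(jax.vmap(crra_utility)(c_grid))
```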
### Migration from Numba jitclass
```python
# ❌ Old Numba jitclass pattern
import numpy as np
from numba import float64
from numba.experimental import jitclass

spec = [
    ('α', float64),
    ('β', float64),
    ('state', float64[:])
]

@jitclass(spec)
class NumbaModel:

    def __init__(self, α, β):
        self.α = α
        self.β = β
        self.state = np.zeros(100)

    def update(self, shock):
        self.state[0] += shock * self.α

# ✅ New JAX pattern - much cleaner!
from typing import NamedTuple

import jax
import jax.numpy as jnp

class JAXModel(NamedTuple):
    α: float
    β: float

def create_jax_model(α=0.5, β=0.95):
    return JAXModel(α=α, β=β)

@jax.jit
def update_state(model: JAXModel, state: jnp.ndarray, shock: float):
    return state.at[0].add(shock * model.α)
```
## Migration Patterns

### From Numba to JAX

```python
# Numba function decoration
from numba import jit

@jit  # ❌ Numba JIT
def compute_value(α, β, data):
    result = 0.0
    for i in range(len(data)):
        result += α * data[i] + β
    return result

# ✅ JAX equivalent
import jax
import jax.numpy as jnp

@jax.jit
def compute_value(α: float, β: float, data: jnp.ndarray) -> float:
    return jnp.sum(α * data + β)  # Vectorized operation
```
### NumPy Array Operations

```python
# NumPy → JAX conversions
import numpy as np       # ❌ Old
import jax.numpy as jnp  # ✅ New

# Array creation
np.zeros(10)   # ❌
jnp.zeros(10)  # ✅

# In-place operations → functional updates
arr[0] = 5              # ❌ Mutation
arr = arr.at[0].set(5)  # ✅ Functional update

arr += 1       # ❌ In-place
arr = arr + 1  # ✅ Pure function
```
### Loop Patterns

```python
# Replace explicit loops with JAX constructs

# ❌ Explicit loop
result = []
state = initial_state
for i in range(n_periods):
    state = update_function(state, data[i])
    result.append(state)

# ✅ JAX scan
def step(state, data_i):
    new_state = update_function(state, data_i)
    return new_state, new_state

final_state, path = jax.lax.scan(step, initial_state, data)
```
### Advanced Loop Patterns

For more complex loop patterns, JAX provides `jax.lax.while_loop` and `jax.lax.fori_loop`:

```python
# ✅ JAX while_loop for conditional iterations
def cond_fun(carry):
    i, val = carry
    return (i < 100) & (val < threshold)

def body_fun(carry):
    i, val = carry
    new_val = update_function(val, i)
    return (i + 1, new_val)

# Run while loop
initial_carry = (0, initial_value)
final_i, final_val = jax.lax.while_loop(cond_fun, body_fun, initial_carry)

# ✅ JAX fori_loop for fixed iterations
def body_fun(i, val):
    return update_function(val, i)

# Run for-loop from 0 to n_iterations
final_val = jax.lax.fori_loop(0, n_iterations, body_fun, initial_value)
```
**Best Practices for JAX Loops:**

- Prefer `jax.lax.scan` when you need to collect intermediate results
- Use `jax.lax.fori_loop` for simple fixed-iteration loops
- Use `jax.lax.while_loop` for conditional loops
- Always ensure loop bodies are pure functions
- Keep carry/accumulator types consistent throughout iterations
### Random Number Generation

```python
# NumPy random → JAX random
import numpy as np  # ❌
np.random.seed(42)
shocks = np.random.normal(0, 1, 100)

# ✅ JAX random with explicit key management
import jax.random as jr
key = jr.PRNGKey(42)
shocks = jr.normal(key, (100,))
```
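Because JAX keys are explicit, drawing several independent samples means splitting the key rather than relying on hidden global state. A minimal sketch (the draw sizes are illustrative):

```python
import jax.random as jr

key = jr.PRNGKey(42)

# Split the key so each draw uses independent randomness
key, subkey1, subkey2 = jr.split(key, 3)
income_shocks = jr.normal(subkey1, (100,))
price_shocks = jr.normal(subkey2, (100,))

# Reusing the same key would reproduce identical draws - always split first
```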
## Complete Example: Asset Pricing Model

Here's a complete example showing the conversion of an asset pricing model from Numba to JAX:

### Before (Numba/Class-based)
```python
import numpy as np
from numba import float64
from numba.experimental import jitclass

spec = [
    ('β', float64),
    ('α', float64),
    ('prices', float64[:])
]

@jitclass(spec)
class AssetPricingModel:

    def __init__(self, β=0.95, α=0.5):
        self.β = β
        self.α = α
        self.prices = np.array([0.0])  # placeholder

    def solve_prices(self, dividends, tolerance=1e-6):
        prices = np.zeros_like(dividends)
        # Iterative solution with mutation
        for iteration in range(1000):
            new_prices = self.β * (dividends + self.α * prices)
            if np.max(np.abs(new_prices - prices)) < tolerance:
                break
            prices[:] = new_prices  # In-place mutation
        self.prices = prices
        return prices
```
### After (JAX/Functional)

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class AssetPricingModel(NamedTuple):
    β: float = 0.95
    α: float = 0.5

def create_asset_model(β=0.95, α=0.5):
    """Create asset pricing model with validation."""
    return AssetPricingModel(β=β, α=α)

@jax.jit
def solve_prices(model: AssetPricingModel, dividends: jnp.ndarray,
                 tolerance: float = 1e-6) -> jnp.ndarray:
    """Solve for asset prices using fixed point iteration."""

    def update_prices(prices):
        return model.β * (dividends + model.α * prices)

    def convergence_check(prices_old, prices_new):
        return jnp.max(jnp.abs(prices_new - prices_old)) < tolerance

    # Fixed point iteration
    prices = jnp.zeros_like(dividends)

    def step(state):
        prices, converged = state
        new_prices = update_prices(prices)
        new_converged = convergence_check(prices, new_prices)
        return (new_prices, new_converged)

    def cond(state):
        prices, converged = state
        return ~converged

    final_prices, _ = jax.lax.while_loop(cond, step, (prices, False))
    return final_prices

# Usage
model = create_asset_model(β=0.95, α=0.5)
dividends = jnp.array([1.0, 1.1, 0.9, 1.2])
prices = solve_prices(model, dividends)
```
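As a sanity check on the usage snippet above, the iteration targets the fixed point p = β(d + αp), which has the closed form p = βd / (1 - αβ). A brief check, assuming the parameter values used above:

```python
# Closed-form fixed point of p = β (d + α p) is p = β d / (1 - α β)
β, α = 0.95, 0.5
closed_form = β * dividends / (1 - α * β)
print(jnp.allclose(prices, closed_form, atol=1e-5))  # True, up to the tolerance
```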
## Best Practices Summary

1. **Use NamedTuple for model parameters** - Immutable, type-safe parameter storage
2. **Create factory functions** - Centralized model creation with validation
3. **Write pure functions** - No side effects, return new data
4. **Leverage JAX transformations** - Use `jit`, `vmap`, `grad` for performance
5. **Explicit randomness** - Manage PRNG keys explicitly
6. **Functional updates** - Use `.at[].set()`, `.at[].add()` instead of mutations
7. **Use JAX control flow** - `lax.scan`, `lax.while_loop` instead of Python loops
## Performance Considerations

- Use the `@jax.jit` decorator for compilation
- Batch operations with `jax.vmap` (see the sketch below)
- Avoid Python loops inside JIT-compiled functions
- Keep array shapes static when possible
- Use `jax.lax` operations for control flow
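As one illustration of batching, here is a minimal sketch, assuming the `create_asset_model` and `solve_prices` functions from the asset pricing example above (the dividend values are made up for illustration):

```python
import jax
import jax.numpy as jnp

# A batch of dividend streams, one per row
dividend_batch = jnp.array([
    [1.0, 1.1, 0.9, 1.2],
    [0.8, 1.0, 1.0, 1.1],
    [1.2, 1.3, 1.1, 0.9],
])

model = create_asset_model(β=0.95, α=0.5)

# Map solve_prices over the batch dimension of dividends,
# holding the model parameters fixed (in_axes=(None, 0))
batched_solve = jax.vmap(solve_prices, in_axes=(None, 0))
price_batch = batched_solve(model, dividend_batch)   # shape (3, 4)
```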
## Additional Resources

Following these patterns will result in cleaner, more maintainable, and higher-performance economic models that take full advantage of JAX's capabilities.