nlls_gram
Levenberg-Marquardt nonlinear least-squares solvers for JAX, with dense dual Cholesky/QR solves and optional matrix-free iterative solves.
UnderdeterminedLevenbergMarquardt minimizes ||r(params)||^2 for a user-supplied
residual_fn(params, batch), where params is any JAX pytree (a flat array, a
dict, nnx.state(model, nnx.Param), ...). It follows an init/update protocol:
update(params, state, batch) returns the new params pytree (same structure),
the next state, and an LMInfo. The default dense solver factors the small
residual-space Gram (dual) system. For a Jacobian J with shape m x n, this
means an m x m solve, which is useful when there are many more parameters than
residual rows. jac="vjp" is the only supported Jacobian materialization mode.
Optional solvers provide a dense QR stability path, matrix-free Gram-space CG,
and LSMR on the damped least-squares formulation.
The solver interface is plain JAX and knows nothing about
flax/nnx/optax; the package depends on jax, plus lineax for LSMR.
Dtypes flow from your params/residual, and the damping state follows the
residual dtype; JAX decides float32 vs float64 via jax_enable_x64.
init_damping must be positive; use a small positive value for near
Gauss-Newton behavior.
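The sketch below shows why a small damping gives near Gauss-Newton behavior. It assumes the standard LM step written in the dual form delta = -J.T @ (J @ J.T + lambda * I)^{-1} @ r. The helper lm_step and the random Jacobian are hypothetical and not part of nlls_gram. As lambda shrinks, the step approaches the minimum-norm Gauss-Newton step -pinv(J) @ r. As lambda grows, it approaches a short gradient step -J.T @ r / lambda.

```python
import jax

jax.config.update("jax_enable_x64", True)  # float64 makes the limits easy to see

import jax.numpy as jnp

# Hypothetical underdetermined problem: m = 3 residual rows, n = 6 parameters.
J = jax.random.normal(jax.random.PRNGKey(0), (3, 6))
r = jnp.array([1.0, -2.0, 0.5])


def lm_step(J, r, damping):
    """Hypothetical helper: delta = -J.T @ (J @ J.T + damping * I)^{-1} @ r."""
    gram = J @ J.T + damping * jnp.eye(J.shape[0])
    return -J.T @ jnp.linalg.solve(gram, r)


# Tiny damping: close to the minimum-norm Gauss-Newton step -pinv(J) @ r.
small = lm_step(J, r, 1e-10)
gauss_newton = -jnp.linalg.pinv(J) @ r

# Huge damping: close to a short gradient-descent step -J.T @ r / damping.
large_damping = 1e10
large = lm_step(J, r, large_damping)
gradient_step = -J.T @ r / large_damping

print(jnp.max(jnp.abs(small - gauss_newton)))
print(jnp.linalg.norm(large - gradient_step) / jnp.linalg.norm(gradient_step))
```

This is also why init_damping must stay positive. With lambda > 0, J @ J.T + lambda * I is positive definite even when J is rank deficient.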
Install
For local development on an NVIDIA CUDA 13 machine, use the optional gpu
dependency group:
That group is for this repository's development and GPU tests; it is not a
published nlls-gram[gpu] extra. Users who want to run the optimizer on a GPU
should install the JAX accelerator build that matches their hardware alongside
nlls-gram, for example:
See the JAX installation guide for the current CUDA, ROCm, TPU, and CPU installation choices.
Minimal example
Fit y = a * exp(b * x) to noise-free data generated from (a, b) = (2, -1),
using a plain dict pytree of parameters. With JAX's default configuration, this
runs in float32:
```python
import jax
import jax.numpy as jnp

from nlls_gram import UnderdeterminedLevenbergMarquardt


# residual_fn(params, batch) -> 1-D residual array; the solver minimizes its SSQ.
def residual_fn(params, batch):
    x, y = batch
    return params["a"] * jnp.exp(params["b"] * x) - y


x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)
params = {"a": 1.0, "b": 0.0}

solver = UnderdeterminedLevenbergMarquardt(residual_fn, init_damping=1e-2)
lm_state = solver.init()


# The solver does not jit internally; wrap the train step yourself.
@jax.jit
def train_step(params, lm_state, batch):
    return solver.update(params, lm_state, batch)


for _ in range(50):
    params, lm_state, info = train_step(params, lm_state, (x, y))

print(params["a"], params["b"])  # ~2.0, ~-1.0
print(params["a"].dtype, info.loss.dtype)  # float32 float32
```
Float64 example
Enable x64 before creating arrays, then initialize the data and parameters as float64:
```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

from nlls_gram import UnderdeterminedLevenbergMarquardt

dtype = jnp.float64


def residual_fn(params, batch):
    x, y = batch
    return params["a"] * jnp.exp(params["b"] * x) - y


x = jnp.linspace(0.0, 2.0, 20, dtype=dtype)
y = 2.0 * jnp.exp(-1.0 * x)
params = {
    "a": jnp.asarray(1.0, dtype=dtype),
    "b": jnp.asarray(0.0, dtype=dtype),
}

solver = UnderdeterminedLevenbergMarquardt(residual_fn, init_damping=1e-2)
lm_state = solver.init()

for _ in range(50):
    params, lm_state, info = solver.update(params, lm_state, (x, y))

print(params["a"], params["b"])  # ~2.0, ~-1.0
print(params["a"].dtype, info.loss.dtype, info.damping.dtype)
```
Fletcher regularization
The default regularization="identity" uses the classic LM damping matrix
lambda * I. If parameters are badly scaled, regularization="fletcher" can
help by damping each parameter direction in proportion to diag(J.T @ J).
The diagonal is clipped before use, with defaults
fletcher_min_diagonal=1e-6 and fletcher_max_diagonal=1e6, so nearly unused
or extremely sensitive parameter directions do not dominate the Gram solve.
```python
import jax.numpy as jnp

from nlls_gram import UnderdeterminedLevenbergMarquardt

x = jnp.linspace(0.0, 2.0, 50)
y = 2.0 * jnp.exp(-1.0 * x)
parameter_scale = 1e-3


def residual_fn(params, batch):
    x, y = batch
    b = parameter_scale * params["b_scaled"]
    return params["a"] * jnp.exp(b * x) - y


def iterations_to_threshold(regularization):
    params = {"a": 1.0, "b_scaled": 0.0}
    solver = UnderdeterminedLevenbergMarquardt(
        residual_fn,
        init_damping=1e-2,
        regularization=regularization,
        fletcher_min_diagonal=1e-6,
        fletcher_max_diagonal=1e6,
    )
    lm_state = solver.init()
    for iteration in range(1, 51):
        params, lm_state, info = solver.update(params, lm_state, (x, y))
        if float(info.loss) < 1e-8:
            return iteration
    return None


print(iterations_to_threshold("identity"))  # ~16
print(iterations_to_threshold("fletcher"))  # ~4
```
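The clipped diagonal described above can be sketched in plain JAX. This sketch assumes the Fletcher damping matrix is D = clip(diag(J.T @ J), fletcher_min_diagonal, fletcher_max_diagonal). It also assumes the damped system (J.T @ J + lambda * D) delta = -J.T @ r is solved in residual space as delta = -D^{-1} J.T (J D^{-1} J.T + lambda * I)^{-1} r. The Jacobian mirrors the example's residual at a = 1, b_scaled = 0. It adds a hypothetical third parameter that the residual never uses, to show the lower clip. This is an illustration, not nlls_gram's internal code.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Jacobian of a * exp(1e-3 * b_scaled * x) - y at a = 1, b_scaled = 0, plus a
# hypothetical third parameter that the residual never uses (an all-zero column).
x = jnp.linspace(0.0, 2.0, 50)
y = 2.0 * jnp.exp(-1.0 * x)
r = 1.0 - y
J = jnp.stack([jnp.ones_like(x), 1e-3 * x, jnp.zeros_like(x)], axis=1)  # (m, n) = (50, 3)

damping = 1e-2
fletcher_min_diagonal, fletcher_max_diagonal = 1e-6, 1e6

# diag(J.T @ J) is the squared column norm of J; clipping keeps it in a safe range.
d = jnp.clip(jnp.sum(J * J, axis=0), fletcher_min_diagonal, fletcher_max_diagonal)

# Residual-space form of (J.T @ J + damping * diag(d)) delta = -J.T @ r:
#   delta = -D^{-1} J.T (J D^{-1} J.T + damping * I)^{-1} r
J_scaled = J / d  # J @ D^{-1}: column j divided by d[j]
y_dual = jnp.linalg.solve(J_scaled @ J.T + damping * jnp.eye(J.shape[0]), r)
delta_dual = -(J.T @ y_dual) / d

# Reference: the equivalent parameter-space solve.
delta_primal = -jnp.linalg.solve(J.T @ J + damping * jnp.diag(d), J.T @ r)
print(d)
print(delta_dual, delta_primal)
```

The unused parameter's diagonal is clipped up to fletcher_min_diagonal, so D stays invertible. That parameter's step stays exactly zero.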
Linear solvers
The default linear_solver="cholesky" materializes J.T with VJPs, forms the
dense residual-space Gram matrix J @ J.T, and uses a Cholesky factorization.
For J with shape m x n, where m is the number of residuals and n is the
number of parameters, this is an m x m factorization. This is the original
Gram/dual formulation, and it remains the default because it is usually the
fastest path for overparameterized problems with small residual dimension.
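The following is a minimal plain-JAX sketch of those steps, not nlls_gram's internal code. It materializes J one row at a time with VJPs against basis vectors. It then factors the m x m matrix J @ J.T + lambda * I with Cholesky and maps back with delta = -J.T @ y. The residual, sizes, and damping are hypothetical. The last lines compare the step with the equivalent n x n primal solve (J.T @ J + lambda * I) delta = -J.T @ r.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Hypothetical fat problem: m = 4 residual rows, n = 10 parameters.
x_data = jnp.linspace(0.0, 1.0, 4)


def residual(theta):
    # Polynomial features of x_data passed through tanh, minus a target.
    powers = x_data[:, None] ** jnp.arange(theta.shape[0])[None, :]  # (m, n)
    return jnp.tanh(powers @ theta) - jnp.sin(x_data)


theta = 0.1 * jnp.ones(10)
damping = 1e-2

# VJPs against the basis vectors e_i give the rows of J (the columns of J.T).
r, vjp_fn = jax.vjp(residual, theta)
eye_m = jnp.eye(r.shape[0])
J = jax.vmap(lambda e: vjp_fn(e)[0])(eye_m)  # shape (m, n)

# Dual system: factor the m x m matrix J @ J.T + damping * I with Cholesky.
factor = cho_factor(J @ J.T + damping * eye_m)
y = cho_solve(factor, r)
step_dual = -J.T @ y  # delta = -J.T @ (J @ J.T + damping * I)^{-1} @ r

# Reference: the n x n primal solve (J.T @ J + damping * I) delta = -J.T @ r.
step_primal = -jnp.linalg.solve(J.T @ J + damping * jnp.eye(theta.shape[0]), J.T @ r)
print(jnp.max(jnp.abs(step_dual - step_primal)))
```

Only the 4 x 4 matrix is factored. The primal reference needs a 10 x 10 solve, and that gap grows with the number of parameters.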
When the Gram/Cholesky path is too poorly conditioned, use the more numerically stable QR direct solve:
QR materializes J.T with VJPs and factors the transpose-side problem. The
original Jacobian is fat in the intended overparameterized setting, but J.T is
tall-skinny, so this path can be substantially slower than Gram/Cholesky on CPU
and GPU. Use it when the extra numerical stability is worth the cost. It solves
the damped subproblem through a small augmented QR system, without forming
J @ J.T or J.T @ J. QR currently supports only regularization="identity".
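One way to solve the same damped subproblem by QR without forming J @ J.T is sketched below, assuming the dual form delta = -J.T @ y with y = (J @ J.T + lambda * I)^{-1} r. That y is the least-squares solution of the tall system [J.T; sqrt(lambda) I] y ~= [0; r / sqrt(lambda)], whose reduced QR has only an m x m triangular factor. The Jacobian is a hypothetical random matrix, and the augmented system nlls_gram actually factors may be arranged differently.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Hypothetical fat Jacobian (m = 4 residuals, n = 12 parameters) and residual vector.
J = jax.random.normal(jax.random.PRNGKey(1), (4, 12))
r = jax.random.normal(jax.random.PRNGKey(2), (4,))
damping = 1e-2
m, n = J.shape

# y minimizes ||J.T @ y||^2 + damping * ||y - r / damping||^2, whose normal equations
# are (J @ J.T + damping * I) y = r. Written as one tall least-squares problem:
#   [J.T; sqrt(damping) I] y ~= [0; r / sqrt(damping)]
A = jnp.concatenate([J.T, jnp.sqrt(damping) * jnp.eye(m)], axis=0)  # (n + m, m)
b = jnp.concatenate([jnp.zeros(n), r / jnp.sqrt(damping)])

q, R = jnp.linalg.qr(A)  # reduced QR: q is (n + m, m), R is m x m upper triangular
y = solve_triangular(R, q.T @ b, lower=False)
step_qr = -J.T @ y

# Reference: the Gram-style solve that does form J @ J.T.
step_gram = -J.T @ jnp.linalg.solve(J @ J.T + damping * jnp.eye(m), r)
print(jnp.max(jnp.abs(step_qr - step_gram)))
```

The QR works on the (n + m) x m matrix directly. Its conditioning is that of [J.T; sqrt(lambda) I] rather than its square, which is where the extra stability comes from.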
For larger identity-regularized problems, use an iterative solver with JAX
JVP/VJP linearization instead of materializing J.
Iterative solvers default to a small fixed iteration budget:
iterative_tol=0.0, iterative_atol=0.0, and iterative_maxiter=8. This avoids
extra tolerance-driven convergence work and is intended for low-rank local
linear solves. Set a positive iterative_tol or iterative_atol if you want
early convergence checks instead.
For fixed-budget Gram-space CG, use:
```python
solver = UnderdeterminedLevenbergMarquardt(
    residual_fn,
    init_damping=1e-2,
    linear_solver="cg",
    iterative_tol=0.0,
    iterative_atol=0.0,
    iterative_maxiter=8,
)
```
CG currently supports only regularization="identity". It solves in residual
space, so the Krylov vectors have length equal to the number of residuals. It
uses matrix-free JVPs for J @ v and VJPs/linear transposes for J.T @ u.
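Residual-space CG can be sketched with jax.scipy.sparse.linalg.cg, which accepts a callable operator. The product u -> (J @ J.T + lambda * I) @ u composes a VJP (J.T @ u) with a JVP (J @ v), and params stay a pytree throughout. The residual, sizes, and the helper names jvp_fn and gram_matvec are hypothetical. JAX's cg stands in here for whatever CG loop nlls_gram uses internally. The tol=0.0, atol=0.0, and maxiter=8 arguments mirror the fixed-budget defaults described above.

```python
import jax

jax.config.update("jax_enable_x64", True)  # float64 keeps the comparison tight

import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical problem: m = 4 residual rows, params pytree with 6 + 2 = 8 entries.
W = jax.random.normal(jax.random.PRNGKey(0), (4, 6))
t = jnp.linspace(0.0, 1.0, 4)


def residual(params):
    return jnp.tanh(W @ params["w"]) + params["c"][0] * t + params["c"][1] - jnp.cos(t)


params = {"w": 0.1 * jnp.ones(6), "c": jnp.zeros(2)}
damping = 1e-2

# One VJP linearization gives r and u -> J.T @ u (as a params-shaped pytree).
r, vjp_fn = jax.vjp(residual, params)


def jvp_fn(tangent):
    # v -> J @ v for a params-shaped tangent pytree.
    return jax.jvp(residual, (params,), (tangent,))[1]


def gram_matvec(u):
    # (J @ J.T + damping * I) @ u in residual space, without materializing J.
    (jt_u,) = vjp_fn(u)
    return jvp_fn(jt_u) + damping * u


# Fixed budget: tol=0 and atol=0 disable early stopping, so CG runs until maxiter
# (or until its residual norm is exactly zero).
y, _ = cg(gram_matvec, r, tol=0.0, atol=0.0, maxiter=8)

# Map back to parameter space: delta = -J.T @ y, returned with the params structure.
(step,) = vjp_fn(-y)
print(jax.tree_util.tree_map(jnp.shape, step))
```

Every Krylov vector here has length m = 4, regardless of how many parameters the pytree holds.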
linear_solver="lsmr" uses Lineax LSMR on the damped least-squares problem
directly.
For fixed-budget LSMR, use:
```python
solver = UnderdeterminedLevenbergMarquardt(
    residual_fn,
    init_damping=1e-2,
    linear_solver="lsmr",
    iterative_tol=0.0,
    iterative_atol=0.0,
    iterative_maxiter=8,
    lsmr_conlim=float("inf"),
)
```
It uses the augmented operator [J; sqrt(lambda) I], so matrix-vector products
call JAX JVPs for J @ s and transposed products call VJPs/linear transposes for
J.T @ u. LSMR does not use the dense Gram or QR factorizations. Its default
lsmr_conlim=float("inf") prevents condition-limit early termination; Lineax
still computes LSMR's internal norm estimates each iteration. Iterative solvers
can reduce memory and factorization cost on larger dense GPU problems, but each
iteration performs matrix-free Jacobian-vector and transpose-vector products, so
the dense direct solvers remain better for small residual dimensions.
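The sketch below builds the augmented operator used by the LSMR path. It assumes the damped problem is min ||J s + r||^2 + lambda ||s||^2, which is min ||[J; sqrt(lambda) I] s + [r; 0]||. The forward product uses a JVP and the transposed product a VJP. Lineax is not called here. To check the formulation, a dense jnp.linalg.lstsq solve stands in for LSMR and is compared with the Gram/dual step. The residual and the helper names aug_matvec and aug_rmatvec are hypothetical.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Hypothetical residual over a flat parameter vector: m = 3 rows, n = 7 parameters.
W = jax.random.normal(jax.random.PRNGKey(3), (3, 7))
target = jnp.array([0.5, -1.0, 2.0])


def residual(theta):
    return jnp.tanh(W @ theta) - target


theta = jnp.zeros(7)
damping = 1e-2
sqrt_damping = jnp.sqrt(damping)
r, vjp_fn = jax.vjp(residual, theta)
m, n = r.shape[0], theta.shape[0]


def aug_matvec(s):
    # [J; sqrt(damping) I] @ s, with J @ s from a JVP.
    j_s = jax.jvp(residual, (theta,), (s,))[1]
    return jnp.concatenate([j_s, sqrt_damping * s])


def aug_rmatvec(u):
    # [J; sqrt(damping) I].T @ u = J.T @ u[:m] + sqrt(damping) * u[m:], J.T @ u[:m] via a VJP.
    (jt_u,) = vjp_fn(u[:m])
    return jt_u + sqrt_damping * u[m:]


# Dense stand-in for LSMR, only to check the formulation: materialize the
# augmented matrix column by column and solve min ||A s + [r; 0]||.
A = jax.vmap(aug_matvec, in_axes=1, out_axes=1)(jnp.eye(n))  # shape (m + n, n)
rhs = -jnp.concatenate([r, jnp.zeros(n)])
step_lsq, *_ = jnp.linalg.lstsq(A, rhs)

# Same step from the Gram/dual formulation.
J = A[:m]
step_dual = -J.T @ jnp.linalg.solve(J @ J.T + damping * jnp.eye(m), r)
print(jnp.max(jnp.abs(step_lsq - step_dual)))
```

An LSMR implementation only needs aug_matvec and aug_rmatvec. The dense matrix A above exists purely for the check.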
Geodesic acceleration
Geodesic acceleration is off by default. When enabled, the solver uses analytic JAX forward-over-forward JVPs to build an accelerated candidate; it does not use finite differences.
```python
solver = UnderdeterminedLevenbergMarquardt(
    residual_fn,
    init_damping=1e-2,
    geodesic_acceleration=True,
)
```
The accelerated candidate is used only when its acceleration ratio,
2 * ||a|| / ||v||, is at or below a positive geodesic_acceptance_ratio and
its loss is no worse than the plain LM velocity candidate. Otherwise the update
automatically falls back to the velocity step. Use LMInfo.used_geodesic,
LMInfo.acceleration_ratio, LMInfo.loss_old, LMInfo.loss_candidate, and
LMInfo.damping_factor to tune damping and geodesic behavior.
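This is a sketch of one geodesic-accelerated step, assuming the standard geodesic-acceleration construction with identity damping. The velocity v is the plain LM step. r_vv is the second directional derivative of the residual along v, computed with two nested jax.jvp calls. The acceleration a solves the same damped system with r_vv in place of r, and the accelerated candidate is theta + v + a / 2. The threshold 0.75 and all names are hypothetical, and nlls_gram's exact candidate formula and defaults may differ.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Hypothetical curved problem: fit y = a * exp(b * x) to data generated from (2, -1).
x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)


def residual(theta):
    return theta[0] * jnp.exp(theta[1] * x) - y


def loss(theta):
    return jnp.sum(residual(theta) ** 2)


theta = jnp.array([1.0, 0.0])
damping = 1e-2
acceptance_ratio = 0.75  # hypothetical stand-in for geodesic_acceptance_ratio

# Materialize J with VJPs and build the damped residual-space system once.
r, vjp_fn = jax.vjp(residual, theta)
J = jax.vmap(lambda e: vjp_fn(e)[0])(jnp.eye(r.shape[0]))
gram = J @ J.T + damping * jnp.eye(r.shape[0])

# Velocity: the plain LM step.
velocity = -J.T @ jnp.linalg.solve(gram, r)


# Second directional derivative r''(theta)[v, v] via forward-over-forward JVPs.
def along_velocity(t):
    return jax.jvp(residual, (t,), (velocity,))[1]


r_vv = jax.jvp(along_velocity, (theta,), (velocity,))[1]

# Acceleration: same damped system, with r_vv in place of r.
acceleration = -J.T @ jnp.linalg.solve(gram, r_vv)
ratio = 2.0 * jnp.linalg.norm(acceleration) / jnp.linalg.norm(velocity)

velocity_candidate = theta + velocity
geodesic_candidate = theta + velocity + 0.5 * acceleration

# Accept the accelerated candidate only if the ratio is small and its loss is no worse.
use_geodesic = (ratio <= acceptance_ratio) & (loss(geodesic_candidate) <= loss(velocity_candidate))
new_theta = jnp.where(use_geodesic, geodesic_candidate, velocity_candidate)
print(bool(use_geodesic), float(ratio))
```

Because the fallback compares losses, the accepted step is never worse than the plain velocity step. In this sketch the only extra cost is one nested JVP and a second solve with the already-formed system.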
params can be any pytree. With Flax NNX, pass nnx.state(model, nnx.Param) as
params and write residual_fn(state, batch) using nnx.merge; the solver itself
stays NNX-agnostic.
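Because params can be any pytree, the m x n Jacobian discussed above is defined over the flattened leaves. The sketch below uses jax.flatten_util.ravel_pytree to show what n means for a nested dict and how a flat update maps back to the original structure. Whether nlls_gram flattens internally this way is an assumption, and the model and step are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical nested params pytree: 2 + 3 * 2 = 8 scalar parameters in total.
params = {"bias": jnp.zeros(2), "layer": {"kernel": jnp.ones((3, 2))}}
x = jnp.array([[1.0, 0.0, -1.0], [0.5, 2.0, 0.0], [0.0, 1.0, 1.0]])  # 3 samples x 3 features
targets = jnp.zeros((3, 2))


def residual_fn(params, batch):
    inputs, targets = batch
    pred = inputs @ params["layer"]["kernel"] + params["bias"]
    return (pred - targets).reshape(-1)  # 1-D residual with m = 6 rows


flat, unravel = ravel_pytree(params)  # flat has n = 8 entries


def flat_residual(theta):
    return residual_fn(unravel(theta), (x, targets))


J = jax.jacrev(flat_residual)(flat)
print(flat.shape, J.shape)  # (8,) (6, 8)

# A flat update maps back to the original structure with unravel.
delta = -0.1 * jnp.ones_like(flat)  # hypothetical step, for illustration only
new_params = unravel(flat + delta)
```
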
Filtering / freezing parameters
update optimizes exactly the params pytree you pass. For Flax NNX transfer
learning, construct or load the full module first, choose the trainable leaves
with an NNX filter, and pass only that trainable state to the solver. This mirrors
the wrt argument used by nnx.Optimizer: wrt means "differentiate and update
these leaves", while the trailing ... (Ellipsis) filter in nnx.split captures
the already-initialized frozen remainder.
Install Flax in your project to run this example.
```python
import jax
import jax.numpy as jnp
from flax import nnx

from nlls_gram import UnderdeterminedLevenbergMarquardt


class ExpModel(nnx.Module):
    def __init__(self):
        self.a = nnx.Param(jnp.asarray(1.0))
        self.b = nnx.Param(jnp.asarray(-1.0))

    def __call__(self, x):
        return self.a[...] * jnp.exp(self.b[...] * x)


x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)

model = ExpModel()
wrt = nnx.PathContains("a")  # train "a"; keep all other initialized state fixed
graphdef, trainable, frozen = nnx.split(model, wrt, ...)


def residual_fn(trainable, batch):
    x, y = batch
    model = nnx.merge(graphdef, trainable, frozen)
    return model(x) - y


solver = UnderdeterminedLevenbergMarquardt(residual_fn, init_damping=1e-2)
lm_state = solver.init()

for _ in range(50):
    trainable, lm_state, info = solver.update(trainable, lm_state, (x, y))

model = nnx.merge(graphdef, trainable, frozen)
print(model.a[...], model.b[...])  # ~2.0, -1.0
```
For built-in NNX layers, set both computation and parameter initialization dtypes when you want an all-float64 model:
Benchmarks
Optional pytest-benchmark checks live outside the normal test suite and do not run in CI by default:
For a larger RBF-style interpolation profile with CPU/GPU, Cholesky/QR/CG/LSMR, and geodesic on/off variants:
```bash
uv run --group benchmark --group gpu pytest \
  benchmarks/test_large_interpolation_benchmark.py --benchmark-only
```
For a small classic geodesic-acceleration convergence benchmark based on the GSL modified Rosenbrock example:
On machines with a CUDA-enabled JAX install, the optional GPU test checks that a jitted geodesic update runs on a GPU device:
API reference
nlls_gram.UnderdeterminedLevenbergMarquardt
Source code in src/nlls_gram/gram_lm.py
nlls_gram.LMState
nlls_gram.LMInfo
Bases: NamedTuple