diff --git a/benchmarks/bfgs.py b/benchmarks/bfgs.py
index 6276a580a..83bf415d3 100644
--- a/benchmarks/bfgs.py
+++ b/benchmarks/bfgs.py
@@ -26,7 +26,7 @@ def multiclass_logreg_jaxopt(X, y):
   fun = jaxopt.objective.multiclass_logreg
   init = jnp.zeros((X.shape[1], FLAGS.n_classes))
   bfgs = jaxopt.BFGS(
-      fun=fun,
+      fun=fun, linesearch='zoom',
       maxiter=FLAGS.maxiter,
       maxls=FLAGS.maxls,
@@ -59,8 +59,8 @@ def main(argv):
   start_time = time.time()
   dex_value = dex_bfgs(
-      jnp.array(X),
-      jnp.array(y),
+      jnp.array(X),
+      jnp.array(y),
       FLAGS.n_classes,
       FLAGS.maxiter,
       FLAGS.maxls,
diff --git a/examples/bfgs.dx b/examples/bfgs.dx
index 677de2c10..237664a39 100644
--- a/examples/bfgs.dx
+++ b/examples/bfgs.dx
@@ -1,14 +1,14 @@
 '# BFGS optimizer
-The BFGS method is a quasi-Newton algorithm for solving unconstrained nonlinear
-optimization problems. A BFGS iteration entails computing a line search
+The BFGS method is a quasi-Newton algorithm for solving unconstrained nonlinear
+optimization problems. A BFGS iteration entails computing a line search
 direction based on the gradient and Hessian approximation, finding a new point
 in the line search direction that satisfies the Wolfe conditions, and updating
 the Hessian approximation at the new point. This implementation is based on
-BFGS as described in Nocedal, Jorge; Wright, Stephen J. (2006), Numerical
+BFGS as described in Nocedal, Jorge; Wright, Stephen J. (2006), Numerical
 Optimization (2nd ed).

-'This example demonstrates Dex's ability to do fast, stateful loops with a
-statically unknown number of iterations. See `benchmarks/bfgs.py` for a
+'This example demonstrates Dex's ability to do fast, stateful loops with a
+statically unknown number of iterations. See `benchmarks/bfgs.py` for a
 comparison with Jaxopt BFGS on a multiclass logistic regression problem.

 def outer_product(x:n=>Float, y:m=>Float) -> (n=>m=>Float) given (n|Ix, m|Ix) =
@@ -16,7 +16,7 @@ def outer_product(x:n=>Float, y:m=>Float) -> (n=>m=>Float) given (n|Ix, m|Ix) =

 def zoom(
     f_line: (Float)->Float,
-    a_lo_init:Float,
+    a_lo_init:Float,
     a_hi_init:Float,
     c1:Float,
     c2:Float
@@ -40,7 +40,7 @@ def zoom(
     else
       f_ai = f_line a_i
       if f_ai > (f0 + c1 * a_i * g0) || f_ai >= f_line a_lo
-        then
+        then
           a_hi_ref := a_i
           Continue
         else
@@ -81,7 +81,7 @@ def zoom_line_search(
       else
         if g_i >= 0.
           then Done (zoom f a_i a_last c1 c2)
-          else
+          else
            a_ref_last := a_i
            a_ref := 0.5 * (a_i + a_max)
            Continue
@@ -91,7 +91,7 @@ def backtracking_line_search(
     f: (Float)->Float
     ) -> Float =
   -- Algorithm 3.1 in Nocedal and Wright (2006).
-  a_init = 1.
+  a_init = 1.
   f_0 = f 0.
   g_0 = grad f 0.
   rho = 0.5
@@ -106,12 +106,12 @@ def backtracking_line_search(
       a_ref := a_i * rho
       Continue

-struct BFGSresults(n|Ix) =
+struct BFGSresults(n|Ix) =
   fval : Float
   x_opt: (n=>Float)
-  error: Float
+  error: Float
   num_iter: Nat
-
+
 def bfgs_minimize(
     f: (n=>Float)->Float, --Objective function.
     x0: n=>Float, --Initial point.
@@ -119,10 +119,10 @@ def bfgs_minimize(
     linesearch: ((Float)->Float)->Float, --Line search that returns a step size.
     tol: Float, --Convergence tolerance (of the gradient L2 norm).
     maxiter: Nat --Maximum number of BFGS iterations.
-    ) -> BFGSresults n given (n|Ix) =
+    ) -> BFGSresults n given (n|Ix) =
   -- Algorithm 6.1 in Nocedal and Wright (2006).

-  xref <- with_state x0
+  xref <- with_state x0
   Href <- with_state H0
   gref <- with_state (grad f x0)
@@ -137,7 +137,7 @@ def bfgs_minimize(
     H = get Href
     search_direction = -H**.g
     f_line = \s:Float. f (x + s .* search_direction)
-    step_size = linesearch f_line
+    step_size = linesearch f_line
     x_diff = step_size .* search_direction
     x_next = x + x_diff
     g_next = grad f x_next
@@ -150,7 +150,7 @@ def bfgs_minimize(
       rho = 1. / rho_inv
       y = (eye - rho .* outer_product x_diff grad_diff)
       Href := y ** H ** (transpose y) + rho .* outer_product x_diff x_diff
-
+
     xref := x_next
     gref := g_next
     Continue
@@ -162,14 +162,14 @@ def rosenbrock(coord:(Fin 2)=>Float) -> Float =
   y = coord[1@_]
   pow (1 - x) 2 + 100 * pow (y - x * x) 2

-%bench "rosenbrock"
+%time
 bfgs_minimize rosenbrock [10., 10.] eye (\f. backtracking_line_search 15 f) 0.001 100
 > BFGSresults(8.668621e-13, [0.9999993, 0.9999985], 2.538457e-05, 41)
 >
-> rosenbrock
-> Compile time: 618.962 ms
-> Run time: 57.489 us (based on 1 run)
+> Compile time: 675.707 ms
+> Run time: 220.998 us

+@noinline
 def multiclass_logistic_loss(xs: n=>d=>Float, ys: n=>m, w: (d, m)=>Float) -> Float given (n|Ix, d|Ix, m|Ix) =
   w_arr = for i:d. for j:m. w[(i, j)]
   logits = xs ** w_arr
@@ -185,10 +185,10 @@ def multiclass_logreg(
     tol:Float) -> Float given (n|Ix, d|Ix, m|Ix)=
   ob_fun = \v. multiclass_logistic_loss xs ys v
   w0 = zero
-  res = bfgs_minimize ob_fun w0 eye (\f. zoom_line_search maxls f) tol maxiter
+  res = bfgs_minimize ob_fun w0 eye (\f. zoom_line_search maxls f) tol maxiter
   res.fval

--- Define a version of `multiclass_logreg` with Int instead of Nat labels, callable from Python
+-- Define a version of `multiclass_logreg` with Int instead of Nat labels, callable from Python
 -- (see benchmarks/bfgs.py).
 def multiclass_logreg_int(
     xs:(Fin n)=>(Fin d)=>Float,
@@ -199,3 +199,20 @@ def multiclass_logreg_int(
     tol:Float) -> Float given (n, d) =
   y_ind = Fin (i32_to_n num_classes)
   multiclass_logreg xs (for i. i32_to_n ys[i] @ y_ind) (i32_to_n maxiter) (i32_to_n maxls) tol
+
+n_samples = 100
+n_features = 20
+n_classes = 5
+maxiter = 30
+maxls = 15
+tol = 0.001
+
+xs = rand_mat n_samples n_features randn (new_key 0)
+ys : (Fin n_samples) => (Fin n_classes) = rand_vec n_samples rand_idx (new_key 1)
+
+%time
+multiclass_logreg xs ys maxiter maxls tol
+> 1.609437
+>
+> Compile time: 3.473 s
+> Run time: 195.542 us
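
For context on the Hessian-approximation step touched by the `@@ -150,7 +150,7 @@` hunk: the `Href` assignment in `bfgs_minimize` is the standard inverse-Hessian BFGS update from Nocedal and Wright (2006), H_new = (I - rho s y^T) H (I - rho y s^T) + rho s s^T with rho = 1 / (y^T s), where s is the step (`x_diff`) and y is the gradient change (`grad_diff`). Below is a minimal NumPy sketch of that single step; the names (`bfgs_inverse_hessian_update`, `H`, `s`, `y`) are illustrative rather than taken from either file, and the non-positive-curvature guard is a common convention, not a transcription of the Dex code's branch.

import numpy as np

def bfgs_inverse_hessian_update(H, s, y):
    # H: current inverse-Hessian approximation, shape (n, n)
    # s: step taken, x_next - x        (x_diff in examples/bfgs.dx)
    # y: gradient change, g_next - g   (grad_diff in examples/bfgs.dx)
    rho_inv = y @ s
    if rho_inv <= 0.0:
        # Curvature condition not satisfied; keep the previous approximation.
        return H
    rho = 1.0 / rho_inv
    I = np.eye(H.shape[0])
    V = I - rho * np.outer(s, y)   # mirrors (eye - rho .* outer_product x_diff grad_diff)
    return V @ H @ V.T + rho * np.outer(s, s)

Starting from the identity (the `eye` passed to `bfgs_minimize` in both `%time` blocks) and applying this update after each accepted step yields the H used to form the next search direction -H g; the line searches (`backtracking_line_search`, `zoom_line_search`) only choose the step size that produces s.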