Extend the BFGS example to test the larger logistic regression loss.
Also
- Tag the loss function noinline, which speeds up compilation ~2x,
- Use %time rather than %bench in the example, and
- Fix whitespace.
axch committed Jul 13, 2023
1 parent 808f6e9 commit 1650ea7
Showing 2 changed files with 42 additions and 25 deletions.
benchmarks/bfgs.py (6 changes: 3 additions & 3 deletions)
@@ -26,7 +26,7 @@ def multiclass_logreg_jaxopt(X, y):
fun = jaxopt.objective.multiclass_logreg
init = jnp.zeros((X.shape[1], FLAGS.n_classes))
bfgs = jaxopt.BFGS(
-fun=fun,
+fun=fun,
linesearch='zoom',
maxiter=FLAGS.maxiter,
maxls=FLAGS.maxls,
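(Aside: the snippet above only constructs the solver. A usage sketch, not part of this commit, assuming `X`, `y`, `fun`, and `init` from the surrounding benchmark; jaxopt solvers are driven with `.run`, which forwards extra arguments to the objective:)

```python
# Hedged usage sketch: jaxopt.objective.multiclass_logreg expects
# (params, data) with data = (X, y), so the data is passed through .run.
params, state = bfgs.run(init, data=(X, y))
print(state.iter_num, state.error)  # iterations used and final optimality error
```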
@@ -59,8 +59,8 @@ def main(argv):

start_time = time.time()
dex_value = dex_bfgs(
-jnp.array(X),
-jnp.array(y),
+jnp.array(X),
+jnp.array(y),
FLAGS.n_classes,
FLAGS.maxiter,
FLAGS.maxls,
examples/bfgs.dx (61 changes: 39 additions & 22 deletions)
@@ -1,22 +1,22 @@
'# BFGS optimizer
-The BFGS method is a quasi-Newton algorithm for solving unconstrained nonlinear
-optimization problems. A BFGS iteration entails computing a line search
+The BFGS method is a quasi-Newton algorithm for solving unconstrained nonlinear
+optimization problems. A BFGS iteration entails computing a line search
direction based on the gradient and Hessian approximation, finding a new point
in the line search direction that satisfies the Wolfe conditions, and updating
the Hessian approximation at the new point. This implementation is based on
-BFGS as described in Nocedal, Jorge; Wright, Stephen J. (2006), Numerical
+BFGS as described in Nocedal, Jorge; Wright, Stephen J. (2006), Numerical
Optimization (2nd ed).
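(For reference, the Hessian-approximation update mentioned here is the standard rank-two update on the inverse Hessian, Eq. 6.17 of Nocedal and Wright. A minimal NumPy sketch; the names are illustrative, not taken from this file:)

```python
import numpy as np

def bfgs_inverse_hessian_update(H, s, y):
    # H: current inverse-Hessian approximation
    # s: step taken (x_next - x); y: gradient change (g_next - g)
    rho = 1.0 / (y @ s)                 # assumes positive curvature, y @ s > 0
    V = np.eye(len(s)) - rho * np.outer(s, y)
    return V @ H @ V.T + rho * np.outer(s, s)
```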

-'This example demonstrates Dex's ability to do fast, stateful loops with a
-statically unknown number of iterations. See `benchmarks/bfgs.py` for a
+'This example demonstrates Dex's ability to do fast, stateful loops with a
+statically unknown number of iterations. See `benchmarks/bfgs.py` for a
comparison with Jaxopt BFGS on a multiclass logistic regression problem.

def outer_product(x:n=>Float, y:m=>Float) -> (n=>m=>Float) given (n|Ix, m|Ix) =
for i:n. for j:m. x[i]* y[j]

def zoom(
f_line: (Float)->Float,
-a_lo_init:Float,
+a_lo_init:Float,
a_hi_init:Float,
c1:Float,
c2:Float
@@ -40,7 +40,7 @@ def zoom(
else
f_ai = f_line a_i
if f_ai > (f0 + c1 * a_i * g0) || f_ai >= f_line a_lo
-then
+then
a_hi_ref := a_i
Continue
else
@@ -81,7 +81,7 @@ def zoom_line_search(
else
if g_i >= 0.
then Done (zoom f a_i a_last c1 c2)
-else
+else
a_ref_last := a_i
a_ref := 0.5 * (a_i + a_max)
Continue
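(The acceptance test the zoom search above enforces is the strong Wolfe conditions from Section 3.1 of Nocedal and Wright: sufficient decrease plus a curvature bound. An illustrative Python predicate; the constants are the usual defaults, not values read from this file:)

```python
def satisfies_strong_wolfe(f_line, g_line, a, c1=1e-4, c2=0.9):
    # f_line(a): objective along the search direction; g_line(a): its derivative
    f0, g0 = f_line(0.0), g_line(0.0)
    sufficient_decrease = f_line(a) <= f0 + c1 * a * g0   # Armijo condition
    curvature = abs(g_line(a)) <= c2 * abs(g0)            # curvature condition
    return sufficient_decrease and curvature
```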
@@ -91,7 +91,7 @@ def backtracking_line_search(
f: (Float)->Float
) -> Float =
-- Algorithm 3.1 in Nocedal and Wright (2006).
-a_init = 1.
+a_init = 1.
f_0 = f 0.
g_0 = grad f 0.
rho = 0.5
@@ -106,23 +106,23 @@
a_ref := a_i * rho
Continue
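(Algorithm 3.1, implemented above, is plain backtracking: start from a unit step and shrink by a fixed factor until sufficient decrease holds. A compact Python sketch of the same idea; the shrink factor mirrors the rho = 0.5 above, while c1 is an assumed typical value since the Dex body is partly collapsed in this diff:)

```python
def backtracking(f_line, g_line, maxiter=15, c1=1e-4, rho=0.5):
    f0, g0 = f_line(0.0), g_line(0.0)
    a = 1.0
    for _ in range(maxiter):
        if f_line(a) <= f0 + c1 * a * g0:   # sufficient decrease reached
            return a
        a *= rho
    return a                                # give up and return the last trial
```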

-struct BFGSresults(n|Ix) =
+struct BFGSresults(n|Ix) =
fval : Float
x_opt: (n=>Float)
-error: Float
+error: Float
num_iter: Nat

def bfgs_minimize(
f: (n=>Float)->Float, --Objective function.
x0: n=>Float, --Initial point.
H0: n=>n=>Float, --Initial inverse Hessian approximation.
linesearch: ((Float)->Float)->Float, --Line search that returns a step size.
tol: Float, --Convergence tolerance (of the gradient L2 norm).
maxiter: Nat --Maximum number of BFGS iterations.
-) -> BFGSresults n given (n|Ix) =
+) -> BFGSresults n given (n|Ix) =
-- Algorithm 6.1 in Nocedal and Wright (2006).

-xref <- with_state x0
+xref <- with_state x0
Href <- with_state H0
gref <- with_state (grad f x0)

@@ -137,7 +137,7 @@ def bfgs_minimize(
H = get Href
search_direction = -H**.g
f_line = \s:Float. f (x + s .* search_direction)
-step_size = linesearch f_line
+step_size = linesearch f_line
x_diff = step_size .* search_direction
x_next = x + x_diff
g_next = grad f x_next
@@ -150,7 +150,7 @@ def bfgs_minimize(
rho = 1. / rho_inv
y = (eye - rho .* outer_product x_diff grad_diff)
Href := y ** H ** (transpose y) + rho .* outer_product x_diff x_diff

xref := x_next
gref := g_next
Continue
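(Putting the pieces together, Algorithm 6.1 is a short loop: search along -H g, step, update H. A minimal NumPy driver combining a line search with the update sketched earlier; illustrative only, and the return values merely echo the shape of `BFGSresults`:)

```python
import numpy as np

def bfgs_minimize_np(f, grad_f, x0, line_search, tol=1e-3, maxiter=100):
    x = np.asarray(x0, dtype=float)
    H = np.eye(len(x))
    g = grad_f(x)
    num_iter = 0
    while np.linalg.norm(g) > tol and num_iter < maxiter:
        p = -H @ g                                # search direction
        a = line_search(lambda s: f(x + s * p))  # step size along p
        s_step = a * p
        x_next = x + s_step
        g_next = grad_f(x_next)
        y = g_next - g
        if y @ s_step > 1e-10:                    # skip update on bad curvature
            rho = 1.0 / (y @ s_step)
            V = np.eye(len(x)) - rho * np.outer(s_step, y)
            H = V @ H @ V.T + rho * np.outer(s_step, s_step)
        x, g = x_next, g_next
        num_iter += 1
    return x, f(x), np.linalg.norm(g), num_iter
```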
@@ -162,14 +162,14 @@ def rosenbrock(coord:(Fin 2)=>Float) -> Float =
y = coord[1@_]
pow (1 - x) 2 + 100 * pow (y - x * x) 2

-%bench "rosenbrock"
+%time
bfgs_minimize rosenbrock [10., 10.] eye (\f. backtracking_line_search 15 f) 0.001 100
> BFGSresults(8.668621e-13, [0.9999993, 0.9999985], 2.538457e-05, 41)
>
-> rosenbrock
-> Compile time: 618.962 ms
-> Run time: 57.489 us (based on 1 run)
+> Compile time: 675.707 ms
+> Run time: 220.998 us
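(As a cross-check on this result, SciPy's reference BFGS reaches the same Rosenbrock optimum near (1, 1) from the same start point; a sketch independent of this file:)

```python
import numpy as np
from scipy.optimize import minimize, rosen

res = minimize(rosen, np.array([10.0, 10.0]), method="BFGS",
               options={"gtol": 1e-3})
print(res.x, res.fun)   # expect x near [1, 1] and fun near 0
```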

+@noinline
def multiclass_logistic_loss(xs: n=>d=>Float, ys: n=>m, w: (d, m)=>Float) -> Float given (n|Ix, d|Ix, m|Ix) =
w_arr = for i:d. for j:m. w[(i, j)]
logits = xs ** w_arr
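(The loss body is collapsed in this diff; the quantity being computed is the standard mean softmax cross-entropy. A textbook NumPy formulation, not a line-for-line translation of the Dex code:)

```python
import numpy as np

def multiclass_logistic_loss_np(X, y, W):
    # X: (n, d) features; y: (n,) integer labels in [0, m); W: (d, m) weights
    logits = X @ W
    shifted = logits - logits.max(axis=1, keepdims=True)   # log-sum-exp trick
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(y)), y].mean()
```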
@@ -185,10 +185,10 @@ def multiclass_logreg(
tol:Float) -> Float given (n|Ix, d|Ix, m|Ix)=
ob_fun = \v. multiclass_logistic_loss xs ys v
w0 = zero
-res = bfgs_minimize ob_fun w0 eye (\f. zoom_line_search maxls f) tol maxiter
+res = bfgs_minimize ob_fun w0 eye (\f. zoom_line_search maxls f) tol maxiter
res.fval

--- Define a version of `multiclass_logreg` with Int instead of Nat labels, callable from Python
+-- Define a version of `multiclass_logreg` with Int instead of Nat labels, callable from Python
-- (see benchmarks/bfgs.py).
def multiclass_logreg_int(
xs:(Fin n)=>(Fin d)=>Float,
@@ -199,3 +199,20 @@ def multiclass_logreg_int(
tol:Float) -> Float given (n, d) =
y_ind = Fin (i32_to_n num_classes)
multiclass_logreg xs (for i. i32_to_n ys[i] @ y_ind) (i32_to_n maxiter) (i32_to_n maxls) tol

+n_samples = 100
+n_features = 20
+n_classes = 5
+maxiter = 30
+maxls = 15
+tol = 0.001
+
+xs = rand_mat n_samples n_features randn (new_key 0)
+ys : (Fin n_samples) => (Fin n_classes) = rand_vec n_samples rand_idx (new_key 1)
+
+%time
+multiclass_logreg xs ys maxiter maxls tol
+> 1.609437
+>
+> Compile time: 3.473 s
+> Run time: 195.542 us
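(Sanity check on the value printed above: with all-zero initial weights the softmax is uniform over the 5 classes, so the mean cross-entropy there is ln 5, matching the printed 1.609437 to single precision:)

```python
import numpy as np
print(np.log(5.0))   # 1.6094379... the loss of uniform predictions over 5 classes
```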
