Change BanditDuality to use a prior and softmax to randomize arms
odow committed Sep 4, 2024
1 parent d2495d6 commit 99586ce
Showing 1 changed file with 30 additions and 23 deletions.
53 changes: 30 additions & 23 deletions src/plugins/duality_handlers.jl
@@ -336,7 +336,17 @@ mutable struct BanditDuality <: AbstractDualityHandler
logs_seen::Int

function BanditDuality(args::AbstractDualityHandler...)
return new(_BanditArm[_BanditArm(arg, Float64[]) for arg in args], 1, 1)
# We initialize the arms with an informative prior to ensure that:
# 1) the `mean` is positive
# 2) the `std` is defined
# It doesn't matter that all arms have the same prior because we are
# sampling them based on a softmax in _choose_arm.
prior = [0.0, 1.0]
return new(
_BanditArm[_BanditArm(arg, prior) for arg in args],
1,
1,
)
end
end
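
For reference, the shared prior gives every arm a positive score before any real rewards arrive: the mean of [0.0, 1.0] is 0.5 and its sample standard deviation is about 0.71, so mean + std is defined and strictly positive. A quick check with Julia's Statistics standard library:

import Statistics

prior = [0.0, 1.0]
Statistics.mean(prior)  # 0.5
Statistics.std(prior)   # ≈ 0.7071, the (corrected) sample standard deviation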

@@ -352,12 +362,22 @@ function BanditDuality()
return BanditDuality(ContinuousConicDuality(), StrengthenedConicDuality())
end

function _choose_best_arm(handler::BanditDuality)
_, index = findmax(
map(handler.arms) do arm
return Statistics.mean(arm.rewards) + Statistics.std(arm.rewards)
end,
)
function _choose_arm(handler::BanditDuality)
scores = map(handler.arms) do arm
return Statistics.mean(arm.rewards) + Statistics.std(arm.rewards)
end
# Compute softmax
z = exp.(scores .- maximum(scores))
z ./= sum(z)
# Sample arm from softmax
r = rand()
index = length(z)
for i in 1:length(z)
    r -= z[i]
    if r <= 0
        index = i
        break
    end
end
handler.last_arm_index = index
return handler.arms[index]
end
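
The selection step above is softmax (Boltzmann) exploration followed by inverse-CDF sampling. A minimal, self-contained sketch of the same pattern, using a hypothetical scores vector and a helper name that is not part of SDDP.jl:

# Hypothetical helper, not part of SDDP.jl: sample an index with
# probability proportional to softmax(scores).
function softmax_sample(scores::Vector{Float64})
    # Subtract the maximum before exponentiating for numerical stability.
    z = exp.(scores .- maximum(scores))
    z ./= sum(z)
    # Inverse-CDF sampling: subtract probabilities from a uniform draw until
    # it crosses zero; the initial value handles floating-point round-off.
    r, index = rand(), length(z)
    for i in 1:length(z)
        r -= z[i]
        if r <= 0
            index = i
            break
        end
    end
    return index
end

softmax_sample([0.3, 1.2, 0.7])  # returns 2 more often than 1 or 3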
@@ -369,21 +389,8 @@ function _update_rewards(handler::BanditDuality, log::Vector{Log})
#          |bound_t - bound_{t-1}|
# reward = -----------------------
#            time_t - time_{t-1}
t, t′ = log[end], log[end-1]
reward = abs(t.bound - t′.bound) / (t.time - t′.time)
# This check is needed because we should probably keep using the first
# handler until we start to improve the bound. This can take quite a few
# iterations in some models. (Until we start to improve, the reward will be
# zero, so we'd never revisit it.)
const_bound = isapprox(log[1].bound, log[end].bound; atol = 1e-6)
# To start with, we should add the reward to all arms to construct a prior
# distribution for the arms. The 10 is somewhat arbitrary.
if length(log) < 10 || const_bound
for arm in handler.arms
push!(arm.rewards, reward)
end
else
push!(handler.arms[handler.last_arm_index].rewards, reward)
end
reward = abs(t.bound - t′.bound) / max(t.time - t′.time, 0.1)
push!(handler.arms[handler.last_arm_index].rewards, reward)
return
end
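
As a worked example with hypothetical numbers: if the bound moves from 10.0 to 10.5 while the iteration takes 0.02 seconds, the denominator is floored at 0.1, so the arm that was just used receives a reward of 5.0 rather than 25.0:

# Hypothetical values: bound improves from 10.0 to 10.5 in 0.02 seconds.
reward = abs(10.5 - 10.0) / max(0.02, 0.1)  # 5.0; the 0.1 floor guards
                                            # against near-zero time steps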

@@ -396,7 +403,7 @@ function prepare_backward_pass(
_update_rewards(handler, options.log)
handler.logs_seen = length(options.log)
end
arm = _choose_best_arm(handler)
arm = _choose_arm(handler)
return prepare_backward_pass(node, arm.handler, options)
end
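
For context, a sketch of how the handler might be passed to training, assuming an SDDP.jl policy graph named model has been built elsewhere:

import SDDP

# Assuming `model` is an SDDP.PolicyGraph constructed elsewhere; the bandit
# handler then switches between its arms on each backward pass.
SDDP.train(
    model;
    duality_handler = SDDP.BanditDuality(),
    iteration_limit = 100,
)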
