Skip to content

Commit

Permalink
Merge pull request #3 from Naphthalin/revert-1-new_u
Browse files Browse the repository at this point in the history
Revert "New u"
  • Loading branch information
Naphthalin authored Sep 23, 2019
2 parents ea4d067 + 9f3b4a9 commit 59f10c3
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 42 deletions.
4 changes: 0 additions & 4 deletions src/mcts/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,6 @@ class EdgeAndNode {
return numerator * GetP() / (1 + GetNStarted());
}

float GetNewU(float numerator) const {
return numerator * GetP() / std::pow(1.0 + GetNStarted(), 1.5);
}

int GetVisitsToReachU(float target_score, float numerator,
float default_q) const {
const auto q = GetQ(default_q);
Expand Down
6 changes: 0 additions & 6 deletions src/mcts/params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,6 @@ const OptionId SearchParams::kCpuctBaseId{
"higher growth of Cpuct as number of node visits grows."};
const OptionId SearchParams::kCpuctFactorId{
"cpuct-factor", "CPuctFactor", "Multiplier for the cpuct growth formula."};
const OptionId SearchParams::kNewUEnabledId{
"new-u-enabled", "NewUEnabled",
"Whether the formula for U is the old AlphaZero or the new equilibrium"
"based one."};
const OptionId SearchParams::kTemperatureId{
"temperature", "Temperature",
"Tau value from softmax formula for the first move. If equal to 0, the "
Expand Down Expand Up @@ -199,7 +195,6 @@ void SearchParams::Populate(OptionsParser* options) {
options->Add<FloatOption>(kCpuctId, 0.0f, 100.0f) = 3.0f;
options->Add<FloatOption>(kCpuctBaseId, 1.0f, 1000000000.0f) = 19652.0f;
options->Add<FloatOption>(kCpuctFactorId, 0.0f, 1000.0f) = 2.0f;
options->Add<BoolOption>(kNewUEnabledId) = false;
options->Add<FloatOption>(kTemperatureId, 0.0f, 100.0f) = 0.0f;
options->Add<IntOption>(kTempDecayMovesId, 0, 100) = 0;
options->Add<IntOption>(kTemperatureCutoffMoveId, 0, 1000) = 0;
Expand Down Expand Up @@ -241,7 +236,6 @@ SearchParams::SearchParams(const OptionsDict& options)
kCpuct(options.Get<float>(kCpuctId.GetId())),
kCpuctBase(options.Get<float>(kCpuctBaseId.GetId())),
kCpuctFactor(options.Get<float>(kCpuctFactorId.GetId())),
kNewUEnabled(options.Get<bool>(kNewUEnabledId.GetId())),
kNoise(options.Get<bool>(kNoiseId.GetId())),
kSmartPruningFactor(options.Get<float>(kSmartPruningFactorId.GetId())),
kFpuAbsolute(options.Get<std::string>(kFpuStrategyId.GetId()) ==
Expand Down
3 changes: 0 additions & 3 deletions src/mcts/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class SearchParams {
float GetCpuct() const { return kCpuct; }
float GetCpuctBase() const { return kCpuctBase; }
float GetCpuctFactor() const { return kCpuctFactor; }
bool GetNewUEnabled() const { return kNewUEnabled; }
float GetTemperature() const {
return options_.Get<float>(kTemperatureId.GetId());
}
Expand Down Expand Up @@ -106,7 +105,6 @@ class SearchParams {
static const OptionId kCpuctId;
static const OptionId kCpuctBaseId;
static const OptionId kCpuctFactorId;
static const OptionId kNewUEnabledId;
static const OptionId kTemperatureId;
static const OptionId kTempDecayMovesId;
static const OptionId kTemperatureCutoffMoveId;
Expand Down Expand Up @@ -145,7 +143,6 @@ class SearchParams {
const float kCpuct;
const float kCpuctBase;
const float kCpuctFactor;
const bool kNewUEnabled;
const bool kNoise;
const float kSmartPruningFactor;
const bool kFpuAbsolute;
Expand Down
36 changes: 7 additions & 29 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,28 +212,17 @@ std::vector<std::string> Search::GetVerboseStats(Node* node,
const float fpu = GetFpu(params_, node, node == root_node_);
const float cpuct = ComputeCpuct(params_, node->GetN());
const float U_coeff =
cpuct * (params_.GetNewUEnabled() ?
std::max(node->GetChildrenVisits(), 1u) :
std::sqrt(std::max(node->GetChildrenVisits(), 1u)));
cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u));

std::vector<EdgeAndNode> edges;
for (const auto& edge : node->Edges()) edges.push_back(edge);

if (params_.GetNewUEnabled()) {
std::sort(
edges.begin(), edges.end(),
[&fpu, &U_coeff](EdgeAndNode a, EdgeAndNode b) {
return std::forward_as_tuple(a.GetN(), a.GetQ(fpu) + a.GetNewU(U_coeff)) <
std::forward_as_tuple(b.GetN(), b.GetQ(fpu) + b.GetNewU(U_coeff));
});
} else {
std::sort(
edges.begin(), edges.end(),
[&fpu, &U_coeff](EdgeAndNode a, EdgeAndNode b) {
return std::forward_as_tuple(a.GetN(), a.GetQ(fpu) + a.GetU(U_coeff)) <
std::forward_as_tuple(b.GetN(), b.GetQ(fpu) + b.GetU(U_coeff));
});
}

std::vector<std::string> infos;
for (const auto& edge : edges) {
Expand All @@ -257,12 +246,11 @@ std::vector<std::string> Search::GetVerboseStats(Node* node,
oss << "(D: " << std::setw(6) << std::setprecision(3) << edge.GetD()
<< ") ";

oss << "(U: " << std::setw(6) << std::setprecision(5)
<< (params_.GetNewUEnabled() ? edge.GetNewU(U_coeff) : edge.GetU(U_coeff))
oss << "(U: " << std::setw(6) << std::setprecision(5) << edge.GetU(U_coeff)
<< ") ";

oss << "(Q+U: " << std::setw(8) << std::setprecision(5)
<< edge.GetQ(fpu) + (params_.GetNewUEnabled() ? edge.GetNewU(U_coeff) :edge.GetU(U_coeff)) << ") ";
<< edge.GetQ(fpu) + edge.GetU(U_coeff) << ") ";

oss << "(V: ";
optional<float> v;
Expand Down Expand Up @@ -943,10 +931,7 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend(
// playout remains incomplete; we must go deeper.
const float cpuct = ComputeCpuct(params_, node->GetN());
const float puct_mult =
cpuct * (params_.GetNewUEnabled() ?
std::max(node->GetChildrenVisits(), 1u) :
std::sqrt(std::max(node->GetChildrenVisits(), 1u)));

cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u));
float best = std::numeric_limits<float>::lowest();
float second_best = std::numeric_limits<float>::lowest();
int possible_moves = 0;
Expand All @@ -971,10 +956,7 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend(
++possible_moves;
}
const float Q = child.GetQ(fpu);
const float score = Q +
(params_.GetNewUEnabled() ?
child.GetNewU(puct_mult):
child.GetU(puct_mult));
const float score = child.GetU(puct_mult) + Q;
if (score > best) {
second_best = best;
second_best_edge = best_edge;
Expand Down Expand Up @@ -1169,16 +1151,12 @@ int SearchWorker::PrefetchIntoCache(Node* node, int budget) {
std::vector<ScoredEdge> scores;
const float cpuct = ComputeCpuct(params_, node->GetN());
const float puct_mult =
cpuct * (params_.GetNewUEnabled() ?
std::max(node->GetChildrenVisits(), 1u) :
std::sqrt(std::max(node->GetChildrenVisits(), 1u)));
cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u));
const float fpu = GetFpu(params_, node, node == search_->root_node_);
for (auto edge : node->Edges()) {
if (edge.GetP() == 0.0f) continue;
// Flip the sign of a score to be able to easily sort.
scores.emplace_back(-(params_.GetNewUEnabled() ? edge.GetNewU(puct_mult) :
edge.GetU(puct_mult))
- edge.GetQ(fpu), edge);
scores.emplace_back(-edge.GetU(puct_mult) - edge.GetQ(fpu), edge);
}

size_t first_unsorted_index = 0;
Expand Down

0 comments on commit 59f10c3

Please sign in to comment.