[RLlib] AlphaStar polishing (fix logger.info bug). (ray-project#22281)
sven1977 authored Apr 1, 2022
1 parent 4510c2d commit 0bb82f2
Showing 3 changed files with 73 additions and 16 deletions.
46 changes: 38 additions & 8 deletions rllib/agents/alpha_star/alpha_star.py
@@ -2,9 +2,8 @@
A multi-agent, distributed multi-GPU, league-capable asynch. PPO
================================================================
"""
from collections import defaultdict
import gym
from typing import DefaultDict, Optional, Type
from typing import Optional, Type

import ray
from ray.actor import ActorHandle
@@ -57,10 +56,28 @@
# old (replayed) ones.
"replay_buffer_replay_ratio": 0.5,

# Timeout to use for `ray.wait()` when waiting for samplers to have placed
# new data into the buffers. If no samples are ready within the timeout,
# the buffers used for mixin-sampling will return only older samples.
"sample_wait_timeout": 0.0,
# Timeout to use for `ray.wait()` when waiting for the policy learner actors
# to have performed an update and returned learning stats. If no learner
# actors have produced any learning results in the meantime, their
# learner-stats in the results will be empty for that iteration.
"learn_wait_timeout": 0.0,

# League-building parameters.
# The LeagueBuilder class to be used for league building logic.
"league_builder_config": {
"type": AlphaStarLeagueBuilder,
# The number of random policies to add to the league. This must be an
# even number (including 0) as these will be evenly distributed
# amongst league- and main- exploiters.
"num_random_policies": 2,
# The number of initially learning league-exploiters to create.
"num_learning_league_exploiters": 4,
# The number of initially learning main-exploiters to create.
"num_learning_main_exploiters": 4,
# Minimum win-rate (between 0.0 = 0% and 1.0 = 100%) of any policy to
# be considered for snapshotting (cloning). The cloned copy may then
# be frozen (no further learning) or keep learning (independent of
@@ -264,9 +281,6 @@ def _set_policy_learners(worker):

self.distributed_learners = distributed_learners

# Store the win rates for league overview printouts.
self.win_rates: DefaultDict[PolicyID, float] = defaultdict(float)

@override(Trainer)
def step(self) -> ResultDict:
# Perform a full step (including evaluation).
@@ -287,7 +301,7 @@ def training_iteration(self) -> ResultDict:
sample_results = asynchronous_parallel_requests(
remote_requests_in_flight=self.remote_requests_in_flight,
actors=self.workers.remote_workers() or [self.workers.local_worker()],
ray_wait_timeout_s=0.01,
ray_wait_timeout_s=self.config["sample_wait_timeout"],
max_remote_requests_in_flight_per_actor=2,
remote_fn=self._sample_and_send_to_buffer,
)
@@ -307,7 +321,7 @@ def training_iteration(self) -> ResultDict:
train_results = asynchronous_parallel_requests(
remote_requests_in_flight=self.remote_requests_in_flight,
actors=pol_actors,
ray_wait_timeout_s=0.1,
ray_wait_timeout_s=self.config["learn_wait_timeout"],
max_remote_requests_in_flight_per_actor=2,
remote_fn=self._update_policy,
remote_args=args,
@@ -339,7 +353,7 @@ def training_iteration(self) -> ResultDict:

global_vars = {
"timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
"win_rates": self.win_rates,
"league_builder": self.league_builder.__getstate__(),
}

for worker in self.workers.remote_workers():
@@ -455,3 +469,19 @@ def _update_policy(policy: Policy, replay_actor: ActorHandle, pid: PolicyID):
policy.update_kl(kl)

return train_results

@override(appo.APPOTrainer)
def __getstate__(self) -> dict:
state = super().__getstate__()
state.update(
{
"league_builder": self.league_builder.__getstate__(),
}
)
return state

@override(appo.APPOTrainer)
def __setstate__(self, state: dict) -> None:
state_copy = state.copy()
self.league_builder.__setstate__(state.pop("league_builder", {}))
super().__setstate__(state_copy)
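The two new timeout keys and the league-builder settings shown in the hunks above are plain config-dict entries. Below is a minimal, hedged usage sketch, not part of this commit: the environment name and the timeout values are placeholders, the package-level import of AlphaStarTrainer follows the test file in this diff, and only config keys that appear in this diff are used.

from ray.rllib.agents.alpha_star import AlphaStarTrainer
from ray.rllib.agents.alpha_star.league_builder import AlphaStarLeagueBuilder

config = {
    # Wait up to 100 ms for sampler actors to place new data into the
    # mixin buffers; the default of 0.0 returns immediately with whatever
    # (possibly older) samples are already there.
    "sample_wait_timeout": 0.1,
    # Wait up to 200 ms for the policy learner actors to report learning
    # stats; if none arrive in time, their stats are empty that iteration.
    "learn_wait_timeout": 0.2,
    # League-building settings, handled by AlphaStarLeagueBuilder.
    "league_builder_config": {
        "type": AlphaStarLeagueBuilder,
        "num_random_policies": 2,
        "num_learning_league_exploiters": 4,
        "num_learning_main_exploiters": 4,
    },
}

# "my_multi_agent_env" is a placeholder for a registered multi-agent env;
# partial config dicts are merged with the trainer's defaults.
trainer = AlphaStarTrainer(config=config, env="my_multi_agent_env")
print(trainer.train())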
40 changes: 34 additions & 6 deletions rllib/agents/alpha_star/league_builder.py
@@ -1,14 +1,16 @@
from abc import ABCMeta
from collections import defaultdict
import logging
import numpy as np
import re
from typing import Any, DefaultDict, Dict

from ray.rllib.agents.trainer import Trainer
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.numpy import softmax
from ray.rllib.utils.typing import TrainerConfigDict, ResultDict
from ray.rllib.utils.typing import PolicyID, TrainerConfigDict, ResultDict

logger = logging.getLogger(__name__)

@@ -39,6 +41,14 @@ def build_league(self, result: ResultDict) -> None:
"""
raise NotImplementedError

def __getstate__(self) -> Dict[str, Any]:
"""Returns a state dict, mapping str keys to state variables.
Returns:
The current state dict of this LeagueBuilder.
"""
return {}


@ExperimentalAPI
class NoLeagueBuilder(LeagueBuilder):
@@ -80,9 +90,9 @@ def __init__(
num_random_policies: The number of random policies to add to the
league. This must be an even number (including 0) as these
will be evenly distributed amongst league- and main- exploiters.
num_learning_league_exploiters: The number of learning
num_learning_league_exploiters: The number of initially learning
league-exploiters to create.
num_learning_main_exploiters: The number of learning
num_learning_main_exploiters: The number of initially learning
main-exploiters to create.
win_rate_threshold_for_new_snapshot: The win-rate to be achieved
for a learning policy to get snapshot'd (forked into `self` +
@@ -106,6 +116,8 @@ def __init__(
self.prob_main_exploiter_playing_against_learning_main = (
prob_main_exploiter_playing_against_learning_main
)
# Store the win rates for league overview printouts.
self.win_rates: DefaultDict[PolicyID, float] = defaultdict(float)

assert num_random_policies % 2 == 0, (
"ERROR: `num_random_policies` must be even number (we'll distribute "
@@ -149,7 +161,7 @@ def __init__(
policies[pid] = PolicySpec()
ma_config["policies_to_train"].append(pid)

# Initial policy mapping function: main_0 vs main_exploiter_0.
# Build initial policy mapping function: main_0 vs main_exploiter_0.
ma_config["policy_mapping_fn"] = (
lambda aid, ep, worker, **kw: "main_0"
if ep.episode_id % 2 == aid
@@ -166,6 +178,8 @@ def build_league(self, result: ResultDict) -> None:
else:
hist_stats = result["hist_stats"]

# TODO: Add example on how to use callable here, instead of updating
# policies_to_train via this simple set.
trainable_policies = local_worker.get_policies_to_train()
non_trainable_policies = (
set(local_worker.policy_map.keys()) - trainable_policies
@@ -188,13 +202,13 @@ def build_league(self, result: ResultDict) -> None:
win_rate = won / len(rew)
# TODO: This should probably be a running average
# (instead of hard-overriding it with the most recent data).
self.trainer.win_rates[policy_id] = win_rate
self.win_rates[policy_id] = win_rate

# Policy is a snapshot (frozen) -> Ignore.
if policy_id not in trainable_policies:
continue

logger.info(f"\t{policy_id} win-rate={win_rate} -> ", end="")
logger.info(f"\t{policy_id} win-rate={win_rate} -> ")

# If win rate is good enough -> Snapshot current policy and decide,
# whether to freeze the new snapshot or not.
@@ -354,3 +368,17 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs):

else:
logger.info("not good enough; will keep learning ...")

def __getstate__(self) -> Dict[str, Any]:
return {
"win_rates": self.win_rates,
"main_policies": self.main_policies,
"league_exploiters": self.league_exploiters,
"main_exploiters": self.main_exploiters,
}

def __setstate__(self, state) -> None:
self.win_rates = state["win_rates"]
self.main_policies = state["main_policies"]
self.league_exploiters = state["league_exploiters"]
self.main_exploiters = state["main_exploiters"]
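The one-line change from logger.info(f"...", end="") to logger.info(f"...") above is the logger.info bug named in the commit title: end is a print() keyword argument, and logging.Logger.info() forwards unknown keyword arguments down to Logger._log(), which rejects them with a TypeError. A stdlib-only sketch of the failure mode, with made-up values for illustration:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

policy_id, win_rate = "main_0", 0.87  # placeholder values

try:
    # Old call: passes print()'s `end` kwarg through Logger.info() ...
    logger.info(f"\t{policy_id} win-rate={win_rate} -> ", end="")
except TypeError as err:
    # ... which fails: _log() got an unexpected keyword argument 'end'
    print("buggy call failed:", err)

# Fixed call, as in this commit:
logger.info(f"\t{policy_id} win-rate={win_rate} -> ")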
3 changes: 1 addition & 2 deletions rllib/agents/alpha_star/tests/test_alpha_star.py
@@ -1,4 +1,3 @@
import pprint
import pyspiel
import unittest

@@ -61,8 +60,8 @@ def test_alpha_star_compilation(self):
trainer = alpha_star.AlphaStarTrainer(config=_config)
for i in range(num_iterations):
results = trainer.train()
print(results)
check_train_results(results)
pprint.pprint(results)
check_compute_single_action(trainer)
trainer.stop()

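Taken together, the new __getstate__/__setstate__ hooks on AlphaStarTrainer and AlphaStarLeagueBuilder make the league-builder state (for example the win-rate table) part of trainer checkpoints. Below is a rough sketch of the round trip these hooks are meant to enable, reusing the trainer and config placeholders from the sketch after the first file above; Trainer.save() and Trainer.restore() are the usual RLlib checkpointing entry points, and the exact checkpoint contents are an assumption here, not something this diff shows.

trainer.train()
checkpoint = trainer.save()       # serialized trainer state now carries "league_builder"
trainer.stop()

restored = AlphaStarTrainer(config=config, env="my_multi_agent_env")
restored.restore(checkpoint)      # __setstate__() pops "league_builder" and
                                  # forwards it to league_builder.__setstate__()
print(restored.league_builder.win_rates)  # win-rate table carried over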
