diff --git a/rllib/agents/callbacks.py b/rllib/agents/callbacks.py index a24b5cbbbcfe..b6252f34037e 100644 --- a/rllib/agents/callbacks.py +++ b/rllib/agents/callbacks.py @@ -14,6 +14,7 @@ import psutil if TYPE_CHECKING: + from ray.rllib.agents.trainer import Trainer from ray.rllib.evaluation import RolloutWorker @@ -35,28 +36,21 @@ def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None): "a class extending rllib.agents.callbacks.DefaultCallbacks") self.legacy_callbacks = legacy_callbacks_dict or {} - def on_episode_start(self, - *, - worker: "RolloutWorker", - base_env: BaseEnv, + def on_episode_start(self, *, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy], - episode: MultiAgentEpisode, - env_index: Optional[int] = None, - **kwargs) -> None: + episode: MultiAgentEpisode, **kwargs) -> None: """Callback run on the rollout worker before each episode starts. Args: - worker (RolloutWorker): Reference to the current rollout worker. - base_env (BaseEnv): BaseEnv running the episode. The underlying + worker: Reference to the current rollout worker. + base_env: BaseEnv running the episode. The underlying env object can be gotten by calling base_env.get_unwrapped(). - policies (dict): Mapping of policy id to policy objects. In single + policies: Mapping of policy id to policy objects. In single agent mode there will only be a single "default" policy. - episode (MultiAgentEpisode): Episode object which contains episode + episode: Episode object which contains episode state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. - env_index (EnvID): Obsoleted: The ID of the environment, which the - episode belongs to. kwargs: Forward compatibility placeholder. """ @@ -73,7 +67,6 @@ def on_episode_step(self, base_env: BaseEnv, policies: Optional[Dict[PolicyID, Policy]] = None, episode: MultiAgentEpisode, - env_index: Optional[int] = None, **kwargs) -> None: """Runs on each episode step. @@ -88,8 +81,6 @@ def on_episode_step(self, state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. - env_index (EnvID): Obsoleted: The ID of the environment, which the - episode belongs to. kwargs: Forward compatibility placeholder. """ @@ -99,14 +90,9 @@ def on_episode_step(self, "episode": episode }) - def on_episode_end(self, - *, - worker: "RolloutWorker", - base_env: BaseEnv, + def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy], - episode: MultiAgentEpisode, - env_index: Optional[int] = None, - **kwargs) -> None: + episode: MultiAgentEpisode, **kwargs) -> None: """Runs when an episode is done. Args: @@ -120,8 +106,6 @@ def on_episode_end(self, state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. - env_index (EnvID): Obsoleted: The ID of the environment, which the - episode belongs to. kwargs: Forward compatibility placeholder. """ @@ -201,12 +185,13 @@ def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch, pass - def on_train_result(self, *, trainer, result: dict, **kwargs) -> None: + def on_train_result(self, *, trainer: "Trainer", result: dict, + **kwargs) -> None: """Called at the end of Trainable.train(). Args: - trainer (Trainer): Current trainer instance. - result (dict): Dict of results returned from trainer.train() call. 
+            trainer: Current trainer instance.
+            result: Dict of results returned from trainer.train() call.
                 You can mutate this object to add additional metrics.
             kwargs: Forward compatibility placeholder.
         """
diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 44b23dbc7a68..3100f635e4c6 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -529,19 +529,34 @@ def with_common_config(
 
 @PublicAPI
 class Trainer(Trainable):
-    """A trainer coordinates the optimization of one or more RL policies.
-
-    All RLlib trainers extend this base class, e.g., the A3CTrainer implements
-    the A3C algorithm for single and multi-agent training.
-
-    Trainer objects retain internal model state between calls to train(), so
-    you should create a new trainer instance for each training session.
-
-    Attributes:
-        env_creator (func): Function that creates a new training env.
-        config (obj): Algorithm-specific configuration data.
-        logdir (str): Directory in which training outputs should be placed.
+    """An RLlib algorithm responsible for optimizing one or more Policies.
+
+    Trainers contain a WorkerSet under `self.workers`. A WorkerSet is
+    normally composed of a single local worker
+    (self.workers.local_worker()), used to compute and apply learning updates,
+    and optionally one or more remote workers (self.workers.remote_workers()),
+    used to generate environment samples in parallel.
+
+    Each worker (remote or local) contains a PolicyMap, which itself
+    may contain either one policy for single-agent training or one or more
+    policies for multi-agent training. Policies are synchronized
+    automatically from time to time using ray.remote calls. The exact
+    synchronization logic depends on the specific algorithm (Trainer) used,
+    but this usually happens from local worker to all remote workers and
+    after each training update.
+
+    You can write your own Trainer sub-classes by using the
+    rllib.agents.trainer_template.py::build_trainer() utility function.
+    This allows you to provide a custom `execution_plan`. You can find the
+    different built-in algorithms' execution plans in their respective main
+    py files, e.g. rllib.agents.dqn.dqn.py or rllib.agents.impala.impala.py.
+
+    The most important API methods a Trainer exposes are `train()`,
+    `evaluate()`, `save()` and `restore()`. Trainer objects retain internal
+    model state between calls to train(), so you should create a new
+    Trainer instance for each training session.
     """
 
+    # Whether to allow unknown top-level config keys.
     _allow_unknown_configs = False
 
@@ -562,15 +577,18 @@ class Trainer(Trainable):
     @PublicAPI
     def __init__(self,
                  config: TrainerConfigDict = None,
-                 env: str = None,
+                 env: Union[str, EnvType, None] = None,
                  logger_creator: Callable[[], Logger] = None):
-        """Initialize an RLLib trainer.
+        """Initializes a Trainer instance.
 
         Args:
-            config (dict): Algorithm-specific configuration data.
-            env (str): Name of the environment to use. Note that this can also
-                be specified as the `env` key in config.
-            logger_creator (func): Function that creates a ray.tune.Logger
+            config: Algorithm-specific configuration dict.
+            env: Name of the environment to use (e.g. a gym-registered str),
+                a full class path (e.g.
+                "ray.rllib.examples.env.random_env.RandomEnv"), or an Env
+                class directly. Note that this arg can also be specified via
+                the "env" key in `config`.
+            logger_creator: Callable that creates a ray.tune.Logger
                 object. If unspecified, a default logger is created.
""" @@ -623,151 +641,6 @@ def default_logger_creator(config): super().__init__(config, logger_creator) - @classmethod - @override(Trainable) - def default_resource_request( - cls, config: PartialTrainerConfigDict) -> \ - Union[Resources, PlacementGroupFactory]: - cf = dict(cls._default_config, **config) - - eval_config = cf["evaluation_config"] - - # TODO(ekl): add custom resources here once tune supports them - # Return PlacementGroupFactory containing all needed resources - # (already properly defined as device bundles). - return PlacementGroupFactory( - bundles=[{ - # Driver. - "CPU": cf["num_cpus_for_driver"], - "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"], - }] + [ - { - # RolloutWorkers. - "CPU": cf["num_cpus_per_worker"], - "GPU": cf["num_gpus_per_worker"], - } for _ in range(cf["num_workers"]) - ] + ([ - { - # Evaluation workers. - # Note: The local eval worker is located on the driver CPU. - "CPU": eval_config.get("num_cpus_per_worker", - cf["num_cpus_per_worker"]), - "GPU": eval_config.get("num_gpus_per_worker", - cf["num_gpus_per_worker"]), - } for _ in range(cf["evaluation_num_workers"]) - ] if cf["evaluation_interval"] else []), - strategy=config.get("placement_strategy", "PACK")) - - @override(Trainable) - @PublicAPI - def train(self) -> ResultDict: - """Overrides super.train to synchronize global vars.""" - - result = None - for _ in range(1 + MAX_WORKER_FAILURE_RETRIES): - try: - result = Trainable.train(self) - except RayError as e: - if self.config["ignore_worker_failures"]: - logger.exception( - "Error in train call, attempting to recover") - self._try_recover() - else: - logger.info( - "Worker crashed during call to train(). To attempt to " - "continue training without the failed worker, set " - "`'ignore_worker_failures': True`.") - raise e - except Exception as e: - time.sleep(0.5) # allow logs messages to propagate - raise e - else: - break - if result is None: - raise RuntimeError("Failed to recover from worker crash") - - if hasattr(self, "workers") and isinstance(self.workers, WorkerSet): - self._sync_filters_if_needed(self.workers) - - return result - - def _sync_filters_if_needed(self, workers: WorkerSet): - if self.config.get("observation_filter", "NoFilter") != "NoFilter": - FilterManager.synchronize( - workers.local_worker().filters, - workers.remote_workers(), - update_remote=self.config["synchronize_filters"]) - logger.debug("synchronized filters: {}".format( - workers.local_worker().filters)) - - @override(Trainable) - def log_result(self, result: ResultDict): - self.callbacks.on_train_result(trainer=self, result=result) - # log after the callback is invoked, so that the user has a chance - # to mutate the result - Trainable.log_result(self, result) - - @DeveloperAPI - def _create_local_replay_buffer_if_necessary(self, config): - """Create a LocalReplayBuffer instance if necessary. - - Args: - config (dict): Algorithm-specific configuration data. - - Returns: - LocalReplayBuffer instance based on trainer config. - None, if local replay buffer is not needed. - """ - # These are the agents that utilizes a local replay buffer. - if ("replay_buffer_config" not in config - or not config["replay_buffer_config"]): - # Does not need a replay buffer. - return None - - replay_buffer_config = config["replay_buffer_config"] - if ("type" not in replay_buffer_config - or replay_buffer_config["type"] != "LocalReplayBuffer"): - # DistributedReplayBuffer coming soon. 
- return None - - capacity = config.get("buffer_size", DEPRECATED_VALUE) - if capacity != DEPRECATED_VALUE: - # Print a deprecation warning. - deprecation_warning( - old="config['buffer_size']", - new="config['replay_buffer_config']['capacity']", - error=False) - else: - # Get capacity out of replay_buffer_config. - capacity = replay_buffer_config["capacity"] - - if config.get("prioritized_replay"): - prio_args = { - "prioritized_replay_alpha": config["prioritized_replay_alpha"], - "prioritized_replay_beta": config["prioritized_replay_beta"], - "prioritized_replay_eps": config["prioritized_replay_eps"], - } - else: - prio_args = {} - - return LocalReplayBuffer( - num_shards=1, - learning_starts=config["learning_starts"], - capacity=capacity, - replay_batch_size=config["train_batch_size"], - replay_mode=config["multiagent"]["replay_mode"], - replay_sequence_length=config.get("replay_sequence_length", 1), - replay_burn_in=config.get("burn_in", 0), - replay_zero_init_states=config.get("zero_init_states", True), - **prio_args) - - @DeveloperAPI - def _kwargs_for_execution_plan(self): - kwargs = {} - if self.local_replay_buffer: - kwargs["local_replay_buffer"] = self.local_replay_buffer - return kwargs - @override(Trainable) def setup(self, config: PartialTrainerConfigDict): env = self._env_id @@ -839,6 +712,8 @@ def env_creator_from_classpath(env_context): self.local_replay_buffer = ( self._create_local_replay_buffer_if_necessary(self.config)) + # Make the call to self._init. Sub-classes should override this + # method to implement custom initialization logic. self._init(self.config, self.env_creator) # Evaluation setup. @@ -875,69 +750,53 @@ def env_creator_from_classpath(env_context): config=evaluation_config, num_workers=self.config["evaluation_num_workers"]) - @override(Trainable) - def cleanup(self): - if hasattr(self, "workers"): - self.workers.stop() - if hasattr(self, "optimizer") and self.optimizer: - self.optimizer.stop() - - @override(Trainable) - def save_checkpoint(self, checkpoint_dir: str) -> str: - checkpoint_path = os.path.join(checkpoint_dir, - "checkpoint-{}".format(self.iteration)) - pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) - - return checkpoint_path - - @override(Trainable) - def load_checkpoint(self, checkpoint_path: str): - extra_data = pickle.load(open(checkpoint_path, "rb")) - self.__setstate__(extra_data) - @DeveloperAPI - def _make_workers( - self, *, env_creator: Callable[[EnvContext], EnvType], - validate_env: Optional[Callable[[EnvType, EnvContext], None]], - policy_class: Type[Policy], config: TrainerConfigDict, - num_workers: int) -> WorkerSet: - """Default factory method for a WorkerSet running under this Trainer. + def _init(self, config: TrainerConfigDict, + env_creator: Callable[[EnvContext], EnvType]) -> None: + """Subclasses should override this for custom initialization. - Override this method by passing a custom `make_workers` into - `build_trainer`. + In the case of Trainer, this is called from inside `self.setup()`. Args: - env_creator (callable): A function that return and Env given an env - config. - validate_env (Optional[Callable[[EnvType, EnvContext], None]]): - Optional callable to validate the generated environment (only - on worker=0). - policy (Type[Policy]): The Policy class to use for creating the - policies of the workers. - config (TrainerConfigDict): The Trainer's config. - num_workers (int): Number of remote rollout workers to create. - 0 for local only. - - Returns: - WorkerSet: The created WorkerSet. 
+ config: Algorithm-specific configuration dict. + env_creator: A callable taking an EnvContext as only arg and + returning an environment (of any type: e.g. gym.Env, RLlib + BaseEnv, MultiAgentEnv, etc..). """ - return WorkerSet( - env_creator=env_creator, - validate_env=validate_env, - policy_class=policy_class, - trainer_config=config, - num_workers=num_workers, - logdir=self.logdir) - - @DeveloperAPI - def _init(self, config: TrainerConfigDict, - env_creator: Callable[[EnvContext], EnvType]): - """Subclasses should override this for custom initialization.""" raise NotImplementedError - @Deprecated(new="Trainer.evaluate", error=False) - def _evaluate(self) -> dict: - return self.evaluate() + @override(Trainable) + @PublicAPI + def train(self) -> ResultDict: + """Overrides super.train to synchronize global vars.""" + + result = None + for _ in range(1 + MAX_WORKER_FAILURE_RETRIES): + try: + result = Trainable.train(self) + except RayError as e: + if self.config["ignore_worker_failures"]: + logger.exception( + "Error in train call, attempting to recover") + self._try_recover() + else: + logger.info( + "Worker crashed during call to train(). To attempt to " + "continue training without the failed worker, set " + "`'ignore_worker_failures': True`.") + raise e + except Exception as e: + time.sleep(0.5) # allow logs messages to propagate + raise e + else: + break + if result is None: + raise RuntimeError("Failed to recover from worker crash") + + if hasattr(self, "workers") and isinstance(self.workers, WorkerSet): + self._sync_filters_if_needed(self.workers) + + return result @PublicAPI def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None @@ -948,10 +807,10 @@ def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None merging evaluation_config with the normal trainer config. Args: - episodes_left_fn (Optional[Callable[[int], int]]): An optional - callable taking the already run num episodes as only arg - and returning the number of episodes left to run. It's used - to find out whether evaluation should continue. + episodes_left_fn: An optional callable taking the already run + num episodes as only arg and returning the number of + episodes left to run. It's used to find out whether + evaluation should continue. """ # In case we are evaluating (in a thread) parallel to training, # we may have to re-enable eager mode here (gets disabled in the @@ -963,8 +822,8 @@ def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None # Call the `_before_evaluate` hook. self._before_evaluate() + # Sync weights to the evaluation WorkerSet. if self.evaluation_workers is not None: - # Sync weights to the evaluation WorkerSet. self._sync_weights_to_workers(worker_set=self.evaluation_workers) self._sync_filters_if_needed(self.evaluation_workers) @@ -1053,25 +912,6 @@ def episodes_left_fn(num_episodes_done): self.evaluation_workers.remote_workers()) return {"evaluation": metrics} - @DeveloperAPI - def _before_evaluate(self): - """Pre-evaluation callback.""" - pass - - @DeveloperAPI - def _sync_weights_to_workers( - self, - *, - worker_set: Optional[WorkerSet] = None, - workers: Optional[List[RolloutWorker]] = None, - ) -> None: - """Sync "main" weights to given WorkerSet or list of workers.""" - assert worker_set is not None - # Broadcast the new policy weights to all evaluation workers. 
- logger.info("Synchronizing weights to workers.") - weights = ray.put(self.workers.local_worker().save()) - worker_set.foreach_worker(lambda w: w.restore(ray.get(weights))) - @PublicAPI def compute_single_action( self, @@ -1223,10 +1063,6 @@ def compute_single_action( else: return action - @Deprecated(new="compute_single_action", error=False) - def compute_action(self, *args, **kwargs): - return self.compute_single_action(*args, **kwargs) - @PublicAPI def compute_actions( self, @@ -1253,7 +1089,7 @@ def compute_actions( self.get_policy(policy_id) and call compute_actions() on it directly. Args: - observation: observation from the environment. + observation: Observation from the environment. state: RNN hidden state, if any. If state is not None, then all of compute_single_action(...) is returned (computed action, rnn state(s), logits dictionary). @@ -1284,7 +1120,7 @@ def compute_actions( Returns: any: The computed action if full_fetch=False, or tuple: The full output of policy.compute_actions() if - full_fetch=True or we have an RNN-based Policy. + full_fetch=True or we have an RNN-based Policy. """ if normalize_actions is not None: deprecation_warning( @@ -1359,31 +1195,21 @@ def compute_actions( else: return actions - @property - def _name(self) -> str: - """Subclasses should override this to declare their name.""" - raise NotImplementedError - - @property - def _default_config(self) -> TrainerConfigDict: - """Subclasses should override this to declare their default config.""" - raise NotImplementedError - @PublicAPI def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy: """Return policy for the specified id, or None. Args: - policy_id (PolicyID): ID of the policy to return. + policy_id: ID of the policy to return. """ return self.workers.local_worker().get_policy(policy_id) @PublicAPI - def get_weights(self, policies: List[PolicyID] = None) -> dict: + def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict: """Return a dictionary of policy ids to weights. Args: - policies (list): Optional list of policies to return weights for, + policies: Optional list of policies to return weights for, or None for all policies. """ return self.workers.local_worker().get_weights(policies) @@ -1393,7 +1219,7 @@ def set_weights(self, weights: Dict[PolicyID, dict]): """Set policy weights by policy id. Args: - weights (dict): Map of policy ids to weights to set. + weights: Map of policy ids to weights to set. """ self.workers.local_worker().set_weights(weights) @@ -1502,35 +1328,38 @@ def fn(worker): def export_policy_model(self, export_dir: str, policy_id: PolicyID = DEFAULT_POLICY_ID, - onnx: Optional[int] = None): - """Export policy model with given policy_id to local directory. + onnx: Optional[int] = None) -> None: + """Exports policy model with given policy_id to a local directory. Args: - export_dir (string): Writable local directory. - policy_id (string): Optional policy id to export. - onnx (int): If given, will export model in ONNX format. The + export_dir: Writable local directory. + policy_id: Optional policy id to export. + onnx: If given, will export model in ONNX format. The value of this parameter set the ONNX OpSet version to use. + If None, the output format will be DL framework specific. 
Example: >>> trainer = MyTrainer() >>> for _ in range(10): >>> trainer.train() - >>> trainer.export_policy_model("/tmp/export_dir") + >>> trainer.export_policy_model("/tmp/dir") + >>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1) """ - self.workers.local_worker().export_policy_model( - export_dir, policy_id, onnx) + self.get_policy(policy_id).export_model(export_dir, onnx) @DeveloperAPI - def export_policy_checkpoint(self, - export_dir: str, - filename_prefix: str = "model", - policy_id: PolicyID = DEFAULT_POLICY_ID): - """Export tensorflow policy model checkpoint to local directory. + def export_policy_checkpoint( + self, + export_dir: str, + filename_prefix: str = "model", + policy_id: PolicyID = DEFAULT_POLICY_ID, + ) -> None: + """Exports policy model checkpoint to a local directory. Args: - export_dir (string): Writable local directory. - filename_prefix (string): file name prefix of checkpoint files. - policy_id (string): Optional policy id to export. + export_dir: Writable local directory. + filename_prefix: file name prefix of checkpoint files. + policy_id: Optional policy id to export. Example: >>> trainer = MyTrainer() @@ -1538,18 +1367,20 @@ def export_policy_checkpoint(self, >>> trainer.train() >>> trainer.export_policy_checkpoint("/tmp/export_dir") """ - self.workers.local_worker().export_policy_checkpoint( - export_dir, filename_prefix, policy_id) + self.get_policy(policy_id).export_checkpoint(export_dir, + filename_prefix) @DeveloperAPI - def import_policy_model_from_h5(self, - import_file: str, - policy_id: PolicyID = DEFAULT_POLICY_ID): + def import_policy_model_from_h5( + self, + import_file: str, + policy_id: PolicyID = DEFAULT_POLICY_ID, + ) -> None: """Imports a policy's model with given policy_id from a local h5 file. Args: - import_file (str): The h5 file to import from. - policy_id (string): Optional policy id to import into. + import_file: The h5 file to import from. + policy_id: Optional policy id to import into. Example: >>> trainer = MyTrainer() @@ -1557,8 +1388,9 @@ def import_policy_model_from_h5(self, >>> for _ in range(10): >>> trainer.train() """ - self.workers.local_worker().import_policy_model_from_h5( - import_file, policy_id) + self.get_policy(policy_id).import_model_from_h5(import_file) + # Sync new weights to remote workers. + self._sync_weights_to_workers() @DeveloperAPI def collect_metrics(self, @@ -1572,6 +1404,158 @@ def collect_metrics(self, min_history=self.config["metrics_smoothing_episodes"], selected_workers=selected_workers) + @override(Trainable) + def save_checkpoint(self, checkpoint_dir: str) -> str: + checkpoint_path = os.path.join(checkpoint_dir, + "checkpoint-{}".format(self.iteration)) + pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) + + return checkpoint_path + + @override(Trainable) + def load_checkpoint(self, checkpoint_path: str) -> None: + extra_data = pickle.load(open(checkpoint_path, "rb")) + self.__setstate__(extra_data) + + @override(Trainable) + def log_result(self, result: ResultDict) -> None: + # Log after the callback is invoked, so that the user has a chance + # to mutate the result. + self.callbacks.on_train_result(trainer=self, result=result) + # Then log according to Trainable's logging logic. + Trainable.log_result(self, result) + + @override(Trainable) + def cleanup(self) -> None: + # Stop all workers. + if hasattr(self, "workers"): + self.workers.stop() + # Stop all optimizers. 
+ if hasattr(self, "optimizer") and self.optimizer: + self.optimizer.stop() + # Then stop according to Trainable's logic. + Trainable.stop(self) + + @classmethod + @override(Trainable) + def default_resource_request( + cls, config: PartialTrainerConfigDict) -> \ + Union[Resources, PlacementGroupFactory]: + + # Default logic for RLlib algorithms (Trainers): + # Create one bundle per individual worker (local or remote). + # Use `num_cpus_for_driver` and `num_gpus` for the local worker and + # `num_cpus_per_worker` and `num_gpus_per_worker` for the remote + # workers to determine their CPU/GPU resource needs. + + # Convenience config handles. + cf = dict(cls._default_config, **config) + eval_cf = cf["evaluation_config"] + + # TODO(ekl): add custom resources here once tune supports them + # Return PlacementGroupFactory containing all needed resources + # (already properly defined as device bundles). + return PlacementGroupFactory( + bundles=[{ + # Local worker. + "CPU": cf["num_cpus_for_driver"], + "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"], + }] + [ + { + # RolloutWorkers. + "CPU": cf["num_cpus_per_worker"], + "GPU": cf["num_gpus_per_worker"], + } for _ in range(cf["num_workers"]) + ] + ([ + { + # Evaluation workers. + # Note: The local eval worker is located on the driver CPU. + "CPU": eval_cf.get("num_cpus_per_worker", + cf["num_cpus_per_worker"]), + "GPU": eval_cf.get("num_gpus_per_worker", + cf["num_gpus_per_worker"]), + } for _ in range(cf["evaluation_num_workers"]) + ] if cf["evaluation_interval"] else []), + strategy=config.get("placement_strategy", "PACK")) + + @DeveloperAPI + def _before_evaluate(self): + """Pre-evaluation callback.""" + pass + + @DeveloperAPI + def _make_workers( + self, + *, + env_creator: Callable[[EnvContext], EnvType], + validate_env: Optional[Callable[[EnvType, EnvContext], None]], + policy_class: Type[Policy], + config: TrainerConfigDict, + num_workers: int, + ) -> WorkerSet: + """Default factory method for a WorkerSet running under this Trainer. + + Override this method by passing a custom `make_workers` into + `build_trainer`. + + Args: + env_creator: A function that return and Env given an env + config. + validate_env: Optional callable to validate the generated + environment. The env to be checked is the one returned from + the env creator, which may be a (single, not-yet-vectorized) + gym.Env or your custom RLlib env type (e.g. MultiAgentEnv, + VectorEnv, BaseEnv, etc..). + policy_class: The Policy class to use for creating the policies + of the workers. + config: The Trainer's config. + num_workers: Number of remote rollout workers to create. + 0 for local only. + + Returns: + WorkerSet: The created WorkerSet. 
+ """ + return WorkerSet( + env_creator=env_creator, + validate_env=validate_env, + policy_class=policy_class, + trainer_config=config, + num_workers=num_workers, + logdir=self.logdir) + + def _sync_filters_if_needed(self, workers: WorkerSet): + if self.config.get("observation_filter", "NoFilter") != "NoFilter": + FilterManager.synchronize( + workers.local_worker().filters, + workers.remote_workers(), + update_remote=self.config["synchronize_filters"]) + logger.debug("synchronized filters: {}".format( + workers.local_worker().filters)) + + @DeveloperAPI + def _sync_weights_to_workers( + self, + *, + worker_set: Optional[WorkerSet] = None, + workers: Optional[List[RolloutWorker]] = None, + ) -> None: + """Sync "main" weights to given WorkerSet or list of workers.""" + assert worker_set is not None + # Broadcast the new policy weights to all evaluation workers. + logger.info("Synchronizing weights to workers.") + weights = ray.put(self.workers.local_worker().save()) + worker_set.foreach_worker(lambda w: w.restore(ray.get(weights))) + + @property + def _name(self) -> str: + """Subclasses should override this to declare their name.""" + raise NotImplementedError + + @property + def _default_config(self) -> TrainerConfigDict: + """Subclasses should override this to declare their default config.""" + raise NotImplementedError + @classmethod @override(Trainable) def resource_help(cls, config: TrainerConfigDict) -> str: @@ -1909,6 +1893,67 @@ def with_updates(**overrides) -> Type["Trainer"]: "that were generated via the `ray.rllib.agents.trainer_template." "build_trainer()` function!") + @DeveloperAPI + def _create_local_replay_buffer_if_necessary(self, config): + """Create a LocalReplayBuffer instance if necessary. + + Args: + config (dict): Algorithm-specific configuration data. + + Returns: + LocalReplayBuffer instance based on trainer config. + None, if local replay buffer is not needed. + """ + # These are the agents that utilizes a local replay buffer. + if ("replay_buffer_config" not in config + or not config["replay_buffer_config"]): + # Does not need a replay buffer. + return None + + replay_buffer_config = config["replay_buffer_config"] + if ("type" not in replay_buffer_config + or replay_buffer_config["type"] != "LocalReplayBuffer"): + # DistributedReplayBuffer coming soon. + return None + + capacity = config.get("buffer_size", DEPRECATED_VALUE) + if capacity != DEPRECATED_VALUE: + # Print a deprecation warning. + deprecation_warning( + old="config['buffer_size']", + new="config['replay_buffer_config']['capacity']", + error=False) + else: + # Get capacity out of replay_buffer_config. 
+ capacity = replay_buffer_config["capacity"] + + if config.get("prioritized_replay"): + prio_args = { + "prioritized_replay_alpha": config["prioritized_replay_alpha"], + "prioritized_replay_beta": config["prioritized_replay_beta"], + "prioritized_replay_eps": config["prioritized_replay_eps"], + } + else: + prio_args = {} + + return LocalReplayBuffer( + num_shards=1, + learning_starts=config["learning_starts"], + capacity=capacity, + replay_batch_size=config["train_batch_size"], + replay_mode=config["multiagent"]["replay_mode"], + replay_sequence_length=config.get("replay_sequence_length", 1), + replay_burn_in=config.get("burn_in", 0), + replay_zero_init_states=config.get("zero_init_states", True), + **prio_args) + + @DeveloperAPI + def _kwargs_for_execution_plan(self): + kwargs = {} + if self.local_replay_buffer: + kwargs["local_replay_buffer"] = self.local_replay_buffer + return kwargs + def _register_if_needed(self, env_object: Union[str, EnvType, None], config) -> Optional[str]: if isinstance(env_object, str): @@ -1939,5 +1984,13 @@ def _is_multi_agent(self): "You can specify a custom env as either a class " "(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").") + @Deprecated(new="Trainer.evaluate", error=False) + def _evaluate(self) -> dict: + return self.evaluate() + + @Deprecated(new="compute_single_action", error=False) + def compute_action(self, *args, **kwargs): + return self.compute_single_action(*args, **kwargs) + def __repr__(self): return self._name diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index b3a8ff71c29c..450d91cefbbf 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -75,58 +75,50 @@ def build_trainer( allow_unknown_subkeys: Optional[List[str]] = None, override_all_subkeys_if_type_changes: Optional[List[str]] = None, ) -> Type[Trainer]: - """Helper function for defining a custom trainer. + """Helper function for defining a custom Trainer class. Functions will be run in this order to initialize the trainer: - 1. Config setup: validate_config, get_policy - 2. Worker setup: before_init, execution_plan - 3. Post setup: after_init + 1. Config setup: validate_config, get_policy. + 2. Worker setup: before_init, execution_plan. + 3. Post setup: after_init. Args: - name (str): name of the trainer (e.g., "PPO") - default_config (Optional[TrainerConfigDict]): The default config dict - of the algorithm, otherwise uses the Trainer default config. - validate_config (Optional[Callable[[TrainerConfigDict], None]]): - Optional callable that takes the config to check for correctness. - It may mutate the config as needed. - default_policy (Optional[Type[Policy]]): The default Policy class to - use if `get_policy_class` returns None. - get_policy_class (Optional[Callable[ - TrainerConfigDict, Optional[Type[Policy]]]]): Optional callable - that takes a config and returns the policy class or None. If None - is returned, will use `default_policy` (which must be provided - then). - validate_env (Optional[Callable[[EnvType, EnvContext], None]]): - Optional callable to validate the generated environment (only - on worker=0). - before_init (Optional[Callable[[Trainer], None]]): Optional callable to - run before anything is constructed inside Trainer (Workers with - Policies, execution plan, etc..). Takes the Trainer instance as - argument. - after_init (Optional[Callable[[Trainer], None]]): Optional callable to - run at the end of trainer init (after all Workers and the exec. - plan have been constructed). 
Takes the Trainer instance as - argument. - before_evaluate_fn (Optional[Callable[[Trainer], None]]): Callback to - run before evaluation. This takes the trainer instance as argument. - mixins (list): list of any class mixins for the returned trainer class. + name: name of the trainer (e.g., "PPO") + default_config: The default config dict of the algorithm, + otherwise uses the Trainer default config. + validate_config: Optional callable that takes the config to check + for correctness. It may mutate the config as needed. + default_policy: The default Policy class to use if `get_policy_class` + returns None. + get_policy_class: Optional callable that takes a config and returns + the policy class or None. If None is returned, will use + `default_policy` (which must be provided then). + validate_env: Optional callable to validate the generated environment + (only on worker=0). + before_init: Optional callable to run before anything is constructed + inside Trainer (Workers with Policies, execution plan, etc..). + Takes the Trainer instance as argument. + after_init: Optional callable to run at the end of trainer init + (after all Workers and the exec. plan have been constructed). + Takes the Trainer instance as argument. + before_evaluate_fn: Callback to run before evaluation. This takes + the trainer instance as argument. + mixins: List of any class mixins for the returned trainer class. These mixins will be applied in order and will have higher precedence than the Trainer class. - execution_plan (Optional[Callable[[WorkerSet, TrainerConfigDict], - Iterable[ResultDict]]]): Optional callable that sets up the + execution_plan: Optional callable that sets up the distributed execution workflow. - allow_unknown_configs (bool): Whether to allow unknown top-level config - keys. - allow_unknown_subkeys (Optional[List[str]]): List of top-level keys + allow_unknown_configs: Whether to allow unknown top-level config keys. + allow_unknown_subkeys: List of top-level keys with value=dict, for which new sub-keys are allowed to be added to the value dict. Appends to Trainer class defaults. - override_all_subkeys_if_type_changes (Optional[List[str]]): List of top - level keys with value=dict, for which we always override the entire - value (dict), iff the "type" key in that value dict changes. - Appends to Trainer class defaults. + override_all_subkeys_if_type_changes: List of top level keys with + value=dict, for which we always override the entire value (dict), + iff the "type" key in that value dict changes. Appends to Trainer + class defaults. Returns: - Type[Trainer]: A Trainer sub-class configured by the specified args. + A Trainer sub-class configured by the specified args. """ original_kwargs = locals().copy()
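Usage sketch (not part of the patch above): a minimal example of how the documented `build_trainer()` helper is typically used to define a custom Trainer class. The choice of PGTFPolicy, the config values, and the reliance on the default execution plan are illustrative assumptions, not taken from this diff.

# A minimal sketch, assuming a TF setup and the default execution plan.
# PGTFPolicy and all config values below are illustrative choices only.
import ray
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer


def validate_config(config):
    # May check and/or mutate the config before workers are built.
    if config["num_workers"] < 0:
        raise ValueError("`num_workers` must be >= 0.")


# Start from the common Trainer config and override a few keys.
DEFAULT_CONFIG = with_common_config({
    "num_workers": 0,
    "lr": 0.0004,
})

# build_trainer() returns a Trainer sub-class configured by these args.
MyPGTrainer = build_trainer(
    name="MyPG",
    default_config=DEFAULT_CONFIG,
    default_policy=PGTFPolicy,
    validate_config=validate_config,
)

if __name__ == "__main__":
    ray.init()
    trainer = MyPGTrainer(env="CartPole-v0")
    for _ in range(2):
        print(trainer.train()["episode_reward_mean"])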
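Usage sketch (not part of the patch above): the top-level Trainer API named in the new class docstring, i.e. `train()`, `evaluate()`, `save()` and `restore()`. PPOTrainer, CartPole-v0 and the config values are illustrative assumptions.

# A minimal sketch of the documented Trainer API; built-in trainer and
# config values are illustrative assumptions.
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(
    env="CartPole-v0",
    config={
        "num_workers": 1,
        # Create the separate evaluation WorkerSet used by evaluate().
        "evaluation_interval": 1,
        "evaluation_num_workers": 1,
    })

# train() runs one training iteration and returns a result dict.
for _ in range(2):
    result = trainer.train()
    print(result["episode_reward_mean"])

# evaluate() runs the evaluation workers and returns their metrics.
print(trainer.evaluate()["evaluation"]["episode_reward_mean"])

# save()/restore() checkpoint and restore the full Trainer state.
checkpoint_path = trainer.save()
trainer.restore(checkpoint_path)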
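Usage sketch (not part of the patch above): a custom callbacks class matching the updated DefaultCallbacks signatures in this diff (keyword-only args, no `env_index`, `on_train_result` taking a Trainer). The metric and result key names are illustrative assumptions.

# A minimal sketch of a callbacks sub-class for the updated signatures.
from ray.rllib.agents.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies, episode,
                       **kwargs):
        # Entries in `episode.custom_metrics` are aggregated into the
        # training results under "custom_metrics".
        episode.custom_metrics["episode_len"] = episode.length

    def on_train_result(self, *, trainer, result, **kwargs):
        # `result` may be mutated to add additional metrics.
        result["callback_ok"] = True


# Plugged in via the Trainer config (pass the class, not an instance):
# config = {"callbacks": MyCallbacks, ...}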