forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RLlib] Deprecate (delete) `contrib` folder. (ray-project#30992)
Signed-off-by: tmynn <[email protected]>
- Loading branch information
1 parent
8f1568f
commit bd3877d
Showing 18 changed files with 188 additions and 199 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from ray.rllib.algorithms.random_agent.random_agent import ( | ||
RandomAgent, | ||
RandomAgentConfig, | ||
) | ||
|
||
# Public API of the random_agent package: the algorithm and its config class.
__all__ = ["RandomAgent", "RandomAgentConfig"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import numpy as np | ||
from typing import Optional | ||
|
||
from ray.rllib.algorithms.algorithm import Algorithm | ||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided | ||
from ray.rllib.utils.annotations import override | ||
|
||
|
||
class RandomAgentConfig(AlgorithmConfig):
    """Configuration class from which a RandomAgent Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.random_agent import RandomAgentConfig
        >>> config = RandomAgentConfig().rollouts(rollouts_per_iteration=20)
        >>> print(config.to_dict())  # doctest: +SKIP
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> algo = config.build(env="CartPole-v1")
        >>> algo.train()  # doctest: +SKIP
    """

    def __init__(self, algo_class=None):
        """Initializes a RandomAgentConfig instance.

        Args:
            algo_class: Algorithm class this config is bound to. Any falsy
                value (including the default None) selects RandomAgent.
        """
        bound_class = algo_class if algo_class else RandomAgent
        super().__init__(algo_class=bound_class)

        # Number of episodes sampled per training iteration.
        self.rollouts_per_iteration = 10

    def rollouts(
        self,
        *,
        rollouts_per_iteration: Optional[int] = NotProvided,
        **kwargs,
    ) -> "RandomAgentConfig":
        """Sets the rollout configuration.

        Args:
            rollouts_per_iteration: How many episodes to run per training
                iteration. Left untouched when not provided.
            **kwargs: Forwarded unchanged to the parent class' `rollouts()`.

        Returns:
            This updated AlgorithmConfig object.
        """
        # Let the base class consume all generic rollout settings first.
        super().rollouts(**kwargs)

        # Nothing RandomAgent-specific was passed in: keep the current value.
        if rollouts_per_iteration is NotProvided:
            return self

        self.rollouts_per_iteration = rollouts_per_iteration
        return self
|
||
|
||
# fmt: off | ||
# __sphinx_doc_begin__ | ||
class RandomAgent(Algorithm):
    """Algo that produces random actions and never learns."""

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        """Returns the default configuration for RandomAgent.

        Returns:
            A fresh RandomAgentConfig, whose constructor sets
            `rollouts_per_iteration` to 10.
        """
        # Use the dedicated config class rather than patching a bare
        # AlgorithmConfig, so the default is defined in exactly one place.
        return RandomAgentConfig()

    @override(Algorithm)
    def _init(self, config, env_creator):
        """Creates the environment that `step()` samples from.

        Args:
            config: Config object; its "env_config" entry is passed to
                `env_creator`.
            env_creator: Callable that builds the environment from the
                env config.
        """
        self.env = env_creator(config["env_config"])

    @override(Algorithm)
    def step(self):
        """Runs `rollouts_per_iteration` episodes with uniformly sampled actions.

        Returns:
            A dict with the mean episode return ("episode_reward_mean") and the
            total number of env steps taken ("timesteps_this_iter").
        """
        rewards = []
        steps = 0
        for _ in range(self.config.rollouts_per_iteration):
            self.env.reset()
            done = False
            reward = 0.0
            while not done:
                action = self.env.action_space.sample()
                outcome = self.env.step(action)
                if len(outcome) == 5:
                    # Newer env API: (obs, reward, terminated, truncated, info).
                    # The episode is over as soon as either flag is set.
                    _, r, terminated, truncated, _ = outcome
                    done = terminated or truncated
                else:
                    # Legacy env API: (obs, reward, done, info).
                    _, r, done, _ = outcome
                reward += r
                steps += 1
            rewards.append(reward)
        return {
            "episode_reward_mean": np.mean(rewards),
            "timesteps_this_iter": steps,
        }
# __sphinx_doc_end__ | ||
|
||
|
||
if __name__ == "__main__":
    # Smoke test: build a RandomAgent on CartPole and run one "training"
    # iteration, checking that the mean episode return is plausible.

    # Define a config object.
    config = (
        RandomAgentConfig()
        .environment("CartPole-v1")
        .rollouts(rollouts_per_iteration=10)
    )
    # Build the agent.
    algo = config.build()
    try:
        # "Train" one iteration.
        result = algo.train()
        assert result["episode_reward_mean"] > 10, result
    finally:
        # Always stop the algorithm, even if training or the check above
        # fails, so that it is not left running.
        algo.stop()

    print("Test: OK")
Oops, something went wrong.