[RLlib] SimpleQ PolicyV2 (sub-classing). #25871

Merged
33 changes: 17 additions & 16 deletions rllib/agents/dqn/__init__.py
@@ -1,27 +1,28 @@
import ray.rllib.agents.dqn.apex as apex # noqa
import ray.rllib.agents.dqn.simple_q as simple_q # noqa
from ray.rllib.algorithms.apex_dqn.apex_dqn import (
ApexDQNConfig,
ApexDQN as ApexTrainer,
APEX_DEFAULT_CONFIG,
)
from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQN as DQNTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.apex_dqn.apex_dqn import APEX_DEFAULT_CONFIG
from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQN as ApexTrainer
from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQNConfig
from ray.rllib.algorithms.dqn.dqn import DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQN as DQNTrainer
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.algorithms.r2d2.r2d2 import (
R2D2 as R2D2Trainer,
R2D2Config,
R2D2_DEFAULT_CONFIG,
)
from ray.rllib.algorithms.r2d2.r2d2 import R2D2 as R2D2Trainer
from ray.rllib.algorithms.r2d2.r2d2 import R2D2_DEFAULT_CONFIG, R2D2Config
from ray.rllib.algorithms.r2d2.r2d2_tf_policy import R2D2TFPolicy
from ray.rllib.algorithms.r2d2.r2d2_torch_policy import R2D2TorchPolicy
from ray.rllib.algorithms.simple_q.simple_q import (
SimpleQ as SimpleQTrainer,
SimpleQConfig,
DEFAULT_CONFIG as SIMPLE_Q_DEFAULT_CONFIG,
)
from ray.rllib.algorithms.simple_q.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.algorithms.simple_q.simple_q import SimpleQ as SimpleQTrainer
from ray.rllib.algorithms.simple_q.simple_q import SimpleQConfig
from ray.rllib.algorithms.simple_q.simple_q_tf_policy import (
SimpleQTF1Policy,
SimpleQTF2Policy,
)
from ray.rllib.algorithms.simple_q.simple_q_torch_policy import SimpleQTorchPolicy
from ray.rllib.utils.deprecation import deprecation_warning

__all__ = [
"ApexDQNConfig",
@@ -35,7 +36,8 @@
"R2D2TorchPolicy",
"R2D2Trainer",
"SimpleQConfig",
"SimpleQTFPolicy",
"SimpleQTF1Policy",
"SimpleQTF2Policy",
"SimpleQTorchPolicy",
"SimpleQTrainer",
# Deprecated.
@@ -45,7 +47,6 @@
"SIMPLE_Q_DEFAULT_CONFIG",
]

from ray.rllib.utils.deprecation import deprecation_warning

deprecation_warning(
"ray.rllib.agents.dqn",
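For downstream code that imported the old single TF policy from this package, the change above splits it into a static-graph and an eager variant. A minimal import sketch (class names and module paths are the ones shown in this diff; the rest is illustrative only):

from ray.rllib.algorithms.simple_q import SimpleQTF1Policy, SimpleQTF2Policy

# The legacy `ray.rllib.agents.dqn` package still re-exports both classes
# (and emits a deprecation warning on import), so old import paths keep
# working while pointing at the new classes:
from ray.rllib.agents.dqn import SimpleQTF1Policy as LegacyTF1Policy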
12 changes: 4 additions & 8 deletions rllib/algorithms/dqn/dqn_tf_policy.py
@@ -4,18 +4,17 @@

import gym
import numpy as np

import ray
from ray.rllib.algorithms.dqn.distributional_q_tf_model import DistributionalQTFModel
from ray.rllib.algorithms.simple_q.utils import Q_SCOPE, Q_TARGET_SCOPE
from ray.rllib.evaluation.postprocessing import adjust_nstep
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_mixins import (
LearningRateSchedule,
TargetNetworkMixin,
)
from ray.rllib.policy.tf_mixins import LearningRateSchedule, TargetNetworkMixin
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.exploration import ParameterNoise
@@ -27,13 +26,10 @@
minimize_and_clip,
reduce_mean_ignore_inf,
)
from ray.rllib.utils.typing import ModelGradients, TensorType, AlgorithmConfigDict
from ray.rllib.utils.typing import AlgorithmConfigDict, ModelGradients, TensorType

tf1, tf, tfv = try_import_tf()

Q_SCOPE = "q_func"
Q_TARGET_SCOPE = "target_q_func"

# Importance sampling weights for prioritized replay
PRIO_WEIGHTS = "weights"

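The module-level Q_SCOPE / Q_TARGET_SCOPE constants are dropped here in favor of the shared definitions in simple_q.utils (see the new import at the top of the hunk). For code that referenced them directly, the updated import looks like this (values as defined in the removed lines above):

from ray.rllib.algorithms.simple_q.utils import Q_SCOPE, Q_TARGET_SCOPE

# Same scope names as before; only the module they live in has changed.
assert Q_SCOPE == "q_func"
assert Q_TARGET_SCOPE == "target_q_func"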
10 changes: 4 additions & 6 deletions rllib/algorithms/ppo/ppo_tf_policy.py
@@ -17,18 +17,14 @@
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_mixins import (
EntropyCoeffSchedule,
LearningRateSchedule,
KLCoeffMixin,
LearningRateSchedule,
ValueNetworkMixin,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import explained_variance
from ray.rllib.utils.typing import (
TensorType,
TFPolicyV2Type,
AlgorithmConfigDict,
)
from ray.rllib.utils.typing import AlgorithmConfigDict, TensorType, TFPolicyV2Type

tf1, tf, tfv = try_import_tf()

@@ -79,6 +75,8 @@ def __init__(
base.enable_eager_execution_if_necessary()

config = dict(ray.rllib.algorithms.ppo.ppo.PPOConfig().to_dict(), **config)
# TODO: Move into Policy API, if needed at all here. Why not move this into
# `PPOConfig`?.
validate_config(config)

# Initialize base class.
4 changes: 3 additions & 1 deletion rllib/algorithms/ppo/ppo_torch_policy.py
@@ -7,8 +7,8 @@
Postprocessing,
compute_gae_for_sample_batch,
)
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_mixins import (
EntropyCoeffSchedule,
@@ -43,6 +43,8 @@ class PPOTorchPolicy(

def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.algorithms.ppo.ppo.PPOConfig().to_dict(), **config)
# TODO: Move into Policy API, if needed at all here. Why not move this into
# `PPOConfig`?.
validate_config(config)

TorchPolicyV2.__init__(
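Both PPO policies keep the same pattern of merging the incoming config over the PPOConfig defaults before calling validate_config (the new TODO asks whether that validation belongs in PPOConfig instead). A minimal sketch of what that merge does, using the PPOConfig path referenced in the diff; the user_config values are purely illustrative:

from ray.rllib.algorithms.ppo.ppo import PPOConfig

user_config = {"lr": 1e-4, "framework": "torch"}  # hypothetical partial config
# Keys present in user_config override the defaults; every other key falls
# back to the value from PPOConfig().to_dict().
merged = dict(PPOConfig().to_dict(), **user_config)
assert merged["lr"] == 1e-4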
11 changes: 7 additions & 4 deletions rllib/algorithms/simple_q/__init__.py
@@ -1,16 +1,19 @@
from ray.rllib.algorithms.simple_q.simple_q import (
DEFAULT_CONFIG,
SimpleQ,
SimpleQConfig,
DEFAULT_CONFIG,
)
from ray.rllib.algorithms.simple_q.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.algorithms.simple_q.simple_q_tf_policy import (
SimpleQTF1Policy,
SimpleQTF2Policy,
)
from ray.rllib.algorithms.simple_q.simple_q_torch_policy import SimpleQTorchPolicy


__all__ = [
"SimpleQ",
"SimpleQConfig",
"SimpleQTFPolicy",
"SimpleQTF1Policy",
"SimpleQTF2Policy",
"SimpleQTorchPolicy",
"DEFAULT_CONFIG",
]
38 changes: 17 additions & 21 deletions rllib/algorithms/simple_q/simple_q.py
@@ -12,37 +12,32 @@
import logging
from typing import List, Optional, Type, Union

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.simple_q.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.algorithms.simple_q.simple_q_torch_policy import SimpleQTorchPolicy
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
from ray.rllib.utils.replay_buffers.utils import (
validate_buffer_config,
update_priorities_in_replay_buffer,
)
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
train_one_step,
multi_gpu_train_one_step,
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.simple_q.simple_q_tf_policy import (
SimpleQTF1Policy,
SimpleQTF2Policy,
)
from ray.rllib.algorithms.simple_q.simple_q_torch_policy import SimpleQTorchPolicy
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
from ray.rllib.policy.policy import Policy
from ray.rllib.utils import deep_update
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
NUM_TARGET_UPDATES,
SYNCH_WORKER_WEIGHTS_TIMER,
TARGET_NET_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
ResultDict,
AlgorithmConfigDict,
from ray.rllib.utils.replay_buffers.utils import (
update_priorities_in_replay_buffer,
validate_buffer_config,
)
from ray.rllib.utils.typing import AlgorithmConfigDict, ResultDict

logger = logging.getLogger(__name__)

@@ -307,10 +302,11 @@ def get_default_policy_class(
) -> Optional[Type[Policy]]:
if config["framework"] == "torch":
return SimpleQTorchPolicy
elif config["framework"] == "tf":
return SimpleQTF1Policy
else:
return SimpleQTFPolicy
return SimpleQTF2Policy

@ExperimentalAPI
@override(Algorithm)
def training_step(self) -> ResultDict:
"""Simple Q training iteration function.
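With the change to get_default_policy_class above, the framework setting now selects between three policy implementations instead of two. A minimal end-to-end sketch, assuming the Ray 2.x AlgorithmConfig API and a standard Gym environment (nothing here is prescribed by this PR beyond the policy classes shown in the diff):

from ray.rllib.algorithms.simple_q import SimpleQConfig

config = (
    SimpleQConfig()
    .environment("CartPole-v1")
    .framework("tf2")  # "torch" -> SimpleQTorchPolicy, "tf" -> SimpleQTF1Policy,
                       # anything else (e.g. "tf2") -> SimpleQTF2Policy
)
algo = config.build()
print(type(algo.get_policy()).__name__)  # expected: SimpleQTF2Policy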