From b00b1f4b01db58eb33b2eb58a6da2d036d5b2ec8 Mon Sep 17 00:00:00 2001 From: Risto Vuorio Date: Thu, 28 Mar 2019 19:44:53 -0400 Subject: [PATCH 1/2] Fixes Inconsistent weight assignment operations in DQNPolicyGraph (#4502) --- python/ray/rllib/agents/dqn/dqn_policy_graph.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index 318e0758b38a..04eae50480d9 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -387,9 +387,7 @@ def __init__(self, observation_space, action_space, config): # update_target_fn will be called periodically to copy Q network to # target Q network update_target_expr = [] - for var, var_target in zip( - sorted(self.q_func_vars, key=lambda v: v.name), - sorted(self.target_q_func_vars, key=lambda v: v.name)): + for var, var_target in zip(self.q_func_vars, self.target_q_func_vars): update_target_expr.append(var_target.assign(var)) self.update_target_expr = tf.group(*update_target_expr) From b634e819a42c898606d0a267339d8d7f665e93a0 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 28 Mar 2019 17:01:47 -0700 Subject: [PATCH 2/2] Update dqn_policy_graph.py --- python/ray/rllib/agents/dqn/dqn_policy_graph.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index 04eae50480d9..e52dbcca41d3 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -387,6 +387,8 @@ def __init__(self, observation_space, action_space, config): # update_target_fn will be called periodically to copy Q network to # target Q network update_target_expr = [] + assert len(self.q_func_vars) == len(self.target_q_func_vars), \ + (self.q_func_vars, self.target_q_func_vars) for var, var_target in zip(self.q_func_vars, self.target_q_func_vars): update_target_expr.append(var_target.assign(var)) self.update_target_expr = tf.group(*update_target_expr)