fix(pu): fix np.asarray in sampled related buffer/policy
puyuan1996 committed Aug 19, 2024
1 parent 8300a52 commit 9e372fd
Showing 11 changed files with 33 additions and 43 deletions.
6 changes: 2 additions & 4 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -504,17 +504,15 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
                 target_values.append(value_list[value_index])
                 target_rewards.append(reward_list[current_index])
             else:
-                target_values.append(np.array([0.]))
-                target_rewards.append(np.array([0.]))
+                target_values.append(np.array(0.))
+                target_rewards.append(np.array(0.))
             value_index += 1

         batch_rewards.append(target_rewards)
         batch_target_values.append(target_values)

     batch_rewards = np.asarray(batch_rewards)
     batch_target_values = np.asarray(batch_target_values)
-    batch_rewards = np.squeeze(batch_rewards, axis=-1)
-    batch_target_values = np.squeeze(batch_target_values, axis=-1)

     return batch_rewards, batch_target_values
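
The switch to 0-d np.array(0.) padding is what lets the trailing np.squeeze(..., axis=-1) calls be dropped. A minimal, self-contained numpy sketch of the shape difference (toy values, not repository data):

import numpy as np

# Old pattern: each padded target is a shape-(1,) array, so stacking the
# nested lists yields (batch_size, num_steps, 1) and needs an extra squeeze.
old_targets = [[np.array([1.5]), np.array([0.])], [np.array([2.0]), np.array([0.])]]
old_batch = np.squeeze(np.asarray(old_targets), axis=-1)   # (2, 2, 1) -> (2, 2)

# New pattern: each padded target is a 0-d scalar array, so np.asarray
# already produces (batch_size, num_steps) and the squeeze can be dropped.
new_targets = [[np.array(1.5), np.array(0.)], [np.array(2.0), np.array(0.)]]
new_batch = np.asarray(new_targets)                         # (2, 2)

assert old_batch.shape == new_batch.shape == (2, 2)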

15 changes: 6 additions & 9 deletions lzero/mcts/buffer/game_buffer_sampled_efficientzero.py
@@ -375,7 +375,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
         target_values = []
         target_value_prefixs = []

-        value_prefix = np.array([0.])
+        value_prefix = np.array(0.)
         base_index = state_index
         for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
             bootstrap_index = current_index + td_steps_list[value_index]
@@ -393,20 +393,19 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A

             # reset every lstm_horizon_len
             if horizon_id % self._cfg.lstm_horizon_len == 0:
-                value_prefix = np.array([0.])
+                value_prefix = np.array(0.)
                 base_index = current_index
             horizon_id += 1

             if current_index < game_segment_len_non_re:
                 target_values.append(value_list[value_index])
                 # Since the horizon is small and the discount_factor is close to 1.
                 # Compute the reward sum to approximate the value prefix for simplification
-                value_prefix += reward_list[current_index
-                                            ]  # * config.discount_factor ** (current_index - base_index)
-                target_value_prefixs.append(value_prefix)
+                value_prefix += reward_list[current_index].item()  # * config.discount_factor ** (current_index - base_index)
+                target_value_prefixs.append(value_prefix.item())
             else:
-                target_values.append(np.array([0.]))
-                target_value_prefixs.append(value_prefix)
+                target_values.append(np.array(0.))
+                target_value_prefixs.append(value_prefix.item())

             value_index += 1

@@ -415,8 +414,6 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A

     batch_value_prefixs = np.asarray(batch_value_prefixs)
     batch_target_values = np.asarray(batch_target_values)
-    batch_value_prefixs = np.squeeze(batch_value_prefixs, axis=-1)
-    batch_target_values = np.squeeze(batch_target_values, axis=-1)

     return batch_value_prefixs, batch_target_values
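
A side effect of appending value_prefix.item() instead of the running array itself is that every stored prefix is a detached Python float, so later in-place += updates cannot alias entries that were already appended. A toy sketch of that pitfall (assumed here to be part of the motivation, alongside the dtype/shape cleanup; values are made up):

import numpy as np

rewards = [1.0, 2.0, 3.0]

# Appending the running array: all three list entries reference one object,
# so the final in-place += has rewritten the earlier partial sums as well.
prefix = np.array(0.)
aliased = []
for r in rewards:
    prefix += r
    aliased.append(prefix)
print([float(p) for p in aliased])   # [6.0, 6.0, 6.0]

# Appending .item() snapshots keeps each partial sum independent.
prefix = np.array(0.)
snapshots = []
for r in rewards:
    prefix += r
    snapshots.append(prefix.item())
print(snapshots)                      # [1.0, 3.0, 6.0]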

12 changes: 4 additions & 8 deletions lzero/mcts/buffer/game_buffer_sampled_muzero.py
@@ -371,11 +371,9 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
         target_values = []
         target_rewards = []

-        reward = np.array([0.])
         base_index = state_index
         for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
             bootstrap_index = current_index + td_steps_list[value_index]
-            # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]):
             for i, reward in enumerate(reward_list[current_index:bootstrap_index]):
                 if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
                     # TODO(pu): for board_games, very important, to check
@@ -390,11 +388,11 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
             horizon_id += 1

             if current_index < game_segment_len_non_re:
-                target_values.append(value_list[value_index])
-                target_rewards.append(reward_list[current_index])
+                target_values.append(value_list[value_index].item())
+                target_rewards.append(reward_list[current_index].item())
             else:
-                target_values.append(np.array([0.]))
-                target_rewards.append(np.array([0.]))
+                target_values.append(np.array(0.))
+                target_rewards.append(np.array(0.))

             value_index += 1

@@ -403,8 +401,6 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A

     batch_rewards = np.asarray(batch_rewards)
     batch_target_values = np.asarray(batch_target_values)
-    batch_rewards = np.squeeze(batch_rewards, axis=-1)
-    batch_target_values = np.squeeze(batch_target_values, axis=-1)

     return batch_rewards, batch_target_values
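
Converting the real targets with .item() also keeps each per-step list homogeneous: once the padding entries become 0-d np.array(0.) scalars, mixing them with shape-(1,) entries would no longer stack cleanly. A small illustration (toy values; that value_list entries are shape-(1,) arrays is an assumption for the sketch):

import numpy as np

# Mixed shapes, e.g. a shape-(1,) value next to a 0-d padding entry.
mixed = [[np.array([1.5]), np.array(0.)]]
try:
    print(np.asarray(mixed).dtype)    # older NumPy falls back to a ragged object array
except ValueError:
    print("inhomogeneous shapes")     # NumPy >= 1.24 refuses to build such an array

# With .item() on the real targets and np.array(0.) padding, every entry is
# scalar-like and np.asarray stacks to a clean float array.
uniform = [[np.array([1.5]).item(), np.array(0.)]]
print(np.asarray(uniform).shape, np.asarray(uniform).dtype)   # (1, 2) float64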

6 changes: 2 additions & 4 deletions lzero/mcts/buffer/game_buffer_unizero.py
@@ -499,16 +499,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
                 target_values.append(value_list[value_index])
                 target_rewards.append(reward_list[current_index])
             else:
-                target_values.append(np.array([0.]))
-                target_rewards.append(np.array([0.]))
+                target_values.append(np.array(0.))
+                target_rewards.append(np.array(0.))
             value_index += 1

         batch_rewards.append(target_rewards)
         batch_target_values.append(target_values)

     batch_rewards = np.asarray(batch_rewards)
     batch_target_values = np.asarray(batch_target_values)
-    batch_rewards = np.squeeze(batch_rewards, axis=-1)
-    batch_target_values = np.squeeze(batch_target_values, axis=-1)

     return batch_rewards, batch_target_values
2 changes: 0 additions & 2 deletions lzero/model/sampled_muzero_model_mlp.py
@@ -330,8 +330,6 @@ def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[
             # (batch_size, action_dim, 1) -> (batch_size, action_dim)
             # e.g., torch.Size([8, 2, 1]) -> torch.Size([8, 2])
             action = action.squeeze(-1)
-        else:
-            raise ValueError("The shape of action is not supported.")

         action_encoding = action

10 changes: 4 additions & 6 deletions lzero/policy/muzero.py
@@ -71,7 +71,7 @@ class MuZeroPolicy(Policy):
         # (bool) Whether to analyze dormant ratio.
         analysis_dormant_ratio=False,
         # (bool) Whether to use HarmonyDream to balance weights between different losses. Default to False.
-        # More details can be found in https://arxiv.org/abs/2310.00344
+        # More details can be found in https://arxiv.org/abs/2310.00344.
         harmony_balance=False
     ),
     # ****** common ******
@@ -367,12 +367,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
             obs_target_batch = self.image_transforms.transform(obs_target_batch)

         # shape: (batch_size, num_unroll_steps, action_dim)
-        # NOTE: .long(), in discrete action space.
+        # NOTE: .long() is only for discrete action space.
         action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long()
-        data_list = [
-            mask_batch,
-            target_reward.astype('float32'),
-            target_value.astype('float32'), target_policy, weights
+        data_list = [mask_batch, target_reward,
+                     target_value, target_policy, weights
         ]
         [mask_batch, target_reward, target_value, target_policy,
          weights] = to_torch_float_tensor(data_list, self._cfg.device)
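
Dropping the .astype('float32') calls relies on the conversion helper doing the cast itself. The snippet below uses a hypothetical stand-in (to_torch_float_tensor_sketch is not the repository's function, only an illustration of the assumed behavior): if the helper already returns float32 tensors on the configured device, pre-casting the numpy targets is redundant work.

import numpy as np
import torch

def to_torch_float_tensor_sketch(data_list, device):
    # Assumed behavior: every entry is converted to a float32 tensor on `device`.
    return [torch.as_tensor(np.asarray(d), dtype=torch.float32, device=device) for d in data_list]

target_reward = np.zeros((4, 6), dtype=np.float64)           # float64, no manual astype
(reward_tensor,) = to_torch_float_tensor_sketch([target_reward], device='cpu')
print(reward_tensor.dtype)                                    # torch.float32 either way
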
10 changes: 5 additions & 5 deletions lzero/policy/sampled_efficientzero.py
@@ -340,12 +340,12 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
             obs_target_batch = self.image_transforms.transform(obs_target_batch)

         # shape: (batch_size, num_unroll_steps, action_dim)
-        # NOTE: .float(), in continuous action space.
-        action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float()
+        # NOTE: .float() in continuous action space.
+        action_batch = torch.from_numpy(action_batch).to(self._cfg.device)
         data_list = [
             mask_batch,
-            target_value_prefix.astype('float32'),
-            target_value.astype('float32'), target_policy, weights
+            target_value_prefix,
+            target_value, target_policy, weights
         ]
         [mask_batch, target_value_prefix, target_value, target_policy,
          weights] = to_torch_float_tensor(data_list, self._cfg.device)
@@ -535,7 +535,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
             'total_loss': loss.mean().item(),
             'policy_loss': policy_loss.mean().item(),
             'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1),
-            'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1),
+            'target_policy_entropy': target_policy_entropy / (self._cfg.num_unroll_steps + 1),
             'value_prefix_loss': value_prefix_loss.mean().item(),
             'value_loss': value_loss.mean().item(),
             'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps,
4 changes: 2 additions & 2 deletions lzero/policy/sampled_muzero.py
@@ -343,8 +343,8 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
         action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float()
         data_list = [
             mask_batch,
-            target_reward.astype('float32'),
-            target_value.astype('float32'), target_policy, weights
+            target_reward,
+            target_value, target_policy, weights
         ]
         [mask_batch, target_reward, target_value, target_policy,
          weights] = to_torch_float_tensor(data_list, self._cfg.device)
4 changes: 3 additions & 1 deletion lzero/policy/unizero.py
@@ -63,6 +63,8 @@ class UniZeroPolicy(MuZeroPolicy):
         # (int) The save interval of the model.
         learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ),
         world_model_cfg=dict(
+            # (bool) If True, the action space of the environment is continuous, otherwise discrete.
+            continuous_action_space=False,
             # (int) The number of tokens per block.
             tokens_per_block=2,
             # (int) The maximum number of blocks.
@@ -86,7 +88,7 @@ class UniZeroPolicy(MuZeroPolicy):
             # (str) The type of attention mechanism used. Options could be ['causal'].
             attention='causal',
             # (int) The number of layers in the model.
-            num_layers=4,
+            num_layers=2,
             # (int) The number of attention heads.
             num_heads=8,
             # (int) The dimension of the embedding.
3 changes: 1 addition & 2 deletions zoo/atari/config/atari_unizero_config.py
@@ -55,9 +55,8 @@
             max_tokens=2 * num_unroll_steps,  # NOTE: each timestep has 2 tokens: obs and action
             context_length=2 * infer_context_length,
             device='cuda',
-            # device='cpu',
             action_space_size=action_space_size,
-            num_layers=4,
+            num_layers=2,
             num_heads=8,
             embed_dim=768,
             obs_type='image',
4 changes: 4 additions & 0 deletions
@@ -15,6 +15,10 @@
 max_env_step = int(5e5)
 reanalyze_ratio = 0.
 norm_type = 'LN'
+
+# only for debug
+# num_simulations = 5
+# batch_size = 2
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================
