From 9e372fd53e940ab3840a6d149bf79b298471c71d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Mon, 19 Aug 2024 19:06:23 +0800 Subject: [PATCH] fix(pu): fix np.asarray in sampled related buffer/policy --- lzero/mcts/buffer/game_buffer_muzero.py | 6 ++---- .../buffer/game_buffer_sampled_efficientzero.py | 15 ++++++--------- lzero/mcts/buffer/game_buffer_sampled_muzero.py | 12 ++++-------- lzero/mcts/buffer/game_buffer_unizero.py | 6 ++---- lzero/model/sampled_muzero_model_mlp.py | 2 -- lzero/policy/muzero.py | 10 ++++------ lzero/policy/sampled_efficientzero.py | 10 +++++----- lzero/policy/sampled_muzero.py | 4 ++-- lzero/policy/unizero.py | 4 +++- zoo/atari/config/atari_unizero_config.py | 3 +-- ...narlander_cont_sampled_efficientzero_config.py | 4 ++++ 11 files changed, 33 insertions(+), 43 deletions(-) diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 5c956ad14..c965e49b0 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -504,8 +504,8 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A target_values.append(value_list[value_index]) target_rewards.append(reward_list[current_index]) else: - target_values.append(np.array([0.])) - target_rewards.append(np.array([0.])) + target_values.append(np.array(0.)) + target_rewards.append(np.array(0.)) value_index += 1 batch_rewards.append(target_rewards) @@ -513,8 +513,6 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_rewards = np.asarray(batch_rewards) batch_target_values = np.asarray(batch_target_values) - batch_rewards = np.squeeze(batch_rewards, axis=-1) - batch_target_values = np.squeeze(batch_target_values, axis=-1) return batch_rewards, batch_target_values diff --git a/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py b/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py index 1c91e5e2e..1821f7a2e 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py @@ -375,7 +375,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A target_values = [] target_value_prefixs = [] - value_prefix = np.array([0.]) + value_prefix = np.array(0.) base_index = state_index for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): bootstrap_index = current_index + td_steps_list[value_index] @@ -393,7 +393,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # reset every lstm_horizon_len if horizon_id % self._cfg.lstm_horizon_len == 0: - value_prefix = np.array([0.]) + value_prefix = np.array(0.) base_index = current_index horizon_id += 1 @@ -401,12 +401,11 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A target_values.append(value_list[value_index]) # Since the horizon is small and the discount_factor is close to 1. 
# Compute the reward sum to approximate the value prefix for simplification - value_prefix += reward_list[current_index - ] # * config.discount_factor ** (current_index - base_index) - target_value_prefixs.append(value_prefix) + value_prefix += reward_list[current_index].item() # * config.discount_factor ** (current_index - base_index) + target_value_prefixs.append(value_prefix.item()) else: - target_values.append(np.array([0.])) - target_value_prefixs.append(value_prefix) + target_values.append(np.array(0.)) + target_value_prefixs.append(value_prefix.item()) value_index += 1 @@ -415,8 +414,6 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_value_prefixs = np.asarray(batch_value_prefixs) batch_target_values = np.asarray(batch_target_values) - batch_value_prefixs = np.squeeze(batch_value_prefixs, axis=-1) - batch_target_values = np.squeeze(batch_target_values, axis=-1) return batch_value_prefixs, batch_target_values diff --git a/lzero/mcts/buffer/game_buffer_sampled_muzero.py b/lzero/mcts/buffer/game_buffer_sampled_muzero.py index f01664996..977a81daf 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_muzero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_muzero.py @@ -371,11 +371,9 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A target_values = [] target_rewards = [] - reward = np.array([0.]) base_index = state_index for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): bootstrap_index = current_index + td_steps_list[value_index] - # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]): for i, reward in enumerate(reward_list[current_index:bootstrap_index]): if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: # TODO(pu): for board_games, very important, to check @@ -390,11 +388,11 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A horizon_id += 1 if current_index < game_segment_len_non_re: - target_values.append(value_list[value_index]) - target_rewards.append(reward_list[current_index]) + target_values.append(value_list[value_index].item()) + target_rewards.append(reward_list[current_index].item()) else: - target_values.append(np.array([0.])) - target_rewards.append(np.array([0.])) + target_values.append(np.array(0.)) + target_rewards.append(np.array(0.)) value_index += 1 @@ -403,8 +401,6 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_rewards = np.asarray(batch_rewards) batch_target_values = np.asarray(batch_target_values) - batch_rewards = np.squeeze(batch_rewards, axis=-1) - batch_target_values = np.squeeze(batch_target_values, axis=-1) return batch_rewards, batch_target_values diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index fe57bebf0..15a14cb78 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -499,8 +499,8 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A target_values.append(value_list[value_index]) target_rewards.append(reward_list[current_index]) else: - target_values.append(np.array([0.])) - target_rewards.append(np.array([0.])) + target_values.append(np.array(0.)) + target_rewards.append(np.array(0.)) value_index += 1 batch_rewards.append(target_rewards) @@ -508,7 +508,5 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_rewards = np.asarray(batch_rewards) batch_target_values = 
np.asarray(batch_target_values) - batch_rewards = np.squeeze(batch_rewards, axis=-1) - batch_target_values = np.squeeze(batch_target_values, axis=-1) return batch_rewards, batch_target_values diff --git a/lzero/model/sampled_muzero_model_mlp.py b/lzero/model/sampled_muzero_model_mlp.py index 611efe149..c119c2198 100644 --- a/lzero/model/sampled_muzero_model_mlp.py +++ b/lzero/model/sampled_muzero_model_mlp.py @@ -330,8 +330,6 @@ def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[ # (batch_size, action_dim, 1) -> (batch_size, action_dim) # e.g., torch.Size([8, 2, 1]) -> torch.Size([8, 2]) action = action.squeeze(-1) - else: - raise ValueError("The shape of action is not supported.") action_encoding = action diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 94e007edd..9ae0578f5 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -71,7 +71,7 @@ class MuZeroPolicy(Policy): # (bool) Whether to analyze dormant ratio. analysis_dormant_ratio=False, # (bool) Whether to use HarmonyDream to balance weights between different losses. Default to False. - # More details can be found in https://arxiv.org/abs/2310.00344 + # More details can be found in https://arxiv.org/abs/2310.00344. harmony_balance=False ), # ****** common ****** @@ -367,12 +367,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in obs_target_batch = self.image_transforms.transform(obs_target_batch) # shape: (batch_size, num_unroll_steps, action_dim) - # NOTE: .long(), in discrete action space. + # NOTE: .long() is only for discrete action space. action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() - data_list = [ - mask_batch, - target_reward.astype('float32'), - target_value.astype('float32'), target_policy, weights + data_list = [mask_batch, target_reward, + target_value, target_policy, weights ] [mask_batch, target_reward, target_value, target_policy, weights] = to_torch_float_tensor(data_list, self._cfg.device) diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 98945542b..f74e2ef11 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -340,12 +340,12 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: obs_target_batch = self.image_transforms.transform(obs_target_batch) # shape: (batch_size, num_unroll_steps, action_dim) - # NOTE: .float(), in continuous action space. - action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float() + # NOTE: .float() in continuous action space. 
+ action_batch = torch.from_numpy(action_batch).to(self._cfg.device) data_list = [ mask_batch, - target_value_prefix.astype('float32'), - target_value.astype('float32'), target_policy, weights + target_value_prefix, + target_value, target_policy, weights ] [mask_batch, target_value_prefix, target_value, target_policy, weights] = to_torch_float_tensor(data_list, self._cfg.device) @@ -535,7 +535,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: 'total_loss': loss.mean().item(), 'policy_loss': policy_loss.mean().item(), 'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1), - 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'target_policy_entropy': target_policy_entropy / (self._cfg.num_unroll_steps + 1), 'value_prefix_loss': value_prefix_loss.mean().item(), 'value_loss': value_loss.mean().item(), 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, diff --git a/lzero/policy/sampled_muzero.py b/lzero/policy/sampled_muzero.py index 721c9620e..982866248 100644 --- a/lzero/policy/sampled_muzero.py +++ b/lzero/policy/sampled_muzero.py @@ -343,8 +343,8 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float() data_list = [ mask_batch, - target_reward.astype('float32'), - target_value.astype('float32'), target_policy, weights + target_reward, + target_value, target_policy, weights ] [mask_batch, target_reward, target_value, target_policy, weights] = to_torch_float_tensor(data_list, self._cfg.device) diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index a8f318c8b..68ba75d12 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -63,6 +63,8 @@ class UniZeroPolicy(MuZeroPolicy): # (int) The save interval of the model. learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), world_model_cfg=dict( + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, # (int) The number of tokens per block. tokens_per_block=2, # (int) The maximum number of blocks. @@ -86,7 +88,7 @@ class UniZeroPolicy(MuZeroPolicy): # (str) The type of attention mechanism used. Options could be ['causal']. attention='causal', # (int) The number of layers in the model. - num_layers=4, + num_layers=2, # (int) The number of attention heads. num_heads=8, # (int) The dimension of the embedding. diff --git a/zoo/atari/config/atari_unizero_config.py b/zoo/atari/config/atari_unizero_config.py index 1c549010f..af83988a0 100644 --- a/zoo/atari/config/atari_unizero_config.py +++ b/zoo/atari/config/atari_unizero_config.py @@ -55,9 +55,8 @@ max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action context_length=2 * infer_context_length, device='cuda', - # device='cpu', action_space_size=action_space_size, - num_layers=4, + num_layers=2, num_heads=8, embed_dim=768, obs_type='image', diff --git a/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py index 517417d2d..a9a7f2267 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py @@ -15,6 +15,10 @@ max_env_step = int(5e5) reanalyze_ratio = 0. 
norm_type = 'LN' + +# only for debug +# num_simulations = 5 +# batch_size = 2 # ============================================================== # end of the most frequently changed config specified by the user # ==============================================================
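
The buffer changes above all follow the same pattern: per-step targets are appended as scalars (np.array(0.) placeholders, or .item() on length-1 arrays) rather than shape-(1,) arrays, so np.asarray stacks them directly into a (batch_size, num_unroll_steps + 1) array and the trailing np.squeeze(..., axis=-1) calls can be dropped. A minimal, self-contained sketch of that shape equivalence (toy sizes, not taken from the patch):

    import numpy as np

    batch_size, num_unroll_steps = 4, 2

    # Old behaviour: each target is a shape-(1,) array, so stacking yields
    # (batch_size, num_unroll_steps + 1, 1) and needs a squeeze over the last axis.
    old_targets = [[np.array([0.]) for _ in range(num_unroll_steps + 1)] for _ in range(batch_size)]
    old_batch = np.squeeze(np.asarray(old_targets), axis=-1)   # shape (4, 3)

    # New behaviour: each target is a 0-d array (or a plain float from .item()),
    # so np.asarray already produces the desired 2-D layout.
    new_targets = [[np.array(0.) for _ in range(num_unroll_steps + 1)] for _ in range(batch_size)]
    new_batch = np.asarray(new_targets)                        # shape (4, 3)

    assert old_batch.shape == new_batch.shape == (batch_size, num_unroll_steps + 1)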
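
Similarly, the policy changes drop the explicit .astype('float32') casts before the targets are handed to to_torch_float_tensor. This is safe on the assumption that to_torch_float_tensor (imported from lzero.policy) casts every array in the list to a float tensor on the configured device, as its name suggests. A hedged sketch with a hypothetical stand-in for that helper:

    import numpy as np
    import torch

    def to_torch_float_tensor_stub(data_list, device):
        # Hypothetical stand-in written for this example only; it is assumed to
        # mirror lzero's helper by casting each array to float32 on the device.
        return [torch.as_tensor(d, dtype=torch.float32, device=device) for d in data_list]

    # e.g. the stacked targets produced by the buffers above (float64 by default)
    target_reward = np.zeros((4, 3))
    target_value = np.zeros((4, 3))

    reward_t, value_t = to_torch_float_tensor_stub([target_reward, target_value], 'cpu')
    assert reward_t.dtype == value_t.dtype == torch.float32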