-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature(rjy): add mamujoco env and related configs #153
base: main
Are you sure you want to change the base?
Changes from 6 commits
f1528d4
de89e54
933f673
1e04545
c54c0a5
f888f6f
27420d4
fc84583
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
import numpy as np | ||
import torch | ||
from ding.utils import BUFFER_REGISTRY | ||
from ding.utils.data import default_collate, default_decollate | ||
from ding.torch_utils import to_tensor, to_device, to_dtype, to_ndarray | ||
|
||
from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree as MCTSCtree | ||
from lzero.mcts.tree_search.mcts_ptree_sampled import SampledEfficientZeroMCTSPtree as MCTSPtree | ||
|
@@ -140,7 +142,9 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: | |
# sampled related core code | ||
# ============================================================== | ||
actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + | ||
self._cfg.num_unroll_steps].tolist() | ||
self._cfg.num_unroll_steps] | ||
if not isinstance(actions_tmp, list): | ||
actions_tmp = actions_tmp.tolist() | ||
|
||
# NOTE: self._cfg.num_unroll_steps + 1 | ||
root_sampled_actions_tmp = game.root_sampled_actions[pos_in_game_segment:pos_in_game_segment + | ||
|
@@ -152,14 +156,25 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: | |
|
||
# pad random action | ||
if self._cfg.model.continuous_action_space: | ||
actions_tmp += [ | ||
np.random.randn(self._cfg.model.action_space_size) | ||
if self._multi_agent: | ||
actions_tmp += [ | ||
np.random.randn(self._cfg.model.agent_num, self._cfg.model.action_space_size) | ||
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) | ||
] | ||
root_sampled_actions_tmp += [ | ||
root_sampled_actions_tmp += [ | ||
np.random.rand(self._cfg.model.agent_num, self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) | ||
for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) | ||
] | ||
else: | ||
actions_tmp += [ | ||
np.random.randn(self._cfg.model.action_space_size) | ||
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) | ||
] | ||
root_sampled_actions_tmp += [ | ||
np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) | ||
for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) | ||
] | ||
|
||
else: | ||
# generate random `padded actions_tmp` | ||
actions_tmp += generate_random_actions_discrete( | ||
|
@@ -192,7 +207,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: | |
mask_list.append(mask_tmp) | ||
|
||
# formalize the input observations | ||
obs_list = prepare_observation(obs_list, self._cfg.model.model_type) | ||
if not self._multi_agent: | ||
obs_list = prepare_observation(obs_list, self._cfg.model.model_type) | ||
# ============================================================== | ||
# sampled related core code | ||
# ============================================================== | ||
|
@@ -202,7 +218,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: | |
] | ||
|
||
for i in range(len(current_batch)): | ||
current_batch[i] = np.asarray(current_batch[i]) | ||
current_batch[i] = to_ndarray(current_batch[i]) | ||
|
||
total_transitions = self.get_num_of_transitions() | ||
|
||
|
@@ -272,16 +288,20 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A | |
|
||
batch_target_values, batch_value_prefixs = [], [] | ||
with torch.no_grad(): | ||
value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) | ||
if not self._multi_agent: | ||
value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) | ||
# split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors | ||
slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) | ||
network_output = [] | ||
for i in range(slices): | ||
beg_index = self._cfg.mini_infer_size * i | ||
end_index = self._cfg.mini_infer_size * (i + 1) | ||
m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() | ||
m_obs = to_dtype(to_device(to_tensor(value_obs_list[beg_index:end_index]), self._cfg.device), torch.float) | ||
|
||
# calculate the target value | ||
m_obs = default_collate(m_obs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 类似上面的问题 |
||
if self._multi_agent: | ||
m_obs = m_obs[0] | ||
m_output = model.initial_inference(m_obs) | ||
|
||
# TODO(pu) | ||
|
@@ -355,12 +375,20 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A | |
] | ||
) | ||
else: | ||
value_list = value_list.reshape(-1) * ( | ||
np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list | ||
) | ||
if self._multi_agent: | ||
value_list = value_list.reshape(transition_batch_size, self._cfg.model.agent_num) | ||
factor = np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list | ||
value_list = value_list * factor.reshape(transition_batch_size, 1).astype(np.float32) | ||
else: | ||
value_list = value_list.reshape(-1) * ( | ||
np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list | ||
) | ||
|
||
value_list = value_list * np.array(value_mask) | ||
value_list = value_list.tolist() | ||
if self._multi_agent: | ||
value_list = value_list * np.array(value_mask)[:, np.newaxis] | ||
else: | ||
value_list = value_list * np.array(value_mask) | ||
value_list = value_list.tolist() | ||
|
||
horizon_id, value_index = 0, 0 | ||
for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, | ||
|
@@ -399,16 +427,17 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A | |
] # * config.discount_factor ** (current_index - base_index) | ||
target_value_prefixs.append(value_prefix) | ||
else: | ||
target_values.append(0) | ||
target_value_prefixs.append(value_prefix) | ||
target_values.append(np.zeros_like(value_list[0])) | ||
target_value_prefixs.append(np.array([0,])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 单/多智能体运行都是正常的吗?测试一下mamujoco hopper和lunarlander-cont |
||
|
||
value_index += 1 | ||
|
||
batch_value_prefixs.append(target_value_prefixs) | ||
batch_target_values.append(target_values) | ||
|
||
batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=object) | ||
batch_target_values = np.asarray(batch_target_values, dtype=object) | ||
if not self._multi_agent: | ||
batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=np.float32) | ||
batch_target_values = np.asarray(batch_target_values, dtype=np.float32) | ||
|
||
return batch_value_prefixs, batch_target_values | ||
|
||
|
@@ -557,8 +586,8 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: | |
policy_index += 1 | ||
|
||
batch_target_policies_re.append(target_policies) | ||
|
||
batch_target_policies_re = np.array(batch_target_policies_re) | ||
if not self._multi_agent: | ||
batch_target_policies_re = np.array(batch_target_policies_re) | ||
|
||
return batch_target_policies_re, root_sampled_actions | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
from easydict import EasyDict | ||
|
||
from ding.utils.compression_helper import jpeg_data_decompressor | ||
from ding.torch_utils import to_ndarray | ||
|
||
|
||
class GameSegment: | ||
|
@@ -96,20 +97,31 @@ def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool | |
if padding: | ||
pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_obs) | ||
if pad_len > 0: | ||
pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)]) | ||
stacked_obs = np.concatenate((stacked_obs, pad_frames)) | ||
pad_frames = [stacked_obs[-1] for _ in range(pad_len)] | ||
stacked_obs += pad_frames | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 单/多智能体运行都是正常的吗?测试一下mamujoco hopper和lunarlander-cont |
||
if self.transform2string: | ||
stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs] | ||
return stacked_obs | ||
|
||
def _zero_obs(self, input_data): | ||
if isinstance(input_data, dict): | ||
# Process dict | ||
return {k: self._zero_obs(v) for k, v in input_data.items()} | ||
elif isinstance(input_data, (list, np.ndarray)): | ||
# Process arrays or lists | ||
return np.zeros_like(input_data) | ||
else: | ||
# Process other types (e.g. numbers, strings, etc.) | ||
return input_data | ||
|
||
def zero_obs(self) -> List: | ||
""" | ||
Overview: | ||
Return an observation frame filled with zeros. | ||
Returns: | ||
ndarray: An array filled with zeros. | ||
""" | ||
return [np.zeros(self.zero_obs_shape, dtype=np.float32) for _ in range(self.frame_stack_num)] | ||
return [self._zero_obs(self.obs_segment[0]) for _ in range(self.frame_stack_num)] | ||
|
||
def get_obs(self) -> List: | ||
""" | ||
|
@@ -212,9 +224,9 @@ def store_search_stats( | |
Overview: | ||
store the visit count distributions and value of the root node after MCTS. | ||
""" | ||
sum_visits = sum(visit_counts) | ||
sum_visits = np.sum(visit_counts, axis=-1) | ||
if idx is None: | ||
self.child_visit_segment.append([visit_count / sum_visits for visit_count in visit_counts]) | ||
self.child_visit_segment.append([visit_count / sum_visits[i] for i,visit_count in enumerate(visit_counts)]) | ||
self.root_value_segment.append(root_value) | ||
if self.sampled_algo: | ||
self.root_sampled_actions.append(root_sampled_actions) | ||
|
@@ -272,26 +284,26 @@ def game_segment_to_array(self) -> None: | |
For environments with a variable action space, such as board games, the elements in `child_visit_segment` may have | ||
different lengths. In such scenarios, it is necessary to use the object data type for `self.child_visit_segment`. | ||
""" | ||
self.obs_segment = np.array(self.obs_segment) | ||
self.action_segment = np.array(self.action_segment) | ||
self.reward_segment = np.array(self.reward_segment) | ||
self.obs_segment = to_ndarray(self.obs_segment) | ||
self.action_segment = to_ndarray(self.action_segment) | ||
self.reward_segment = to_ndarray(self.reward_segment) | ||
|
||
# Check if all elements in self.child_visit_segment have the same length | ||
if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment): | ||
self.child_visit_segment = np.array(self.child_visit_segment) | ||
self.child_visit_segment = to_ndarray(self.child_visit_segment) | ||
else: | ||
# In the case of environments with a variable action space, such as board games, | ||
# the elements in child_visit_segment may have different lengths. | ||
# In such scenarios, it is necessary to use the object data type. | ||
self.child_visit_segment = np.array(self.child_visit_segment, dtype=object) | ||
self.child_visit_segment = to_ndarray(self.child_visit_segment, dtype=object) | ||
|
||
self.root_value_segment = np.array(self.root_value_segment) | ||
self.improved_policy_probs = np.array(self.improved_policy_probs) | ||
self.root_value_segment = to_ndarray(self.root_value_segment) | ||
self.improved_policy_probs = to_ndarray(self.improved_policy_probs) | ||
|
||
self.action_mask_segment = np.array(self.action_mask_segment) | ||
self.to_play_segment = np.array(self.to_play_segment) | ||
self.action_mask_segment = to_ndarray(self.action_mask_segment) | ||
self.to_play_segment = to_ndarray(self.to_play_segment) | ||
if self.use_ture_chance_label_in_chance_encoder: | ||
self.chance_segment = np.array(self.chance_segment) | ||
self.chance_segment = to_ndarray(self.chance_segment) | ||
|
||
def reset(self, init_observations: np.ndarray) -> None: | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么要这样修改呢?之前的方法在多智能体下面会有报错吗?你现在的写法是在单/多智能体下都能与预期一致吗