polish(pu): optimize kv_caching update()
PaParaZz1 committed Aug 22, 2024
1 parent bc5332f commit a6c6a8e
Showing 7 changed files with 39 additions and 37 deletions.
2 changes: 1 addition & 1 deletion lzero/mcts/buffer/game_buffer_muzero.py
@@ -236,7 +236,7 @@ def _prepare_reward_value_context(
action_mask_segment, to_play_segment = [], []

td_steps_list = []
for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list):
for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list):
game_segment_len = len(game_segment)
game_segment_lens.append(game_segment_len)

2 changes: 2 additions & 0 deletions lzero/mcts/buffer/game_buffer_unizero.py
@@ -192,6 +192,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
return context


def _prepare_policy_reanalyzed_context(
self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str]
) -> List[Any]:
@@ -368,6 +369,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

return batch_target_policies_re

# This can directly replace the corresponding function in game_buffer_muzero
def _compute_target_policy_non_reanalyzed(
self, policy_non_re_context: List[Any], policy_shape: Optional[int]
) -> np.ndarray:
44 changes: 22 additions & 22 deletions lzero/model/unizero_world_models/kv_caching.py
@@ -62,17 +62,17 @@ def get(self) -> torch.Tensor:
"""
return self._cache[:, :, :self._size, :]

def update(self, x: torch.Tensor) -> None:
def update(self, x: torch.Tensor, tokens: int) -> None:
"""
Overview:
Update the cache with new values.
Arguments:
- x (:obj:`torch.Tensor`): The new values to update the cache with.
"""
assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 1, 3)])
assert self._size + x.size(2) <= self._cache.shape[2] # TODO
self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 2, self._size, self._size + x.size(2))
self._size += x.size(2)
# assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 1, 3)])
# assert self._size + tokens <= self._cache.shape[2] # TODO
self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 2, self._size, self._size + tokens)
self._size += tokens


class KVCache:
@@ -136,8 +136,8 @@ def update(self, k: torch.Tensor, v: torch.Tensor):
- k (:obj:`torch.Tensor`): The new values to update the key cache with.
- v (:obj:`torch.Tensor`): The new values to update the value cache with.
"""
self._k_cache.update(k)
self._v_cache.update(v)
self._k_cache.update(k, k.size(2))
self._v_cache.update(v, v.size(2))


class KeysValues:
@@ -203,22 +203,22 @@ def prune(self, mask: np.ndarray) -> None:
for kv_cache in self._keys_values:
kv_cache.prune(mask)

def to_device(self, device: str):
"""
Transfer all KVCache objects within the KeysValues object to a certain device.
Not used in the current implementation.
# def to_device(self, device: str):
# """
# Transfer all KVCache objects within the KeysValues object to a certain device.
# Not used in the current implementation.

Arguments:
- self._keys_values (KeysValues): The KeysValues object to be transferred.
- device (str): The device to transfer to.
Returns:
- keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device.
"""
device = torch.device(device if torch.cuda.is_available() else 'cpu')
for kv_cache in self._keys_values:
kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device)
kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device)
return self._keys_values
# Arguments:
# - self._keys_values (KeysValues): The KeysValues object to be transferred.
# - device (str): The device to transfer to.
# Returns:
# - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device.
# """
# device = torch.device(device if torch.cuda.is_available() else 'cpu')
# for kv_cache in self._keys_values:
# kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device)
# kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device)
# return self._keys_values


class AssignWithoutInplaceCheck(torch.autograd.Function):
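The core of this commit is the kv_caching.py change above: Cache.update() now receives the number of new tokens from the caller instead of re-deriving it from x.size(2), and the per-call shape assertions are commented out of the hot path. Below is a minimal, self-contained sketch of that pattern; SimpleCache and SimpleKVCache are illustrative names, and a plain sliced copy stands in for the repository's AssignWithoutInplaceCheck autograd function, so this is not the actual LightZero implementation.

```python
import torch


class SimpleCache:
    """Sketch of a per-layer cache that grows along the token dimension (dim=2)."""

    def __init__(self, num_heads: int, max_tokens: int, embed_dim: int) -> None:
        # Pre-allocate the full buffer once; `_size` tracks how much is filled.
        self._cache = torch.zeros(1, num_heads, max_tokens, embed_dim)
        self._size = 0

    def get(self) -> torch.Tensor:
        # Return only the filled part of the cache.
        return self._cache[:, :, :self._size, :]

    def update(self, x: torch.Tensor, tokens: int) -> None:
        # `tokens` is precomputed by the caller (x.size(2)), so this hot path
        # neither queries the tensor shape nor re-runs the assertions.
        self._cache[:, :, self._size:self._size + tokens, :] = x
        self._size += tokens


class SimpleKVCache:
    """Pairs a key cache and a value cache, mirroring KVCache.update(k, v)."""

    def __init__(self, num_heads: int, max_tokens: int, embed_dim: int) -> None:
        self._k_cache = SimpleCache(num_heads, max_tokens, embed_dim)
        self._v_cache = SimpleCache(num_heads, max_tokens, embed_dim)

    def update(self, k: torch.Tensor, v: torch.Tensor) -> None:
        # The token count is computed once per tensor and forwarded explicitly,
        # matching the patched Cache.update(x, tokens) signature.
        self._k_cache.update(k, k.size(2))
        self._v_cache.update(v, v.size(2))


if __name__ == "__main__":
    kv = SimpleKVCache(num_heads=2, max_tokens=16, embed_dim=8)
    kv.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8))
    print(kv._k_cache.get().shape)  # torch.Size([1, 2, 4, 8])
```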
2 changes: 1 addition & 1 deletion lzero/model/unizero_world_models/transformer.py
@@ -215,7 +215,7 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,

if kv_cache is not None:
kv_cache.update(k, v) # time 21%
k, v = kv_cache.get()
k, v = kv_cache.get() # time 5%

att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

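For context, the transformer.py hunk above is the consumer of that cache: each attention forward appends the new keys/values and then reads the whole cached prefix back before computing attention. The sketch below reproduces that update-then-get flow with plain tensors under assumed sizes; the function name cached_attention, the module-level buffers, and the absence of a causal mask are simplifications, not the actual LightZero forward pass.

```python
import math
import torch

# Hypothetical sizes; the real transformer.py reads them from its config.
batch, num_heads, head_dim, max_tokens = 1, 2, 8, 16

# Pre-allocated key/value buffers and a fill counter, standing in for kv_cache.
k_cache = torch.zeros(batch, num_heads, max_tokens, head_dim)
v_cache = torch.zeros(batch, num_heads, max_tokens, head_dim)
size = 0


def cached_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """One decoding step: append k/v to the cache, then attend over the prefix."""
    global size
    tokens = k.size(2)
    # kv_cache.update(k, v): write the new tokens into the pre-allocated buffers.
    k_cache[:, :, size:size + tokens, :] = k
    v_cache[:, :, size:size + tokens, :] = v
    size += tokens
    # kv_cache.get(): read back everything cached so far.
    k_all = k_cache[:, :, :size, :]
    v_all = v_cache[:, :, :size, :]
    att = (q @ k_all.transpose(-2, -1)) * (1.0 / math.sqrt(k_all.size(-1)))
    att = torch.softmax(att, dim=-1)
    return att @ v_all


q = torch.randn(batch, num_heads, 1, head_dim)
k = torch.randn(batch, num_heads, 1, head_dim)
v = torch.randn(batch, num_heads, 1, head_dim)
print(cached_attention(q, k, v).shape)  # torch.Size([1, 2, 1, 8])
```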
2 changes: 1 addition & 1 deletion lzero/policy/unizero.py
@@ -372,7 +372,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
target_reward = target_reward.view(self._cfg.batch_size, -1)
target_value = target_value.view(self._cfg.batch_size, -1)

assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0)
# assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0)

# Transform rewards and values to their scaled forms
transformed_target_reward = scalar_transform(target_reward)
22 changes: 11 additions & 11 deletions zoo/atari/config/atari_unizero_config.py
@@ -20,14 +20,14 @@
infer_context_length = 4

# ====== only for debug =====
# collector_env_num = 2
# n_episode = 2
# evaluator_env_num = 2
# num_simulations = 2
# max_env_step = int(2e5)
# reanalyze_ratio = 0.
# batch_size = 2
# num_unroll_steps = 10
collector_env_num = 2
n_episode = 2
evaluator_env_num = 2
num_simulations = 2
max_env_step = int(2e5)
reanalyze_ratio = 0.
batch_size = 2
num_unroll_steps = 10
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
@@ -43,8 +43,8 @@
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
# TODO: only for debug
# collect_max_episode_steps=int(50),
# eval_max_episode_steps=int(50),
collect_max_episode_steps=int(50),
eval_max_episode_steps=int(50),
),
policy=dict(
model=dict(
@@ -104,7 +104,7 @@
for seed in seeds:
# Update exp_name to include the current seed
# main_config.exp_name = f'data_unizero_efficiency/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}_origin'
main_config.exp_name = f'data_unizero_efficiency/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}_nlayer2_optimizehash_custom-deepcopy_targevalue-cuda_optimize-computeloss_optimize-value-lst'
main_config.exp_name = f'data_unizero_efficiency_debug/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}_nlayer2_optimizehash_custom-deepcopy_targevalue-cuda_optimize-computeloss_optimize-value-lst_optimize-targetpolicy-nonrer'

# main_config.exp_name = f'data_unizero_efficiency/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}_optimizehash_1deepcopy-init-infer'
from lzero.entry import train_unizero
2 changes: 1 addition & 1 deletion zoo/atari/config/sco_acp_uz.sh
@@ -5,7 +5,7 @@ script='source activate cloud-ai-lab && cd /mnt/miaohua/niuyazhe/code/LightZero
echo "The final script is: " $script
sco acp jobs create --workspace-name=miaohua \
--aec2-name=miaohua \
--job-name="unizero-pong-nlayer2-200k-optimizehash_custom-deepcopy_targevalue-cuda_optimize-computeloss_optimize-value-lst_s0" \
--job-name="unizero-pong-nlayer2-200k-optimizehash_custom-deepcopy_targevalue-cuda_opt-computeloss_opt-value-lst_opt-targetpolicy-nonrer_s0" \
--container-image-url='registry.ms-sc-01.maoshanwangtech.com/ccr_2/aicl-ding-v1:20240719-18h48m08s' \
--training-framework=pytorch \
--enable-mpi \
