Commit b40c71b

polish(pu): use custom deepcopy for kv_cache

PaParaZz1 committed Aug 22, 2024
1 parent 00147f4 commit b40c71b

Showing 8 changed files with 113 additions and 59 deletions.
26 changes: 16 additions & 10 deletions lzero/mcts/buffer/game_buffer_unizero.py
@@ -413,7 +413,8 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
m_output = model.initial_inference(m_obs, action_batch)
# ======================================================================

if not model.training:
# if not model.training:
if self._cfg.device == 'cuda':
# if not in training, obtain the scalars of the value/reward
[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
[
@@ -422,6 +423,16 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
m_output.policy_logits
]
)
elif self._cfg.device == 'cpu':
# TODO
[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
[
m_output.latent_state,
inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
m_output.policy_logits
]
)

network_output.append(m_output)

# concat the output slices after model inference
@@ -499,19 +510,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
target_values.append(value_list[value_index])
target_rewards.append(reward_list[current_index])
else:
target_values.append(np.array([0.]))
target_rewards.append(np.array([0.]))
target_values.append(np.array(0.))
target_rewards.append(np.array(0.))
value_index += 1

batch_rewards.append(target_rewards)
batch_target_values.append(target_values)

# batch_rewards = np.asarray(batch_rewards)
# batch_target_values = np.asarray(batch_target_values)
# batch_rewards = np.squeeze(batch_rewards, axis=-1)
# batch_target_values = np.squeeze(batch_target_values, axis=-1)

batch_rewards = np.asarray(batch_rewards, dtype=object)
batch_target_values = np.asarray(batch_target_values, dtype=object)
batch_rewards = np.asarray(batch_rewards)
batch_target_values = np.asarray(batch_target_values)

return batch_rewards, batch_target_values
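
A brief numpy illustration of the target-shape change above: once the zero targets are stored as 0-d arrays (np.array(0.)) rather than 1-element vectors (np.array([0.])), np.asarray can stack the per-trajectory lists into a regular 2-D float batch, which is presumably why the dtype=object workaround and the commented-out squeeze are no longer needed. The shapes and values below are illustrative only:

```python
import numpy as np

# Illustrative only: two trajectories, two unrolled steps each.
# 0-d targets stack into a plain 2-D float array.
target_values = [[np.array(0.), np.array(1.5)], [np.array(0.5), np.array(0.)]]
batch_target_values = np.asarray(target_values)
print(batch_target_values.shape, batch_target_values.dtype)  # (2, 2) float64

# 1-element targets produce a trailing singleton axis instead,
# which previously motivated dtype=object plus np.squeeze(..., axis=-1).
target_values_vec = [[np.array([0.]), np.array([1.5])], [np.array([0.5]), np.array([0.])]]
print(np.asarray(target_values_vec).shape)  # (2, 2, 1)
```
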
8 changes: 4 additions & 4 deletions lzero/model/unizero_world_models/lpips.py
@@ -21,16 +21,16 @@ def __init__(self, use_dropout: bool = True):
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512] # vg16 features
# Comment out the following line if you don't need perceptual loss
self.net = vgg16(pretrained=True, requires_grad=False)
# self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
# Comment out the following line if you don't need perceptual loss
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False
# self.load_from_pretrained()
# for param in self.parameters():
# param.requires_grad = False

def load_from_pretrained(self) -> None:
ckpt = get_ckpt_path(name="vgg_lpips", root=Path.home() / ".cache/iris/tokenizer_pretrained_vgg") # Download VGG if necessary
2 changes: 1 addition & 1 deletion lzero/model/unizero_world_models/transformer.py
@@ -214,7 +214,7 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)

if kv_cache is not None:
kv_cache.update(k, v)
kv_cache.update(k, v) # time 21%
k, v = kv_cache.get()

att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
58 changes: 50 additions & 8 deletions lzero/model/unizero_world_models/utils.py
@@ -1,4 +1,5 @@
import hashlib
import xxhash
from dataclasses import dataclass

import numpy as np
@@ -8,6 +9,51 @@
from .kv_caching import KeysValues


def custom_copy_kv_cache_to_dict(src_kv: KeysValues, dst_dict: dict, cache_key: str) -> None:
"""
Overview:
Efficiently copy the contents of a KeysValues object to a new entry in a dictionary.
Arguments:
- src_kv (:obj:`KeysValues`): The source KeysValues object to copy from.
- dst_dict (:obj:`dict`): The destination dictionary to copy to.
- cache_key (:obj:`str`): The key for the new entry in the destination dictionary.
"""
dst_kv = KeysValues(
src_kv._keys_values[0].shape[0], # n
src_kv._keys_values[0].shape[1], # num_heads
src_kv._keys_values[0].shape[2], # max_tokens
src_kv._keys_values[0].shape[3] * src_kv._keys_values[0].shape[1], # embed_dim
len(src_kv), # num_layers
src_kv._keys_values[0]._k_cache._cache.device, # device
)

for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values):
dst_layer._k_cache._cache = src_layer._k_cache._cache.detach().clone()
dst_layer._v_cache._cache = src_layer._v_cache._cache.detach().clone()
dst_layer._k_cache._size = src_layer._k_cache._size
dst_layer._v_cache._size = src_layer._v_cache._size

dst_dict[cache_key] = dst_kv

def custom_copy_kv_cache(src_kv: KeysValues) -> KeysValues:
dst_kv = KeysValues(
src_kv._keys_values[0].shape[0], # n
src_kv._keys_values[0].shape[1], # num_heads
src_kv._keys_values[0].shape[2], # max_tokens
src_kv._keys_values[0].shape[3] * src_kv._keys_values[0].shape[1], # embed_dim
len(src_kv), # num_layers
src_kv._keys_values[0]._k_cache._cache.device, # device
)

for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values):
dst_layer._k_cache._cache = src_layer._k_cache._cache.detach().clone()
dst_layer._v_cache._cache = src_layer._v_cache._cache.detach().clone()
dst_layer._k_cache._size = src_layer._k_cache._size
dst_layer._v_cache._size = src_layer._v_cache._size

return dst_kv


def to_device_for_kvcache(keys_values: KeysValues, device: str) -> KeysValues:
"""
Transfer all KVCache objects within the KeysValues object to a certain device.
@@ -18,7 +64,7 @@ def to_device_for_kvcache(keys_values: KeysValues, device: str) -> KeysValues:
Returns:
- keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device.
"""
device = torch.device(device if torch.cuda.is_available() else 'cpu')
# device = torch.device(device if torch.cuda.is_available() else 'cpu')

for kv_cache in keys_values:
kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device)
@@ -68,7 +114,7 @@ def calculate_cuda_memory_gb(past_keys_values_cache, num_layers: int):
return total_memory_gb


# def hash_state(state, num_buckets=100):
# def hash_state_origin(state, num_buckets=100):
# """
# Quantize the state vector.

@@ -85,9 +131,6 @@ def calculate_cuda_memory_gb(past_keys_values_cache, num_layers: int):
# hash_object = hashlib.sha256(quantized_state_bytes)
# return hash_object.hexdigest()

import numpy as np
import xxhash


def hash_state(state):
"""
@@ -99,9 +142,8 @@ def hash_state(state):
The hash value of the state vector.
"""
# Use xxhash for faster hashing
hash_value = xxhash.xxh64(state).hexdigest()

return hash_value
# return xxhash.xxh64(state.view(-1).cpu().numpy()).hexdigest()
return xxhash.xxh64(state).hexdigest()

@dataclass
class WorldModelOutput:
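
For orientation, a minimal sketch of how the two new copy helpers are meant to stand in for copy.deepcopy when KeysValues objects are cached and retrieved. The KeysValues constructor arguments and all concrete sizes below are assumptions inferred from the helpers above, not values taken from the repository:

```python
import torch

from lzero.model.unizero_world_models.kv_caching import KeysValues
from lzero.model.unizero_world_models.utils import (
    custom_copy_kv_cache,
    custom_copy_kv_cache_to_dict,
    hash_state,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical cache: (n, num_heads, max_tokens, embed_dim, num_layers, device),
# mirroring the argument order used inside the copy helpers above.
src_kv = KeysValues(8, 8, 32, 512, 2, device)

# Storing: detach-and-clone the per-layer key/value tensors into a fresh entry
# of the cache dict, instead of copy.deepcopy-ing the whole KeysValues object.
cache = {}
latent_state = torch.randn(8, 768)  # illustrative latent batch
cache_key = hash_state(latent_state[0].view(-1).cpu().numpy())
custom_copy_kv_cache_to_dict(src_kv, cache, cache_key)

# Retrieving: copy again before reuse, because the transformer forward pass
# updates the retrieved cache in place.
matched_value = cache.get(cache_key)
if matched_value is not None:
    kv_for_inference = custom_copy_kv_cache(src_kv=matched_value)
```
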
62 changes: 35 additions & 27 deletions lzero/model/unizero_world_models/world_model.py
@@ -16,7 +16,7 @@
from .slicer import Head, PolicyHeadCont
from .tokenizer import Tokenizer
from .transformer import Transformer, TransformerConfig
from .utils import LossWithIntermediateLosses, init_weights, to_device_for_kvcache
from .utils import LossWithIntermediateLosses, init_weights, to_device_for_kvcache, custom_copy_kv_cache, custom_copy_kv_cache_to_dict
from .utils import WorldModelOutput, hash_state
from torch.distributions import Categorical, Independent, Normal

@@ -110,6 +110,10 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
# Initialize keys and values for transformer
self._initialize_transformer_keys_values()

# TODO: check
self.latent_recon_loss = torch.tensor(0., device=self.device)
self.perceptual_loss = torch.tensor(0., device=self.device)

def _initialize_config_parameters(self) -> None:
"""Initialize configuration parameters."""
self.policy_entropy_weight = self.config.policy_entropy_weight
@@ -195,9 +199,14 @@ def _initialize_last_layer(self) -> None:

def _initialize_cache_structures(self) -> None:
"""Initialize cache structures for past keys and values."""
self.past_kv_cache_recurrent_infer = collections.OrderedDict()
# self.past_kv_cache_init_infer = collections.OrderedDict()
self.past_kv_cache_init_infer_envs = [collections.OrderedDict() for _ in range(self.env_num)]
# self.past_kv_cache_recurrent_infer = collections.OrderedDict()
# self.past_kv_cache_init_infer_envs = [collections.OrderedDict() for _ in range(self.env_num)]
# TODO: check
from collections import defaultdict
self.past_kv_cache_recurrent_infer = defaultdict(dict)
self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)]

self.keys_values_wm_list = []
self.keys_values_wm_size_list = []

@@ -540,9 +549,9 @@ def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: tor
for i in range(ready_env_num):
# Retrieve latent state for a single environment
state_single_env = latent_state[i]
quantized_state = state_single_env.detach().cpu().numpy()
# Compute hash value using quantized state
cache_key = hash_state(quantized_state)
# Compute hash value using latent state for a single environment
cache_key = hash_state(state_single_env.view(-1).cpu().numpy()) # latent_state[i] is torch.Tensor

# Retrieve cached value
matched_value = self.past_kv_cache_init_infer_envs[i].get(cache_key)

Expand All @@ -552,7 +561,7 @@ def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: tor
self.root_hit_cnt += 1
# deepcopy is needed because forward modifies matched_value in place
# self.keys_values_wm_list.append(copy.deepcopy(to_device_for_kvcache(matched_value, self.device)))
self.keys_values_wm_list.append(to_device_for_kvcache(matched_value, self.device))
self.keys_values_wm_list.append(custom_copy_kv_cache(src_kv=to_device_for_kvcache(matched_value, self.device)))
self.keys_values_wm_size_list.append(matched_value.size)
else:
# Reset using zero values
@@ -678,7 +687,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0,
# print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio)
# print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt)

# Trim and pad kv_cache
# Trim and pad kv_cache: modify self.keys_values_wm in-place
self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False)
self.keys_values_wm_size_list_current = self.keys_values_wm_size_list

Expand Down Expand Up @@ -718,8 +727,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0,
latent_state_index_in_search_path=latent_state_index_in_search_path
)

return (
outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value)
return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value)

def trim_and_pad_kv_cache(self, is_init_infer=True) -> list:
"""
@@ -790,11 +798,10 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
return
for i in range(latent_state.size(0)):
# ============ Iterate over each environment ============
state_single_env = latent_state[i]
quantized_state = state_single_env.detach().cpu().numpy()
cache_key = hash_state(quantized_state)
cache_key = hash_state(latent_state[i].view(-1).cpu().numpy()) # latent_state[i] is torch.Tensor
context_length = self.context_length


if not is_init_infer:
# ============ Internal Node ============
# Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding
@@ -911,12 +918,12 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
# Store the latest key-value cache for initial inference
# self.past_kv_cache_init_infer_envs[i][cache_key] = copy.deepcopy(
# to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
self.past_kv_cache_init_infer_envs[i][cache_key] = to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')
custom_copy_kv_cache_to_dict(to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'), self.past_kv_cache_init_infer_envs[i], cache_key)
else:
# Store the latest key-value cache for recurrent inference
# self.past_kv_cache_recurrent_infer[cache_key] = copy.deepcopy(
# to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
self.past_kv_cache_recurrent_infer[cache_key] = to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')
custom_copy_kv_cache_to_dict(to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'), self.past_kv_cache_recurrent_infer, cache_key)

def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int,
simulation_index: int = 0) -> list:
@@ -936,8 +943,8 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int,
"""
for i in range(ready_env_num):
self.total_query_count += 1
state_single_env = latent_state[i] # Get the latent state for a single environment
cache_key = hash_state(state_single_env) # Compute the hash value using the quantized state
state_single_env = latent_state[i] # latent_state[i] is np.array
cache_key = hash_state(state_single_env)

# Try to retrieve the cached value from past_kv_cache_init_infer_envs
matched_value = self.past_kv_cache_init_infer_envs[i].get(cache_key)
@@ -951,7 +958,7 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int,
self.hit_count += 1
# Perform a deep copy because the transformer's forward pass might modify matched_value in-place
# self.keys_values_wm_list.append(copy.deepcopy(to_device_for_kvcache(matched_value, self.device)))
self.keys_values_wm_list.append(to_device_for_kvcache(matched_value, self.device))
self.keys_values_wm_list.append(custom_copy_kv_cache(src_kv=to_device_for_kvcache(matched_value, self.device)))
self.keys_values_wm_size_list.append(matched_value.size)
else:
# If no matching cache is found, generate a new one using zero reset
@@ -999,7 +1006,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar

if self.obs_type == 'image':
# Reconstruct observations from latent state representations
reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)
# reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)

# ========== for visualization ==========
# Uncomment the lines below for visual analysis
@@ -1014,10 +1021,9 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
# Calculate reconstruction loss and perceptual loss
# latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1
# perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1
latent_recon_loss = torch.tensor(0., device=batch['observations'].device,
dtype=batch['observations'].dtype)
perceptual_loss = torch.tensor(0., device=batch['observations'].device,
dtype=batch['observations'].dtype)

latent_recon_loss = self.latent_recon_loss
perceptual_loss = self.perceptual_loss

elif self.obs_type == 'vector':
perceptual_loss = torch.tensor(0., device=batch['observations'].device,
@@ -1091,7 +1097,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'])

# Compute labels for observations, rewards, and ends
labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(target_obs_embeddings,
labels_observations, labels_rewards, _ = self.compute_labels_world_model(target_obs_embeddings,
batch['rewards'],
batch['ends'],
batch['mask_padding'])
@@ -1341,7 +1347,7 @@ def compute_policy_entropy_loss(self, logits, mask):

def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor,
mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag
# assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag
mask_fill = torch.logical_not(mask_padding)

# Prepare observation labels
@@ -1352,9 +1358,11 @@ def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torc
labels_rewards = rewards.masked_fill(mask_fill_rewards, -100)

# Fill the masked areas of ends
labels_ends = ends.masked_fill(mask_fill, -100)
# labels_ends = ends.masked_fill(mask_fill, -100)

# return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1)
return labels_observations, labels_rewards.view(-1, self.support_size), None # TODO

return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1)

def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor,
mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]:
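
A small sketch of the label-masking convention relied on by compute_labels_world_model above: padded time steps are overwritten with the sentinel value -100 so that the downstream loss can ignore them. The tensor shapes and the construction of mask_fill_rewards are assumptions for illustration, not code from the repository:

```python
import torch

# Assumed shapes: (batch, time, support_size) reward targets and a (batch, time) padding mask.
support_size = 11
rewards = torch.randn(2, 4, support_size)
mask_padding = torch.tensor([[True, True, True, False],
                             [True, True, False, False]])

mask_fill = torch.logical_not(mask_padding)
# Broadcast the padding mask over the support dimension (assumed construction of mask_fill_rewards).
mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards)
labels_rewards = rewards.masked_fill(mask_fill_rewards, -100)

# Flattened to (batch * time, support_size), matching labels_rewards.view(-1, self.support_size) above.
print(labels_rewards.view(-1, support_size).shape)  # torch.Size([8, 11])
```
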