From 1f1df62493f45370e4fcd8a649424211317bd5c7 Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Thu, 8 Aug 2024 18:45:46 +0800 Subject: [PATCH 1/2] feature(pu): add rope in unizero's transformer --- .../model/unizero_world_models/transformer.py | 56 +++++++++++++++++-- .../model/unizero_world_models/world_model.py | 18 ++++-- zoo/atari/config/atari_unizero_config.py | 21 +++---- 3 files changed, 74 insertions(+), 21 deletions(-) diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index 714bc13d6..e3a118f47 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -4,7 +4,7 @@ import math from dataclasses import dataclass -from typing import Optional +from typing import Optional, Tuple import torch import torch.nn as nn @@ -55,6 +55,15 @@ def __init__(self, config: TransformerConfig) -> None: self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)]) self.ln_f = nn.LayerNorm(config.embed_dim) + self.config.rope_theta = 500000 + self.config.max_seq_len = 2048 + + self.freqs_cis = precompute_freqs_cis( + self.config.embed_dim // self.config.num_heads, + self.config.max_seq_len * 2, + self.config.rope_theta, + ) + def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: """ Generate a placeholder for keys and values. @@ -70,7 +79,7 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0) -> torch.Tensor: """ Forward pass of the Transformer model. @@ -82,10 +91,14 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues Returns: - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). """ + seqlen = sequences.shape[1] + self.freqs_cis = self.freqs_cis.to(sequences.device) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + assert past_keys_values is None or len(past_keys_values) == len(self.blocks) x = self.drop(sequences) for i, block in enumerate(self.blocks): - x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths) + x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths, start_pos, freqs_cis) x = self.ln_f(x) return x @@ -129,7 +142,7 @@ def __init__(self, config: TransformerConfig) -> None: ) def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0, freqs_cis: torch.Tensor = None) -> torch.Tensor: """ Forward pass of the Transformer block. @@ -141,7 +154,7 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None Returns: - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). 
""" - x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths) + x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, start_pos, freqs_cis) if self.gru_gating: x = self.gate1(x, x_attn) x = self.gate2(x, self.mlp(self.ln2(x))) @@ -152,6 +165,34 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None return x +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + class SelfAttention(nn.Module): """ Implements self-attention mechanism for transformers. @@ -189,7 +230,7 @@ def __init__(self, config: TransformerConfig) -> None: self.register_buffer('mask', causal_mask) def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0, freqs_cis: torch.Tensor = None) -> torch.Tensor: """ Forward pass for the self-attention mechanism. 
@@ -212,6 +253,9 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) + + if self.config.rotary_emb: + q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) if kv_cache is not None: kv_cache.update(k, v) diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index ef31d951c..913581e2f 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -56,9 +56,11 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self._initialize_patterns() # Position embedding - self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device) - self.precompute_pos_emb_diff_kv() - print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + if not self.config.rotary_emb: + self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device) + self.precompute_pos_emb_diff_kv() + + print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") # Initialize action embedding table self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) @@ -271,8 +273,11 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu if len(obs_embeddings.shape) == 2: obs_embeddings = obs_embeddings.unsqueeze(1) num_steps = obs_embeddings.size(1) - sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent, + if not self.config.rotary_emb: + sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths) + else: + sequences = obs_embeddings # Process action tokens elif 'act_tokens' in obs_embeddings_or_act_tokens: @@ -281,8 +286,11 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu act_tokens = act_tokens.squeeze(1) num_steps = act_tokens.size(1) act_embeddings = self.act_embedding_table(act_tokens) - sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent, + if not self.config.rotary_emb: + sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths) + else: + sequences = act_embeddings # Process combined observation embeddings and action tokens else: diff --git a/zoo/atari/config/atari_unizero_config.py b/zoo/atari/config/atari_unizero_config.py index 1c549010f..86c95438e 100644 --- a/zoo/atari/config/atari_unizero_config.py +++ b/zoo/atari/config/atari_unizero_config.py @@ -20,14 +20,14 @@ infer_context_length = 4 # ====== only for debug ===== -# collector_env_num = 2 -# n_episode = 2 -# evaluator_env_num = 2 -# num_simulations = 5 -# max_env_step = int(5e5) -# reanalyze_ratio = 0. -# batch_size = 2 -# num_unroll_steps = 10 +collector_env_num = 2 +n_episode = 2 +evaluator_env_num = 2 +num_simulations = 5 +max_env_step = int(5e5) +reanalyze_ratio = 0. 
+batch_size = 2 +num_unroll_steps = 10 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -51,13 +51,14 @@ observation_shape=(3, 64, 64), action_space_size=action_space_size, world_model_cfg=dict( + rotary_emb=True, max_blocks=num_unroll_steps, max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action context_length=2 * infer_context_length, device='cuda', # device='cpu', action_space_size=action_space_size, - num_layers=4, + num_layers=2, num_heads=8, embed_dim=768, obs_type='image', @@ -101,6 +102,6 @@ seeds = [0] # You can add more seed values here for seed in seeds: # Update exp_name to include the current seed - main_config.exp_name = f'data_unizero/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}' + main_config.exp_name = f'data_unizero_debug/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}' from lzero.entry import train_unizero train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) From fd0d68deaf67f2b1153ce4159ab250f236d9b4c7 Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Thu, 8 Aug 2024 22:06:49 +0800 Subject: [PATCH 2/2] fix(pu): fix rope in unizero's transformer --- .../model/unizero_world_models/transformer.py | 16 ++++++-- .../model/unizero_world_models/world_model.py | 37 +++++++++++-------- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index e3a118f47..4f0ab01f1 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -175,8 +175,12 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): ndim = x.ndim assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + # assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + # shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + # TODO + assert freqs_cis.shape == (x.shape[2], x.shape[-1]) + shape = [d if i == 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) @@ -188,8 +192,12 @@ def apply_rotary_emb( xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + # xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + # xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + # TODO + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) + return xq_out.type_as(xq), xk_out.type_as(xk) diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index 913581e2f..2f844ead4 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -363,7 +363,12 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): obs_act = 
torch.cat([obs, act], dim=1) obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + # return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + + return_result = obs_act_embeddings + if not self.config.rotary_emb: + return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + return return_result, num_steps def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths): """ @@ -741,13 +746,14 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde k_cache_trimmed = k_cache_current[:, :, 2:context_length - 1, :].squeeze(0) v_cache_trimmed = v_cache_current[:, :, 2:context_length - 1, :].squeeze(0) - # Index pre-computed positional encoding differences - pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] - pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] - # ============ NOTE: Very Important ============ - # Apply positional encoding correction to k and v - k_cache_trimmed += pos_emb_diff_k.squeeze(0) - v_cache_trimmed += pos_emb_diff_v.squeeze(0) + if not self.config.rotary_emb: + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) # Pad the last 3 steps along the third dimension with zeros # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). @@ -782,13 +788,14 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde k_cache_trimmed = k_cache_current[:, 2:context_length - 1, :] v_cache_trimmed = v_cache_current[:, 2:context_length - 1, :] - # Index pre-computed positional encoding differences - pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] - pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] - # ============ NOTE: Very Important ============ - # Apply positional encoding correction to k and v - k_cache_trimmed += pos_emb_diff_k.squeeze(0) - v_cache_trimmed += pos_emb_diff_v.squeeze(0) + if not self.config.rotary_emb: + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) # Pad the last 3 steps along the third dimension with zeros # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right).
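
Note (editor's addition, not part of the patches above): the sketch below is a minimal, self-contained illustration of how the RoPE helpers introduced in PATCH 1/2 are expected to behave on the (B, num_heads, T, head_size) query/key layout used by SelfAttention.forward, after the axis-2 broadcasting and flatten(-2) fixes from PATCH 2/2. The standalone copies of precompute_freqs_cis / reshape_for_broadcast / apply_rotary_emb, the toy tensor sizes, and the __main__ shape check are illustrative assumptions for reading the diff, not code taken from the repository.

# --- illustrative sketch only; not part of the patches above ---
from typing import Tuple

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # One complex rotation factor per (position, feature pair), as in PATCH 1/2.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64, shape (end, dim // 2)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # x is the complex view of q or k with shape (B, num_heads, T, head_size // 2),
    # so freqs_cis must line up with the sequence axis (dim 2), per the PATCH 2/2 fix.
    assert freqs_cis.shape == (x.shape[2], x.shape[-1])
    shape = [d if i == 2 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor,
                     freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # xq, xk: (B, num_heads, T, head_size); freqs_cis: (T, head_size // 2).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)  # flatten(-2): PATCH 2/2 fix
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)


if __name__ == "__main__":
    # Assumed toy sizes: embed_dim=768 split over 8 heads, as in the Atari config above.
    B, num_heads, T, head_size = 2, 8, 16, 96
    start_pos = 0  # offset into the precomputed table, mirroring Transformer.forward
    freqs_cis = precompute_freqs_cis(head_size, 2048, theta=500000)
    q = torch.randn(B, num_heads, T, head_size)
    k = torch.randn(B, num_heads, T, head_size)
    q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis[start_pos:start_pos + T])
    assert q_rot.shape == q.shape and k_rot.shape == k.shape

This also clarifies the accompanying world_model.py changes: with rotary embeddings the positions are injected inside attention rather than added to the token embeddings, so the learned nn.Embedding positional table, the _add_position_embeddings calls, and the precomputed pos_emb_diff_k / pos_emb_diff_v corrections applied to the trimmed KV cache are all skipped when config.rotary_emb is set.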