feature(pu): add rope in unizero's transformer #261

Closed · wants to merge 2 commits

This pull request adds rotary position embeddings (RoPE) to the UniZero world-model transformer. A new rotary_emb flag in world_model_cfg gates the change: when it is enabled, queries and keys are rotated inside self-attention, and the learned positional-embedding path (including the KV-cache positional correction) is skipped.
64 changes: 58 additions & 6 deletions lzero/model/unizero_world_models/transformer.py
@@ -4,7 +4,7 @@

import math
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple

import torch
import torch.nn as nn
@@ -55,6 +55,15 @@ def __init__(self, config: TransformerConfig) -> None:
self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])
self.ln_f = nn.LayerNorm(config.embed_dim)

self.config.rope_theta = 500000
self.config.max_seq_len = 2048

self.freqs_cis = precompute_freqs_cis(
self.config.embed_dim // self.config.num_heads,
self.config.max_seq_len * 2,
self.config.rope_theta,
)

def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
"""
Generate a placeholder for keys and values.
@@ -70,7 +79,7 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device)

def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0) -> torch.Tensor:
"""
Forward pass of the Transformer model.

@@ -82,10 +91,14 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues
Returns:
- torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
"""
seqlen = sequences.shape[1]
self.freqs_cis = self.freqs_cis.to(sequences.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

assert past_keys_values is None or len(past_keys_values) == len(self.blocks)
x = self.drop(sequences)
for i, block in enumerate(self.blocks):
x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths)
x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths, start_pos, freqs_cis)

x = self.ln_f(x)
return x
@@ -129,7 +142,7 @@ def __init__(self, config: TransformerConfig) -> None:
)

def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0, freqs_cis: torch.Tensor = None) -> torch.Tensor:
"""
Forward pass of the Transformer block.

@@ -141,7 +154,7 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
Returns:
- torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
"""
x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths)
x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, start_pos, freqs_cis)
if self.gru_gating:
x = self.gate1(x, x_attn)
x = self.gate2(x, self.mlp(self.ln2(x)))
@@ -152,6 +165,42 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
return x


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
# assert freqs_cis.shape == (x.shape[1], x.shape[-1])
# shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
# TODO
assert freqs_cis.shape == (x.shape[2], x.shape[-1])
shape = [d if i == 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]

return freqs_cis.view(*shape)


def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
# xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
# xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
# TODO
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)

return xq_out.type_as(xq), xk_out.type_as(xk)


class SelfAttention(nn.Module):
"""
Implements self-attention mechanism for transformers.
@@ -189,7 +238,7 @@ def __init__(self, config: TransformerConfig) -> None:
self.register_buffer('mask', causal_mask)

def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0, freqs_cis: torch.Tensor = None) -> torch.Tensor:
"""
Forward pass for the self-attention mechanism.

@@ -212,6 +261,9 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)
k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)
v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)

if self.config.rotary_emb:
q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

if kv_cache is not None:
kv_cache.update(k, v)
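The following is a minimal standalone sketch (not part of the diff) of how the new helpers fit together. It builds the frequency table the same way Transformer.__init__ above does (head_dim = embed_dim // num_heads, theta = 500000) and applies it to dummy query/key tensors in the (B, num_heads, T, head_size) layout used in SelfAttention.forward; the batch and sequence sizes are illustrative, and precompute_freqs_cis / apply_rotary_emb are assumed to be importable from transformer.py.

import torch

# Assumes precompute_freqs_cis and apply_rotary_emb from the diff above are in scope.
embed_dim, num_heads, max_seq_len, rope_theta = 768, 8, 2048, 500000.0
head_dim = embed_dim // num_heads  # 96, rotated in channel pairs

# Same call as in Transformer.__init__: one complex rotation per (position, channel pair).
freqs_cis = precompute_freqs_cis(head_dim, max_seq_len * 2, rope_theta)  # (4096, 48), complex64

B, T, start_pos = 2, 16, 0
q = torch.randn(B, num_heads, T, head_dim)
k = torch.randn(B, num_heads, T, head_dim)

# Slice the rotations for the current window, as Transformer.forward does with start_pos.
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=freqs_cis[start_pos:start_pos + T])
assert q_rot.shape == q.shape and k_rot.shape == k.shape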
55 changes: 35 additions & 20 deletions lzero/model/unizero_world_models/world_model.py
@@ -56,9 +56,11 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
self._initialize_patterns()

# Position embedding
self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device)
self.precompute_pos_emb_diff_kv()
print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")
if not self.config.rotary_emb:
self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device)
self.precompute_pos_emb_diff_kv()

print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")

# Initialize action embedding table
self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device)
@@ -271,8 +273,11 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
if len(obs_embeddings.shape) == 2:
obs_embeddings = obs_embeddings.unsqueeze(1)
num_steps = obs_embeddings.size(1)
sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent,
if not self.config.rotary_emb:
sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent,
is_init_infer, valid_context_lengths)
else:
sequences = obs_embeddings

# Process action tokens
elif 'act_tokens' in obs_embeddings_or_act_tokens:
@@ -281,8 +286,11 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
act_tokens = act_tokens.squeeze(1)
num_steps = act_tokens.size(1)
act_embeddings = self.act_embedding_table(act_tokens)
sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent,
if not self.config.rotary_emb:
sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent,
is_init_infer, valid_context_lengths)
else:
sequences = act_embeddings

# Process combined observation embeddings and action tokens
else:
@@ -355,7 +363,12 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps):
obs_act = torch.cat([obs, act], dim=1)
obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act

return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps
# return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps

return_result = obs_act_embeddings
if not self.config.rotary_emb:
return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device))
return return_result, num_steps

def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths):
"""
@@ -733,13 +746,14 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
k_cache_trimmed = k_cache_current[:, :, 2:context_length - 1, :].squeeze(0)
v_cache_trimmed = v_cache_current[:, :, 2:context_length - 1, :].squeeze(0)

# Index pre-computed positional encoding differences
pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)]
pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)]
# ============ NOTE: Very Important ============
# Apply positional encoding correction to k and v
k_cache_trimmed += pos_emb_diff_k.squeeze(0)
v_cache_trimmed += pos_emb_diff_v.squeeze(0)
if not self.config.rotary_emb:
# Index pre-computed positional encoding differences
pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)]
pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)]
# ============ NOTE: Very Important ============
# Apply positional encoding correction to k and v
k_cache_trimmed += pos_emb_diff_k.squeeze(0)
v_cache_trimmed += pos_emb_diff_v.squeeze(0)

# Pad the last 3 steps along the third dimension with zeros
# F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right).
@@ -774,13 +788,14 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
k_cache_trimmed = k_cache_current[:, 2:context_length - 1, :]
v_cache_trimmed = v_cache_current[:, 2:context_length - 1, :]

# Index pre-computed positional encoding differences
pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)]
pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)]
# ============ NOTE: Very Important ============
# Apply positional encoding correction to k and v
k_cache_trimmed += pos_emb_diff_k.squeeze(0)
v_cache_trimmed += pos_emb_diff_v.squeeze(0)
if not self.config.rotary_emb:
# Index pre-computed positional encoding differences
pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)]
pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)]
# ============ NOTE: Very Important ============
# Apply positional encoding correction to k and v
k_cache_trimmed += pos_emb_diff_k.squeeze(0)
v_cache_trimmed += pos_emb_diff_v.squeeze(0)

# Pad the last 3 steps along the third dimension with zeros
# F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right).
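The pos_emb_diff_k / pos_emb_diff_v corrections above can be skipped when rotary_emb is enabled because RoPE rotates q and k inside self-attention instead of adding a learned offset to the token embeddings, and the resulting q·k score depends only on the relative distance between positions. The toy check below (not from the PR; a single-vector re-implementation of the rotation) illustrates that invariance with the same theta used in the diff.

import torch

def rope(x: torch.Tensor, pos: float, theta: float = 500000.0) -> torch.Tensor:
    # Rotate consecutive channel pairs of x by position-dependent angles.
    d = x.shape[-1]
    freqs = 1.0 / (theta ** (torch.arange(0, d, 2).float() / d))
    cis = torch.polar(torch.ones_like(freqs), pos * freqs)  # e^{i * pos * freq}
    xc = torch.view_as_complex(x.float().reshape(-1, 2))    # (d/2,) complex pairs
    return torch.view_as_real(xc * cis).flatten()

q, k = torch.randn(96), torch.randn(96)

# The same relative offset (3) at different absolute positions gives the same score,
# which is why no additive positional correction is needed when the KV cache is trimmed.
score_a = rope(q, 10.0) @ rope(k, 7.0)
score_b = rope(q, 20.0) @ rope(k, 17.0)
assert torch.allclose(score_a, score_b, atol=1e-4)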
21 changes: 11 additions & 10 deletions zoo/atari/config/atari_unizero_config.py
@@ -20,14 +20,14 @@
infer_context_length = 4

# ====== only for debug =====
# collector_env_num = 2
# n_episode = 2
# evaluator_env_num = 2
# num_simulations = 5
# max_env_step = int(5e5)
# reanalyze_ratio = 0.
# batch_size = 2
# num_unroll_steps = 10
collector_env_num = 2
n_episode = 2
evaluator_env_num = 2
num_simulations = 5
max_env_step = int(5e5)
reanalyze_ratio = 0.
batch_size = 2
num_unroll_steps = 10
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
@@ -51,13 +51,14 @@
observation_shape=(3, 64, 64),
action_space_size=action_space_size,
world_model_cfg=dict(
rotary_emb=True,
max_blocks=num_unroll_steps,
max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action
context_length=2 * infer_context_length,
device='cuda',
# device='cpu',
action_space_size=action_space_size,
num_layers=4,
num_layers=2,
num_heads=8,
embed_dim=768,
obs_type='image',
@@ -101,6 +102,6 @@
seeds = [0] # You can add more seed values here
for seed in seeds:
# Update exp_name to include the current seed
main_config.exp_name = f'data_unizero/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}'
main_config.exp_name = f'data_unizero_debug/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}'
from lzero.entry import train_unizero
train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
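For an A/B comparison against the learned positional-embedding path, the flag introduced in this PR can be flipped in the same config before the training entry is called. The attribute path below is an assumption based on the usual LightZero config layout and should be verified against the actual config object.

# Hypothetical toggle (attribute path assumed): disable RoPE and fall back to the nn.Embedding pos_emb path.
main_config.policy.model.world_model_cfg.rotary_emb = False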