Rework embedding table to not use complex numbers #432

Merged
4 commits merged on Nov 11, 2024
8 changes: 4 additions & 4 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -104,11 +104,11 @@ def forward(
# Fast path to start_index based embedding lookup if available.
# Falls back to a slower position based index lookup.
if start_index is not None:
xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index)
xq = embedding.forward(xt=xq, start_index=start_index)
xk = embedding.forward(xt=xk, start_index=start_index)
else:
xq, xk = embedding.apply_batched_mask(
xq=xq, xk=xk, mask=embedding_batch_mask
)
xq = embedding.apply_batched_mask(xt=xq, mask=embedding_batch_mask)
xk = embedding.apply_batched_mask(xt=xk, mask=embedding_batch_mask)

# Full sequence length.
kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride
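Note: the call-site change above works because rotary embedding acts on each tensor independently, so queries and keys can be rotated in two separate calls at the same start index. A minimal sketch of the property that makes this safe (hypothetical `rope` helper and sizes, not the layer's actual implementation): rotating q and k separately leaves their dot product dependent only on the relative offset between their positions.

```python
import torch

def rope(x: torch.Tensor, pos: int, freqs: torch.Tensor) -> torch.Tensor:
    # Rotate (real, imag) pairs of a single vector by pos * freqs.
    ang = pos * freqs
    cos, sin = torch.cos(ang), torch.sin(ang)
    xr, xi = x.unflatten(-1, (-1, 2)).unbind(-1)
    return torch.cat((xr * cos - xi * sin, xr * sin + xi * cos), dim=-1)

dim = 16
freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
q, k = torch.randn(dim), torch.randn(dim)

# Same relative offset (4) at different absolute positions gives the same score.
s1 = torch.dot(rope(q, 7, freqs), rope(k, 3, freqs))
s2 = torch.dot(rope(q, 12, freqs), rope(k, 8, freqs))
torch.testing.assert_close(s1, s2, rtol=1e-4, atol=1e-4)
```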
171 changes: 69 additions & 102 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -53,58 +53,46 @@ def rotary_embed_table(self):
return self.static_rotary_embed_table
return self._create_rotary_embed_table()

if self.tensor_parallelism_size == 1:
return None

nt = namedtuple("replicated_tensor", ["shards"])
return nt([None] * self.tensor_parallelism_size)
return None

def forward(
self,
*,
xq: Union[torch.Tensor, SplitPrimitiveTensor],
xk: Union[torch.Tensor, SplitPrimitiveTensor],
xt: Union[torch.Tensor, SplitPrimitiveTensor],
start_index: int,
):
if isinstance(xq, SplitPrimitiveTensor):
assert (
isinstance(xk, SplitPrimitiveTensor)
and xq.shard_count == xk.shard_count
and xk.shard_dim == xq.shard_dim
)
assert (
isinstance(self.rotary_embed_table, ReplicatedTensor)
and xq.shard_count == self.rotary_embed_table.shard_count
)
xqk_shards = [
if isinstance(xt, SplitPrimitiveTensor):
rotary_shards = [None] * xt.shard_count
if self.rotary_embed_table is not None:
assert (
isinstance(self.rotary_embed_table, ReplicatedTensor)
and xt.shard_count == self.rotary_embed_table.shard_count
)
rotary_shards = [
unbox_tensor(shard) for shard in self.rotary_embed_table.shards
]

xt_shards = [
self.forward_unsharded(
xq=unbox_tensor(xq_shard),
xk=unbox_tensor(xk_shard),
xt=unbox_tensor(xt_shard),
start_index=start_index,
rotary_embed_table=unbox_tensor(rotary_embed_table_shard),
)
for xq_shard, xk_shard, rotary_embed_table_shard in zip(
xq.shards, xk.shards, self.rotary_embed_table.shards
rotary_embed_table=rotary_shard,
)
for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
]
xq_shards = [xqk[0] for xqk in xqk_shards]
xk_shards = [xqk[1] for xqk in xqk_shards]
xq = SplitPrimitiveTensor(ts=xq_shards, shard_dim=xq.shard_dim)
xk = SplitPrimitiveTensor(ts=xk_shards, shard_dim=xk.shard_dim)
return xq, xk
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
return xt
else:
return self.forward_unsharded(
xq=xq,
xk=xk,
xt=xt,
start_index=start_index,
rotary_embed_table=self.rotary_embed_table,
)

def forward_unsharded(
self,
*,
xq: torch.Tensor,
xk: torch.Tensor,
xt: torch.Tensor,
start_index: int,
rotary_embed_table: Optional[torch.Tensor],
):
@@ -149,44 +137,39 @@ def create_ordering_tensor(dim):
return order_tensor

if self.use_hf:
xq = xq[..., create_interleaved_tensor(xq.shape[-1])]
xk = xk[..., create_interleaved_tensor(xq.shape[-1])]

xq_ = torch.view_as_complex(xq.unflatten(-1, (-1, 2)))
xk_ = torch.view_as_complex(xk.unflatten(-1, (-1, 2)))
_, sl, _, dim = xq_.shape
xt = xt[..., create_interleaved_tensor(xt.shape[-1])]
xt_ = xt.unflatten(-1, (-1, 2))
_, sl, _, dim, _ = xt_.shape

# Offset the table based on starting position.
if self.use_table:
freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
freqs_cis = rotary_embed_table[:, start_index : start_index + sl, :]
else:
freqs_cis = torch.arange(start_index, start_index + sl, device=xq.device)
freqs_cis = torch.arange(start_index, start_index + sl, device=xt.device)
freqs_cis = self._compute_rotary_embed_table(freqs_cis)
freqs_cis = self._replicate(freqs_cis)

assert freqs_cis.shape[-1] == dim
assert (
freqs_cis.shape[0] >= sl
freqs_cis.shape[1] >= sl
), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})"

broadcast_freqs_cis = freqs_cis[None, 0:sl, None, :]
broadcast_freqs_cis = freqs_cis[:, None, 0:sl, None, :]

if self.use_hf:
xq_out = torch.view_as_real(
self.complex_multiply(xq_, broadcast_freqs_cis)
).flatten(3)
xk_out = torch.view_as_real(
self.complex_multiply(xk_, broadcast_freqs_cis)
).flatten(3)
cos = broadcast_freqs_cis[0]
sin = broadcast_freqs_cis[1]
xt_r = xt_[..., 0]
xt_i = xt_[..., 1]

xt_out_r = xt_r * cos - xt_i * sin
xt_out_i = xt_i * cos + xt_r * sin

xq_out = xq_out[..., create_ordering_tensor(xq_out.shape[-1])]
xk_out = xk_out[..., create_ordering_tensor(xq_out.shape[-1])]
xt_out = torch.concatenate((xt_out_r, xt_out_i), dim=-1)

return xq_out.type_as(xq), xk_out.type_as(xk)
if self.use_hf:
xt_out = xt_out[..., create_ordering_tensor(xt_out.shape[-1])]
return xt_out.type_as(xt)

xq_out = torch.view_as_real(xq_ * broadcast_freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * broadcast_freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
return xt_out.type_as(xt)
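Note: the real-valued path above computes the same rotation the removed code expressed with `view_as_complex` and `torch.polar`. A standalone check of that identity (hypothetical shapes and angle tensor, not part of the diff):

```python
import torch

bs, sl, heads, dim = 2, 8, 4, 16
xt = torch.randn(bs, sl, heads, dim)
angles = torch.randn(sl, dim // 2)  # stand-in for the rotary angle table

# Previous approach: view (real, imag) pairs as complex and multiply.
xt_c = torch.view_as_complex(xt.unflatten(-1, (-1, 2)))
freqs_cis = torch.polar(torch.ones_like(angles), angles)[None, :, None, :]
ref = torch.view_as_real(xt_c * freqs_cis).flatten(3)

# Reworked approach: explicit cos/sin arithmetic on the same pairs.
cos = torch.cos(angles)[None, :, None, :]
sin = torch.sin(angles)[None, :, None, :]
xt_r, xt_i = xt.unflatten(-1, (-1, 2)).unbind(-1)
out_r = xt_r * cos - xt_i * sin
out_i = xt_i * cos + xt_r * sin
# Re-interleave the pairs to match view_as_real(...).flatten(3).
out = torch.stack((out_r, out_i), dim=-1).flatten(3)

torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-5)
```

One layout detail worth noting: the diff concatenates the real and imaginary halves along the last dimension rather than re-interleaving them, so the rotated tensor's element order differs from the old interleaved output; the sketch re-interleaves only to compare against the complex reference.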

def complex_multiply(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""Function for elementwise-multiplication of two complex torch tensors.
@@ -224,11 +207,11 @@ def compute_batch_mask(
self.trace_tensor("rope.positions_seq", positions_seq)

if self.use_table:
freqs_cis = self.rotary_embed_table[positions_seq]
freqs_cis = self.rotary_embed_table[:, positions_seq]
else:
shape = positions_seq.shape
freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
freqs_cis = freqs_cis.unflatten(0, shape)
freqs_cis = freqs_cis.unflatten(1, shape)

# Unsqueeze a unit dim for attention heads.
broadcast_freqs_cis = freqs_cis.unsqueeze(2)
@@ -237,41 +220,24 @@ def apply_batched_mask(
def apply_batched_mask(
self,
*,
xq: Union[torch.Tensor, SplitPrimitiveTensor],
xk: Union[torch.Tensor, SplitPrimitiveTensor],
xt: Union[torch.Tensor, SplitPrimitiveTensor],
mask: Union[torch.Tensor, ReplicatedTensor],
):
if isinstance(xq, SplitPrimitiveTensor):
assert (
isinstance(xk, SplitPrimitiveTensor)
and xq.shard_count == xk.shard_count
and xk.shard_dim == xq.shard_dim
)
assert (
isinstance(mask, ReplicatedTensor)
and mask.shard_count == xq.shard_count
if not isinstance(xt, SplitPrimitiveTensor):
return self.apply_batched_mask_unsharded(xt=xt, mask=mask)

assert isinstance(mask, ReplicatedTensor) and mask.shard_count == xt.shard_count
xt_shards = [
self.apply_batched_mask_unsharded(
xt=unbox_tensor(xt_shard),
mask=unbox_tensor(mask_shard),
)
xqk_shards = [
self.apply_batched_mask_unsharded(
xq=unbox_tensor(xq_shard),
xk=unbox_tensor(xk_shard),
mask=unbox_tensor(mask_shard),
)
for xq_shard, xk_shard, mask_shard in zip(
xq.shards, xk.shards, mask.shards
)
]
xq_shards = [xqk[0] for xqk in xqk_shards]
xk_shards = [xqk[1] for xqk in xqk_shards]
xq = SplitPrimitiveTensor(ts=xq_shards, shard_dim=xq.shard_dim)
xk = SplitPrimitiveTensor(ts=xk_shards, shard_dim=xk.shard_dim)
return xq, xk
else:
return self.apply_batched_mask_unsharded(xq=xq, xk=xk, mask=mask)
for xt_shard, mask_shard in zip(xt.shards, mask.shards)
]
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
return xt

def apply_batched_mask_unsharded(
self, *, xq: torch.Tensor, xk: torch.Tensor, mask: torch.Tensor
):
def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
"""Applies the embedding to a ragged batch of queries and keys.

This does a more complicated indexing operation for cases when each
@@ -281,13 +247,17 @@ def apply_batched_mask_unsharded(
"""
# xq_, xk_ shape: bs, sl, _, dim
# freqs_cis shape: max_sl, dim
xq_ = torch.view_as_complex(xq.unflatten(-1, (-1, 2)))
xk_ = torch.view_as_complex(xk.unflatten(-1, (-1, 2)))
_, sl, _, dim = xq_.shape
cos = mask[0]
sin = mask[1]

xq_out = torch.view_as_real(xq_ * mask).flatten(3)
xk_out = torch.view_as_real(xk_ * mask).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
xt_ = xt.unflatten(-1, (-1, 2))
xt_r = xt_[..., 0]
xt_i = xt_[..., 1]

xt_out_r = xt_r * cos - xt_i * sin
xt_out_i = xt_r * sin + xt_i * cos
xt_out = torch.concatenate((xt_out_r, xt_out_i), dim=-1)
return xt_out.type_as(xt)
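Note: the batched-mask path is the same rotation, but with cos/sin gathered per token position rather than from a single start offset, which is what lets each batch row begin at its own position. A minimal sketch of that gather-and-rotate pattern (hypothetical shapes and broadcasting, independent of the classes above):

```python
import torch

bs, sl, heads, dim, max_sl = 2, 4, 3, 8, 32
table = torch.randn(2, max_sl, dim // 2)      # stacked (cos, sin) planes
positions = torch.tensor([[5, 6, 7, 8],       # each row starts at its own offset
                          [0, 1, 2, 3]])      # (bs, sl)

mask = table[:, positions]                    # (2, bs, sl, dim // 2)
mask = mask.unsqueeze(3)                      # unit dim for attention heads
cos, sin = mask[0], mask[1]

xt = torch.randn(bs, sl, heads, dim)
xt_r, xt_i = xt.unflatten(-1, (-1, 2)).unbind(-1)
out = torch.cat((xt_r * cos - xt_i * sin,
                 xt_r * sin + xt_i * cos), dim=-1)   # (bs, sl, heads, dim)
```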

def _compute_rotary_embed_table(self, t):
dim = self.rope_dimension_count
@@ -297,13 +267,10 @@ def _compute_rotary_embed_table(self, t):
)
freqs = torch.outer(t, freqs).float()

freqs_cis = (
torch.complex(torch.cos(freqs), torch.sin(freqs))
if self.use_hf
else torch.polar(torch.ones_like(freqs), freqs)
)
cos = torch.cos(freqs).unsqueeze(0)
sin = torch.sin(freqs).unsqueeze(0)

return freqs_cis
return torch.concatenate((cos, sin), dim=0)
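Note: the table now stores cos and sin as two stacked planes instead of a complex tensor, so a single start-index slice keeps both planes paired. A small sketch of the resulting layout (hypothetical sizes; the inverse-frequency line is the standard RoPE form and stands in for the collapsed lines above):

```python
import torch

max_seqlen, dim, rope_freq_base = 32, 16, 10000.0

t = torch.arange(max_seqlen).float()
inv_freq = 1.0 / (rope_freq_base ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(t, inv_freq)                       # (max_seqlen, dim // 2)

table = torch.cat((torch.cos(freqs).unsqueeze(0),
                   torch.sin(freqs).unsqueeze(0)), dim=0)
print(table.shape)                                     # torch.Size([2, 32, 8])

# Offsetting by start_index, as in forward_unsharded, keeps cos and sin together:
start_index, sl = 4, 8
window = table[:, start_index : start_index + sl, :]   # (2, sl, dim // 2)
cos, sin = window[0], window[1]
```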

def _create_rotary_embed_table(self):
t = torch.arange(self.max_seqlen, device=self.device)
23 changes: 14 additions & 9 deletions sharktank/sharktank/types/tensors.py
@@ -543,9 +543,10 @@ def _clone_with_globals(
) -> "InferenceTensor":
return DefaultPrimitiveTensor(name=self.name, data=new_globals[self.name])

def __getitem__(self, keys):
if not isinstance(keys, list) and not isinstance(keys, tuple):
keys = [keys]
def __getitem__(self, key):
keys = [key]
if isinstance(key, tuple) or isinstance(key, list):
keys = key

keys = [
unbox_tensor(key) if isinstance(key, PrimitiveTensor) else key
@@ -1188,15 +1189,19 @@ def create(
raise IOError(f"Missing component tensor '' in {raw_tensors.keys()}") from e
return cls(name=name, ts=ts)

def __getitem__(self, keys):
if not isinstance(keys, list) and not isinstance(keys, tuple):
keys = [keys]
def __getitem__(self, key):
keys = [key]
if isinstance(key, tuple) or isinstance(key, list):
keys = key

shards = []
for i, shard in enumerate(self.shards):
shard_keys = [
k.shards[i] if isinstance(k, ReplicatedTensor) else k for k in keys
]
shard_keys = []
for k in keys:
if isinstance(k, ReplicatedTensor):
shard_keys.append(k.shards[i])
else:
shard_keys.append(k)
shards.append(shard[*shard_keys])
return ReplicatedTensor(ts=shards)
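Note: both `__getitem__` rewrites follow the same pattern: normalize the key into a list of per-dimension keys, then (for the replicated case) swap any ReplicatedTensor key for its shard-local piece. A standalone sketch of that pattern (hypothetical `FakeReplicated` stand-in, plain Python instead of the tensor classes):

```python
from dataclasses import dataclass
from typing import Any, Sequence

@dataclass
class FakeReplicated:
    # Toy stand-in for ReplicatedTensor: one key value per shard.
    shards: Sequence[Any]

def normalize_keys(key: Any) -> list:
    # Mirrors the new __getitem__ preamble: a lone key becomes a one-element
    # list, while tuple/list keys are used as given.
    keys = [key]
    if isinstance(key, (tuple, list)):
        keys = list(key)
    return keys

def resolve_for_shard(keys: Sequence[Any], i: int) -> list:
    # Mirrors the per-shard loop: replicated keys become their i-th shard.
    return [k.shards[i] if isinstance(k, FakeReplicated) else k for k in keys]

# Usage: a plain index plus a replicated index, resolved for shard 1.
keys = normalize_keys((0, FakeReplicated(shards=[10, 11, 12])))
print(resolve_for_shard(keys, 1))   # [0, 11]
```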

6 changes: 4 additions & 2 deletions sharktank/tests/layers/sharded_rotary_embedding_test.py
@@ -35,7 +35,8 @@ def test_sharded_rotary_table():
max_seqlen=max_seqlen,
rope_freq_base=rope_freq_base,
)
oq, ok = default_layer(xq=xq, xk=xk, start_index=0)
oq = default_layer(xt=xq, start_index=0)
ok = default_layer(xt=xk, start_index=0)

# Then we can shard the same inputs and layer
xq = SplitPrimitiveTensor(ts=xq, shard_dim=2, shard_count=4)
@@ -46,7 +47,8 @@
rope_freq_base=rope_freq_base,
tensor_parallelism_size=4,
)
sq, sk = shard_layer(xq=xq, xk=xk, start_index=0)
sq = shard_layer(xt=xq, start_index=0)
sk = shard_layer(xt=xk, start_index=0)

# Gathering and unboxing should yield the same results
sq = ops.unshard(sq)