[WIP] Add sharded block-cyclic split tensor
sogartar committed Sep 10, 2024
1 parent c15b751 commit ecf937e
Showing 6 changed files with 180 additions and 2 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -2,6 +2,7 @@
gguf==0.6.0
numpy==1.26.3
onnx==1.15.0
numpy>=1.15

# Model deps.
huggingface-hub==0.22.2
12 changes: 12 additions & 0 deletions sharktank/sharktank/ops/default_impls.py
@@ -133,6 +133,13 @@ def equal_default(a, b) -> bool:
return torch.equal(unbox_tensor(a), unbox_tensor(b))


@flatten.override(Tensor)
def flatten_default(
input: Union[PrimitiveTensor, Tensor], start_dim: int, end_dim: int
) -> Tensor:
return torch.flatten(unbox_tensor(input), start_dim, end_dim)


@gemm.override(AllOfType(Tensor, InferenceTensor))
def gemm(
a: AnyTensor,
@@ -242,6 +249,11 @@ def scaled_dot_product_attention(q, k, v, a) -> Tensor:
)


@reshape.override(Tensor)
def reshape_default(input: Union[PrimitiveTensor, Tensor], shape: List[int]) -> Tensor:
return torch.reshape(unbox_tensor(input), shape)


# RMS norm
@rms_norm.override(Tensor, Tensor)
def rms_norm_default(x, weight, *, epsilon: float) -> Tensor:
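The two overrides above just unwrap the tensor and forward to the corresponding torch function. A minimal sketch of how they are reached through the dispatcher (assuming the `sharktank.ops` package re-exports the signatures added later in this commit):

```python
import torch
from sharktank import ops

# For a plain torch.Tensor the trampolines resolve to flatten_default /
# reshape_default above, which forward to torch.flatten / torch.reshape.
t = torch.rand(2, 3, 4)
print(ops.flatten(t, 1, 2).shape)    # torch.Size([2, 12])
print(ops.reshape(t, [4, 6]).shape)  # torch.Size([4, 6])
```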
11 changes: 11 additions & 0 deletions sharktank/sharktank/ops/shape.py
@@ -6,6 +6,7 @@

from typing import Sequence
from ..types.tensors import AnyTensor
import numpy as np


def broadcast_dim(
@@ -53,3 +54,13 @@ def broadcast_dims(
ranks = [len(shape) for shape in shaped_or_shape]
broadcast_rank = max(ranks)
return [dim + max(0, broadcast_rank - rank) for dim, rank in zip(dims, ranks)]


def flatten_shape(shape: Sequence[int], start_dim: int = 0, end_dim: int = -1):
    end_dim = end_dim if end_dim >= 0 else len(shape) - 1
    shape = list(shape)
    return (
        shape[:start_dim]
        + [int(np.prod(shape[start_dim : end_dim + 1]))]
        + shape[end_dim + 1 :]
    )
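A quick check of the shape arithmetic in `flatten_shape` above (a sketch, not part of the commit; it assumes the module is importable as `sharktank.ops.shape`):

```python
import torch

from sharktank.ops.shape import flatten_shape

t = torch.empty(2, 3, 4, 5)
# Collapsing dims 1..2 turns [2, 3, 4, 5] into [2, 12, 5], matching torch.flatten.
assert flatten_shape(t.shape, start_dim=1, end_dim=2) == [2, 12, 5]
assert flatten_shape(t.shape, 1, 2) == list(torch.flatten(t, 1, 2).shape)
```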
39 changes: 37 additions & 2 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -6,12 +6,14 @@

import torch
from torch import Tensor
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Union
import itertools
from numbers import Number
import numpy as np

from ..types import (
AnyTensor,
BlockCyclicSplitTensor,
DefaultPrimitiveTensor,
InferenceTensor,
ReplicatedTensor,
@@ -24,7 +26,7 @@
from ..types.tensors import unbox_tensor
from ._registry import AllOfType
from .signatures import *
from .shape import broadcast_dims
from .shape import broadcast_dims, flatten_shape


@all_gather.override(SplitPrimitiveTensor)
@@ -318,6 +320,25 @@ def equal_split(a: SplitPrimitiveTensor, b: AnyTensor) -> bool:
return a.is_deep_equal(b)


@flatten.override(SplitPrimitiveTensor)
def flatten_sharded_split_tensor(
    input: SplitPrimitiveTensor, start_dim: int, end_dim: int
) -> Union[SplitPrimitiveTensor, BlockCyclicSplitTensor]:
    end_dim_deduced = end_dim if end_dim >= 0 else len(input.shape) - 1
    flatten_dim_len = end_dim_deduced - start_dim + 1
    # The split dim is not being flattened, or this is the degenerate case of
    # flattening a single dimension.
    if (
        input.shard_dim < start_dim
        or input.shard_dim > end_dim_deduced
        or flatten_dim_len == 1
    ):
        shards = [flatten(shard, start_dim, end_dim) for shard in input.shards]
        # When the split dim lies after the flattened range it shifts left by the
        # number of collapsed dimensions.
        shard_dim = input.shard_dim
        if input.shard_dim > end_dim_deduced:
            shard_dim -= flatten_dim_len - 1
        return SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)

    # Flattening across the split dim produces a block-cyclic layout.
    flattened_shape = flatten_shape(input.shape, start_dim, end_dim)
    assert False, "TODO"


@group_norm_affine.override(
SplitPrimitiveTensor, SplitPrimitiveTensor, SplitPrimitiveTensor
)
@@ -563,6 +584,20 @@ def replicate_unsharded(input, *, count: int) -> ReplicatedTensor:
return ReplicatedTensor(ts=torch_input, shard_count=count)


@reshape.override(SplitPrimitiveTensor)
def reshape_sharded_split_tensor(
    input: SplitPrimitiveTensor, shape: List[int]
) -> Union[SplitPrimitiveTensor, BlockCyclicSplitTensor]:
    if (
        len(input.shape) == len(shape)
        and input.shape[input.shard_dim] == shape[input.shard_dim]
    ):
        # Only unsharded dimensions change, so each shard can be reshaped locally.
        # Let reshape infer each shard's size along the split dim.
        shard_shape = list(shape)
        shard_shape[input.shard_dim] = -1
        shards = [reshape(shard, shard_shape) for shard in input.shards]
        return SplitPrimitiveTensor(shard_dim=input.shard_dim, ts=shards, shape=shape)

    assert False, "TODO"


@reshard.override(Tensor, sharding.Split)
def reshard_tensor_split(input: Tensor, spec: sharding.Split) -> AnyTensor:
return reshard_split(input, dim=spec.shard_dim, count=spec.shard_count)
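A sketch of the case the new flatten override already handles, where the split dimension lies outside the flattened range (it assumes the shard-dim bookkeeping above and that `sharktank.ops` re-exports `reshard_split` and `flatten`):

```python
import torch
from sharktank import ops

# A [2, 3, 4, 6] tensor split into 3 shards along dim 3.
t = torch.rand(2, 3, 4, 6)
sharded = ops.reshard_split(t, dim=3, count=3)

# Flattening dims 1..2 does not touch the split dim, so the result stays a
# SplitPrimitiveTensor of shape [2, 12, 6]; the shard dim shifts from 3 to 2
# because two dimensions were collapsed into one.
flat = ops.flatten(sharded, 1, 2)
print(list(flat.shape), flat.shard_dim)  # [2, 12, 6] 2

# Flattening across the split dim (e.g. dims 2..3 here) is the case that needs
# the block-cyclic layout and is still TODO in this commit.
```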
40 changes: 40 additions & 0 deletions sharktank/sharktank/ops/signatures.py
@@ -24,6 +24,7 @@
"elementwise",
"embedding_lookup",
"equal",
"flatten",
"gemm",
"group_norm_affine",
"layer_norm",
@@ -33,6 +34,7 @@
"permute",
"rms_norm",
"replicate",
"reshape",
"reshard",
"reshard_split",
"reshard_like",
@@ -232,6 +234,25 @@ def _equal_trampoline(d: SignatureDispatcher, a: AnyTensor, b: AnyTensor):
d.fail(tensors)


@overridable
def flatten(input: AnyTensor, start_dim: int = 0, end_dim: int = -1) -> AnyTensor:
"""Flattens input by reshaping it into a one-dimensional tensor."""
...


@flatten.trampoline
def _flatten_trampoline(
    d: SignatureDispatcher, input: AnyTensor, start_dim: int = 0, end_dim: int = -1
):
tensors = (input,)
for override in d.find_overrides(tensors):
result = override(input, start_dim, end_dim)
if result is not NotImplemented:
return override, result
else:
d.fail(tensors)


@overridable
def gemm(
a: AnyTensor,
@@ -529,6 +550,25 @@ def _scaled_dot_product_attention(
d.fail(tensors)


@overridable
def reshape(input: AnyTensor, shape: List[int]) -> AnyTensor:
"""Returns a tensor with the same data and number of elements as input, but with
the specified shape.
"""
...


@reshape.trampoline
def _reshape_trampoline(d: SignatureDispatcher, input, shape) -> AnyTensor:
    dispatch_args = (input,)
for override in d.find_overrides(dispatch_args):
result = override(input, shape)
if result is not NotImplemented:
return override, result
else:
d.fail(dispatch_args)


@overridable
def reshard(
input: AnyTensor | Theta,
79 changes: 79 additions & 0 deletions sharktank/sharktank/types/tensors.py
@@ -22,6 +22,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np

import torch
from torch import Tensor
@@ -36,6 +37,7 @@

__all__ = [
"AnyTensor",
"BlockCyclicSplitTensor",
"DefaultPrimitiveTensor",
"flatten_tensor_tree",
"InferenceTensor",
@@ -850,6 +852,83 @@ def __getitem__(self, key):
return SplitPrimitiveTensor(ts=shards, shard_dim=self.shard_dim)


@register_inference_tensor
class BlockCyclicSplitTensor(ShardedTensor):
"""A tensor that is sharded into blocks of the same size.
The blocks are assigned cyclically to the list of devices.
Example distribution of a 2D tensor consisting of blocks Bij over 4 devices
arranged in a 2D mesh [[D0, D1], [D2, D3]].
```
+--------+--------+--------+--------+
| B00 D0 | B01 D1 | B02 D0 | B03 D1 |
+--------+--------+--------+--------+
| B10 D2 | B11 D3 | B12 D2 | B13 D3 |
+--------+--------+--------+--------+
| B20 D0 | B21 D1 | B22 D0 | B23 D1 |
+--------+--------+--------+--------+
```
Each device holds a tensor consisting of the blocks assigned to it, concatenated
into a single tensor according to the device mesh structure. In the above example,
device D0 holds the tensor
```
+-----+-----+
| B00 | B02 |
+-----+-----+
| B20 | B22 |
+-----+-----+
```
If an unsharded tensor dimension has a size that is not divisible by the
corresponding block shape dimension, then the last block along that dimension will
have a reduced size, such that the sum of block sizes equals the unsharded tensor
dimension size.
One use of this type of sharding is to reshape split tensors without moving data
between devices.
For example, consider an MxN tensor that is split across dimension 1 and
distributed across 3 devices, so each device gets a tensor of size [M, N/3].
For simplicity we assume `N mod 3 == 0`.
```
N
+----+----+----+
| D0 | D1 | D2 | M
+----+----+----+
```
Flattening the tensor would result in a 1D tensor with block-cyclic sharding with
blocks of size N/3.
```
MxN
+----+----+----+----+----+----+-----+----+----+----+
| D0 | D1 | D2 | D0 | D1 | D2 | ... | D0 | D1 | D2 |
+----+----+----+----+----+----+-----+----+----+----+
```
"""

def __init__(
self,
*,
shape: List[int],
shards: List[torch.Tensor],
mesh_shape: List[int],
block_shape: List[int],
name: str = UnnamedTensorName,
):
super().__init__(name=name, ts=shards, shape=shape)
assert np.prod(mesh_shape) == len(shards)
self._mesh_shape = mesh_shape
self._block_shape = block_shape

    @property
    def block_shape(self) -> List[int]:
        return self._block_shape

    @property
    def mesh_shape(self) -> List[int]:
        return self._mesh_shape


@register_inference_tensor
class ReplicatedTensor(ShardedTensor):
"""A tensor that is replicated across all shards."""
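To illustrate the cyclic block-to-device assignment described in the `BlockCyclicSplitTensor` docstring, a small standalone sketch (not part of the commit) that reproduces the 3x4 block grid over the [[D0, D1], [D2, D3]] mesh:

```python
# Block (i, j) goes to mesh position (i % mesh_rows, j % mesh_cols), i.e. to the
# flat device index (i % mesh_rows) * mesh_cols + (j % mesh_cols).
mesh_rows, mesh_cols = 2, 2
block_grid = (3, 4)  # number of blocks along each tensor dimension

for i in range(block_grid[0]):
    row = []
    for j in range(block_grid[1]):
        device = (i % mesh_rows) * mesh_cols + (j % mesh_cols)
        row.append(f"B{i}{j} D{device}")
    print(" | ".join(row))
# B00 D0 | B01 D1 | B02 D0 | B03 D1
# B10 D2 | B11 D3 | B12 D2 | B13 D3
# B20 D0 | B21 D1 | B22 D0 | B23 D1
```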
