Hook the einsum op and the `get_index` op into their implementations
KyleHerndon committed Sep 25, 2024
1 parent 6315229 commit 8de4a21
Showing 6 changed files with 117 additions and 6 deletions.
4 changes: 0 additions & 4 deletions sharktank/sharktank/kernels/einsum_2args_q4.py
@@ -127,7 +127,6 @@ def select(self, ksel: KernelSelection):
            m_desc.t.dtype == d_desc.t.dtype and len(m_dims) == len(qs_dims),
            lambda: f"einsum_2args_q4 arg 'm': Incorrect dtype (got {m_desc.t.dtype})",
        )

        # einsum_str
        torch._check(
            einsum_str.count(",") == 1 and einsum_str.count("->") == 1,
@@ -139,9 +138,7 @@ def select(self, ksel: KernelSelection):
        es_set = set(es_out)

        shp = qs_desc.t.shape
        print(shp)
        b_dims = list(shp[:-2]) + [shp[-2] * block_size]
        print(b_dims)
        torch._check(
            len(es_in0) == len(a_desc.t.shape)
            and len(es_in1)
@@ -262,5 +259,4 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
            c_size=len(es_out),
            out_dyn_dim_size_str=oddss,
        )
        print(target_function)
        kb.yield_results(*call_function(target_function, *kb.arg_bindings))
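The deleted lines above strip leftover debug `print` calls from the kernel's `select` and `generate` paths. For context, here is a minimal sketch (not part of the commit; the shape and `block_size` values are invented) of the unblocked-shape arithmetic that `b_dims` performs:

```python
# Minimal sketch with a hypothetical packed qs shape and block size.
shp = (4, 32, 16)   # (..., num_blocks, packed_bytes_per_block)
block_size = 32     # quantization block size assumed by the kernel
b_dims = list(shp[:-2]) + [shp[-2] * block_size]
print(b_dims)       # [4, 1024]: 32 blocks of 32 logical values each
```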
3 changes: 2 additions & 1 deletion sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir
@@ -25,6 +25,7 @@ module {
util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
%a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type)
-> !c_tensor_type {
%debug = tensor.empty() : tensor<1xf32>
%zero = arith.constant 0.0: !accum_type
// todo: loop
{% for i in range(a_size) %}
@@ -43,7 +44,7 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
%b_unblocked_dim = arith.muli %b{{b_size-1}}, %bs : index

//%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type
%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b_unblocked_dim{{"}"}}
%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}}

// Dequantize.
// todo: loop
13 changes: 13 additions & 0 deletions sharktank/sharktank/ops/custom_impls.py
@@ -9,6 +9,7 @@
import torch.nn.functional as F

from ..kernels import (
    einsum_2args_q4,
    mmt_block_scaled_offset_q4_unsigned,
    mmt_block_scaled_q8,
    mmtfp,
@@ -44,6 +45,18 @@
# return mmtfp(lhs, rhs)


# Einsum


@einsum_2args.override(Tensor, QuantizedTensor)
def einsum_2args_QuantizedTensor(input0, input1, einsum_str):
    unpacked = input1.unpack()
    layout = input1.layout_type
    if not isinstance(unpacked, BlockScaledI4Layout):
        return NotImplemented
    return einsum_2args_q4(input0, unpacked.d, unpacked._qs, unpacked.m, einsum_str)


# Quantized Matmul


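The new override only claims the case it can accelerate: if the unpacked layout is not `BlockScaledI4Layout` it returns `NotImplemented`, and the dispatcher falls through to another registration (such as the plain `torch.einsum` default added in `default_impls.py`). A self-contained sketch of that contract, using stand-in functions rather than sharktank's real registry:

```python
# Sketch of the override/fallthrough contract (stand-in names, not sharktank's API).
import torch


def quantized_override(a, b, einsum_str):
    # Mirrors the BlockScaledI4Layout guard: decline inputs this path cannot handle.
    if b.dtype != torch.int8:
        return NotImplemented
    return torch.einsum(einsum_str, a, b.to(a.dtype))


def default_override(a, b, einsum_str):
    return torch.einsum(einsum_str, a, b)


def dispatch(a, b, einsum_str):
    # Same shape as the trampoline loop in signatures.py: the first override
    # that does not return NotImplemented wins.
    for override in (quantized_override, default_override):
        result = override(a, b, einsum_str)
        if result is not NotImplemented:
            return result
    raise TypeError("no override matched")


print(dispatch(torch.randn(2, 3), torch.randn(3, 4), "ij,jk->ik").shape)  # torch.Size([2, 4])
```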
39 changes: 38 additions & 1 deletion sharktank/sharktank/ops/default_impls.py
@@ -14,7 +14,13 @@
import torch.nn.functional as F
from numbers import Number

from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor
from ..types import (
    PrimitiveTensor,
    QuantizedTensor,
    InferenceTensor,
    PlanarQuantizedTensor,
    BlockScaledI4Layout,
)
from ..types.tensors import unbox_tensor, AnyTensor
from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType
from .signatures import *
@@ -62,6 +68,12 @@ def conv2d_default(
conv2d.override(Tensor, Tensor, Tensor, auto_dequant=True)(conv2d_default)
conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default)

# Einsum
@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor))
def einsum_2args(x, y, einsum_str):
    return torch.einsum(einsum_str, unbox_tensor(x), unbox_tensor(y))


# Elementwise
@elementwise.override(Tensor)
def elementwise_unary(operator, x):
@@ -133,6 +145,31 @@ def equal_default(a, b) -> bool:
    return torch.equal(unbox_tensor(a), unbox_tensor(b))


@get_index.override(AllOfType(Tensor, PrimitiveTensor))
def get_index_default(tensor, key: slice):
    return unbox_tensor(tensor).__getitem__(key)


@get_index.override(QuantizedTensor)
def get_index_QuantizedTensor(tensor: QuantizedTensor, key: slice):
    unpacked = tensor.unpack()
    if isinstance(unpacked, BlockScaledI4Layout):
        mul = 2
    else:
        return NotImplemented
    new_d = unpacked._d[key]
    new_qs = unpacked._qs[key]
    if unpacked.m is not None:
        new_m = unpacked.m[key]
    dims = new_qs.shape
    dims = dims[:-2] + (dims[-2] * dims[-1] * mul,)
    layout = BlockScaledI4Layout(shape=dims, d=new_d, qs=new_qs, m=new_m)
    return PlanarQuantizedTensor(shape=dims, layout=layout)


# get_index.override(PlanarQuantizedTensor, slice)(get_index_QuantizedTensor)


@gemm.override(AllOfType(Tensor, InferenceTensor))
def gemm(
    a: AnyTensor,
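`get_index_QuantizedTensor` slices the packed planes and then flattens the last two packed dimensions back into one logical dimension, doubling for the two i4 values held in each int8 byte. A standalone sketch of that shape arithmetic (the plane shapes here are invented for illustration):

```python
# Sketch of the shape recomputation above, with made-up plane shapes.
import torch

qs = torch.zeros(8, 16, 16, dtype=torch.int8)  # packed qs planes: (rows, blocks, bytes_per_block)
key = slice(0, 4)
new_qs = qs[key]

mul = 2  # each int8 byte packs two i4 values
dims = tuple(new_qs.shape)
dims = dims[:-2] + (dims[-2] * dims[-1] * mul,)
print(dims)  # (4, 512): 16 blocks * 16 bytes * 2 nibbles per kept row
```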
59 changes: 59 additions & 0 deletions sharktank/sharktank/ops/signatures.py
@@ -21,9 +21,11 @@
    "all_reduce",
    "cat",
    "conv2d",
    "einsum_2args",
    "elementwise",
    "embedding_lookup",
    "equal",
    "get_index",
    "gemm",
    "group_norm_affine",
    "layer_norm",
@@ -151,6 +153,37 @@ def _conv2d_trampoline(
        d.fail(tensors)


@overridable
def einsum_2args(
    input0: AnyTensor,
    input1: AnyTensor,
    einsum_str: str,
    *,
    accum_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Executes a given Einstein summation notation string on the provided tensors.

    Equivalent to:
    ```
    y = torch.einsum(einsum_str, input0, input1)
    ```
    """
    raise NotImplementedError


@einsum_2args.trampoline
def _einsum_trampoline(
    d: SignatureDispatcher, input0: AnyTensor, input1: AnyTensor, einsum_str: str
):
    tensors = (input0, input1)
    for override in d.find_overrides(tensors):
        result = override(input0, input1, einsum_str)
        if result is not NotImplemented:
            return override, result
    else:
        d.fail(tensors)


@overridable
def elementwise(operator, *args: AnyTensor) -> AnyTensor:
"""Applies an elementwise operator against arguments."""
@@ -232,6 +265,32 @@ def _equal_trampoline(d: SignatureDispatcher, a: AnyTensor, b: AnyTensor):
        d.fail(tensors)


@overridable
def get_index(
    tensor: AnyTensor,
    key: slice,
) -> torch.Tensor:
    """Indexes the tensor using the key.

    Equivalent to:
    ```
    out = tensor[key]
    ```
    """
    raise NotImplementedError


@get_index.trampoline
def _get_index_trampoline(d: SignatureDispatcher, tensor: AnyTensor, key: slice):
    tensors = (tensor,)
    for override in d.find_overrides(tensors):
        result = override(tensor, key)
        if result is not NotImplemented:
            return override, result
    else:
        d.fail(tensors)


@overridable
def gemm(
    a: AnyTensor,
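With the signatures registered, callers go through the dispatcher rather than calling an implementation directly. A hedged usage sketch, assuming `sharktank.ops` re-exports the names listed in `__all__` and that the default overrides above are installed:

```python
# Hedged usage sketch; the import path and re-export behavior are assumptions.
import torch
from sharktank import ops

x = torch.randn(2, 3)
y = torch.randn(3, 5)

# Plain torch tensors hit the default overrides (torch.einsum and tensor indexing).
z = ops.einsum_2args(x, y, "ij,jk->ik")
row = ops.get_index(x, slice(0, 1))
print(z.shape, row.shape)  # torch.Size([2, 5]) torch.Size([1, 3])
```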
5 changes: 5 additions & 0 deletions sharktank/sharktank/types/tensors.py
@@ -298,6 +298,11 @@ def __rmul__(self, lhs):
        # numbers on the lhs.
        return self.__mul__(lhs)

    def __getitem__(self, key):
        from ..ops import get_index

        return get_index(self, key)


REGISTERED_INFERENCE_TENSOR_CLASSES: dict[str, Type[InferenceTensor]] = {}

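Because `__getitem__` delegates to `ops.get_index`, ordinary Python slicing now works on `InferenceTensor` subclasses, including quantized ones via the override in `default_impls.py`. A small usage sketch; constructing a `DefaultPrimitiveTensor` with a `data=` keyword is an assumption about the existing API, not something introduced by this commit:

```python
# Hedged usage sketch: slicing an InferenceTensor dispatches through ops.get_index.
import torch
from sharktank.types import DefaultPrimitiveTensor

t = DefaultPrimitiveTensor(data=torch.arange(12).reshape(3, 4))
# The default override returns the sliced backing torch.Tensor.
print(t[0:2])
```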
