Skip to content

Commit

Permalink
dialects: (csl) builtin math lib for DSDs (#2686)
Browse files Browse the repository at this point in the history
DSD builtin operations typically come with a variety of overloads. This
PR proposes a base class handling ops and verification. The subclasses
correspond to CSL builtin ops. Their implementation only needs to
specify op name and the valid function signatures. The goal is to keep
everything really simple.

---------

Co-authored-by: n-io <[email protected]>
  • Loading branch information
n-io and n-io authored Jun 6, 2024
1 parent a83b98b commit 6095319
Show file tree
Hide file tree
Showing 3 changed files with 881 additions and 3 deletions.
43 changes: 43 additions & 0 deletions tests/dialects/test_csl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest

from xdsl.dialects.builtin import Float32Type, IntegerType, Signedness, TensorType
from xdsl.dialects.csl import Add16Op, DsdKind, DsdType, GetMemDsdOp
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.test_value import TestSSAValue

tensor = TestSSAValue(TensorType(Float32Type(), [4]))
size_i32 = TestSSAValue(IntegerType(32, Signedness.SIGNED))
dest_dsd = GetMemDsdOp(
operands=[tensor, size_i32], result_types=[DsdType(DsdKind.mem1d_dsd)]
)
src_dsd1 = GetMemDsdOp(
operands=[tensor, size_i32], result_types=[DsdType(DsdKind.mem1d_dsd)]
)
src_dsd2 = GetMemDsdOp(
operands=[tensor, size_i32], result_types=[DsdType(DsdKind.mem1d_dsd)]
)
i16_value = TestSSAValue(IntegerType(16, Signedness.SIGNED))
u16_value = TestSSAValue(IntegerType(16, Signedness.UNSIGNED))


def test_verify_valid_builtin_signature():
Add16Op(operands=[(dest_dsd, src_dsd1, src_dsd2)], result_types=[]).verify_()
Add16Op(operands=[(dest_dsd, i16_value, src_dsd1)], result_types=[]).verify_()
Add16Op(operands=[(dest_dsd, u16_value, src_dsd1)], result_types=[]).verify_()
Add16Op(operands=[(dest_dsd, src_dsd1, i16_value)], result_types=[]).verify_()
Add16Op(operands=[(dest_dsd, src_dsd1, u16_value)], result_types=[]).verify_()


def test_verify_invalid_builtin_signature():
with pytest.raises(VerifyException):
Add16Op(
operands=[(dest_dsd, src_dsd1, src_dsd2, dest_dsd)], result_types=[]
).verify_()
with pytest.raises(VerifyException):
Add16Op(operands=[(dest_dsd, src_dsd1)], result_types=[]).verify_()
with pytest.raises(VerifyException):
Add16Op(operands=[(dest_dsd, i16_value, u16_value)], result_types=[]).verify_()
with pytest.raises(VerifyException):
Add16Op(operands=[(i16_value, src_dsd1, u16_value)], result_types=[]).verify_()
with pytest.raises(VerifyException):
Add16Op(operands=[(dest_dsd, src_dsd1, size_i32)], result_types=[]).verify_()
Loading

0 comments on commit 6095319

Please sign in to comment.