Skip to content

Commit

Permalink
fix i386 fp16 cases
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif committed Aug 29, 2022
1 parent 59f7db4 commit 13aa737
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions tests/python/unittest/test_tir_imm_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_tir_intimm_overflow():

def compare_float_value(value, expect, msg):
if math.isfinite(value):
assert value == expect, f"{value} vs {expect}, {msg}"
assert np.abs(value - expect) < 1e-5, f"{value} vs {expect}, {msg}"
elif math.isnan(value):
assert math.isnan(expect), f"{value} vs {expect}, {msg}"
elif math.isinf(value):
Expand Down Expand Up @@ -209,7 +209,7 @@ def check_tir_const_fold(
if isinstance(x_range, (int, float)):
x = x_range
elif dtype.startswith("int") or dtype.startswith("uint"):
x = np.random.randint(x_range[0], x_range[1] + 1)
x = np.random.randint(x_range[0], x_range[1] + 1, dtype=dtype)
else:
x = np.random.uniform(x_range[0], x_range[1])

Expand All @@ -218,7 +218,7 @@ def check_tir_const_fold(
if isinstance(y_range, (int, float)):
y = y_range
elif dtype.startswith("int") or dtype.startswith("uint"):
y = np.random.randint(y_range[0], y_range[1] + 1)
y = np.random.randint(y_range[0], y_range[1] + 1, dtype=dtype)
else:
y = np.random.uniform(y_range[0], y_range[1])

Expand All @@ -239,35 +239,50 @@ def check_tir_const_fold(
+ "This test is intentionally non-deterministic, "
+ f"if it fails please report it in github issue together with this seed {seed}\n"
)
compare_float_value(calc_res, fold_res.value, flaky_msg)
if expect:
compare_float_value(expect, calc_res, flaky_msg)
if dtype.startswith("float"):
compare_float_value(calc_res, fold_res.value, flaky_msg)
if expect:
compare_float_value(expect, calc_res, flaky_msg)
else:
assert calc_res == fold_res.value, flaky_msg
if expect:
assert expect == calc_res, flaky_msg


@tvm.testing.requires_llvm()
def test_tir_floatimm_const_fold():
"""Behavior check: folding fp32 match platform f32 arithmetic"""

@T.prim_func
def float_imm_multiply(x: T.float32, y: T.float32) -> T.float32:
T.evaluate(T.ret(x * y, dtype="float32"))
def float_imm_multiply(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]):
z[()] = x * y

@T.prim_func
def float_imm_add(x: T.float32, y: T.float32) -> T.float32:
T.evaluate(T.ret(x + y, dtype="float32"))
def float_imm_add(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]):
z[()] = x + y

@T.prim_func
def float_imm_sub(x: T.float32, y: T.float32) -> T.float32:
T.evaluate(T.ret(x - y, dtype="float32"))
def float_imm_sub(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]):
z[()] = x - y

@T.prim_func
def float_imm_div(x: T.float32, y: T.float32) -> T.float32:
T.evaluate(T.ret(x / y, dtype="float32"))
def float_imm_div(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]):
z[()] = x / y

def __wrap_build(f):
lib = tvm.build(f, target="llvm")
z = tvm.nd.array(np.zeros([]).astype("float32"))

def _func(x, y):
lib(x, y, z)
return z.numpy()

return _func

fmul = tvm.build(float_imm_multiply, target="llvm")
fadd = tvm.build(float_imm_add, target="llvm")
fsub = tvm.build(float_imm_sub, target="llvm")
fdiv = tvm.build(float_imm_div, target="llvm")
fmul = __wrap_build(float_imm_multiply)
fadd = __wrap_build(float_imm_add)
fsub = __wrap_build(float_imm_sub)
fdiv = __wrap_build(float_imm_div)

# overflow
check_tir_const_fold("float32", lambda x, y: x * y, fmul, 3.0e30, 3.0e30, np.inf)
Expand Down

0 comments on commit 13aa737

Please sign in to comment.