Skip to content

Commit

Permalink
Apply 'Black' formatter to py/test/correctness and py/test/generators (
Browse files Browse the repository at this point in the history
…#7135)

* Apply 'Black' formatter to py/test/correctness and py/test/generators

Trying to regularize all our Python code to a common style. Should be no functional changes here, just autoformatting + a few tweaks.

* Update complexpy_generator.py
  • Loading branch information
steven-johnson authored Oct 31, 2022
1 parent 0c03ff8 commit bad945f
Show file tree
Hide file tree
Showing 27 changed files with 739 additions and 454 deletions.
143 changes: 112 additions & 31 deletions python_bindings/test/correctness/addconstant_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test(addconstant_impl_func, offset):
input_float = numpy.array([3.14, 2.718, 1.618], dtype=numpy.float32)
input_double = numpy.array([3.14, 2.718, 1.618], dtype=numpy.float64)
input_half = numpy.array([3.14, 2.718, 1.618], dtype=numpy.float16)
input_2d = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=numpy.int8, order='F')
input_2d = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=numpy.int8, order="F")
input_3d = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=numpy.int8)

output_u8 = numpy.zeros((3,), dtype=numpy.uint8)
Expand All @@ -53,20 +53,47 @@ def test(addconstant_impl_func, offset):
output_float = numpy.zeros((3,), dtype=numpy.float32)
output_double = numpy.zeros((3,), dtype=numpy.float64)
output_half = numpy.zeros((3,), dtype=numpy.float16)
output_2d = numpy.zeros((2, 3), dtype=numpy.int8, order='F')
output_2d = numpy.zeros((2, 3), dtype=numpy.int8, order="F")
output_3d = numpy.zeros((2, 2, 2), dtype=numpy.int8)

addconstant_impl_func(
scalar_u1,
scalar_u8, scalar_u16, scalar_u32, scalar_u64,
scalar_i8, scalar_i16, scalar_i32, scalar_i64,
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_half, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_half, output_2d, output_3d,
scalar_u8,
scalar_u16,
scalar_u32,
scalar_u64,
scalar_i8,
scalar_i16,
scalar_i32,
scalar_i64,
scalar_float,
scalar_double,
input_u8,
input_u16,
input_u32,
input_u64,
input_i8,
input_i16,
input_i32,
input_i64,
input_float,
input_double,
input_half,
input_2d,
input_3d,
output_u8,
output_u16,
output_u32,
output_u64,
output_i8,
output_i16,
output_i32,
output_i64,
output_float,
output_double,
output_half,
output_2d,
output_3d,
)

combinations = [
Expand Down Expand Up @@ -102,20 +129,47 @@ def test(addconstant_impl_func, offset):
scalar_i32 = 0
addconstant_impl_func(
scalar_u1,
scalar_u8, scalar_u16, scalar_u32, scalar_u64,
scalar_i8, scalar_i16, scalar_i32, scalar_i64,
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_half, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_half, output_2d, output_3d,
scalar_u8,
scalar_u16,
scalar_u32,
scalar_u64,
scalar_i8,
scalar_i16,
scalar_i32,
scalar_i64,
scalar_float,
scalar_double,
input_u8,
input_u16,
input_u32,
input_u64,
input_i8,
input_i16,
input_i32,
input_i64,
input_float,
input_double,
input_half,
input_2d,
input_3d,
output_u8,
output_u16,
output_u32,
output_u64,
output_i8,
output_i16,
output_i32,
output_i64,
output_float,
output_double,
output_half,
output_2d,
output_3d,
)
except RuntimeError as e:
assert str(e) == "Halide Runtime Error: -27", e
else:
assert False, 'Did not see expected exception!'
assert False, "Did not see expected exception!"

try:
# Expected requirement failure #2 -- note that for AOT-compiled
Expand All @@ -124,20 +178,47 @@ def test(addconstant_impl_func, offset):
scalar_i32 = -1
addconstant_impl_func(
scalar_u1,
scalar_u8, scalar_u16, scalar_u32, scalar_u64,
scalar_i8, scalar_i16, scalar_i32, scalar_i64,
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_half, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_half, output_2d, output_3d,
scalar_u8,
scalar_u16,
scalar_u32,
scalar_u64,
scalar_i8,
scalar_i16,
scalar_i32,
scalar_i64,
scalar_float,
scalar_double,
input_u8,
input_u16,
input_u32,
input_u64,
input_i8,
input_i16,
input_i32,
input_i64,
input_float,
input_double,
input_half,
input_2d,
input_3d,
output_u8,
output_u16,
output_u32,
output_u64,
output_i8,
output_i16,
output_i32,
output_i64,
output_float,
output_double,
output_half,
output_2d,
output_3d,
)
except RuntimeError as e:
assert str(e) == "Halide Runtime Error: -27", e
else:
assert False, 'Did not see expected exception!'
assert False, "Did not see expected exception!"


if __name__ == "__main__":
Expand Down
10 changes: 6 additions & 4 deletions python_bindings/test/correctness/atomics.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import halide as hl


def test_atomics():
x = hl.Var('x')
im = hl.Func('im')
f = hl.Func('f')
x = hl.Var("x")
im = hl.Func("im")
f = hl.Func("f")
im[x] = (x * x) % 5
r = hl.RDom([(0, 100)])
f[x] = 0
Expand All @@ -16,7 +17,8 @@ def test_atomics():
idx = (i * i) % 5
ref[idx] += 1
for i in range(5):
assert(b[i] == ref[i])
assert b[i] == ref[i]


if __name__ == "__main__":
test_atomics()
50 changes: 26 additions & 24 deletions python_bindings/test/correctness/autodiff.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import halide as hl


def test_autodiff():
x = hl.Var('x')
x = hl.Var("x")
b = hl.Buffer(hl.Float(32), [3])
p = hl.Param(hl.Float(32), 'p', 1)
p = hl.Param(hl.Float(32), "p", 1)
b[0] = 1.0
b[1] = 2.0
b[2] = 3.0
f, g, h = hl.Func('f'), hl.Func('g'), hl.Func('h')
f, g, h = hl.Func("f"), hl.Func("g"), hl.Func("h")
f[x] = b[x]
f[0] = 4.0
g[x] = f[x] * 5.0 * p
Expand All @@ -20,47 +21,48 @@ def test_autodiff():
# gradient w.r.t. the initialization of f
d_f_init = d[f]
d_f_init_buf = d_f_init.realize([3])
assert(d_f_init_buf[0] == 0.0)
assert(d_f_init_buf[1] == 5.0)
assert(d_f_init_buf[2] == 5.0)
d_f_init = d[f]# test different interface
assert d_f_init_buf[0] == 0.0
assert d_f_init_buf[1] == 5.0
assert d_f_init_buf[2] == 5.0
d_f_init = d[f] # test different interface
d_f_init_buf = d_f_init.realize([3])
assert(d_f_init_buf[0] == 0.0)
assert(d_f_init_buf[1] == 5.0)
assert(d_f_init_buf[2] == 5.0)
assert d_f_init_buf[0] == 0.0
assert d_f_init_buf[1] == 5.0
assert d_f_init_buf[2] == 5.0

# gradient w.r.t. the updated f
d_f_update_0 = d[f, 0]
d_f_update_0_buf = d_f_update_0.realize([3])
assert(d_f_update_0_buf[0] == 5.0)
assert(d_f_update_0_buf[1] == 5.0)
assert(d_f_update_0_buf[2] == 5.0)
assert d_f_update_0_buf[0] == 5.0
assert d_f_update_0_buf[1] == 5.0
assert d_f_update_0_buf[2] == 5.0
d_f_update_0 = d[f, 0]
d_f_update_0_buf = d_f_update_0.realize([3])
assert(d_f_update_0_buf[0] == 5.0)
assert(d_f_update_0_buf[1] == 5.0)
assert(d_f_update_0_buf[2] == 5.0)
assert d_f_update_0_buf[0] == 5.0
assert d_f_update_0_buf[1] == 5.0
assert d_f_update_0_buf[2] == 5.0

# gradient w.r.t. the buffer
d_b = d[b]
d_b_buf = d_b.realize([3])
assert(d_b_buf[0] == 0.0)
assert(d_b_buf[1] == 5.0)
assert(d_b_buf[2] == 5.0)
assert d_b_buf[0] == 0.0
assert d_b_buf[1] == 5.0
assert d_b_buf[2] == 5.0
d_b = d[b]
d_b_buf = d_b.realize([3])
assert(d_b_buf[0] == 0.0)
assert(d_b_buf[1] == 5.0)
assert(d_b_buf[2] == 5.0)
assert d_b_buf[0] == 0.0
assert d_b_buf[1] == 5.0
assert d_b_buf[2] == 5.0

# gradient w.r.t. the param
d_p = d[p]
d_p_buf = d_p.realize()
# 5 * (4 + 2 + 3)
assert(abs(d_p_buf[()] - 45.0) < 1e-6)
assert abs(d_p_buf[()] - 45.0) < 1e-6
d_p = d[p]
d_p_buf = d_p.realize()
assert(abs(d_p_buf[()] - 45.0) < 1e-6)
assert abs(d_p_buf[()] - 45.0) < 1e-6


if __name__ == "__main__":
test_autodiff()
Loading

0 comments on commit bad945f

Please sign in to comment.