diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 26f00c5c5e6d..fb8c58bf431e 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -62,6 +62,25 @@ def define_list_nth(self): s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y)) self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a]) + def define_list_update(self): + """Defines a function to update the nth element of a list and return the updated list. + + update(l, i, v) : list[a] -> nat -> a -> list[a] + """ + self.update = GlobalVar("update") + a = TypeVar("a") + l = Var("l", self.l(a)) + n = Var("n", self.nat()) + v = Var("v", a) + + y = Var("y") + + z_case = Clause(PatternConstructor(self.z), self.cons(v, self.tl(l))) + s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), + self.cons(self.hd(l), self.update(self.tl(l), y, v))) + + self.mod[self.update] = Function([l, n, v], Match(n, [z_case, s_case]), self.l(a), [a]) + def define_list_map(self): """Defines a function for mapping a function over a list's elements. That is, map(f, l) returns a new list where @@ -470,6 +489,7 @@ def __init__(self, mod): self.define_nat_add() self.define_list_length() self.define_list_nth() + self.define_list_update() self.define_list_sum() self.define_tree_adt() diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index e176194fede6..e9e2915f28a8 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -26,6 +26,7 @@ hd = p.hd tl = p.tl nth = p.nth +update = p.update length = p.length map = p.map foldl = p.foldl @@ -148,6 +149,23 @@ def test_nth(): assert got == expected +def test_update(): + expected = list(range(10)) + l = nil() + # create zero initialized list + for i in range(len(expected)): + l = cons(build_nat(0), l) + + # set value + for i, v in enumerate(expected): + l = update(l, build_nat(i), build_nat(v)) + + got = [] + for i in range(len(expected)): + got.append(count(intrp.evaluate(nth(l, build_nat(i))))) + + assert got == expected + def test_length(): a = relay.TypeVar("a") assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a])