From ed1fdf69fec98973ee9177c22b61dc5ec8af80e4 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 9 Aug 2024 13:00:23 +0200 Subject: [PATCH 1/3] Add lowering of cast_ builtin function --- .../dace_fieldview/gtir_python_codegen.py | 25 ++++++++++++++----- .../runners/dace_fieldview/gtir_to_tasklet.py | 7 ++++-- .../runners_tests/test_dace_fieldview.py | 17 +++++++------ 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index fcb71e4e6d..98c69cc1b1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, Callable import numpy as np @@ -79,12 +79,25 @@ } -def format_builtin(bultin: str, *args: Any) -> str: - if bultin in MATH_BUILTINS_MAPPING: - fmt = MATH_BUILTINS_MAPPING[bultin] +def builtin_cast(*args: Any) -> str: + val, target_type = args + return MATH_BUILTINS_MAPPING[target_type].format(val) + + +GENERAL_BUILTIN_MAPPING: dict[str, Callable[[Any], str]] = { + "cast_": builtin_cast, +} + + +def format_builtin(builtin: str, *args: Any) -> str: + if builtin in MATH_BUILTINS_MAPPING: + fmt = MATH_BUILTINS_MAPPING[builtin] + return fmt.format(*args) + elif builtin in GENERAL_BUILTIN_MAPPING: + expr_func = GENERAL_BUILTIN_MAPPING[builtin] + return expr_func(*args) else: - raise NotImplementedError(f"'{bultin}' not implemented.") - return fmt.format(*args) + raise NotImplementedError(f"'{builtin}' not implemented.") class PythonCodegen(codegen.TemplatedGenerator): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index 188b19c577..b782cc4034 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -802,5 +802,8 @@ def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr: param = str(node.id) - assert param in self.symbol_map - return self.symbol_map[param] + if param in self.symbol_map: + return self.symbol_map[param] + # if not in the lambda symbol map, this must be a symref to a builtin function + assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING + return SymbolExpr(param, dace.string) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index a4d04511fa..d3eca393da 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -102,35 +102,38 @@ def make_mesh_symbols(mesh: MeshDescriptor): ) -def test_gtir_copy(): +def test_gtir_cast(): domain = im.call("cartesian_domain")( im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") ) + IFTYPE_INT64 = ts.FieldType(IFTYPE.dims, dtype=ts.ScalarType(kind=ts.ScalarKind.INT64)) testee = gtir.Program( - id="gtir_copy", + id="test_gtir_cast", function_definitions=[], params=[ - gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="x", type=IFTYPE_INT64), gtir.Sym(id="y", type=IFTYPE), gtir.Sym(id="size", type=SIZE_TYPE), ], declarations=[], body=[ gtir.SetAt( - expr=im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x"), + expr=im.as_fieldop( + im.lambda_("a")(im.call("cast_")(im.deref("a"), "float64")), domain + )("x"), domain=domain, target=gtir.SymRef(id="y"), ) ], ) - a = np.random.rand(N) - b = np.empty_like(a) + a = np.random.randint(-1000, +1000, N) + b = np.empty_like(a, dtype=np.float64) sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) sdfg(x=a, y=b, **FSYMBOLS) - assert np.allclose(a, b) + assert np.allclose(a.astype(np.float64), b) def test_gtir_update(): From 8c57d683d2a71832e3c4ad5c6f130412df20eb6b Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 13 Aug 2024 17:15:54 +0200 Subject: [PATCH 2/3] Improve test case --- .../runners_tests/test_dace_fieldview.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 276a626731..50294ae294 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -113,22 +113,25 @@ def test_gtir_cast(): declarations=[], body=[ gtir.SetAt( - expr=im.as_fieldop( - im.lambda_("a")(im.call("cast_")(im.deref("a"), "float64")), domain - )("x"), + expr=im.op_as_fieldop("divides", domain)( + im.as_fieldop( + im.lambda_("a")(im.call("cast_")(im.deref("a"), "float64")), domain + )("x"), + 2.0, + ), domain=domain, target=gtir.SymRef(id="y"), ) ], ) - a = np.random.randint(-1000, +1000, N) + a = np.ones(N, dtype=np.int64) b = np.empty_like(a, dtype=np.float64) sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) sdfg(x=a, y=b, **FSYMBOLS) - assert np.allclose(a.astype(np.float64), b) + assert np.allclose(b, 0.5) def test_gtir_update(): From f33b6dbca86421454754e43c1c942b3200eece03 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 14 Aug 2024 08:35:28 +0200 Subject: [PATCH 3/3] Improve test case (1) --- .../runners_tests/test_dace_fieldview.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index 50294ae294..a8fc9eb33e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -101,37 +101,40 @@ def test_gtir_cast(): domain = im.call("cartesian_domain")( im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") ) - IFTYPE_INT64 = ts.FieldType(IFTYPE.dims, dtype=ts.ScalarType(kind=ts.ScalarKind.INT64)) + IFTYPE_FLOAT32 = ts.FieldType(IFTYPE.dims, dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) + IFTYPE_BOOL = ts.FieldType(IFTYPE.dims, dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL)) testee = gtir.Program( id="test_gtir_cast", function_definitions=[], params=[ - gtir.Sym(id="x", type=IFTYPE_INT64), - gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE_FLOAT32), + gtir.Sym(id="z", type=IFTYPE_BOOL), gtir.Sym(id="size", type=SIZE_TYPE), ], declarations=[], body=[ gtir.SetAt( - expr=im.op_as_fieldop("divides", domain)( + expr=im.op_as_fieldop("eq", domain)( im.as_fieldop( - im.lambda_("a")(im.call("cast_")(im.deref("a"), "float64")), domain + im.lambda_("a")(im.call("cast_")(im.deref("a"), "float32")), domain )("x"), - 2.0, + "y", ), domain=domain, - target=gtir.SymRef(id="y"), + target=gtir.SymRef(id="z"), ) ], ) - a = np.ones(N, dtype=np.int64) - b = np.empty_like(a, dtype=np.float64) + a = np.ones(N, dtype=np.float64) * np.sqrt(2.0) + b = a.astype(np.float32) + c = np.empty_like(a, dtype=np.bool_) sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) - sdfg(x=a, y=b, **FSYMBOLS) - assert np.allclose(b, 0.5) + sdfg(x=a, y=b, z=c, **FSYMBOLS) + np.testing.assert_array_equal(c, True) def test_gtir_update():