Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: Skeleton of GTIR DaCe backend #1538

Merged
merged 84 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
b36fcb3
Skeleton for ITIR translation
edopao Apr 18, 2024
2020182
Minor edit
edopao Apr 18, 2024
5c6b6ba
Use Python callstack as a context stack for the ITIR visitor
edopao Apr 19, 2024
60e1c69
Format error
edopao Apr 19, 2024
073a0a4
Refactor tasklet codegen
edopao Apr 19, 2024
50be68f
Code refactoring
edopao Apr 19, 2024
4e2dc15
Add domain to field operator
edopao Apr 22, 2024
ea9da35
Minor edit
edopao Apr 22, 2024
daf7827
Remove hard-coded field shape
edopao Apr 23, 2024
9672b3b
Remove hard-coded target domain
edopao Apr 23, 2024
b6326b8
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Apr 23, 2024
26f3790
Refactoring
edopao Apr 24, 2024
1efffa7
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Apr 24, 2024
f99fa84
Fix formatting
edopao Apr 24, 2024
d6e1088
More refactoring
edopao Apr 24, 2024
9854497
Minor edit
edopao Apr 24, 2024
37d83d7
Fix formatting
edopao Apr 24, 2024
29986ef
Use callable to build taskgraph
edopao Apr 29, 2024
390f3b4
Add draft of select operator
edopao Apr 29, 2024
de27419
Remove node mapping
edopao Apr 30, 2024
cd900f5
Remove node mapping (fix + test case)
edopao Apr 30, 2024
326cbb5
Add test case for inlined mathematic builtins
edopao Apr 30, 2024
a10b614
Go full functional (remove SDFGState member var)
edopao Apr 30, 2024
aef4265
Minor edit
edopao Apr 30, 2024
9e67dfe
Minor edit (1)
edopao Apr 30, 2024
4b4109e
Fix state handling
edopao Apr 30, 2024
495fd0a
Edit comments based on review
edopao Apr 30, 2024
0085194
Add test case for nested select
edopao Apr 30, 2024
41e2a44
Separate builtin translation from driver logic
edopao May 1, 2024
7148c5f
Improve code comments
edopao May 2, 2024
452399d
Avoid inheritance: pass dataflow builder as arg to builtin translator
edopao May 3, 2024
e404226
Codestyle review changes
edopao May 3, 2024
bb0dfac
Remove circular dependency for builtin translators
edopao May 6, 2024
412cd5d
Fix formatting
edopao May 6, 2024
651de5c
Minor edit
edopao May 6, 2024
dcf3eab
Add support to translate each builtin call to a tasklet node
edopao May 6, 2024
7e6909e
Resolve dace warnings
edopao May 7, 2024
2b07cc5
Remove bultin translator for domain expressions
edopao May 7, 2024
2370fa6
Remove bultin translator for domain expressions (1)
edopao May 7, 2024
8e801df
Refactor
edopao May 7, 2024
812a6e5
Minor edit
edopao May 7, 2024
1d0b50b
Extract ITIR visitor to separate class
edopao May 7, 2024
97a1d22
Code refactoring
edopao May 7, 2024
a30cc7d
Fix formatting
edopao May 7, 2024
e9455e3
Changes in preparation for shift builtin
edopao May 13, 2024
801704b
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao May 13, 2024
c45c417
Add support for programs without computation (pure memlets)
edopao May 13, 2024
d67518a
Fix test
edopao May 13, 2024
c4385c1
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao May 16, 2024
ed16fd4
Import updates from branch dace-fieldview-shifts
edopao May 16, 2024
9f7176f
Review comments
edopao May 16, 2024
46febb0
Avoid tasklet-to-tasklet edge connections
edopao May 16, 2024
949bad7
Add support for in-out field parameters
edopao May 16, 2024
8890f95
Refactoring: import modules, not symbols
edopao May 17, 2024
87b71a6
Minor edit
edopao May 17, 2024
665a609
Remove internal package for builtin translators
edopao May 17, 2024
82fdf64
Add wrapper function to build SDFG
edopao May 17, 2024
e4718b0
Merge pull request #4 from edopao/dace-fieldview-refactor_imports
edopao May 17, 2024
6ccecf1
Code changes imported from branch dace-fieldview-shifts
edopao May 17, 2024
3c71efa
Import changes from neighbors branch
edopao May 29, 2024
2f75cfb
Add debuginfo for ir.Program and ir.Stmt nodes
edopao May 29, 2024
085f307
Fix error in debuginfo
edopao May 29, 2024
f19960b
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao May 29, 2024
dc1434c
Fix error in debuginfo (1)
edopao May 29, 2024
a5b0f41
import changes from neighbors branch
edopao Jun 28, 2024
f7ac3d8
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Jun 28, 2024
9318011
Import changes from branch dace-fieldview-neighbors
edopao Jul 4, 2024
11efdeb
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Jul 4, 2024
d7312fa
Support field with start offset
edopao Jul 4, 2024
c4f2738
Test IR updated for literal operand
edopao Jul 4, 2024
0fd0b65
Add test coverage to previous commit
edopao Jul 4, 2024
38d2720
Refactor PrimitiveTranslator interface
edopao Jul 4, 2024
e855ef9
Fix formatting
edopao Jul 5, 2024
4cff071
Fix for domain horzontal/vertical dims
edopao Jul 5, 2024
f642e85
Fix for type inference on single value expression
edopao Jul 5, 2024
fc9661c
Import changes from dace-fieldview-shifts
edopao Jul 5, 2024
e424d4e
Minor edit
edopao Jul 5, 2024
66c5fcd
Address review comments
edopao Jul 10, 2024
d5abad4
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Jul 10, 2024
1df1bc3
Apply convention for map variables
edopao Jul 10, 2024
abf3918
Import changes from dace-fieldview-shifts
edopao Jul 11, 2024
7f60cfe
Import changes from branch dace-fieldview-shifts
edopao Jul 12, 2024
b3131db
Avoid direct import of symbols from module
edopao Jul 12, 2024
130c877
Address review comments
edopao Jul 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib
GTIR_BUILTINS = {
*BUILTINS,
"as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution)
"cond", # `cond(expr, field_a, field_b)` creates the field on one branch or the other
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later


from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import build_sdfg_from_gtir


__all__ = [
"build_sdfg_from_gtir",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,360 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later


edopao marked this conversation as resolved.
Show resolved Hide resolved
from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Optional, Protocol, TypeAlias

import dace
import dace.subsets as sbs

from gt4py.next import common as gtx_common
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.program_processors.runners.dace_fieldview import (
gtir_python_codegen,
gtir_to_tasklet,
utility as dace_fieldview_util,
)
from gt4py.next.type_system import type_specifications as ts


if TYPE_CHECKING:
from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_sdfg


IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes
TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]


class PrimitiveTranslator(Protocol):
@abc.abstractmethod
def __call__(
self,
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
) -> list[TemporaryData]:
"""Creates the dataflow subgraph representing a GTIR primitive function.

This method is used by derived classes to build a specialized subgraph
for a specific GTIR primitive function.

Arguments:
node: The GTIR node describing the primitive to be lowered
sdfg: The SDFG where the primitive subgraph should be instantiated
state: The SDFG state where the result of the primitive function should be made available
sdfg_builder: The object responsible for visiting child nodes of the primitive node.

Returns:
A list of data access nodes and the associated GT4Py data type, which provide
access to the result of the primitive subgraph. The GT4Py data type is useful
in the case the returned data is an array, because the type provdes the domain
information (e.g. order of dimensions, dimension types).
"""


def _parse_arg_expr(
node: gtir.Expr,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
domain: list[
tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
],
) -> gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr:
fields: list[TemporaryData] = sdfg_builder.visit(node, sdfg=sdfg, head_state=state)

assert len(fields) == 1
data_node, arg_type = fields[0]
# require all argument nodes to be data access nodes (no symbols)
assert isinstance(data_node, dace.nodes.AccessNode)

if isinstance(arg_type, ts.ScalarType):
return gtir_to_tasklet.MemletExpr(data_node, sbs.Indices([0]))
else:
assert isinstance(arg_type, ts.FieldType)
indices: dict[gtx_common.Dimension, gtir_to_tasklet.IteratorIndexExpr] = {
dim: gtir_to_tasklet.SymbolExpr(
dace_fieldview_util.get_map_variable(dim),
IteratorIndexDType,
)
for dim, _, _ in domain
}
return gtir_to_tasklet.IteratorExpr(
data_node,
arg_type.dims,
indices,
)


def _create_temporary_field(
sdfg: dace.SDFG,
state: dace.SDFGState,
domain: list[
tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
],
node_type: ts.ScalarType,
output_desc: dace.data.Data,
output_field_type: ts.DataType,
) -> tuple[dace.nodes.AccessNode, ts.FieldType]:
domain_dims, domain_lbs, domain_ubs = zip(*domain)
field_dims = list(domain_dims)
field_shape = [
# diff between upper and lower bound
(ub - lb)
for lb, ub in zip(domain_lbs, domain_ubs)
]
field_offset: Optional[list[dace.symbolic.SymbolicType]] = None
if any(domain_lbs):
field_offset = [-lb for lb in domain_lbs]

if isinstance(output_desc, dace.data.Array):
# extend the result arrays with the local dimensions added by the field operator e.g. `neighbors`)
assert isinstance(output_field_type, ts.FieldType)
# TODO: enable `assert output_field_type.dtype == node_type`, remove variable `dtype`
node_type = output_field_type.dtype
field_dims.extend(output_field_type.dims)
field_shape.extend(output_desc.shape)
else:
assert isinstance(output_desc, dace.data.Scalar)
assert isinstance(output_field_type, ts.ScalarType)
# TODO: enable `assert output_field_type == node_type`, remove variable `dtype`
node_type = output_field_type

# allocate local temporary storage for the result field
temp_name, _ = sdfg.add_temp_transient(
field_shape, dace_fieldview_util.as_dace_type(node_type), offset=field_offset
)
field_node = state.add_access(temp_name)
field_type = ts.FieldType(field_dims, node_type)

return field_node, field_type


def translate_as_field_op(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for the `as_field_op` builtin function."""
assert isinstance(node, gtir.FunCall)
assert cpm.is_call_to(node.fun, "as_fieldop")

fun_node = node.fun
assert len(fun_node.args) == 2
stencil_expr, domain_expr = fun_node.args
# expect stencil (represented as a lambda function) as first argument
assert isinstance(stencil_expr, gtir.Lambda)
# the domain of the field operator is passed as second argument
assert isinstance(domain_expr, gtir.FunCall)

# add local storage to compute the field operator over the given domain
# TODO: use type inference to determine the result type
node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
domain = dace_fieldview_util.get_domain(domain_expr)

# first visit the list of arguments and build a symbol map
stencil_args = [_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain) for arg in node.args]

# represent the field operator as a mapped tasklet graph, which will range over the field domain
taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder)
input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args)
assert isinstance(output_expr, gtir_to_tasklet.ValueExpr)
output_desc = output_expr.node.desc(sdfg)

# retrieve the tasklet node which writes the result
last_node = state.in_edges(output_expr.node)[0].src
if isinstance(last_node, dace.nodes.Tasklet):
# the last transient node can be deleted
last_node_connector = state.in_edges(output_expr.node)[0].src_conn
state.remove_node(output_expr.node)
else:
last_node = output_expr.node
last_node_connector = None

# allocate local temporary storage for the result field
field_node, field_type = _create_temporary_field(
sdfg, state, domain, node_type, output_desc, output_expr.field_type
)

# assume tasklet with single output
output_subset = [dace_fieldview_util.get_map_variable(dim) for dim, _, _ in domain]
if isinstance(output_desc, dace.data.Array):
# additional local dimension for neighbors
assert set(output_desc.offset) == {0}
output_subset.extend(f"0:{size}" for size in output_desc.shape)

# create map range corresponding to the field operator domain
map_ranges = {dace_fieldview_util.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain}
me, mx = sdfg_builder.add_map("field_op", state, map_ranges)

if len(input_connections) == 0:
# dace requires an empty edge from map entry node to tasklet node, in case there no input memlets
state.add_nedge(me, last_node, dace.Memlet())
else:
for data_node, data_subset, lambda_node, lambda_connector in input_connections:
memlet = dace.Memlet(data=data_node.data, subset=data_subset)
state.add_memlet_path(
data_node,
me,
lambda_node,
dst_conn=lambda_connector,
memlet=memlet,
)
state.add_memlet_path(
last_node,
mx,
field_node,
src_conn=last_node_connector,
memlet=dace.Memlet(data=field_node.data, subset=",".join(output_subset)),
)

return [(field_node, field_type)]


def translate_cond(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for the `cond` builtin function."""
assert isinstance(node, gtir.FunCall)
assert cpm.is_call_to(node.fun, "cond")
assert len(node.args) == 0

fun_node = node.fun
assert len(fun_node.args) == 3
cond_expr, true_expr, false_expr = fun_node.args

# expect condition as first argument
cond = gtir_python_codegen.get_source(cond_expr)

# use current head state to terminate the dataflow, and add a entry state
# to connect the true/false branch states as follows:
#
# ------------
# === | cond | ===
# || ------------ ||
# \/ \/
# ------------ -------------
# | true | | false |
# ------------ -------------
# || ||
# || ------------ ||
# ==> | head | <==
# ------------
#
cond_state = sdfg.add_state_before(state, state.label + "_cond")
sdfg.remove_edge(sdfg.out_edges(cond_state)[0])

# expect true branch as second argument
true_state = sdfg.add_state(state.label + "_true_branch")
sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=f"bool({cond})"))
sdfg.add_edge(true_state, state, dace.InterstateEdge())

# and false branch as third argument
false_state = sdfg.add_state(state.label + "_false_branch")
sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})")))
sdfg.add_edge(false_state, state, dace.InterstateEdge())

true_br_args = sdfg_builder.visit(true_expr, sdfg=sdfg, head_state=true_state)
false_br_args = sdfg_builder.visit(false_expr, sdfg=sdfg, head_state=false_state)

output_nodes = []
for true_br, false_br in zip(true_br_args, false_br_args, strict=True):
true_br_node, true_br_type = true_br
assert isinstance(true_br_node, dace.nodes.AccessNode)
false_br_node, _ = false_br
assert isinstance(false_br_node, dace.nodes.AccessNode)
desc = true_br_node.desc(sdfg)
assert false_br_node.desc(sdfg) == desc
data_name, _ = sdfg.add_temp_transient_like(desc)
output_nodes.append((state.add_access(data_name), true_br_type))

true_br_output_node = true_state.add_access(data_name)
true_state.add_nedge(
true_br_node,
true_br_output_node,
dace.Memlet.from_array(data_name, desc),
)

false_br_output_node = false_state.add_access(data_name)
false_state.add_nedge(
false_br_node,
false_br_output_node,
dace.Memlet.from_array(data_name, desc),
)

return output_nodes


def translate_symbol_ref(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for a `ir.SymRef` node."""
assert isinstance(node, (gtir.Literal, gtir.SymRef))

data_type: ts.FieldType | ts.ScalarType
if isinstance(node, gtir.Literal):
sym_value = node.value
data_type = node.type
tasklet_name = "get_literal"
else:
sym_value = str(node.id)
data_type = sdfg_builder.get_symbol_type(sym_value)
tasklet_name = f"get_{sym_value}"

if isinstance(data_type, ts.FieldType):
# add access node to current state
sym_node = state.add_access(sym_value)

else:
# scalar symbols are passed to the SDFG as symbols: build tasklet node
# to write the symbol to a scalar access node
tasklet_node = sdfg_builder.add_tasklet(
tasklet_name,
state,
{},
{"__out"},
f"__out = {sym_value}",
)
temp_name, _ = sdfg.add_temp_transient((1,), dace_fieldview_util.as_dace_type(data_type))
sym_node = state.add_access(temp_name)
state.add_edge(
tasklet_node,
"__out",
sym_node,
None,
dace.Memlet(data=sym_node.data, subset="0"),
)

return [(sym_node, data_type)]


if TYPE_CHECKING:
# Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol
__primitive_translators: list[PrimitiveTranslator] = [
translate_as_field_op,
translate_cond,
translate_symbol_ref,
]
Loading
Loading