feat[next][dace]: Skeleton of GTIR DaCe backend #1538

Merged
merged 84 commits into from
Jul 12, 2024
Changes from 48 commits
b36fcb3
Skeleton for ITIR translation
edopao Apr 18, 2024
2020182
Minor edit
edopao Apr 18, 2024
5c6b6ba
Use Python callstack as a context stack for the ITIR visitor
edopao Apr 19, 2024
60e1c69
Format error
edopao Apr 19, 2024
073a0a4
Refactor tasklet codegen
edopao Apr 19, 2024
50be68f
Code refactoring
edopao Apr 19, 2024
4e2dc15
Add domain to field operator
edopao Apr 22, 2024
ea9da35
Minor edit
edopao Apr 22, 2024
daf7827
Remove hard-coded field shape
edopao Apr 23, 2024
9672b3b
Remove hard-coded target domain
edopao Apr 23, 2024
b6326b8
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Apr 23, 2024
26f3790
Refactoring
edopao Apr 24, 2024
1efffa7
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Apr 24, 2024
f99fa84
Fix formatting
edopao Apr 24, 2024
d6e1088
More refactoring
edopao Apr 24, 2024
9854497
Minor edit
edopao Apr 24, 2024
37d83d7
Fix formatting
edopao Apr 24, 2024
29986ef
Use callable to build taskgraph
edopao Apr 29, 2024
390f3b4
Add draft of select operator
edopao Apr 29, 2024
de27419
Remove node mapping
edopao Apr 30, 2024
cd900f5
Remove node mapping (fix + test case)
edopao Apr 30, 2024
326cbb5
Add test case for inlined mathematic builtins
edopao Apr 30, 2024
a10b614
Go full functional (remove SDFGState member var)
edopao Apr 30, 2024
aef4265
Minor edit
edopao Apr 30, 2024
9e67dfe
Minor edit (1)
edopao Apr 30, 2024
4b4109e
Fix state handling
edopao Apr 30, 2024
495fd0a
Edit comments based on review
edopao Apr 30, 2024
0085194
Add test case for nested select
edopao Apr 30, 2024
41e2a44
Separate builtin translation from driver logic
edopao May 1, 2024
7148c5f
Improve code comments
edopao May 2, 2024
452399d
Avoid inheritance: pass dataflow builder as arg to builtin translator
edopao May 3, 2024
e404226
Codestyle review changes
edopao May 3, 2024
bb0dfac
Remove circular dependency for builtin translators
edopao May 6, 2024
412cd5d
Fix formatting
edopao May 6, 2024
651de5c
Minor edit
edopao May 6, 2024
dcf3eab
Add support to translate each builtin call to a tasklet node
edopao May 6, 2024
7e6909e
Resolve dace warnings
edopao May 7, 2024
2b07cc5
Remove bultin translator for domain expressions
edopao May 7, 2024
2370fa6
Remove bultin translator for domain expressions (1)
edopao May 7, 2024
8e801df
Refactor
edopao May 7, 2024
812a6e5
Minor edit
edopao May 7, 2024
1d0b50b
Extract ITIR visitor to separate class
edopao May 7, 2024
97a1d22
Code refactoring
edopao May 7, 2024
a30cc7d
Fix formatting
edopao May 7, 2024
e9455e3
Changes in preparation for shift builtin
edopao May 13, 2024
801704b
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao May 13, 2024
c45c417
Add support for programs without computation (pure memlets)
edopao May 13, 2024
d67518a
Fix test
edopao May 13, 2024
c4385c1
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao May 16, 2024
ed16fd4
Import updates from branch dace-fieldview-shifts
edopao May 16, 2024
9f7176f
Review comments
edopao May 16, 2024
46febb0
Avoid tasklet-to-tasklet edge connections
edopao May 16, 2024
949bad7
Add support for in-out field parameters
edopao May 16, 2024
8890f95
Refactoring: import modules, not symbols
edopao May 17, 2024
87b71a6
Minor edit
edopao May 17, 2024
665a609
Remove internal package for builtin translators
edopao May 17, 2024
82fdf64
Add wrapper function to build SDFG
edopao May 17, 2024
e4718b0
Merge pull request #4 from edopao/dace-fieldview-refactor_imports
edopao May 17, 2024
6ccecf1
Code changes imported from branch dace-fieldview-shifts
edopao May 17, 2024
3c71efa
Import changes from neighbors branch
edopao May 29, 2024
2f75cfb
Add debuginfo for ir.Program and ir.Stmt nodes
edopao May 29, 2024
085f307
Fix error in debuginfo
edopao May 29, 2024
f19960b
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao May 29, 2024
dc1434c
Fix error in debuginfo (1)
edopao May 29, 2024
a5b0f41
import changes from neighbors branch
edopao Jun 28, 2024
f7ac3d8
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Jun 28, 2024
9318011
Import changes from branch dace-fieldview-neighbors
edopao Jul 4, 2024
11efdeb
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Jul 4, 2024
d7312fa
Support field with start offset
edopao Jul 4, 2024
c4f2738
Test IR updated for literal operand
edopao Jul 4, 2024
0fd0b65
Add test coverage to previous commit
edopao Jul 4, 2024
38d2720
Refactor PrimitiveTranslator interface
edopao Jul 4, 2024
e855ef9
Fix formatting
edopao Jul 5, 2024
4cff071
Fix for domain horzontal/vertical dims
edopao Jul 5, 2024
f642e85
Fix for type inference on single value expression
edopao Jul 5, 2024
fc9661c
Import changes from dace-fieldview-shifts
edopao Jul 5, 2024
e424d4e
Minor edit
edopao Jul 5, 2024
66c5fcd
Address review comments
edopao Jul 10, 2024
d5abad4
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao Jul 10, 2024
1df1bc3
Apply convention for map variables
edopao Jul 10, 2024
abf3918
Import changes from dace-fieldview-shifts
edopao Jul 11, 2024
7f60cfe
Import changes from branch dace-fieldview-shifts
edopao Jul 12, 2024
b3131db
Avoid direct import of symbols from module
edopao Jul 12, 2024
130c877
Address review comments
edopao Jul 12, 2024
1 change: 1 addition & 0 deletions src/gt4py/next/iterator/ir.py
Expand Up @@ -208,6 +208,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib
GTIR_BUILTINS = {
    *BUILTINS,
    "as_fieldop",  # `as_fieldop(stencil)` creates field_operator from stencil
    "select",  # `select(cond, field_a, field_b)` creates the field on one branch or the other
}
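
To make the comment on `select` concrete, here is a minimal pure-Python sketch of the semantics it describes. It assumes the condition is a single scalar value, so one whole field is chosen rather than individual elements; the function below and its list-based fields are illustrative only and are not part of GT4Py.

```python
from typing import Sequence


def select(cond: bool, field_a: Sequence[float], field_b: Sequence[float]) -> list[float]:
    """Return a copy of one whole field, chosen by a single scalar condition."""
    # Assumption: `cond` is a scalar, so the choice is made once for the
    # entire field (unlike an element-wise `where`).
    chosen = field_a if cond else field_b
    return list(chosen)


field_a = [1.0, 2.0, 3.0]
field_b = [10.0, 20.0, 30.0]
result_true = select(True, field_a, field_b)    # copy of field_a
result_false = select(False, field_a, field_b)  # copy of field_b
```

Only the field on the selected branch is materialized, which matches the wording "creates the field on one branch or the other".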


@@ -0,0 +1,13 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
@@ -0,0 +1,31 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_field_operator import (
    GTIRBuiltinAsFieldOp as AsFieldOp,
)
from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_select import (
    GTIRBuiltinSelect as Select,
)
from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_symbol_ref import (
    GTIRBuiltinSymbolRef as SymbolRef,
)


# export short names of translation classes for GTIR builtin functions
__all__ = [
    "AsFieldOp",
    "Select",
    "SymbolRef",
]
@@ -0,0 +1,155 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later


from typing import Callable, TypeAlias

import dace
import dace.subsets as sbs

from gt4py.next.common import Connectivity, Dimension
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import (
    GTIRBuiltinTranslator,
)
from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_tasklet import (
    GTIRToTasklet,
    IteratorExpr,
    MemletExpr,
    SymbolExpr,
    TaskletExpr,
)
from gt4py.next.program_processors.runners.dace_fieldview.utility import get_domain, unique_name
from gt4py.next.type_system import type_specifications as ts


# Define type of variables used for field indexing
_INDEX_DTYPE = dace.int64


class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator):
    """Generates the dataflow subgraph for the `as_fieldop` builtin function."""

    TaskletConnector: TypeAlias = tuple[dace.nodes.Tasklet, str]

    stencil_expr: itir.Lambda
    stencil_args: list[Callable]
    field_domain: dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]
    field_type: ts.FieldType
    offset_provider: dict[str, Connectivity | Dimension]

    def __init__(
        self,
        sdfg: dace.SDFG,
        state: dace.SDFGState,
        node: itir.FunCall,
        stencil_args: list[Callable],
        offset_provider: dict[str, Connectivity | Dimension],
    ):
        super().__init__(sdfg, state)
        self.offset_provider = offset_provider

        assert cpm.is_call_to(node.fun, "as_fieldop")
        assert len(node.fun.args) == 2
        stencil_expr, domain_expr = node.fun.args
        # expect stencil (represented as a lambda function) as first argument
        assert isinstance(stencil_expr, itir.Lambda)
        # the domain of the field operator is passed as second argument
        assert isinstance(domain_expr, itir.FunCall)

        domain = get_domain(domain_expr)
        # define field domain with all dimensions in alphabetical order
        sorted_domain_dims = sorted(domain.keys(), key=lambda x: x.value)

        # add local storage to compute the field operator over the given domain
        # TODO: use type inference to determine the result type
        node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)

        self.field_domain = domain
        self.field_type = ts.FieldType(sorted_domain_dims, node_type)
        self.stencil_expr = stencil_expr
        self.stencil_args = stencil_args

    def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]:
        dimension_index_fmt = "i_{dim}"
        # first visit the list of arguments and build a symbol map
        stencil_args: list[IteratorExpr | MemletExpr] = []
        for arg in self.stencil_args:
            arg_nodes = arg()
            assert len(arg_nodes) == 1
            data_node, arg_type = arg_nodes[0]
            # require all argument nodes to be data access nodes (no symbols)
            assert isinstance(data_node, dace.nodes.AccessNode)

            if isinstance(arg_type, ts.ScalarType):
                scalar_arg = MemletExpr(data_node, sbs.Indices([0]))
                stencil_args.append(scalar_arg)
            else:
                assert isinstance(arg_type, ts.FieldType)
                indices: dict[str, MemletExpr | SymbolExpr | TaskletExpr] = {
                    dim.value: SymbolExpr(
                        dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim.value)),
                        _INDEX_DTYPE,
                    )
                    for dim in self.field_domain.keys()
                }
                iterator_arg = IteratorExpr(
                    data_node,
                    [dim.value for dim in arg_type.dims],
                    sbs.Indices([0] * len(arg_type.dims)),
                    indices,
                )
                stencil_args.append(iterator_arg)

        # represent the field operator as a mapped tasklet graph, which will range over the field domain
        taskgen = GTIRToTasklet(self.sdfg, self.head_state, self.offset_provider)
        input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args)
        assert isinstance(output_expr, TaskletExpr)

        # allocate local temporary storage for the result field
        field_shape = [
Contributor commented:

Note to myself: does the data have to flow through a "local storage"?

Contributor Author replied:

For the final write to an external field, this internal node is redundant, and the DaCe simplify pass will remove it. It is needed, however, when we write back to a field that is also used as input (in-out field parameters): in that case, the DaCe simplify pass correctly keeps it. I have added a test case, test_gtir_update, to cover this.

            # diff between upper and lower bound
            self.field_domain[dim][1] - self.field_domain[dim][0]
            for dim in self.field_type.dims
        ]
        field_node = self.add_local_storage(self.field_type, field_shape)

        # assume tasklet with single output
        output_index = ",".join(
            dimension_index_fmt.format(dim=dim.value) for dim in self.field_type.dims
        )
        output_memlet = dace.Memlet(data=field_node.data, subset=output_index)

        # create map range corresponding to the field operator domain
        map_ranges = {
            dimension_index_fmt.format(dim=dim.value): f"{lb}:{ub}"
            for dim, (lb, ub) in self.field_domain.items()
        }
        me, mx = self.head_state.add_map(unique_name("map"), map_ranges)

        for data_node, data_subset, lambda_node, lambda_connector in input_connections:
            memlet = dace.Memlet(data=data_node.data, subset=data_subset, volume=1)
            self.head_state.add_memlet_path(
                data_node,
                me,
                lambda_node,
                dst_conn=lambda_connector,
                memlet=memlet,
            )
        self.head_state.add_memlet_path(
            output_expr.node, mx, field_node, src_conn=output_expr.connector, memlet=output_memlet
        )

        return [(field_node, self.field_type)]
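
The string bookkeeping in `build` (map ranges, temporary shape, output index) can be sketched without DaCe. The snippet below is a simplified illustration under stated assumptions: dimensions are plain strings rather than `Dimension` objects, bounds are plain integers rather than symbolic expressions, and `map_layout` is a hypothetical helper name that does not appear in the PR.

```python
def map_layout(
    domain: dict[str, tuple[int, int]], dims: list[str]
) -> tuple[dict[str, str], list[int], str]:
    """Derive map ranges, temporary shape and output subset from a field domain."""
    index_fmt = "i_{dim}"  # same naming pattern as `dimension_index_fmt` in the diff
    # one map variable per domain dimension, ranging over "lower:upper"
    map_ranges = {
        index_fmt.format(dim=dim): f"{lb}:{ub}" for dim, (lb, ub) in domain.items()
    }
    # extent of the temporary storage: upper bound minus lower bound per dimension
    shape = [domain[dim][1] - domain[dim][0] for dim in dims]
    # subset string used to write one element per map iteration
    output_index = ",".join(index_fmt.format(dim=dim) for dim in dims)
    return map_ranges, shape, output_index


# Hypothetical domain; dimensions are sorted alphabetically as in `__init__`.
domain = {"KDim": (2, 8), "IDim": (0, 10)}
map_ranges, shape, output_index = map_layout(domain, sorted(domain))
```

The sketch only mirrors how the strings and extents are derived; how map indices relate to storage offsets when a lower bound is non-zero is outside its scope.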
@@ -0,0 +1,125 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later


from typing import Callable

import dace

from gt4py import eve
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import (
    GTIRBuiltinTranslator,
)
from gt4py.next.program_processors.runners.dace_fieldview.utility import get_symbolic_expr
from gt4py.next.type_system import type_specifications as ts


class GTIRBuiltinSelect(GTIRBuiltinTranslator):
    """Generates the dataflow subgraph for the `select` builtin function."""

    true_br_builder: Callable
    false_br_builder: Callable

    def __init__(
        self,
        sdfg: dace.SDFG,
        state: dace.SDFGState,
        dataflow_builder: eve.NodeVisitor,
        node: itir.FunCall,
    ):
        super().__init__(sdfg, state)

        assert cpm.is_call_to(node.fun, "select")
        assert len(node.fun.args) == 3
        cond_expr, true_expr, false_expr = node.fun.args

        # expect condition as first argument
        cond = get_symbolic_expr(cond_expr)

        # use current head state to terminate the dataflow, and add an entry state
        # to connect the true/false branch states as follows:
        #
        #               ------------
        #           === |  select  | ===
        #          ||   ------------   ||
        #          \/                  \/
        #     ------------       -------------
        #     |   true   |       |   false   |
        #     ------------       -------------
        #          ||                  ||
        #          ||   ------------   ||
        #           ==> |   head   | <==
        #               ------------
        #
        select_state = sdfg.add_state_before(state, state.label + "_select")
        sdfg.remove_edge(sdfg.out_edges(select_state)[0])

        # expect true branch as second argument
        true_state = sdfg.add_state(state.label + "_true_branch")
        sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond))
        sdfg.add_edge(true_state, state, dace.InterstateEdge())
        self.true_br_builder = dataflow_builder.visit(true_expr, sdfg=sdfg, head_state=true_state)

        # and false branch as third argument
        false_state = sdfg.add_state(state.label + "_false_branch")
        sdfg.add_edge(select_state, false_state, dace.InterstateEdge(condition=(f"not {cond}")))
        sdfg.add_edge(false_state, state, dace.InterstateEdge())
        self.false_br_builder = dataflow_builder.visit(
            false_expr, sdfg=sdfg, head_state=false_state
        )

    def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]:
        # retrieve true/false states as predecessors of head state
        branch_states = tuple(edge.src for edge in self.sdfg.in_edges(self.head_state))
        assert len(branch_states) == 2
        if branch_states[0].label.endswith("_true_branch"):
            true_state, false_state = branch_states
        else:
            false_state, true_state = branch_states

        true_br_args = self.true_br_builder()
        false_br_args = self.false_br_builder()

        output_nodes = []
        for true_br, false_br in zip(true_br_args, false_br_args, strict=True):
            true_br_node, true_br_type = true_br
            assert isinstance(true_br_node, dace.nodes.AccessNode)
            false_br_node, false_br_type = false_br
            assert isinstance(false_br_node, dace.nodes.AccessNode)
            assert true_br_type == false_br_type
            array_type = self.sdfg.arrays[true_br_node.data]
            access_node = self.add_local_storage(true_br_type, array_type.shape)
            output_nodes.append((access_node, true_br_type))

            data_name = access_node.data
            true_br_output_node = true_state.add_access(data_name)
            true_state.add_nedge(
                true_br_node,
                true_br_output_node,
                dace.Memlet.from_array(
                    true_br_output_node.data, true_br_output_node.desc(self.sdfg)
                ),
            )

            false_br_output_node = false_state.add_access(data_name)
            false_state.add_nedge(
                false_br_node,
                false_br_output_node,
                dace.Memlet.from_array(
                    false_br_output_node.data, false_br_output_node.desc(self.sdfg)
                ),
            )
        return output_nodes
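
The diamond of states drawn in the `__init__` comment can be mimicked with a toy state machine, which may help when reading the edge manipulation above. This is a sketch under assumptions, not DaCe code: states are plain labels, edges are `(src, dst, condition)` tuples, and `add_select_states` and `run` are hypothetical helpers. The sketch also parenthesizes the negated condition, which is its own choice and differs from the `f"not {cond}"` string in the diff.

```python
Edge = tuple[str, str, str]  # (source state, destination state, condition)


def add_select_states(edges: list[Edge], head: str, cond: str) -> tuple[str, str, str]:
    """Insert select/true/false states in front of `head` and return their labels."""
    select = f"{head}_select"
    true_br = f"{head}_true_branch"
    false_br = f"{head}_false_branch"
    # redirect every edge that entered `head` to the new select state
    for i, (src, dst, edge_cond) in enumerate(edges):
        if dst == head:
            edges[i] = (src, select, edge_cond)
    # the two branches leave the select state on complementary conditions
    # and both rejoin at the original head state
    edges.extend(
        [
            (select, true_br, cond),
            (select, false_br, f"not ({cond})"),
            (true_br, head, "True"),
            (false_br, head, "True"),
        ]
    )
    return select, true_br, false_br


def run(edges: list[Edge], start: str, env: dict[str, object]) -> list[str]:
    """Follow the single enabled outgoing edge from each state until none is left."""
    path = [start]
    while True:
        enabled = [dst for src, dst, c in edges if src == path[-1] and eval(c, {}, env)]
        if not enabled:
            return path
        assert len(enabled) == 1, "exactly one branch should be enabled"
        path.append(enabled[0])


edges: list[Edge] = [("init", "head", "True")]
add_select_states(edges, "head", "x > 0")
path_true = run(edges, "init", {"x": 1})
path_false = run(edges, "init", {"x": -1})
```

In both runs the walk ends at `head`, which is where `build` later finds the two branch states as predecessors and tells them apart by their label suffix.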