-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Orchestrated] Signature & empty code issues #70
Comments
7 tasks
Test file to be dropped in:
from ndsl.stencils.corners import fill_corners_dgrid_defn
from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu
from ndsl.constants import X_DIM, Y_DIM, Z_DIM, X_INTERFACE_DIM, Y_INTERFACE_DIM
from ndsl.dsl.typing import Float, FloatField
from ndsl import orchestrate, StencilFactory, DaceConfig
from gt4py.cartesian.gtscript import computation, PARALLEL, interval
class OrchestratedCorner:
    """Orchestrated callable wrapping the `fill_corners_dgrid_defn` stencil.

    The stencil is built once at construction time from the origin/domain
    reported by the factory's grid indexing for `[X_DIM, Y_DIM, Z_DIM]`, with
    the matching axis offsets passed in as externals.
    """

    def __init__(self, stencil_factory: StencilFactory) -> None:
        """Register this object with `orchestrate` and build the corner stencil.

        Args:
            stencil_factory: factory providing the DaCe configuration, the
                backend, the grid indexing and the stencil builder.
        """
        # Use the factory's DaCe configuration when it is truthy; otherwise
        # fall back to a communicator-less configuration on the same backend.
        dace_config = stencil_factory.config.dace_config or DaceConfig(
            communicator=None, backend=stencil_factory.backend
        )
        orchestrate(obj=self, config=dace_config)

        indexing = stencil_factory.grid_indexing
        start, extent = indexing.get_origin_domain(dims=[X_DIM, Y_DIM, Z_DIM])
        self.corner_stencil = stencil_factory.from_origin_domain(
            fill_corners_dgrid_defn,
            externals=indexing.axis_offsets(start, extent),
            origin=start,
            domain=extent,
        )

    def __call__(self, x, y):
        # `x` and `y` are each passed twice, followed by the constant 1.0.
        # NOTE(review): presumably (in, out) pairs plus a sign factor --
        # confirm against the signature of `fill_corners_dgrid_defn`.
        self.corner_stencil(
            x,
            x,
            y,
            y,
            1.0,
        )
def test_empty_corners():
    """Build and call `OrchestratedCorner` on a tile flagged as touching no edge.

    The test has no assertion: it passes when building and calling the
    orchestrated object raises nothing.
    """
    stencil_factory, quantity_factory = get_factories_single_tile_orchestrated_cpu(
        12, 12, 5, 0
    )
    # Flag the tile as touching none of the four edges.
    # NOTE(review): presumably this leaves the corner code of the stencil
    # empty, which is the case named by this test -- confirm.
    stencil_factory.grid_indexing.south_edge = False
    stencil_factory.grid_indexing.north_edge = False
    stencil_factory.grid_indexing.west_edge = False
    stencil_factory.grid_indexing.east_edge = False

    x = quantity_factory.empty(dims=[X_INTERFACE_DIM, Y_DIM, Z_DIM], units="n/a")
    y = quantity_factory.empty(dims=[X_DIM, Y_INTERFACE_DIM, Z_DIM], units="n/a")

    orch_corner = OrchestratedCorner(stencil_factory)
    orch_corner(x, y)
def unusued_parameter_stencil(
    field: FloatField,  # type: ignore
    result: FloatField,  # type: ignore
    weight: Float,  # type: ignore
):
    """Write the sum of the four horizontal neighbours of `field` into `result`.

    `weight` is declared in the signature but never read in the body; the
    stencil exists to exercise that unused scalar parameter.
    """
    with computation(PARALLEL), interval(...):
        result = (
            field[1, 0, 0]
            + field[0, 1, 0]
            + field[-1, 0, 0]
            + field[0, -1, 0]
        )
class OrchestratedUnusedParameter:
    """Orchestrated callable wrapping `unusued_parameter_stencil`.

    The stencil is built once at construction time from the origin/domain
    reported by the factory's grid indexing for `[X_DIM, Y_DIM, Z_DIM]`.
    """

    def __init__(self, stencil_factory: StencilFactory):
        """Register this object with `orchestrate` and build the stencil.

        Args:
            stencil_factory: factory providing the DaCe configuration, the
                backend, the grid indexing and the stencil builder.
        """
        # Use the factory's DaCe configuration when it is truthy; otherwise
        # fall back to a communicator-less configuration on the same backend.
        dace_config = stencil_factory.config.dace_config or DaceConfig(
            communicator=None, backend=stencil_factory.backend
        )
        orchestrate(obj=self, config=dace_config)

        start, extent = stencil_factory.grid_indexing.get_origin_domain(
            dims=[X_DIM, Y_DIM, Z_DIM]
        )
        self.unused_stencil = stencil_factory.from_origin_domain(
            unusued_parameter_stencil, origin=start, domain=extent
        )

    def __call__(self, x, y):
        # The trailing 1.0 fills the stencil's scalar parameter.
        self.unused_stencil(
            x,
            y,
            1.0,
        )
def test_unused_parameters():
    """Build and call `OrchestratedUnusedParameter` on freshly allocated quantities.

    The test has no assertion: it passes when building and calling the
    orchestrated object raises nothing.
    """
    factories = get_factories_single_tile_orchestrated_cpu(12, 12, 5, 2)
    stencil_factory, quantity_factory = factories

    x = quantity_factory.empty(
        dims=[X_INTERFACE_DIM, Y_DIM, Z_DIM],
        units="n/a",
    )
    y = quantity_factory.empty(
        dims=[X_DIM, Y_INTERFACE_DIM, Z_DIM],
        units="n/a",
    )

    orch_unused = OrchestratedUnusedParameter(stencil_factory)
    orch_unused(x, y)
def unusued_field_stencil(
    field: FloatField,  # type: ignore
    other_field: FloatField,  # type: ignore
    result: FloatField,  # type: ignore
):
    """Write the sum of the four horizontal neighbours of `field` into `result`.

    `other_field` is declared in the signature but never read in the body;
    the stencil exists to exercise that unused field argument.
    """
    with computation(PARALLEL), interval(...):
        result = (
            field[1, 0, 0]
            + field[0, 1, 0]
            + field[-1, 0, 0]
            + field[0, -1, 0]
        )
class OrchestratedunusedField:
    """Orchestrated callable wrapping `unusued_field_stencil`.

    The stencil is built once at construction time from the origin/domain
    reported by the factory's grid indexing for `[X_DIM, Y_DIM, Z_DIM]`.
    """

    def __init__(self, stencil_factory: StencilFactory):
        """Register this object with `orchestrate` and build the stencil.

        Args:
            stencil_factory: factory providing the DaCe configuration, the
                backend, the grid indexing and the stencil builder.
        """
        # Use the factory's DaCe configuration when it is truthy; otherwise
        # fall back to a communicator-less configuration on the same backend.
        dace_config = stencil_factory.config.dace_config or DaceConfig(
            communicator=None, backend=stencil_factory.backend
        )
        orchestrate(obj=self, config=dace_config)

        start, extent = stencil_factory.grid_indexing.get_origin_domain(
            dims=[X_DIM, Y_DIM, Z_DIM]
        )
        self.unused_stencil = stencil_factory.from_origin_domain(
            unusued_field_stencil, origin=start, domain=extent
        )

    def __call__(self, x, unused_field, y):
        # Arguments are forwarded positionally, in the stencil's own order.
        self.unused_stencil(x, unused_field, y)
def test_unused_field():
    """Build and call `OrchestratedunusedField` on freshly allocated quantities.

    The test has no assertion: it passes when building and calling the
    orchestrated object raises nothing.
    """
    factories = get_factories_single_tile_orchestrated_cpu(12, 12, 5, 2)
    stencil_factory, quantity_factory = factories

    x = quantity_factory.empty(
        dims=[X_INTERFACE_DIM, Y_DIM, Z_DIM],
        units="n/a",
    )
    # Same dimensions as `x`; passed in the slot the stencil never reads.
    x_unused = quantity_factory.empty(
        dims=[X_INTERFACE_DIM, Y_DIM, Z_DIM],
        units="n/a",
    )
    y = quantity_factory.empty(
        dims=[X_DIM, Y_INTERFACE_DIM, Z_DIM],
        units="n/a",
    )

    orch_unused = OrchestratedunusedField(stencil_factory)
    orch_unused(x, x_unused, y)
if __name__ == "__main__":
    # Run the three reproduction cases in sequence when executed as a script.
    test_unused_parameters()
    test_empty_corners()
    test_unused_field()
Patches to be applied:
diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py
index 7608fcd5..e4699b3b 100644
--- a/src/gt4py/cartesian/backend/dace_backend.py
+++ b/src/gt4py/cartesian/backend/dace_backend.py
@@ -234,6 +234,24 @@ def _sdfg_add_arrays_and_edges(
None,
dace.Memlet(name, subset=dace.subsets.Range(ranges)),
)
+ elif isinstance(array, dace.data.Scalar):
+ wrapper_sdfg.add_scalar(name, dtype=array.dtype, storage=array.storage)
+ if name in inputs:
+ state.add_edge(
+ state.add_read(name),
+ None,
+ nsdfg,
+ name,
+ dace.Memlet(name),
+ )
+ if name in outputs:
+ state.add_edge(
+ nsdfg,
+ name,
+ state.add_write(name),
+ None,
+ dace.Memlet(name),
+ )
def _sdfg_specialize_symbols(wrapper_sdfg, domain: Tuple[int, ...]):
diff --git a/src/gt4py/cartesian/backend/dace_lazy_stencil.py b/src/gt4py/cartesian/backend/dace_lazy_stencil.py
index 2b3cf6fe..0c614ad8 100644
--- a/src/gt4py/cartesian/backend/dace_lazy_stencil.py
+++ b/src/gt4py/cartesian/backend/dace_lazy_stencil.py
@@ -15,6 +15,7 @@ from gt4py.cartesian.backend.dace_backend import SDFGManager
from gt4py.cartesian.backend.dace_stencil_object import DaCeStencilObject, add_optional_fields
from gt4py.cartesian.backend.module_generator import make_args_data_from_gtir
from gt4py.cartesian.lazy_stencil import LazyStencil
+from gt4py.cartesian.gtc.passes.gtir_prune_unused_parameters import prune_unused_parameters
if TYPE_CHECKING:
@@ -26,6 +27,7 @@ class DaCeLazyStencil(LazyStencil, SDFGConvertible):
if "dace" not in builder.backend.name:
raise ValueError("Trying to build a DaCeLazyStencil for non-dace backend.")
super().__init__(builder=builder)
+ self.signature = []
@property
def field_info(self) -> Dict[str, Any]:
@@ -47,7 +49,8 @@ class DaCeLazyStencil(LazyStencil, SDFGConvertible):
def __sdfg__(self, *args, **kwargs) -> dace.SDFG:
sdfg_manager = SDFGManager(self.builder)
args_data = make_args_data_from_gtir(self.builder.gtir_pipeline)
- arg_names = [arg.name for arg in self.builder.gtir.api_signature]
+ assert self.signature != []
+ arg_names = self.signature
assert args_data.domain_info is not None
norm_kwargs = DaCeStencilObject.normalize_args(
*args,
@@ -69,5 +72,9 @@ class DaCeLazyStencil(LazyStencil, SDFGConvertible):
return {}
def __sdfg_signature__(self) -> Tuple[Sequence[str], Sequence[str]]:
- args = [arg.name for arg in self.builder.gtir.api_signature]
- return (args, [])
+ if self.signature == []:
+ self.signature = [
+ str(p)
+ for p in self.builder.gtir_pipeline.apply([prune_unused_parameters]).param_names
+ ]
+ return (self.signature, [])
diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py
index dba6c5a7..9dd57290 100644
--- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py
+++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py
@@ -150,7 +150,7 @@ class OirSDFGBuilder(eve.NodeVisitor):
debuginfo=dace.DebugInfo(0),
)
else:
- ctx.sdfg.add_symbol(param.name, stype=data_type_to_dace_typeclass(param.dtype))
+ ctx.sdfg.add_scalar(param.name, dtype=data_type_to_dace_typeclass(param.dtype))
for decl in node.declarations:
dim_strs = [d for i, d in enumerate("IJK") if decl.dimensions[i]] + [ |
Working solution seems to be updating the Branch under test: https://github.com/FlorianDeconinck/gt4py/tree/cartesian/fix/missing_parameter |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Description
There is a series of bugs related to the signature at the GT4Py level and to the declared symbols/scalars/fields in the DaCe wrapper SDFG. All of those bugs live in the bridge between gt4py & dace.
Those can be classified in 3 groups:
Most of those behaviors are linked to the `prune_unused_argument` pass of GT4Py, which is called at the very beginning of GTIR. While this is clearly not the intended design (passes should be pushed down to OIR or backend IR), it was done to deal with some of those issues. Simply removing the prune pass (which could be done, considering it gives little to no performance improvement) does not fix the issues.
In the comments below we will put down 3 examples that showcase the issues (either plain `ndsl` or relying on `pyfv3`) and some patches that fix some bugs but create others.
To Reproduce
See comment.
The text was updated successfully, but these errors were encountered: