From b0f7e3a77e58963ee4befed63bba07fededd188c Mon Sep 17 00:00:00 2001 From: DropD Date: Tue, 26 Mar 2024 16:57:04 +0100 Subject: [PATCH 01/30] workflowify past linting and args injection --- src/gt4py/next/backend.py | 24 +++++++- src/gt4py/next/ffront/decorator.py | 25 +-------- src/gt4py/next/ffront/past_passes/linters.py | 59 ++++++++++++++++++++ src/gt4py/next/ffront/stages.py | 5 ++ src/gt4py/next/otf/recipes.py | 27 ++------- 5 files changed, 95 insertions(+), 45 deletions(-) create mode 100644 src/gt4py/next/ffront/past_passes/linters.py diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index cfa4911b57..6e821b92c8 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -20,12 +20,30 @@ from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators from gt4py.next.ffront import func_to_past, past_process_args, past_to_itir, stages as ffront_stages -from gt4py.next.otf import recipes +from gt4py.next.ffront.past_passes import linters as past_linters +from gt4py.next.otf import recipes, workflow from gt4py.next.program_processors import processor_interface as ppi +@dataclasses.dataclass(frozen=True) +class ArgsInjector(workflow.Workflow): + args: tuple[Any, ...] 
= dataclasses.field(default_factory=tuple) + kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) + + def __call__(self, inp: ffront_stages.PastProgramDefinition) -> ffront_stages.PastClosure: + return ffront_stages.PastClosure( + past_node=inp.past_node, + closure_vars=inp.closure_vars, + grid_type=inp.grid_type, + args=self.args, + kwargs=self.kwargs, + ) + + DEFAULT_TRANSFORMS = recipes.ProgramTransformWorkflow( func_to_past=func_to_past.OptionalFuncToPastFactory(cached=True), + past_lint=past_linters.LinterFactory(), + past_inject_args=ArgsInjector(), past_transform_args=past_process_args.past_process_args, past_to_itir=past_to_itir.PastToItirFactory(), ) @@ -40,7 +58,9 @@ class Backend(Generic[core_defs.DeviceTypeT]): def __call__( self, program: ffront_stages.ProgramDefinition, *args: tuple[Any], **kwargs: dict[str, Any] ) -> None: - transformer = self.transformer.replace(args=args, kwargs=kwargs) + transformer = self.transformer.replace( + past_inject_args=ArgsInjector(args=args, kwargs=kwargs) + ) program_call = transformer(program) self.executor(program_call.program, *program_call.args, **program_call.kwargs) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 36503989f0..2f7ccd69d4 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -119,28 +119,9 @@ def past_stage(self): return next_backend.DEFAULT_TRANSFORMS.func_to_past(self.definition_stage) def __post_init__(self): - function_closure_vars = transform_utils._filter_closure_vars_by_type( - self.past_stage.closure_vars, GTCallable - ) - misnamed_functions = [ - f"{name} vs. {func.id}" - for name, func in function_closure_vars.items() - if name != func.__gt_itir__().id - ] - if misnamed_functions: - raise RuntimeError( - f"The following symbols resolve to a function with a mismatching name: {','.join(misnamed_functions)}." 
- ) - - undefined_symbols = [ - symbol.id - for symbol in self.past_stage.past_node.closure_vars - if symbol.id not in self.past_stage.closure_vars - ] - if undefined_symbols: - raise RuntimeError( - f"The following closure variables are undefined: {', '.join(undefined_symbols)}." - ) + if self.backend is not None and self.backend.transformer is not None: + self.backend.transformer.past_lint(self.past_stage) + return next_backend.DEFAULT_TRANSFORMS.past_lint(self.past_stage) @property def __name__(self) -> str: diff --git a/src/gt4py/next/ffront/past_passes/linters.py b/src/gt4py/next/ffront/past_passes/linters.py new file mode 100644 index 0000000000..c05981affa --- /dev/null +++ b/src/gt4py/next/ffront/past_passes/linters.py @@ -0,0 +1,59 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import factory + +from gt4py.next.ffront import gtcallable, stages as ffront_stages, transform_utils +from gt4py.next.otf import workflow + + +@workflow.make_step +def lint_misnamed_functions( + inp: ffront_stages.PastProgramDefinition, +) -> ffront_stages.PastProgramDefinition: + function_closure_vars = transform_utils._filter_closure_vars_by_type( + inp.closure_vars, gtcallable.GTCallable + ) + misnamed_functions = [ + f"{name} vs. 
{func.id}" + for name, func in function_closure_vars.items() + if name != func.__gt_itir__().id + ] + if misnamed_functions: + raise RuntimeError( + f"The following symbols resolve to a function with a mismatching name: {','.join(misnamed_functions)}." + ) + return inp + + +@workflow.make_step +def lint_undefined_symbols( + inp: ffront_stages.PastProgramDefinition, +) -> ffront_stages.PastProgramDefinition: + undefined_symbols = [ + symbol.id for symbol in inp.past_node.closure_vars if symbol.id not in inp.closure_vars + ] + if undefined_symbols: + raise RuntimeError( + f"The following closure variables are undefined: {', '.join(undefined_symbols)}." + ) + return inp + + +class LinterFactory(factory.Factory): + class Meta: + model = workflow.CachedStep + + step = lint_misnamed_functions.chain(lint_undefined_symbols) + hash_function = ffront_stages.hash_past_program_definition diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index ed7c65c0af..a79cd79d97 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -18,6 +18,7 @@ import types from typing import Any, Optional +from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.ffront import program_ast as past @@ -42,3 +43,7 @@ class PastClosure: grid_type: Optional[common.GridType] args: tuple[Any, ...] 
kwargs: dict[str, Any] + + +def hash_past_program_definition(past_definition: PastProgramDefinition) -> str: + return eve_utils.content_hash(past_definition.past_node, past_definition.grid_type) diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 702d0ebb9d..cebbd3fe2d 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -14,7 +14,6 @@ from __future__ import annotations import dataclasses -from typing import Any from gt4py.next.ffront import stages as ffront_stages from gt4py.next.otf import stages, step_types, workflow @@ -28,29 +27,15 @@ class ProgramTransformWorkflow(workflow.NamedStepSequence): ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition, ffront_stages.PastProgramDefinition, ] + past_lint: workflow.Workflow[ + ffront_stages.PastProgramDefinition, ffront_stages.PastProgramDefinition + ] + past_inject_args: workflow.Workflow[ + ffront_stages.PastProgramDefinition, ffront_stages.PastClosure + ] past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] - args: tuple[Any, ...] 
= dataclasses.field(default_factory=tuple) - kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) - - def __call__( - self, - inp: ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition, - ) -> stages.ProgramCall: - past_stage = self.func_to_past(inp) - return self.past_to_itir( - self.past_transform_args( - ffront_stages.PastClosure( - past_node=past_stage.past_node, - closure_vars=past_stage.closure_vars, - grid_type=past_stage.grid_type, - args=self.args, - kwargs=self.kwargs, - ) - ) - ) - @dataclasses.dataclass(frozen=True) class OTFCompileWorkflow(workflow.NamedStepSequence): From 445b35d03965c1217065a97e96d078459b3f6488 Mon Sep 17 00:00:00 2001 From: DropD Date: Tue, 2 Apr 2024 13:40:42 +0200 Subject: [PATCH 02/30] workflowify func -> FOAST --- src/gt4py/next/backend.py | 21 ++- src/gt4py/next/ffront/decorator.py | 131 +++++++++--------- .../ffront/foast_passes/type_deduction.py | 7 +- src/gt4py/next/ffront/foast_to_itir.py | 5 + src/gt4py/next/ffront/func_to_foast.py | 81 ++++++++++- src/gt4py/next/ffront/past_passes/linters.py | 2 +- src/gt4py/next/ffront/stages.py | 49 ++++++- src/gt4py/next/otf/recipes.py | 12 ++ .../ffront_tests/test_foast_pretty_printer.py | 4 +- .../test_math_builtin_execution.py | 22 +-- 10 files changed, 243 insertions(+), 91 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 6e821b92c8..c0838cc425 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -19,7 +19,14 @@ from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators -from gt4py.next.ffront import func_to_past, past_process_args, past_to_itir, stages as ffront_stages +from gt4py.next.ffront import ( + foast_to_itir, + func_to_foast, + func_to_past, + past_process_args, + past_to_itir, + stages as ffront_stages, +) from gt4py.next.ffront.past_passes import linters as past_linters from gt4py.next.otf import recipes, workflow from 
gt4py.next.program_processors import processor_interface as ppi @@ -40,6 +47,14 @@ def __call__(self, inp: ffront_stages.PastProgramDefinition) -> ffront_stages.Pa ) +DEFAULT_FIELDOP_TRANSFORMS = recipes.FieldopTransformWorkflow( + func_to_foast=func_to_foast.OptionalFuncToFoastFactory(cached=True), + foast_to_itir=workflow.CachedStep( + step=foast_to_itir.foast_to_itir, hash_function=ffront_stages.hash_foast_operator_definition + ), +) + + DEFAULT_TRANSFORMS = recipes.ProgramTransformWorkflow( func_to_past=func_to_past.OptionalFuncToPastFactory(cached=True), past_lint=past_linters.LinterFactory(), @@ -53,12 +68,12 @@ def __call__(self, inp: ffront_stages.PastProgramDefinition) -> ffront_stages.Pa class Backend(Generic[core_defs.DeviceTypeT]): executor: ppi.ProgramExecutor allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] - transformer: recipes.ProgramTransformWorkflow = DEFAULT_TRANSFORMS + transforms_prog: recipes.ProgramTransformWorkflow = DEFAULT_TRANSFORMS def __call__( self, program: ffront_stages.ProgramDefinition, *args: tuple[Any], **kwargs: dict[str, Any] ) -> None: - transformer = self.transformer.replace( + transformer = self.transforms_prog.replace( past_inject_args=ArgsInjector(args=args, kwargs=kwargs) ) program_call = transformer(program) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 2f7ccd69d4..5d9c205e86 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -48,15 +48,11 @@ transform_utils, type_specifications as ts_ffront, ) -from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction -from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering -from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.ffront.gtcallable import GTCallable from gt4py.next.ffront.past_passes.closure_var_type_deduction import ( ClosureVarTypeDeduction as ProgramClosureVarTypeDeduction, ) from 
gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction -from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils.ir_makers import ( literal_from_value, @@ -114,13 +110,13 @@ def definition(self): @functools.cached_property def past_stage(self): - if self.backend is not None and self.backend.transformer is not None: - return self.backend.transformer.func_to_past(self.definition_stage) + if self.backend is not None and self.backend.transforms_prog is not None: + return self.backend.transforms_prog.func_to_past(self.definition_stage) return next_backend.DEFAULT_TRANSFORMS.func_to_past(self.definition_stage) def __post_init__(self): - if self.backend is not None and self.backend.transformer is not None: - self.backend.transformer.past_lint(self.past_stage) + if self.backend is not None and self.backend.transforms_prog is not None: + self.backend.transforms_prog.past_lint(self.past_stage) return next_backend.DEFAULT_TRANSFORMS.past_lint(self.past_stage) @property @@ -191,8 +187,8 @@ def itir(self) -> itir.FencilDefinition: args=[], kwargs={}, ) - if self.backend is not None and self.backend.transformer is not None: - return self.backend.transformer.past_to_itir(no_args_past) + if self.backend is not None and self.backend.transforms_prog is not None: + return self.backend.transforms_prog.past_to_itir(no_args_past) return past_to_itir.PastToItirFactory()(no_args_past).program def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) -> None: @@ -233,6 +229,11 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor) self.backend(self.past_stage, *args, **(kwargs | {"offset_provider": offset_provider})) + def __post_init__(self): + if self.backend is not None and self.backend.transforms_prog is not None: + 
self.backend.transforms_prog.past_lint(self.past_stage) + return next_backend.DEFAULT_TRANSFORMS.past_lint(self.past_stage) + @dataclasses.dataclass(frozen=True) class ProgramWithBoundArgs(Program): @@ -385,12 +386,8 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): it will be deduced from actually occurring dimensions. """ - foast_node: OperatorNodeT - closure_vars: dict[str, Any] - definition: Optional[types.FunctionType] + definition_stage: ffront_stages.FieldOperatorDefinition backend: Optional[ppi.ProgramExecutor] - grid_type: Optional[GridType] - operator_attributes: Optional[dict[str, Any]] = None _program_cache: dict = dataclasses.field( init=False, default_factory=dict ) # init=False ensure the cache is not copied in calls to replace @@ -405,39 +402,33 @@ def from_function( operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, operator_attributes: Optional[dict[str, Any]] = None, ) -> FieldOperator[OperatorNodeT]: - operator_attributes = operator_attributes or {} - - source_def = SourceDefinition.from_function(definition) - closure_vars = get_closure_vars_from_function(definition) - annotations = typing.get_type_hints(definition) - foast_definition_node = FieldOperatorParser.apply(source_def, closure_vars, annotations) - loc = foast_definition_node.location - operator_attribute_nodes = { - key: foast.Constant(value=value, type=type_translation.from_value(value), location=loc) - for key, value in operator_attributes.items() - } - untyped_foast_node = operator_node_cls( - id=foast_definition_node.id, - definition=foast_definition_node, - location=loc, - **operator_attribute_nodes, - ) - foast_node = FieldOperatorTypeDeduction.apply(untyped_foast_node) return cls( - foast_node=foast_node, - closure_vars=closure_vars, - definition=definition, + definition_stage=ffront_stages.FieldOperatorDefinition( + definition=definition, + grid_type=grid_type, + node_class=operator_node_cls, + attributes=operator_attributes or {}, + ), backend=backend, - 
grid_type=grid_type, - operator_attributes=operator_attributes, ) + def __post_init__(self): + _ = self.foast_stage + + @functools.cached_property + def foast_stage(self) -> ffront_stages.FoastOperatorDefinition: + return next_backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(self.definition_stage) + @property def __name__(self) -> str: - return self.definition.__name__ + return self.definition_stage.definition.__name__ + + @property + def definition(self) -> str: + return self.definition_stage.definition def __gt_type__(self) -> ts.CallableType: - type_ = self.foast_node.type + type_ = self.foast_stage.foast_node.type assert isinstance(type_, ts.CallableType) return type_ @@ -445,20 +436,15 @@ def with_backend(self, backend: ppi.ProgramExecutor) -> FieldOperator: return dataclasses.replace(self, backend=backend) def with_grid_type(self, grid_type: GridType) -> FieldOperator: - return dataclasses.replace(self, grid_type=grid_type) + return dataclasses.replace( + self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) + ) def __gt_itir__(self) -> itir.FunctionDefinition: - if hasattr(self, "__cached_itir"): - return getattr(self, "__cached_itir") - - itir_node: itir.FunctionDefinition = FieldOperatorLowering.apply(self.foast_node) - - object.__setattr__(self, "__cached_itir", itir_node) - - return itir_node + return next_backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_itir(self.foast_stage) def __gt_closure_vars__(self) -> dict[str, Any]: - return self.closure_vars + return self.foast_stage.closure_vars def as_program( self, arg_types: list[ts.TypeSpec], kwarg_types: dict[str, ts.TypeSpec] @@ -476,7 +462,7 @@ def as_program( except KeyError: pass - loc = self.foast_node.location + loc = self.foast_stage.foast_node.location # use a new UID generator to allow caching param_sym_uids = eve_utils.UIDGenerator() @@ -499,12 +485,12 @@ def as_program( ) out_ref = past.Name(id="out", location=loc) - if self.foast_node.id in self.closure_vars: + if 
self.foast_stage.foast_node.id in self.foast_stage.closure_vars: raise RuntimeError("A closure variable has the same name as the field operator itself.") - closure_vars = {self.foast_node.id: self} + closure_vars = {self.foast_stage.foast_node.id: self} closure_symbols = [ past.Symbol( - id=self.foast_node.id, + id=self.foast_stage.foast_node.id, type=ts.DeferredType(constraint=None), namespace=dialect_ast_enums.Namespace.CLOSURE, location=loc, @@ -512,12 +498,12 @@ def as_program( ] untyped_past_node = past.Program( - id=f"__field_operator_{self.foast_node.id}", + id=f"__field_operator_{self.foast_stage.foast_node.id}", type=ts.DeferredType(constraint=ts_ffront.ProgramType), params=[*params_decl, out_sym], body=[ past.Call( - func=past.Name(id=self.foast_node.id, location=loc), + func=past.Name(id=self.foast_stage.foast_node.id, location=loc), args=params_ref, kwargs={"out": out_ref}, location=loc, @@ -534,7 +520,7 @@ def as_program( past_stage=ffront_stages.PastProgramDefinition( past_node=past_node, closure_vars=closure_vars, - grid_type=self.grid_type, + grid_type=self.foast_stage.grid_type, ), backend=self.backend, ) @@ -555,7 +541,9 @@ def __call__( if "out" not in kwargs: raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") - args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) + args, kwargs = type_info.canonicalize_arguments( + self.foast_stage.foast_node.type, args, kwargs + ) # TODO(tehrengruber): check all offset providers are given # deduce argument types arg_types = [] @@ -569,22 +557,33 @@ def __call__( *args, out, offset_provider=offset_provider, **kwargs ) else: - if self.operator_attributes is not None and any( + attributes = ( + self.definition_stage.attributes + if self.definition_stage + else self.foast_stage.attributes + ) + if attributes and any( has_scan_op_attribute := [ - attribute in self.operator_attributes - for attribute in ["init", "axis", "forward"] + attribute in attributes for 
attribute in ["init", "axis", "forward"] ] ): assert all(has_scan_op_attribute) - forward = self.operator_attributes["forward"] - init = self.operator_attributes["init"] - axis = self.operator_attributes["axis"] - op = embedded_operators.ScanOperator(self.definition, forward, init, axis) + forward = attributes["forward"] + init = attributes["init"] + axis = attributes["axis"] + op = embedded_operators.ScanOperator( + self.definition_stage.definition, forward, init, axis + ) else: - op = embedded_operators.EmbeddedOperator(self.definition) + op = embedded_operators.EmbeddedOperator(self.definition_stage.definition) return embedded_operators.field_operator_call(op, args, kwargs) +@dataclasses.dataclass(frozen=True) +class FieldOperatorFromFoast(FieldOperator): + foast_stage: ffront_stages.FoastOperatorDefinition + + @typing.overload def field_operator( definition: types.FunctionType, *, backend: Optional[ppi.ProgramExecutor] diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 6044b41421..7f79769bf1 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, Optional, cast +from typing import Any, Optional, TypeVar, cast import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits @@ -29,6 +29,9 @@ from gt4py.next.type_system import type_info, type_specifications as ts, type_translation +OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) + + def with_altered_scalar_kind( type_spec: ts.TypeSpec, new_scalar_kind: ts.ScalarKind ) -> ts.ScalarType | ts.FieldType: @@ -250,7 +253,7 @@ class FieldOperatorTypeDeduction(traits.VisitorWithSymbolTableTrait, NodeTransla """ @classmethod - def apply(cls, node: foast.FunctionDefinition) -> foast.FunctionDefinition: + def apply(cls, node: 
OperatorNodeT) -> OperatorNodeT: typed_foast_node = cls().visit(node) FieldOperatorTypeDeductionCompletnessValidator.apply(typed_foast_node) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 0e39853a3c..8bc2796a67 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -23,6 +23,7 @@ fbuiltins, field_operator_ast as foast, lowering_utils, + stages as ffront_stages, type_specifications as ts_ffront, ) from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES @@ -33,6 +34,10 @@ from gt4py.next.type_system import type_info, type_specifications as ts +def foast_to_itir(inp: ffront_stages.FoastOperatorDefinition) -> itir.Expr: + return FieldOperatorLowering.apply(inp.foast_node) + + def promote_to_list( node: foast.Symbol | foast.Expr, ) -> Callable[[itir.Expr], itir.Expr]: diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ceac9902cf..b6c1c30069 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -16,11 +16,21 @@ import ast import builtins -from typing import Any, Callable, Iterable, Mapping, Type, cast +import dataclasses +import typing +from typing import Any, Callable, Iterable, Mapping, Type + +import factory import gt4py.eve as eve from gt4py.next import errors -from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast +from gt4py.next.ffront import ( + dialect_ast_enums, + fbuiltins, + field_operator_ast as foast, + source_utils, + stages as ffront_stages, +) from gt4py.next.ffront.ast_passes import ( SingleAssignTargetPass, SingleStaticAssignPass, @@ -35,9 +45,74 @@ from gt4py.next.ffront.foast_passes.iterable_unpack import UnpackedAssignPass from gt4py.next.ffront.foast_passes.type_alias_replacement import TypeAliasReplacement from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction +from gt4py.next.otf 
import workflow from gt4py.next.type_system import type_info, type_specifications as ts, type_translation +@workflow.make_step +def func_to_foast( + inp: ffront_stages.FieldOperatorDefinition[ffront_stages.OperatorNodeT], +) -> ffront_stages.FoastOperatorDefinition[ffront_stages.OperatorNodeT]: + source_def = source_utils.SourceDefinition.from_function(inp.definition) + closure_vars = source_utils.get_closure_vars_from_function(inp.definition) + annotations = typing.get_type_hints(inp.definition) + foast_definition_node = FieldOperatorParser.apply(source_def, closure_vars, annotations) + loc = foast_definition_node.location + operator_attribute_nodes = { + key: foast.Constant(value=value, type=type_translation.from_value(value), location=loc) + for key, value in inp.attributes.items() + } + untyped_foast_node = inp.node_class( + id=foast_definition_node.id, + definition=foast_definition_node, + location=loc, + **operator_attribute_nodes, + ) + foast_node = FieldOperatorTypeDeduction.apply(untyped_foast_node) + return ffront_stages.FoastOperatorDefinition( + foast_node=foast_node, + closure_vars=closure_vars, + grid_type=inp.grid_type, + attributes=inp.attributes, + ) + + +@dataclasses.dataclass(frozen=True) +class OptionalFuncToFoast(workflow.SkippableStep): + step: workflow.Workflow[ + ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition + ] = func_to_foast + + def skip_condition( + self, inp: ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition + ) -> bool: + match inp: + case ffront_stages.FieldOperatorDefinition(): + return False + case ffront_stages.FoastOperatorDefinition(): + return True + + +@dataclasses.dataclass(frozen=True) +class OptionalFuncToFoastFactory(factory.Factory): + class Meta: + model = OptionalFuncToFoast + + class Params: + workflow: workflow.Workflow[ + ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition + ] = func_to_foast + cached = factory.Trait( + 
step=factory.LazyAttribute( + lambda o: workflow.CachedStep( + step=o.workflow, hash_function=ffront_stages.hash_field_operator_definition + ) + ) + ) + + step = factory.LazyAttribute(lambda o: o.workflow) + + class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): """ Parse field operator function definition from source code into FOAST. @@ -141,7 +216,7 @@ def _builtin_type_constructor_symbols( ], # this is a constraint type that will not be inferred (as the function is polymorphic) pos_or_kw_args={}, kw_only_args={}, - returns=cast(ts.DataType, type_translation.from_type_hint(value)), + returns=typing.cast(ts.DataType, type_translation.from_type_hint(value)), ), namespace=dialect_ast_enums.Namespace.CLOSURE, location=location, diff --git a/src/gt4py/next/ffront/past_passes/linters.py b/src/gt4py/next/ffront/past_passes/linters.py index c05981affa..c4e7933fc6 100644 --- a/src/gt4py/next/ffront/past_passes/linters.py +++ b/src/gt4py/next/ffront/past_passes/linters.py @@ -26,7 +26,7 @@ def lint_misnamed_functions( inp.closure_vars, gtcallable.GTCallable ) misnamed_functions = [ - f"{name} vs. {func.id}" + f"{name} vs. 
{func.__gt_itir__().id}" for name, func in function_closure_vars.items() if name != func.__gt_itir__().id ] diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index a79cd79d97..51e610de3b 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -15,12 +15,49 @@ from __future__ import annotations import dataclasses +import inspect import types -from typing import Any, Optional +from typing import Any, Generic, Optional, TypeVar from gt4py.eve import utils as eve_utils from gt4py.next import common -from gt4py.next.ffront import program_ast as past +from gt4py.next.ffront import field_operator_ast as foast, program_ast as past + + +OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) + + +@dataclasses.dataclass(frozen=True) +class FieldOperatorDefinition(Generic[OperatorNodeT]): + definition: types.FunctionType + grid_type: Optional[common.GridType] = None + node_class: type[OperatorNodeT] = dataclasses.field(default=foast.FieldOperator) # type: ignore[assignment] # TODO(ricoh): understand why mypy complains + attributes: dict[str, Any] = dataclasses.field(default_factory=dict) + + +def hash_field_operator_definition(fieldop_definition: FieldOperatorDefinition) -> str: + return eve_utils.content_hash( + fieldop_definition.definition.__name__, + hash(fieldop_definition.definition), + inspect.getsourcelines(fieldop_definition.definition), + fieldop_definition.grid_type, + fieldop_definition.node_class, + fieldop_definition.attributes, + ) + + +@dataclasses.dataclass(frozen=True) +class FoastOperatorDefinition(Generic[OperatorNodeT]): + foast_node: OperatorNodeT + closure_vars: dict[str, Any] + grid_type: Optional[common.GridType] = None + attributes: dict[str, Any] = dataclasses.field(default_factory=dict) + + +def hash_foast_operator_definition(foast_definition: FoastOperatorDefinition) -> str: + return eve_utils.content_hash( + foast_definition.foast_node, foast_definition.grid_type, 
foast_definition.attributes + ) @dataclasses.dataclass(frozen=True) @@ -36,6 +73,10 @@ class PastProgramDefinition: grid_type: Optional[common.GridType] = None +def hash_past_program_definition(past_definition: PastProgramDefinition) -> str: + return eve_utils.content_hash(past_definition.past_node, past_definition.grid_type) + + @dataclasses.dataclass(frozen=True) class PastClosure: closure_vars: dict[str, Any] @@ -43,7 +84,3 @@ class PastClosure: grid_type: Optional[common.GridType] args: tuple[Any, ...] kwargs: dict[str, Any] - - -def hash_past_program_definition(past_definition: PastProgramDefinition) -> str: - return eve_utils.content_hash(past_definition.past_node, past_definition.grid_type) diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index cebbd3fe2d..126bf48d8e 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -16,9 +16,21 @@ import dataclasses from gt4py.next.ffront import stages as ffront_stages +from gt4py.next.iterator import ir as itir from gt4py.next.otf import stages, step_types, workflow +@dataclasses.dataclass(frozen=True) +class FieldopTransformWorkflow(workflow.NamedStepSequence): + """Modular workflow for transformations with access to intermediates.""" + + func_to_foast: workflow.SkippableStep[ + ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition, + ffront_stages.FoastOperatorDefinition, + ] + foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] + + @dataclasses.dataclass(frozen=True) class ProgramTransformWorkflow(workflow.NamedStepSequence): """Modular workflow for transformations with access to intermediates.""" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py index c1bee4fa2f..77ae302efa 100644 --- 
a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py @@ -69,7 +69,7 @@ def bar(inp1: Field[[I], int64], inp2: Field[[I], int64]) -> Field[[I], int64]: """ ).strip() - assert pretty_format(bar.foast_node) == expected + assert pretty_format(bar.foast_stage.foast_node) == expected def test_scanop(): @@ -89,4 +89,4 @@ def scan(inp: int32) -> int32: """ ).strip() - assert pretty_format(scan.foast_node) == expected + assert pretty_format(scan.foast_stage.foast_node) == expected diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index e076ec4227..3419930588 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -18,9 +18,13 @@ import numpy as np import pytest -import gt4py.next as gtx -from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast -from gt4py.next.ffront.decorator import FieldOperator +from gt4py.next.ffront import ( + decorator, + dialect_ast_enums, + fbuiltins, + field_operator_ast as foast, + stages as ffront_stages, +) from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_translation @@ -107,12 +111,14 @@ def make_builtin_field_operator(builtin_name: str, backend: Optional[ppi.Program ) typed_foast_node = FieldOperatorTypeDeduction.apply(foast_node) - return FieldOperator( - foast_node=typed_foast_node, - closure_vars=closure_vars, - definition=None, + return decorator.FieldOperatorFromFoast( + definition_stage=None, + 
foast_stage=ffront_stages.FoastOperatorDefinition( + foast_node=typed_foast_node, + closure_vars=closure_vars, + grid_type=None, + ), backend=backend, - grid_type=None, ) From 3561a0cfc59fe6d6605ff358badb551f1ff55863 Mon Sep 17 00:00:00 2001 From: DropD Date: Tue, 2 Apr 2024 16:21:42 +0200 Subject: [PATCH 03/30] fix missing attribute rename. --- .../feature_tests/ffront_tests/ffront_test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 04fbb58305..2a93828e6e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -49,7 +49,7 @@ def __call__(self, program, *args, **kwargs) -> None: raise ValueError("No backend selected! Backend selection is mandatory in tests.") -no_backend = NoBackend(executor=no_exec, transformer=None, allocator=None) +no_backend = NoBackend(executor=no_exec, transforms_prog=None, allocator=None) OPTIONAL_PROCESSORS = [] From 2842cadd95e1e36989d1ae2114ac7826d5d48cc8 Mon Sep 17 00:00:00 2001 From: DropD Date: Tue, 2 Apr 2024 16:34:22 +0200 Subject: [PATCH 04/30] workflowify `FieldOperator.as_program` --- src/gt4py/next/ffront/decorator.py | 90 +++---------------------- src/gt4py/next/ffront/foast_to_past.py | 92 ++++++++++++++++++++++++++ src/gt4py/next/ffront/stages.py | 17 +++++ 3 files changed, 117 insertions(+), 82 deletions(-) create mode 100644 src/gt4py/next/ffront/foast_to_past.py diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 5d9c205e86..1dd846200a 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -28,7 +28,6 @@ from gt4py import eve from gt4py._core import definitions as core_defs -from gt4py.eve import utils as eve_utils from 
gt4py.eve.extended_typing import Any, Optional from gt4py.next import ( allocators as next_allocators, @@ -39,20 +38,15 @@ from gt4py.next.common import Dimension, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( - dialect_ast_enums, field_operator_ast as foast, + foast_to_past, past_process_args, past_to_itir, - program_ast as past, stages as ffront_stages, transform_utils, type_specifications as ts_ffront, ) from gt4py.next.ffront.gtcallable import GTCallable -from gt4py.next.ffront.past_passes.closure_var_type_deduction import ( - ClosureVarTypeDeduction as ProgramClosureVarTypeDeduction, -) -from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils.ir_makers import ( literal_from_value, @@ -449,83 +443,15 @@ def __gt_closure_vars__(self) -> dict[str, Any]: def as_program( self, arg_types: list[ts.TypeSpec], kwarg_types: dict[str, ts.TypeSpec] ) -> Program: - # TODO(tehrengruber): implement mechanism to deduce default values - # of arg and kwarg types - # TODO(tehrengruber): check foast operator has no out argument that clashes - # with the out argument of the program we generate here. 
- hash_ = eve_utils.content_hash(( - tuple(arg_types), - tuple((name, arg) for name, arg in kwarg_types.items()), - )) - try: - return self._program_cache[hash_] - except KeyError: - pass - - loc = self.foast_stage.foast_node.location - # use a new UID generator to allow caching - param_sym_uids = eve_utils.UIDGenerator() - - type_ = self.__gt_type__() - params_decl: list[past.Symbol] = [ - past.DataSymbol( - id=param_sym_uids.sequential_id(prefix="__sym"), - type=arg_type, - namespace=dialect_ast_enums.Namespace.LOCAL, - location=loc, - ) - for arg_type in arg_types - ] - params_ref = [past.Name(id=pdecl.id, location=loc) for pdecl in params_decl] - out_sym: past.Symbol = past.DataSymbol( - id="out", - type=type_info.return_type(type_, with_args=arg_types, with_kwargs=kwarg_types), - namespace=dialect_ast_enums.Namespace.LOCAL, - location=loc, - ) - out_ref = past.Name(id="out", location=loc) - - if self.foast_stage.foast_node.id in self.foast_stage.closure_vars: - raise RuntimeError("A closure variable has the same name as the field operator itself.") - closure_vars = {self.foast_stage.foast_node.id: self} - closure_symbols = [ - past.Symbol( - id=self.foast_stage.foast_node.id, - type=ts.DeferredType(constraint=None), - namespace=dialect_ast_enums.Namespace.CLOSURE, - location=loc, + past_stage = foast_to_past.foast_to_past( + ffront_stages.FoastWithTypes( + foast_op_def=self.foast_stage, + arg_types=tuple(arg_types), + kwarg_types=kwarg_types, + closure_vars={self.foast_stage.foast_node.id: self}, ), - ] - - untyped_past_node = past.Program( - id=f"__field_operator_{self.foast_stage.foast_node.id}", - type=ts.DeferredType(constraint=ts_ffront.ProgramType), - params=[*params_decl, out_sym], - body=[ - past.Call( - func=past.Name(id=self.foast_stage.foast_node.id, location=loc), - args=params_ref, - kwargs={"out": out_ref}, - location=loc, - ) - ], - closure_vars=closure_symbols, - location=loc, - ) - untyped_past_node = 
ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars) - past_node = ProgramTypeDeduction.apply(untyped_past_node) - - self._program_cache[hash_] = ProgramFromPast( - definition_stage=None, - past_stage=ffront_stages.PastProgramDefinition( - past_node=past_node, - closure_vars=closure_vars, - grid_type=self.foast_stage.grid_type, - ), - backend=self.backend, ) - - return self._program_cache[hash_] + return ProgramFromPast(definition_stage=None, past_stage=past_stage, backend=self.backend) def __call__( self, diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py new file mode 100644 index 0000000000..2ba3104541 --- /dev/null +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -0,0 +1,92 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py.eve import utils as eve_utils +from gt4py.next.ffront import ( + dialect_ast_enums, + program_ast as past, + stages as ffront_stages, + type_specifications as ts_ffront, +) +from gt4py.next.ffront.past_passes import closure_var_type_deduction, type_deduction +from gt4py.next.type_system import type_info, type_specifications as ts + + +def foast_to_past(inp: ffront_stages.FoastWithTypes) -> ffront_stages.PastProgramDefinition: + # TODO(tehrengruber): implement mechanism to deduce default values + # of arg and kwarg types + # TODO(tehrengruber): check foast operator has no out argument that clashes + # with the out argument of the program we generate here. 
+ + loc = inp.foast_op_def.foast_node.location + # use a new UID generator to allow caching + param_sym_uids = eve_utils.UIDGenerator() + + type_ = inp.foast_op_def.foast_node.type + params_decl: list[past.Symbol] = [ + past.DataSymbol( + id=param_sym_uids.sequential_id(prefix="__sym"), + type=arg_type, + namespace=dialect_ast_enums.Namespace.LOCAL, + location=loc, + ) + for arg_type in inp.arg_types + ] + params_ref = [past.Name(id=pdecl.id, location=loc) for pdecl in params_decl] + out_sym: past.Symbol = past.DataSymbol( + id="out", + type=type_info.return_type( + type_, with_args=list(inp.arg_types), with_kwargs=inp.kwarg_types + ), + namespace=dialect_ast_enums.Namespace.LOCAL, + location=loc, + ) + out_ref = past.Name(id="out", location=loc) + + if inp.foast_op_def.foast_node.id in inp.foast_op_def.closure_vars: + raise RuntimeError("A closure variable has the same name as the field operator itself.") + closure_symbols: list[past.Symbol] = [ + past.Symbol( + id=inp.foast_op_def.foast_node.id, + type=ts.DeferredType(constraint=None), + namespace=dialect_ast_enums.Namespace.CLOSURE, + location=loc, + ), + ] + + untyped_past_node = past.Program( + id=f"__field_operator_{inp.foast_op_def.foast_node.id}", + type=ts.DeferredType(constraint=ts_ffront.ProgramType), + params=[*params_decl, out_sym], + body=[ + past.Call( + func=past.Name(id=inp.foast_op_def.foast_node.id, location=loc), + args=params_ref, + kwargs={"out": out_ref}, + location=loc, + ) + ], + closure_vars=closure_symbols, + location=loc, + ) + untyped_past_node = closure_var_type_deduction.ClosureVarTypeDeduction.apply( + untyped_past_node, inp.closure_vars + ) + past_node = type_deduction.ProgramTypeDeduction.apply(untyped_past_node) + + return ffront_stages.PastProgramDefinition( + past_node=past_node, + closure_vars=inp.closure_vars, + grid_type=inp.foast_op_def.grid_type, + ) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 51e610de3b..5e800b3989 100644 --- 
a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -22,6 +22,7 @@ from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.ffront import field_operator_ast as foast, program_ast as past +from gt4py.next.type_system import type_specifications as ts OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) @@ -60,6 +61,22 @@ def hash_foast_operator_definition(foast_definition: FoastOperatorDefinition) -> ) +@dataclasses.dataclass(frozen=True) +class FoastWithTypes(Generic[OperatorNodeT]): + foast_op_def: FoastOperatorDefinition[OperatorNodeT] + arg_types: tuple[ts.TypeSpec] + kwarg_types: dict[str, ts.TypeSpec] + closure_vars: dict[str, Any] + + +def hash_foast_with_types(foast_with_types: FoastWithTypes) -> str: + return eve_utils.content_hash(( + foast_with_types.foast_op_def, + foast_with_types.arg_types, + tuple((name, arg) for name, arg in foast_with_types.kwarg_types.items()), + )) + + @dataclasses.dataclass(frozen=True) class ProgramDefinition: definition: types.FunctionType From 3357d115c812bf8d4012b7b7c8c8db188869ce5e Mon Sep 17 00:00:00 2001 From: DropD Date: Wed, 3 Apr 2024 11:24:05 +0200 Subject: [PATCH 05/30] make sure scan operator attributes are properly hashed --- src/gt4py/next/ffront/stages.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 5e800b3989..487cc54867 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -39,11 +39,10 @@ class FieldOperatorDefinition(Generic[OperatorNodeT]): def hash_field_operator_definition(fieldop_definition: FieldOperatorDefinition) -> str: return eve_utils.content_hash( fieldop_definition.definition.__name__, - hash(fieldop_definition.definition), inspect.getsourcelines(fieldop_definition.definition), fieldop_definition.grid_type, fieldop_definition.node_class, - fieldop_definition.attributes, + 
tuple(fieldop_definition.attributes.items()), ) @@ -57,7 +56,9 @@ class FoastOperatorDefinition(Generic[OperatorNodeT]): def hash_foast_operator_definition(foast_definition: FoastOperatorDefinition) -> str: return eve_utils.content_hash( - foast_definition.foast_node, foast_definition.grid_type, foast_definition.attributes + foast_definition.foast_node, + foast_definition.grid_type, + tuple(foast_definition.attributes.items()), ) From ebf164cdc304afbe1dcb4b0661d8be353d6e5d2d Mon Sep 17 00:00:00 2001 From: DropD Date: Thu, 4 Apr 2024 15:23:21 +0200 Subject: [PATCH 06/30] add closure vars to hash for program definition --- src/gt4py/next/ffront/stages.py | 38 ++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 43ef33358c..54472d214d 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -35,15 +35,18 @@ class FieldOperatorDefinition(Generic[OperatorNodeT]): node_class: type[OperatorNodeT] = dataclasses.field(default=foast.FieldOperator) # type: ignore[assignment] # TODO(ricoh): understand why mypy complains attributes: dict[str, Any] = dataclasses.field(default_factory=dict) + def __getstate__(self) -> dict[str, Any]: + """Make the stage pickleable (but not unpickleable) for use with content_hash.""" + state = self.__dict__.copy() + state["name"] = self.definition.__name__ + state["source"] = inspect.getsource(self.definition) + state |= self.attributes + del state["definition"] + return state + def hash_field_operator_definition(fieldop_definition: FieldOperatorDefinition) -> str: - return eve_utils.content_hash( - fieldop_definition.definition.__name__, - inspect.getsourcelines(fieldop_definition.definition), - fieldop_definition.grid_type, - fieldop_definition.node_class, - tuple(fieldop_definition.attributes.items()), - ) + return eve_utils.content_hash(fieldop_definition) @dataclasses.dataclass(frozen=True) @@ -85,6 
+88,14 @@ class ProgramDefinition: definition: types.FunctionType grid_type: Optional[common.GridType] = None + def __getstate__(self) -> dict[str, Any]: + """Make the stage pickleable (but not unpickleable) for use with content_hash.""" + state = self.__dict__.copy() + state["name"] = self.definition.__name__ + state["source"] = inspect.getsource(self.definition) + del state["definition"] + return state + @dataclasses.dataclass(frozen=True) class PastProgramDefinition: @@ -92,9 +103,20 @@ class PastProgramDefinition: closure_vars: dict[str, Any] grid_type: Optional[common.GridType] = None + def __getstate__(self) -> dict[str, Any]: + """Make the stage pickleable (but not unpickleable) for use with content_hash.""" + hashable_closure_vars = self.closure_vars.copy() + for k, v in self.closure_vars.items(): + if hasattr(v, "definition_stage"): + hashable_closure_vars[k] = v.definition_stage + hashable_closure_vars[f"{k}_backend"] = v.backend.__name__ if v.backend else "None" + state = self.__dict__.copy() + state["closure_vars"] = hashable_closure_vars + return state + def hash_past_program_definition(past_definition: PastProgramDefinition) -> str: - return eve_utils.content_hash(past_definition.past_node, past_definition.grid_type) + return eve_utils.content_hash(past_definition) @dataclasses.dataclass(frozen=True) From d3a78518b1a95fcfeecf0484f6217b04587375fc Mon Sep 17 00:00:00 2001 From: DropD Date: Thu, 4 Apr 2024 17:44:44 +0200 Subject: [PATCH 07/30] [wip] integrating fieldop workflows --- src/gt4py/next/backend.py | 63 ++++++++++++++++++++++---- src/gt4py/next/ffront/decorator.py | 57 +++++++++++++++++------ src/gt4py/next/ffront/foast_to_past.py | 39 +++++++++++++++- src/gt4py/next/ffront/stages.py | 58 ++++++++++++++++++++---- src/gt4py/next/otf/recipes.py | 16 +++++++ 5 files changed, 202 insertions(+), 31 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index c0838cc425..206b4539c1 100644 --- a/src/gt4py/next/backend.py +++ 
b/src/gt4py/next/backend.py @@ -18,9 +18,11 @@ from typing import Any, Generic from gt4py._core import definitions as core_defs +from gt4py.eve import utils as eve_utils from gt4py.next import allocators as next_allocators from gt4py.next.ffront import ( foast_to_itir, + foast_to_past, func_to_foast, func_to_past, past_process_args, @@ -33,7 +35,7 @@ @dataclasses.dataclass(frozen=True) -class ArgsInjector(workflow.Workflow): +class ProgArgsInjector(workflow.Workflow): args: tuple[Any, ...] = dataclasses.field(default_factory=tuple) kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) @@ -47,18 +49,42 @@ def __call__(self, inp: ffront_stages.PastProgramDefinition) -> ffront_stages.Pa ) +@dataclasses.dataclass(frozen=True) +class FopArgsInjector(workflow.Workflow): + args: tuple[Any, ...] = dataclasses.field(default_factory=tuple) + kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) + from_fieldop: Any = None + + def __call__(self, inp: ffront_stages.FoastOperatorDefinition) -> ffront_stages.FoastClosure: + return ffront_stages.FoastClosure( + foast_op_def=inp, + args=self.args, + kwargs=self.kwargs, + closure_vars={inp.foast_node.id: self.from_fieldop}, + ) + + DEFAULT_FIELDOP_TRANSFORMS = recipes.FieldopTransformWorkflow( func_to_foast=func_to_foast.OptionalFuncToFoastFactory(cached=True), + foast_inject_args=FopArgsInjector(), + foast_to_past_closure=foast_to_past.FoastToPastClosure( + foast_to_past=workflow.CachedStep( + foast_to_past.foast_to_past, + hash_function=eve_utils.content_hash, + ) + ), + past_transform_args=past_process_args.past_process_args, + past_to_itir=past_to_itir.PastToItirFactory(), foast_to_itir=workflow.CachedStep( step=foast_to_itir.foast_to_itir, hash_function=ffront_stages.hash_foast_operator_definition ), ) -DEFAULT_TRANSFORMS = recipes.ProgramTransformWorkflow( +DEFAULT_PROG_TRANSFORMS = recipes.ProgramTransformWorkflow( func_to_past=func_to_past.OptionalFuncToPastFactory(cached=True), 
past_lint=past_linters.LinterFactory(), - past_inject_args=ArgsInjector(), + past_inject_args=ProgArgsInjector(), past_transform_args=past_process_args.past_process_args, past_to_itir=past_to_itir.PastToItirFactory(), ) @@ -68,15 +94,34 @@ def __call__(self, inp: ffront_stages.PastProgramDefinition) -> ffront_stages.Pa class Backend(Generic[core_defs.DeviceTypeT]): executor: ppi.ProgramExecutor allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] - transforms_prog: recipes.ProgramTransformWorkflow = DEFAULT_TRANSFORMS + transforms_fop: recipes.FieldopTransformWorkflow = DEFAULT_FIELDOP_TRANSFORMS + transforms_prog: recipes.ProgramTransformWorkflow = DEFAULT_PROG_TRANSFORMS def __call__( - self, program: ffront_stages.ProgramDefinition, *args: tuple[Any], **kwargs: dict[str, Any] + self, + program: ffront_stages.ProgramDefinition | ffront_stages.FieldOperatorDefinition, + *args: tuple[Any], + **kwargs: dict[str, Any], ) -> None: - transformer = self.transforms_prog.replace( - past_inject_args=ArgsInjector(args=args, kwargs=kwargs) - ) - program_call = transformer(program) + if isinstance( + program, (ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition) + ): + offset_provider = kwargs.pop("offset_provider") + from_fieldop = kwargs.pop("from_fieldop") + transforms_fop = self.transforms_fop.replace( + foast_inject_args=FopArgsInjector( + args=args, kwargs=kwargs, from_fieldop=from_fieldop + ) + ) + program_call = transforms_fop(program) + program_call = dataclasses.replace( + program_call, kwargs=program_call.kwargs | {"offset_provider": offset_provider} + ) + else: + transforms_prog = self.transforms_prog.replace( + past_inject_args=ProgArgsInjector(args=args, kwargs=kwargs) + ) + program_call = transforms_prog(program) self.executor(program_call.program, *program_call.args, **program_call.kwargs) @property diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 
10fb86546c..19a3f22dea 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -103,12 +103,12 @@ def definition(self): def past_stage(self): if self.backend is not None and self.backend.transforms_prog is not None: return self.backend.transforms_prog.func_to_past(self.definition_stage) - return next_backend.DEFAULT_TRANSFORMS.func_to_past(self.definition_stage) + return next_backend.DEFAULT_PROG_TRANSFORMS.func_to_past(self.definition_stage) def __post_init__(self): if self.backend is not None and self.backend.transforms_prog is not None: self.backend.transforms_prog.past_lint(self.past_stage) - return next_backend.DEFAULT_TRANSFORMS.past_lint(self.past_stage) + return next_backend.DEFAULT_PROG_TRANSFORMS.past_lint(self.past_stage) @property def __name__(self) -> str: @@ -221,7 +221,7 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): def __post_init__(self): if self.backend is not None and self.backend.transforms_prog is not None: self.backend.transforms_prog.past_lint(self.past_stage) - return next_backend.DEFAULT_TRANSFORMS.past_lint(self.past_stage) + return next_backend.DEFAULT_PROG_TRANSFORMS.past_lint(self.past_stage) @dataclasses.dataclass(frozen=True) @@ -399,6 +399,8 @@ def __post_init__(self): @functools.cached_property def foast_stage(self) -> ffront_stages.FoastOperatorDefinition: + if self.backend is not None and self.backend.transforms_fop is not None: + return self.backend.transforms_fop.func_to_foast(self.definition_stage) return next_backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(self.definition_stage) @property @@ -423,6 +425,8 @@ def with_grid_type(self, grid_type: GridType) -> FieldOperator: ) def __gt_itir__(self) -> itir.FunctionDefinition: + if self.backend is not None and self.backend.transforms_fop is not None: + return self.backend.transforms_fop.foast_to_itir(self.foast_stage) return next_backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_itir(self.foast_stage) def 
__gt_closure_vars__(self) -> dict[str, Any]: @@ -431,7 +435,7 @@ def __gt_closure_vars__(self) -> dict[str, Any]: def as_program( self, arg_types: list[ts.TypeSpec], kwarg_types: dict[str, ts.TypeSpec] ) -> Program: - past_stage = foast_to_past.foast_to_past( + foast_with_types = ( ffront_stages.FoastWithTypes( foast_op_def=self.foast_stage, arg_types=tuple(arg_types), @@ -439,6 +443,22 @@ def as_program( closure_vars={self.foast_stage.foast_node.id: self}, ), ) + past_stage = None + if self.backend is not None and self.backend.transforms_fop is not None: + past_stage = self.backend.transforms_fop.foast_to_past_closure.foast_to_past( + foast_with_types + ) + else: + past_stage = ( + next_backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_past_closure.foast_to_past( + ffront_stages.FoastWithTypes( + foast_op_def=self.foast_stage, + arg_types=tuple(arg_types), + kwarg_types=kwarg_types, + closure_vars={self.foast_stage.foast_node.id: self}, + ), + ) + ) return ProgramFromPast(definition_stage=None, past_stage=past_stage, backend=self.backend) def __call__(self, *args, **kwargs) -> None: @@ -456,15 +476,23 @@ def __call__(self, *args, **kwargs) -> None: ) # TODO(tehrengruber): check all offset providers are given # deduce argument types - arg_types = [] - for arg in args: - arg_types.append(type_translation.from_value(arg)) - kwarg_types = {} - for name, arg in kwargs.items(): - kwarg_types[name] = type_translation.from_value(arg) - - return self.as_program(arg_types, kwarg_types)( - *args, out, offset_provider=offset_provider, **kwargs + # arg_types = [] + # for arg in args: + # arg_types.append(type_translation.from_value(arg)) + # kwarg_types = {} + # for name, arg in kwargs.items(): + # kwarg_types[name] = type_translation.from_value(arg) + + # return self.as_program(arg_types, kwarg_types)( + # *args, out, offset_provider=offset_provider, **kwargs + # ) + return self.backend( + self.definition_stage, + *args, + out=out, + offset_provider=offset_provider, + 
from_fieldop=self, + **kwargs, ) else: attributes = ( @@ -493,6 +521,9 @@ def __call__(self, *args, **kwargs) -> None: class FieldOperatorFromFoast(FieldOperator): foast_stage: ffront_stages.FoastOperatorDefinition + def __call__(self, *args, **kwargs) -> None: + return self.backend(self.foast_stage, *args, from_fieldop=self, **kwargs) + @typing.overload def field_operator( diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 2ba3104541..b3f06db92b 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import dataclasses + from gt4py.eve import utils as eve_utils from gt4py.next.ffront import ( dialect_ast_enums, @@ -20,7 +22,8 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.past_passes import closure_var_type_deduction, type_deduction -from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next.type_system import type_info, type_specifications as ts, type_translation +from gt4py.next.otf import workflow def foast_to_past(inp: ffront_stages.FoastWithTypes) -> ffront_stages.PastProgramDefinition: @@ -90,3 +93,37 @@ def foast_to_past(inp: ffront_stages.FoastWithTypes) -> ffront_stages.PastProgra closure_vars=inp.closure_vars, grid_type=inp.foast_op_def.grid_type, ) + + +@dataclasses.dataclass(frozen=True) +class FoastToPastClosure(workflow.NamedStepSequence): + foast_to_past: workflow.Workflow[ + ffront_stages.FoastWithTypes, ffront_stages.PastProgramDefinition + ] + + def __call__(self, inp: ffront_stages.FoastClosure) -> ffront_stages.PastClosure: + # TODO(tehrengruber): check all offset providers are given + # deduce argument types + arg_types = [] + for arg in inp.args: + arg_types.append(type_translation.from_value(arg)) + kwarg_types = {} + for name, arg in inp.kwargs.items(): + kwarg_types[name] = type_translation.from_value(arg) + + past_def = super().__call__( + 
ffront_stages.FoastWithTypes( + foast_op_def=inp.foast_op_def, + arg_types=arg_types, + kwarg_types=kwarg_types, + closure_vars=inp.closure_vars, + ) + ) + + return ffront_stages.PastClosure( + past_node=past_def.past_node, + closure_vars=past_def.closure_vars, + grid_type=past_def.grid_type, + args=inp.args, + kwargs=inp.kwargs, + ) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 54472d214d..9cceb91457 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -56,6 +56,17 @@ class FoastOperatorDefinition(Generic[OperatorNodeT]): grid_type: Optional[common.GridType] = None attributes: dict[str, Any] = dataclasses.field(default_factory=dict) + def __getstate__(self) -> dict[str, Any]: + """Make the stage pickleable (but not unpickleable) for use with content_hash.""" + hashable_closure_vars = {} + for k, v in self.closure_vars.items(): + if hasattr(v, "definition_stage"): + hashable_closure_vars[k] = v.definition_stage + hashable_closure_vars[f"{k}_backend"] = v.backend.__name__ if v.backend else "None" + state = self.__dict__.copy() + state["closure_vars"] = hashable_closure_vars + return state + def hash_foast_operator_definition(foast_definition: FoastOperatorDefinition) -> str: return eve_utils.content_hash( @@ -72,15 +83,46 @@ class FoastWithTypes(Generic[OperatorNodeT]): kwarg_types: dict[str, ts.TypeSpec] closure_vars: dict[str, Any] + def __getstate__(self) -> dict[str, Any]: + """Make the stage pickleable (but not unpickleable) for use with content_hash.""" + hashable_closure_vars = {} + for k, v in self.closure_vars.items(): + if hasattr(v, "definition_stage"): + hashable_closure_vars[k] = v.definition_stage + hashable_closure_vars[f"{k}_backend"] = v.backend.__name__ if v.backend else "None" + state = self.__dict__.copy() + state["closure_vars"] = hashable_closure_vars + return state + + +@dataclasses.dataclass(frozen=True) +class FoastClosure(Generic[OperatorNodeT]): + foast_op_def: 
FoastOperatorDefinition[OperatorNodeT] + args: tuple[Any, ...] + kwargs: dict[str, Any] + closure_vars: dict[str, Any] + + def __getstate__(self) -> dict[str, Any]: + """Make the stage pickleable (but not unpickleable) for use with content_hash.""" + hashable_closure_vars = {} + for k, v in self.closure_vars.items(): + if hasattr(v, "definition_stage"): + hashable_closure_vars[k] = v.definition_stage + hashable_closure_vars[f"{k}_backend"] = v.backend.__name__ if v.backend else "None" + state = self.__dict__.copy() + state["closure_vars"] = hashable_closure_vars + return state + def hash_foast_with_types(foast_with_types: FoastWithTypes) -> str: - return eve_utils.content_hash( - ( - foast_with_types.foast_op_def, - foast_with_types.arg_types, - tuple((name, arg) for name, arg in foast_with_types.kwarg_types.items()), - ) - ) + return eve_utils.content_hash(foast_with_types) + # return eve_utils.content_hash( + # ( + # foast_with_types.foast_op_def, + # foast_with_types.arg_types, + # tuple((name, arg) for name, arg in foast_with_types.kwarg_types.items()), + # ) + # ) @dataclasses.dataclass(frozen=True) @@ -105,7 +147,7 @@ class PastProgramDefinition: def __getstate__(self) -> dict[str, Any]: """Make the stage pickleable (but not unpickleable) for use with content_hash.""" - hashable_closure_vars = self.closure_vars.copy() + hashable_closure_vars = {} for k, v in self.closure_vars.items(): if hasattr(v, "definition_stage"): hashable_closure_vars[k] = v.definition_stage diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 46f542206e..370e7ce9ab 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -28,8 +28,24 @@ class FieldopTransformWorkflow(workflow.NamedStepSequence): ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition, ffront_stages.FoastOperatorDefinition, ] + foast_inject_args: workflow.Workflow[ + ffront_stages.FoastOperatorDefinition, ffront_stages.FoastClosure + ] + 
foast_to_past_closure: workflow.Workflow[ffront_stages.FoastClosure, ffront_stages.PastClosure] + past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] + past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] + @property + def step_order(self): + return [ + "func_to_foast", + "foast_inject_args", + "foast_to_past_closure", + "past_transform_args", + "past_to_itir", + ] + @dataclasses.dataclass(frozen=True) class ProgramTransformWorkflow(workflow.NamedStepSequence): From a99fc485d56923182b6ea99ff323a14c6406ab75 Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 5 Apr 2024 09:11:48 +0200 Subject: [PATCH 08/30] update foast pretty printer doctest --- src/gt4py/next/ffront/foast_pretty_printer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index e589ecb601..4fa80c4892 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -234,7 +234,7 @@ def pretty_format(node: foast.LocatedNode) -> str: >>> @field_operator ... def field_op(a: Field[[IDim], float64]) -> Field[[IDim], float64]: ... 
return a + 1.0 - >>> print(pretty_format(field_op.foast_node)) + >>> print(pretty_format(field_op.foast_stage.foast_node)) @field_operator def field_op(a: Field[[IDim], float64]) -> Field[[IDim], float64]: return a + 1.0 From c71e0bf25806eb50f4a7afe69f1972bb36ca8a58 Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 5 Apr 2024 09:25:04 +0200 Subject: [PATCH 09/30] fix code quality issues --- src/gt4py/next/backend.py | 2 +- src/gt4py/next/ffront/decorator.py | 13 -------- src/gt4py/next/ffront/foast_to_past.py | 4 +-- src/gt4py/next/ffront/func_to_foast.py | 4 +-- src/gt4py/next/ffront/past_passes/linters.py | 3 +- src/gt4py/next/ffront/stages.py | 30 +------------------ src/gt4py/next/otf/recipes.py | 2 +- .../unit_tests/otf_tests/test_workflow.py | 2 +- 8 files changed, 9 insertions(+), 51 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 206b4539c1..46c1d9acec 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -76,7 +76,7 @@ def __call__(self, inp: ffront_stages.FoastOperatorDefinition) -> ffront_stages. 
past_transform_args=past_process_args.past_process_args, past_to_itir=past_to_itir.PastToItirFactory(), foast_to_itir=workflow.CachedStep( - step=foast_to_itir.foast_to_itir, hash_function=ffront_stages.hash_foast_operator_definition + step=foast_to_itir.foast_to_itir, hash_function=eve_utils.content_hash ), ) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 19a3f22dea..d94b6d1fa4 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -39,7 +39,6 @@ from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( field_operator_ast as foast, - foast_to_past, past_process_args, past_to_itir, stages as ffront_stages, @@ -474,18 +473,6 @@ def __call__(self, *args, **kwargs) -> None: args, kwargs = type_info.canonicalize_arguments( self.foast_stage.foast_node.type, args, kwargs ) - # TODO(tehrengruber): check all offset providers are given - # deduce argument types - # arg_types = [] - # for arg in args: - # arg_types.append(type_translation.from_value(arg)) - # kwarg_types = {} - # for name, arg in kwargs.items(): - # kwarg_types[name] = type_translation.from_value(arg) - - # return self.as_program(arg_types, kwarg_types)( - # *args, out, offset_provider=offset_provider, **kwargs - # ) return self.backend( self.definition_stage, *args, diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index b3f06db92b..b2e6324860 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -22,8 +22,8 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.past_passes import closure_var_type_deduction, type_deduction -from gt4py.next.type_system import type_info, type_specifications as ts, type_translation from gt4py.next.otf import workflow +from gt4py.next.type_system import type_info, type_specifications as ts, type_translation def foast_to_past(inp: ffront_stages.FoastWithTypes) -> 
ffront_stages.PastProgramDefinition: @@ -114,7 +114,7 @@ def __call__(self, inp: ffront_stages.FoastClosure) -> ffront_stages.PastClosure past_def = super().__call__( ffront_stages.FoastWithTypes( foast_op_def=inp.foast_op_def, - arg_types=arg_types, + arg_types=tuple(arg_types), kwarg_types=kwarg_types, closure_vars=inp.closure_vars, ) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 1655560a3c..bcb400df4a 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -104,9 +104,7 @@ class Params: ] = func_to_foast cached = factory.Trait( step=factory.LazyAttribute( - lambda o: workflow.CachedStep( - step=o.workflow, hash_function=ffront_stages.hash_field_operator_definition - ) + lambda o: workflow.CachedStep(step=o.workflow, hash_function=eve.utils.content_hash) ) ) diff --git a/src/gt4py/next/ffront/past_passes/linters.py b/src/gt4py/next/ffront/past_passes/linters.py index c4e7933fc6..a1a734da66 100644 --- a/src/gt4py/next/ffront/past_passes/linters.py +++ b/src/gt4py/next/ffront/past_passes/linters.py @@ -14,6 +14,7 @@ import factory +from gt4py import eve from gt4py.next.ffront import gtcallable, stages as ffront_stages, transform_utils from gt4py.next.otf import workflow @@ -56,4 +57,4 @@ class Meta: model = workflow.CachedStep step = lint_misnamed_functions.chain(lint_undefined_symbols) - hash_function = ffront_stages.hash_past_program_definition + hash_function = eve.utils.content_hash diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 9cceb91457..29b8a9c9ed 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -19,7 +19,6 @@ import types from typing import Any, Generic, Optional, TypeVar -from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.ffront import field_operator_ast as foast, program_ast as past from gt4py.next.type_system import type_specifications as ts @@ 
-45,10 +44,6 @@ def __getstate__(self) -> dict[str, Any]: return state -def hash_field_operator_definition(fieldop_definition: FieldOperatorDefinition) -> str: - return eve_utils.content_hash(fieldop_definition) - - @dataclasses.dataclass(frozen=True) class FoastOperatorDefinition(Generic[OperatorNodeT]): foast_node: OperatorNodeT @@ -68,18 +63,10 @@ def __getstate__(self) -> dict[str, Any]: return state -def hash_foast_operator_definition(foast_definition: FoastOperatorDefinition) -> str: - return eve_utils.content_hash( - foast_definition.foast_node, - foast_definition.grid_type, - tuple(foast_definition.attributes.items()), - ) - - @dataclasses.dataclass(frozen=True) class FoastWithTypes(Generic[OperatorNodeT]): foast_op_def: FoastOperatorDefinition[OperatorNodeT] - arg_types: tuple[ts.TypeSpec] + arg_types: tuple[ts.TypeSpec, ...] kwarg_types: dict[str, ts.TypeSpec] closure_vars: dict[str, Any] @@ -114,17 +101,6 @@ def __getstate__(self) -> dict[str, Any]: return state -def hash_foast_with_types(foast_with_types: FoastWithTypes) -> str: - return eve_utils.content_hash(foast_with_types) - # return eve_utils.content_hash( - # ( - # foast_with_types.foast_op_def, - # foast_with_types.arg_types, - # tuple((name, arg) for name, arg in foast_with_types.kwarg_types.items()), - # ) - # ) - - @dataclasses.dataclass(frozen=True) class ProgramDefinition: definition: types.FunctionType @@ -157,10 +133,6 @@ def __getstate__(self) -> dict[str, Any]: return state -def hash_past_program_definition(past_definition: PastProgramDefinition) -> str: - return eve_utils.content_hash(past_definition) - - @dataclasses.dataclass(frozen=True) class PastClosure: closure_vars: dict[str, Any] diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 370e7ce9ab..87188ac9ba 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -37,7 +37,7 @@ class FieldopTransformWorkflow(workflow.NamedStepSequence): foast_to_itir: 
workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] @property - def step_order(self): + def step_order(self) -> list[str]: return [ "func_to_foast", "foast_inject_args", diff --git a/tests/next_tests/unit_tests/otf_tests/test_workflow.py b/tests/next_tests/unit_tests/otf_tests/test_workflow.py index 9e1e14edaf..2a274c0110 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_workflow.py +++ b/tests/next_tests/unit_tests/otf_tests/test_workflow.py @@ -77,7 +77,7 @@ def test_cached_with_hashing(): def hashing(inp: list[int]) -> int: return hash(sum(inp)) - wf = workflow.CachedStep(step=lambda inp: inp + [1], hash_function=hashing) + wf = workflow.CachedStep(step=lambda inp: [*inp, 1], hash_function=hashing) assert wf([1, 2, 3]) == [1, 2, 3, 1] assert wf([3, 2, 1]) == [1, 2, 3, 1] From df648e8b211bb5ec065a1866d5ba69471e337eea Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 5 Apr 2024 10:18:02 +0200 Subject: [PATCH 10/30] reuse content hashing code in ffront.stages --- src/gt4py/next/ffront/stages.py | 159 +++++++++++++++++--------------- 1 file changed, 83 insertions(+), 76 deletions(-) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 29b8a9c9ed..a220ba995f 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -27,115 +27,122 @@ OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) -@dataclasses.dataclass(frozen=True) -class FieldOperatorDefinition(Generic[OperatorNodeT]): - definition: types.FunctionType - grid_type: Optional[common.GridType] = None - node_class: type[OperatorNodeT] = dataclasses.field(default=foast.FieldOperator) # type: ignore[assignment] # TODO(ricoh): understand why mypy complains - attributes: dict[str, Any] = dataclasses.field(default_factory=dict) +class ContentHashableMixin: + """ + Allows deriving dataclasses to modify what goes into the content hash per-field. + + Warning: Using this will modify how the class gets pickled. 
If unpickling is desired, + extra care has to be taken. The hasher must not remove crucial data and any modifications + have to be undone while loading in the __setstate__ method (which needs to be implemented). + + In fact, when unpickling is required, it is probably best to implement both + __setstate__ and __getstate__ by hand for the entire class rather than per-field. + """ def __getstate__(self) -> dict[str, Any]: - """Make the stage pickleable (but not unpickleable) for use with content_hash.""" + if not dataclasses.is_dataclass(self): + raise TypeError(f"'{self.__class__}' is not a dataclass.") state = self.__dict__.copy() - state["name"] = self.definition.__name__ - state["source"] = inspect.getsource(self.definition) - state |= self.attributes - del state["definition"] + for field in dataclasses.fields(self): + if "content_hasher" in field.metadata: + field.metadata["content_hasher"](state, getattr(self, field.name), field.name) return state +def function_type_hasher(state: dict[str, Any], value: types.FunctionType, name: str) -> None: + state[f"{name}__name"] = value.__name__ + state[f"{name}__source"] = inspect.getsource(value) + del state[name] + + +def simple_dict_hasher(state: dict[str, Any], value: dict[str, Any], name: str) -> None: + for k, v in value.items(): + state[f"{name}__{k}"] = v + + +def closure_vars_hasher(state: dict[str, Any], value: dict[str, Any], name: str) -> None: + hashable_closure_vars = {} + for k, v in value.items(): + # replace the decorator with the earliest canonical dsl representation available + if hasattr(v, "definition_stage"): + hashable_closure_vars[k] = v.definition_stage + elif hasattr(v, "foast_stage"): + hashable_closure_vars[k] = v.foast_stage + elif hasattr(v, "past_stage"): + hashable_closure_vars[k] = v.past_stage + # put the backend into the hash because it may influence the toolchain + # TODO(ricoh): This is not perfect, since backend names are allowed to clash (low priority). 
+ if hasattr(v, "backend"): + hashable_closure_vars[f"{k}_backend"] = v.backend.__name__ if v.backend else "None" + state[name] = hashable_closure_vars + + @dataclasses.dataclass(frozen=True) -class FoastOperatorDefinition(Generic[OperatorNodeT]): - foast_node: OperatorNodeT - closure_vars: dict[str, Any] +class FieldOperatorDefinition(ContentHashableMixin, Generic[OperatorNodeT]): + definition: types.FunctionType = dataclasses.field( + metadata={"content_hasher": function_type_hasher} + ) grid_type: Optional[common.GridType] = None - attributes: dict[str, Any] = dataclasses.field(default_factory=dict) + node_class: type[OperatorNodeT] = dataclasses.field(default=foast.FieldOperator) # type: ignore[assignment] # TODO(ricoh): understand why mypy complains + attributes: dict[str, Any] = dataclasses.field( + default_factory=dict, metadata={"content_hasher": simple_dict_hasher} + ) - def __getstate__(self) -> dict[str, Any]: - """Make the stage pickleable (but not unpickleable) for use with content_hash.""" - hashable_closure_vars = {} - for k, v in self.closure_vars.items(): - if hasattr(v, "definition_stage"): - hashable_closure_vars[k] = v.definition_stage - hashable_closure_vars[f"{k}_backend"] = v.backend.__name__ if v.backend else "None" - state = self.__dict__.copy() - state["closure_vars"] = hashable_closure_vars - return state + +@dataclasses.dataclass(frozen=True) +class FoastOperatorDefinition(ContentHashableMixin, Generic[OperatorNodeT]): + foast_node: OperatorNodeT + closure_vars: dict[str, Any] = dataclasses.field( + metadata={"content_hasher": closure_vars_hasher} + ) + grid_type: Optional[common.GridType] = None + attributes: dict[str, Any] = dataclasses.field( + default_factory=dict, metadata={"content_hasher": simple_dict_hasher} + ) @dataclasses.dataclass(frozen=True) -class FoastWithTypes(Generic[OperatorNodeT]): +class FoastWithTypes(ContentHashableMixin, Generic[OperatorNodeT]): foast_op_def: FoastOperatorDefinition[OperatorNodeT] arg_types: 
tuple[ts.TypeSpec, ...] kwarg_types: dict[str, ts.TypeSpec] - closure_vars: dict[str, Any] - - def __getstate__(self) -> dict[str, Any]: - """Make the stage pickleable (but not unpickleable) for use with content_hash.""" - hashable_closure_vars = {} - for k, v in self.closure_vars.items(): - if hasattr(v, "definition_stage"): - hashable_closure_vars[k] = v.definition_stage - hashable_closure_vars[f"{k}_backend"] = v.backend.__name__ if v.backend else "None" - state = self.__dict__.copy() - state["closure_vars"] = hashable_closure_vars - return state + closure_vars: dict[str, Any] = dataclasses.field( + metadata={"content_hasher": closure_vars_hasher} + ) @dataclasses.dataclass(frozen=True) -class FoastClosure(Generic[OperatorNodeT]): +class FoastClosure(ContentHashableMixin, Generic[OperatorNodeT]): foast_op_def: FoastOperatorDefinition[OperatorNodeT] args: tuple[Any, ...] kwargs: dict[str, Any] - closure_vars: dict[str, Any] - - def __getstate__(self) -> dict[str, Any]: - """Make the stage pickleable (but not unpickleable) for use with content_hash.""" - hashable_closure_vars = {} - for k, v in self.closure_vars.items(): - if hasattr(v, "definition_stage"): - hashable_closure_vars[k] = v.definition_stage - hashable_closure_vars[f"{k}_backend"] = v.backend.__name__ if v.backend else "None" - state = self.__dict__.copy() - state["closure_vars"] = hashable_closure_vars - return state + closure_vars: dict[str, Any] = dataclasses.field( + metadata={"content_hasher": closure_vars_hasher} + ) @dataclasses.dataclass(frozen=True) -class ProgramDefinition: - definition: types.FunctionType +class ProgramDefinition(ContentHashableMixin): + definition: types.FunctionType = dataclasses.field( + metadata={"content_hasher": function_type_hasher} + ) grid_type: Optional[common.GridType] = None - def __getstate__(self) -> dict[str, Any]: - """Make the stage pickleable (but not unpickleable) for use with content_hash.""" - state = self.__dict__.copy() - state["name"] = 
self.definition.__name__ - state["source"] = inspect.getsource(self.definition) - del state["definition"] - return state - @dataclasses.dataclass(frozen=True) -class PastProgramDefinition: +class PastProgramDefinition(ContentHashableMixin): past_node: past.Program - closure_vars: dict[str, Any] + closure_vars: dict[str, Any] = dataclasses.field( + metadata={"content_hasher": closure_vars_hasher} + ) grid_type: Optional[common.GridType] = None - def __getstate__(self) -> dict[str, Any]: - """Make the stage pickleable (but not unpickleable) for use with content_hash.""" - hashable_closure_vars = {} - for k, v in self.closure_vars.items(): - if hasattr(v, "definition_stage"): - hashable_closure_vars[k] = v.definition_stage - hashable_closure_vars[f"{k}_backend"] = v.backend.__name__ if v.backend else "None" - state = self.__dict__.copy() - state["closure_vars"] = hashable_closure_vars - return state - @dataclasses.dataclass(frozen=True) -class PastClosure: - closure_vars: dict[str, Any] +class PastClosure(ContentHashableMixin): + closure_vars: dict[str, Any] = dataclasses.field( + metadata={"content_hasher": closure_vars_hasher} + ) past_node: past.Program grid_type: Optional[common.GridType] args: tuple[Any, ...] 
From 1f5c336c84988416afa33178eb3849c8cfd6a8ea Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 5 Apr 2024 10:24:33 +0200 Subject: [PATCH 11/30] remove erroneously committed .python-version --- .python-version | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 .python-version diff --git a/.python-version b/.python-version deleted file mode 100644 index 6d1a8d63b5..0000000000 --- a/.python-version +++ /dev/null @@ -1,2 +0,0 @@ -3.11 -3.10 From 0a54c3d6036bcab40369eaa726f2e5f34ccf48ae Mon Sep 17 00:00:00 2001 From: DropD Date: Tue, 9 Apr 2024 10:09:32 +0200 Subject: [PATCH 12/30] re-apply Program.itir fix after merge --- src/gt4py/next/ffront/decorator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index d94b6d1fa4..820a45e70e 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -178,7 +178,7 @@ def itir(self) -> itir.FencilDefinition: kwargs={}, ) if self.backend is not None and self.backend.transforms_prog is not None: - return self.backend.transforms_prog.past_to_itir(no_args_past) + return self.backend.transforms_prog.past_to_itir(no_args_past).program return past_to_itir.PastToItirFactory()(no_args_past).program def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) -> None: From e1ed8d20a6ab0b4caf76e14d57dd570b513db2f9 Mon Sep 17 00:00:00 2001 From: DropD Date: Thu, 11 Apr 2024 11:53:04 +0200 Subject: [PATCH 13/30] move backend transforms to `next.backend` --- src/gt4py/next/backend.py | 56 +++++++++++++++++++++++++++++++---- src/gt4py/next/otf/recipes.py | 47 ----------------------------- 2 files changed, 51 insertions(+), 52 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 46c1d9acec..4842bca515 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -30,10 +30,56 @@ stages as ffront_stages, ) from gt4py.next.ffront.past_passes import linters as 
past_linters -from gt4py.next.otf import recipes, workflow +from gt4py.next.iterator import ir as itir +from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import processor_interface as ppi +@dataclasses.dataclass(frozen=True) +class FieldopTransformWorkflow(workflow.NamedStepSequence): + """Modular workflow for transformations with access to intermediates.""" + + func_to_foast: workflow.SkippableStep[ + ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition, + ffront_stages.FoastOperatorDefinition, + ] + foast_inject_args: workflow.Workflow[ + ffront_stages.FoastOperatorDefinition, ffront_stages.FoastClosure + ] + foast_to_past_closure: workflow.Workflow[ffront_stages.FoastClosure, ffront_stages.PastClosure] + past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] + past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] + foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] + + @property + def step_order(self) -> list[str]: + return [ + "func_to_foast", + "foast_inject_args", + "foast_to_past_closure", + "past_transform_args", + "past_to_itir", + ] + + +@dataclasses.dataclass(frozen=True) +class ProgramTransformWorkflow(workflow.NamedStepSequence): + """Modular workflow for transformations with access to intermediates.""" + + func_to_past: workflow.SkippableStep[ + ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition, + ffront_stages.PastProgramDefinition, + ] + past_lint: workflow.Workflow[ + ffront_stages.PastProgramDefinition, ffront_stages.PastProgramDefinition + ] + past_inject_args: workflow.Workflow[ + ffront_stages.PastProgramDefinition, ffront_stages.PastClosure + ] + past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] + past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] + + @dataclasses.dataclass(frozen=True) class 
ProgArgsInjector(workflow.Workflow): args: tuple[Any, ...] = dataclasses.field(default_factory=tuple) @@ -64,7 +110,7 @@ def __call__(self, inp: ffront_stages.FoastOperatorDefinition) -> ffront_stages. ) -DEFAULT_FIELDOP_TRANSFORMS = recipes.FieldopTransformWorkflow( +DEFAULT_FIELDOP_TRANSFORMS = FieldopTransformWorkflow( func_to_foast=func_to_foast.OptionalFuncToFoastFactory(cached=True), foast_inject_args=FopArgsInjector(), foast_to_past_closure=foast_to_past.FoastToPastClosure( @@ -81,7 +127,7 @@ def __call__(self, inp: ffront_stages.FoastOperatorDefinition) -> ffront_stages. ) -DEFAULT_PROG_TRANSFORMS = recipes.ProgramTransformWorkflow( +DEFAULT_PROG_TRANSFORMS = ProgramTransformWorkflow( func_to_past=func_to_past.OptionalFuncToPastFactory(cached=True), past_lint=past_linters.LinterFactory(), past_inject_args=ProgArgsInjector(), @@ -94,8 +140,8 @@ def __call__(self, inp: ffront_stages.FoastOperatorDefinition) -> ffront_stages. class Backend(Generic[core_defs.DeviceTypeT]): executor: ppi.ProgramExecutor allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] - transforms_fop: recipes.FieldopTransformWorkflow = DEFAULT_FIELDOP_TRANSFORMS - transforms_prog: recipes.ProgramTransformWorkflow = DEFAULT_PROG_TRANSFORMS + transforms_fop: FieldopTransformWorkflow = DEFAULT_FIELDOP_TRANSFORMS + transforms_prog: ProgramTransformWorkflow = DEFAULT_PROG_TRANSFORMS def __call__( self, diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 87188ac9ba..982e2e9b7b 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -15,56 +15,9 @@ import dataclasses -from gt4py.next.ffront import stages as ffront_stages -from gt4py.next.iterator import ir as itir from gt4py.next.otf import stages, step_types, workflow -@dataclasses.dataclass(frozen=True) -class FieldopTransformWorkflow(workflow.NamedStepSequence): - """Modular workflow for transformations with access to intermediates.""" - - func_to_foast: 
workflow.SkippableStep[ - ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition, - ffront_stages.FoastOperatorDefinition, - ] - foast_inject_args: workflow.Workflow[ - ffront_stages.FoastOperatorDefinition, ffront_stages.FoastClosure - ] - foast_to_past_closure: workflow.Workflow[ffront_stages.FoastClosure, ffront_stages.PastClosure] - past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] - past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] - foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] - - @property - def step_order(self) -> list[str]: - return [ - "func_to_foast", - "foast_inject_args", - "foast_to_past_closure", - "past_transform_args", - "past_to_itir", - ] - - -@dataclasses.dataclass(frozen=True) -class ProgramTransformWorkflow(workflow.NamedStepSequence): - """Modular workflow for transformations with access to intermediates.""" - - func_to_past: workflow.SkippableStep[ - ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition, - ffront_stages.PastProgramDefinition, - ] - past_lint: workflow.Workflow[ - ffront_stages.PastProgramDefinition, ffront_stages.PastProgramDefinition - ] - past_inject_args: workflow.Workflow[ - ffront_stages.PastProgramDefinition, ffront_stages.PastClosure - ] - past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] - past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] - - @dataclasses.dataclass(frozen=True) class OTFCompileWorkflow(workflow.NamedStepSequence): """The typical compiled backend steps composed into a workflow.""" From 3e30ca2c5397a57fb139dcab6ad9d213315906e8 Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 12 Apr 2024 10:23:12 +0200 Subject: [PATCH 14/30] add tested toolchain workthrough notebook --- .../next/Advanced_ToolchainWalkthrough.md | 714 ++++++++++++++++++ tox.ini | 2 + 2 files changed, 716 insertions(+) 
create mode 100644 docs/user/next/Advanced_ToolchainWalkthrough.md diff --git a/docs/user/next/Advanced_ToolchainWalkthrough.md b/docs/user/next/Advanced_ToolchainWalkthrough.md new file mode 100644 index 0000000000..bb3c5d5bbc --- /dev/null +++ b/docs/user/next/Advanced_ToolchainWalkthrough.md @@ -0,0 +1,714 @@ +```python +import dataclasses +import inspect + +import gt4py.next as gtx +from gt4py.next import backend + +import devtools +``` + + + + +```python +I = gtx.Dimension("I") +Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,)) +OFFSET_PROVIDER = {"Ioff": I} +``` + +# Toolchain Overview + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) +``` + +# Walkthrough from Field Operator + +## Starting Out + +```python +@gtx.field_operator +def example_fo(a: gtx.Field[[I], gtx.float64]) -> gtx.Field[[I], gtx.float64]: + return a + 1.0 +``` + +```python +start = example_fo.definition_stage +``` + +```python +gtx.ffront.stages.FieldOperatorDefinition? 
+``` + + Init signature: + gtx.ffront.stages.FieldOperatorDefinition( +  definition: 'types.FunctionType', +  grid_type: 'Optional[common.GridType]' = None, +  node_class: 'type[OperatorNodeT]' = <class 'gt4py.next.ffront.field_operator_ast.FieldOperator'>, +  attributes: 'dict[str, Any]' = <factory>, + ) -> None + Docstring: FieldOperatorDefinition(definition: 'types.FunctionType', grid_type: 'Optional[common.GridType]' = None, node_class: 'type[OperatorNodeT]' = , attributes: 'dict[str, Any]' = ) + File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py + Type: type + Subclasses: + +## DSL -> FOAST + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style fdef fill:red +style foast fill:red +linkStyle 0 stroke:red,stroke-width:4px,color:pink +``` + +```python +foast = backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(start) +``` + +```python +gtx.ffront.stages.FoastOperatorDefinition? 
+``` + + Init signature: + gtx.ffront.stages.FoastOperatorDefinition( +  foast_node: 'OperatorNodeT', +  closure_vars: 'dict[str, Any]', +  grid_type: 'Optional[common.GridType]' = None, +  attributes: 'dict[str, Any]' = <factory>, + ) -> None + Docstring: FoastOperatorDefinition(foast_node: 'OperatorNodeT', closure_vars: 'dict[str, Any]', grid_type: 'Optional[common.GridType]' = None, attributes: 'dict[str, Any]' = ) + File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py + Type: type + Subclasses: + +## FOAST -> ITIR + +This also happens inside the `decorator.FieldOperator.__gt_itir__` method during the lowering from calling Programs to ITIR + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style foast fill:red +style itir_expr fill:red +linkStyle 1 stroke:red,stroke-width:4px,color:pink +``` + +```python +fitir = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_itir(foast) +``` + +```python +fitir.__class__ +``` + + gt4py.next.iterator.ir.FunctionDefinition + +## FOAST -> FOAST closure + +This is preparation for "directly calling" a field operator. 
+ +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style foast fill:red +style fclos fill:red +linkStyle 2 stroke:red,stroke-width:4px,color:pink +``` + +Here we have to dynamically generate a workflow step, because the arguments were not known before. + +```python +fclos = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_inject_args.__class__( + args=(gtx.ones(domain={I: 10}, dtype=gtx.float64),), + kwargs={ + "out": gtx.zeros(domain={I: 10}, dtype=gtx.float64) + }, + from_fieldop=example_fo +)(foast) +``` + +```python +gtx.ffront.stages.FoastClosure? +``` + + Init signature: + gtx.ffront.stages.FoastClosure( +  foast_op_def: 'FoastOperatorDefinition[OperatorNodeT]', +  args: 'tuple[Any, ...]', +  kwargs: 'dict[str, Any]', +  closure_vars: 'dict[str, Any]', + ) -> None + Docstring: FoastClosure(foast_op_def: 'FoastOperatorDefinition[OperatorNodeT]', args: 'tuple[Any, ...]', kwargs: 'dict[str, Any]', closure_vars: 'dict[str, Any]') + File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py + Type: type + Subclasses: + +## FOAST with args -> PAST closure + +This auto-generates a program for us, directly in PAST representation and forwards the call arguments to it + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| 
past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style fclos fill:red +style pclos fill:red +linkStyle 3 stroke:red,stroke-width:4px,color:pink +``` + +```python +pclos = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_past_closure(fclos) +``` + +```python +gtx.ffront.stages.PastClosure? +``` + + Init signature: + gtx.ffront.stages.PastClosure( +  closure_vars: 'dict[str, Any]', +  past_node: 'past.Program', +  grid_type: 'Optional[common.GridType]', +  args: 'tuple[Any, ...]', +  kwargs: 'dict[str, Any]', + ) -> None + Docstring: PastClosure(closure_vars: 'dict[str, Any]', past_node: 'past.Program', grid_type: 'Optional[common.GridType]', args: 'tuple[Any, ...]', kwargs: 'dict[str, Any]') + File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py + Type: type + Subclasses: + +## Transform PAST closure arguments + +Don't ask me, seems to be necessary though + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style pclos fill:red +%%style pclos fill:red +linkStyle 4 stroke:red,stroke-width:4px,color:pink +``` + +```python +pclost = backend.DEFAULT_PROG_TRANSFORMS.past_transform_args(pclos) +``` + +## Lower PAST -> ITIR + +still forwarding the call arguments + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + 
+pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style pclos fill:red +style pcall fill:red +linkStyle 5 stroke:red,stroke-width:4px,color:pink +``` + +```python +pitir = backend.DEFAULT_PROG_TRANSFORMS.past_to_itir(pclost) +``` + +```python +gtx.otf.stages.ProgramCall? +``` + + Init signature: + gtx.otf.stages.ProgramCall( +  program: 'itir.FencilDefinition', +  args: 'tuple[Any, ...]', +  kwargs: 'dict[str, Any]', + ) -> None + Docstring: Iterator IR representaion of a program together with arguments to be passed to it. + File: ~/Code/gt4py/src/gt4py/next/otf/stages.py + Type: type + Subclasses: + +## Executing The Result + +```python +gtx.gtfn_cpu.executor(pitir.program, *pitir.args, offset_provider=OFFSET_PROVIDER, **pitir.kwargs) +``` + +```python +pitir.args +``` + + (NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), + NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])), + 10, + 10) + +## Full Field Operator Toolchain + +using the default step order + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style fdef fill:red +style foast fill:red +style fclos fill:red +style pclos fill:red +style pcall fill:red +linkStyle 0,2,3,4,5 stroke:red,stroke-width:4px,color:pink +``` + +### Starting from DSL + +```python +foast_toolchain = 
backend.DEFAULT_FIELDOP_TRANSFORMS.replace( + foast_inject_args=backend.FopArgsInjector(args=fclos.args, kwargs=fclos.kwargs, from_fieldop=example_fo) +) +pitir2 = foast_toolchain(start) +assert pitir2 == pitir +``` + +#### Pass The result to the compile workflow and execute + +```python +example_compiled = gtx.gtfn_cpu.executor.otf_workflow( + dataclasses.replace(pitir2, kwargs=pitir2.kwargs | {"offset_provider": OFFSET_PROVIDER}) +) +``` + +```python +example_compiled(*pitir2.args, offset_provider=OFFSET_PROVIDER) +``` + +```python +example_compiled(pitir2.args[1], *pitir2.args[1:], offset_provider=OFFSET_PROVIDER) +``` + +```python +pitir2.args[1].asnumpy() +``` + + array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]) + +### Starting from FOAST + +Note that it is the exact same call but with a different input stage + +```python +pitir3 = foast_toolchain(foast) +assert pitir3 == pitir +``` + +# Walkthrough starting from Program + +## Starting Out + +```python +@gtx.program +def example_prog(a: gtx.Field[[I], gtx.float64], out: gtx.Field[[I], gtx.float64]) -> None: + example_fo(a, out=out) +``` + +```python +p_start = example_prog.definition_stage +``` + +```python +gtx.ffront.stages.ProgramDefinition? 
+``` + + Init signature: + gtx.ffront.stages.ProgramDefinition( +  definition: 'types.FunctionType', +  grid_type: 'Optional[common.GridType]' = None, + ) -> None + Docstring: ProgramDefinition(definition: 'types.FunctionType', grid_type: 'Optional[common.GridType]' = None) + File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py + Type: type + Subclasses: + +## DSL -> PAST + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style pdef fill:red +style past fill:red +linkStyle 6 stroke:red,stroke-width:4px,color:pink +``` + +```python +p_past = backend.DEFAULT_PROG_TRANSFORMS.func_to_past(p_start) +``` + +## PAST -> Closure + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style past fill:red +style pclos fill:red +linkStyle 7 stroke:red,stroke-width:4px,color:pink +``` + +```python +pclos = backend.DEFAULT_PROG_TRANSFORMS.replace( + past_inject_args=backend.ProgArgsInjector( + args=fclos.args, + kwargs=fclos.kwargs + ) +)(p_past) +``` + +## Full Program Toolchain + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| 
itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style pdef fill:red +style past fill:red +style pclos fill:red +style pcall fill:red +linkStyle 4,5,6,7 stroke:red,stroke-width:4px,color:pink +``` + +### Starting from DSL + +```python +toolchain = backend.DEFAULT_PROG_TRANSFORMS.replace( + past_inject_args=backend.ProgArgsInjector( + args=fclos.args, + kwargs=fclos.kwargs + ) +) +``` + +```python +p_itir1 = toolchain(p_start) +``` + +```python +p_itir2 = toolchain(p_past) +``` + +```python +assert p_itir1 == p_itir2 +``` + +```python +!jupyter nbconvert WorkflowPatterns.ipynb --to slides +``` + + [NbConvertApp] WARNING | pattern 'WorkflowPatterns.ipynb' matched no files + This application is used to convert notebook files (*.ipynb) + to various other formats. + + WARNING: THE COMMANDLINE INTERFACE MAY CHANGE IN FUTURE RELEASES. + + Options + ======= + The options below are convenience aliases to configurable class-options, + as listed in the "Equivalent to" description-line of the aliases. + To see all configurable class-options for some , use: + --help-all + + --debug + set log level to logging.DEBUG (maximize logging output) + Equivalent to: [--Application.log_level=10] + --show-config + Show the application's configuration (human-readable format) + Equivalent to: [--Application.show_config=True] + --show-config-json + Show the application's configuration (json format) + Equivalent to: [--Application.show_config_json=True] + --generate-config + generate default config file + Equivalent to: [--JupyterApp.generate_config=True] + -y + Answer yes to any questions instead of prompting. 
+ Equivalent to: [--JupyterApp.answer_yes=True] + --execute + Execute the notebook prior to export. + Equivalent to: [--ExecutePreprocessor.enabled=True] + --allow-errors + Continue notebook execution even if one of the cells throws an error and include the error message in the cell output (the default behaviour is to abort conversion). This flag is only relevant if '--execute' was specified, too. + Equivalent to: [--ExecutePreprocessor.allow_errors=True] + --stdin + read a single notebook file from stdin. Write the resulting notebook with default basename 'notebook.*' + Equivalent to: [--NbConvertApp.from_stdin=True] + --stdout + Write notebook output to stdout instead of files. + Equivalent to: [--NbConvertApp.writer_class=StdoutWriter] + --inplace + Run nbconvert in place, overwriting the existing notebook (only + relevant when converting to notebook format) + Equivalent to: [--NbConvertApp.use_output_suffix=False --NbConvertApp.export_format=notebook --FilesWriter.build_directory=] + --clear-output + Clear output of current file and save in place, + overwriting the existing notebook. + Equivalent to: [--NbConvertApp.use_output_suffix=False --NbConvertApp.export_format=notebook --FilesWriter.build_directory= --ClearOutputPreprocessor.enabled=True] + --coalesce-streams + Coalesce consecutive stdout and stderr outputs into one stream (within each cell). + Equivalent to: [--NbConvertApp.use_output_suffix=False --NbConvertApp.export_format=notebook --FilesWriter.build_directory= --CoalesceStreamsPreprocessor.enabled=True] + --no-prompt + Exclude input and output prompts from converted document. + Equivalent to: [--TemplateExporter.exclude_input_prompt=True --TemplateExporter.exclude_output_prompt=True] + --no-input + Exclude input cells and output prompts from converted document. + This mode is ideal for generating code-free reports. 
+ Equivalent to: [--TemplateExporter.exclude_output_prompt=True --TemplateExporter.exclude_input=True --TemplateExporter.exclude_input_prompt=True] + --allow-chromium-download + Whether to allow downloading chromium if no suitable version is found on the system. + Equivalent to: [--WebPDFExporter.allow_chromium_download=True] + --disable-chromium-sandbox + Disable chromium security sandbox when converting to PDF.. + Equivalent to: [--WebPDFExporter.disable_sandbox=True] + --show-input + Shows code input. This flag is only useful for dejavu users. + Equivalent to: [--TemplateExporter.exclude_input=False] + --embed-images + Embed the images as base64 dataurls in the output. This flag is only useful for the HTML/WebPDF/Slides exports. + Equivalent to: [--HTMLExporter.embed_images=True] + --sanitize-html + Whether the HTML in Markdown cells and cell outputs should be sanitized.. + Equivalent to: [--HTMLExporter.sanitize_html=True] + --log-level= + Set the log level by value or name. + Choices: any of [0, 10, 20, 30, 40, 50, 'DEBUG', 'INFO', 'WARN', 'ERROR', 'CRITICAL'] + Default: 30 + Equivalent to: [--Application.log_level] + --config= + Full path of a config file. + Default: '' + Equivalent to: [--JupyterApp.config_file] + --to= + The export format to be used, either one of the built-in formats + ['asciidoc', 'custom', 'html', 'latex', 'markdown', 'notebook', 'pdf', 'python', 'qtpdf', 'qtpng', 'rst', 'script', 'slides', 'webpdf'] + or a dotted object name that represents the import path for an + ``Exporter`` class + Default: '' + Equivalent to: [--NbConvertApp.export_format] + --template= + Name of the template to use + Default: '' + Equivalent to: [--TemplateExporter.template_name] + --template-file= + Name of the template file to use + Default: None + Equivalent to: [--TemplateExporter.template_file] + --theme= + Template specific theme(e.g. 
the name of a JupyterLab CSS theme distributed + as prebuilt extension for the lab template) + Default: 'light' + Equivalent to: [--HTMLExporter.theme] + --sanitize_html= + Whether the HTML in Markdown cells and cell outputs should be sanitized.This + should be set to True by nbviewer or similar tools. + Default: False + Equivalent to: [--HTMLExporter.sanitize_html] + --writer= + Writer class used to write the + results of the conversion + Default: 'FilesWriter' + Equivalent to: [--NbConvertApp.writer_class] + --post= + PostProcessor class used to write the + results of the conversion + Default: '' + Equivalent to: [--NbConvertApp.postprocessor_class] + --output= + Overwrite base name use for output files. + Supports pattern replacements '{notebook_name}'. + Default: '{notebook_name}' + Equivalent to: [--NbConvertApp.output_base] + --output-dir= + Directory to write output(s) to. Defaults + to output to the directory of each notebook. To recover + previous default behaviour (outputting to the current + working directory) use . as the flag value. + Default: '' + Equivalent to: [--FilesWriter.build_directory] + --reveal-prefix= + The URL prefix for reveal.js (version 3.x). + This defaults to the reveal CDN, but can be any url pointing to a copy + of reveal.js. + For speaker notes to work, this must be a relative path to a local + copy of reveal.js: e.g., "reveal.js". + If a relative path is given, it must be a subdirectory of the + current directory (from which the server is run). + See the usage documentation + (https://nbconvert.readthedocs.io/en/latest/usage.html#reveal-js-html-slideshow) + for more details. + Default: '' + Equivalent to: [--SlidesExporter.reveal_url_prefix] + --nbformat= + The nbformat version to write. + Use this to downgrade notebooks. 
+ Choices: any of [1, 2, 3, 4] + Default: 4 + Equivalent to: [--NotebookExporter.nbformat_version] + + Examples + -------- + + The simplest way to use nbconvert is + + > jupyter nbconvert mynotebook.ipynb --to html + + Options include ['asciidoc', 'custom', 'html', 'latex', 'markdown', 'notebook', 'pdf', 'python', 'qtpdf', 'qtpng', 'rst', 'script', 'slides', 'webpdf']. + + > jupyter nbconvert --to latex mynotebook.ipynb + + Both HTML and LaTeX support multiple output templates. LaTeX includes + 'base', 'article' and 'report'. HTML includes 'basic', 'lab' and + 'classic'. You can specify the flavor of the format used. + + > jupyter nbconvert --to html --template lab mynotebook.ipynb + + You can also pipe the output to stdout, rather than a file + + > jupyter nbconvert mynotebook.ipynb --stdout + + PDF is generated via latex + + > jupyter nbconvert mynotebook.ipynb --to pdf + + You can get (and serve) a Reveal.js-powered slideshow + + > jupyter nbconvert myslides.ipynb --to slides --post serve + + Multiple notebooks can be given at the command line in a couple of + different ways: + + > jupyter nbconvert notebook*.ipynb + > jupyter nbconvert notebook1.ipynb notebook2.ipynb + + or you can specify the notebooks list in a config file, containing:: + + c.NbConvertApp.notebooks = ["my_notebook.ipynb"] + + > jupyter nbconvert --config mycfg.py + + To see all available configurables, use `--help-all`. 
+ +```python + +``` diff --git a/tox.ini b/tox.ini index d4418c4ebc..8479e4c52c 100644 --- a/tox.ini +++ b/tox.ini @@ -109,10 +109,12 @@ commands = description = Run notebooks commands_pre = jupytext docs/user/next/QuickstartGuide.md --to .ipynb + jupytext docs/user/next/Advanced_ToolchainWalkthrough.md --to .ipynb commands = python -m pytest --nbmake docs/user/next/workshop/slides -v -n {env:NUM_PROCESSES:1} python -m pytest --nbmake docs/user/next/workshop/exercises -k 'solutions' -v -n {env:NUM_PROCESSES:1} python -m pytest --nbmake docs/user/next/QuickstartGuide.ipynb -v -n {env:NUM_PROCESSES:1} + python -m pytest --nbmake docs/user/next/Advanced_ToolchainWalkthrough.ipynb -v -n {env:NUM_PROCESSES:1} python -m pytest --nbmake examples -v -n {env:NUM_PROCESSES:1} # -- Other artefacts -- From 9e50bd836520075a38d2b4934b1a779a90c34a9d Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 12 Apr 2024 13:20:43 +0200 Subject: [PATCH 15/30] fix toolchain walkthrough notebook --- .../next/Advanced_ToolchainWalkthrough.md | 204 +----------------- 1 file changed, 2 insertions(+), 202 deletions(-) diff --git a/docs/user/next/Advanced_ToolchainWalkthrough.md b/docs/user/next/Advanced_ToolchainWalkthrough.md index bb3c5d5bbc..94a7bfa7e2 100644 --- a/docs/user/next/Advanced_ToolchainWalkthrough.md +++ b/docs/user/next/Advanced_ToolchainWalkthrough.md @@ -305,7 +305,7 @@ gtx.otf.stages.ProgramCall? 
## Executing The Result ```python -gtx.gtfn_cpu.executor(pitir.program, *pitir.args, offset_provider=OFFSET_PROVIDER, **pitir.kwargs) +gtx.program_processors.runners.roundtrip.executor(pitir.program, *pitir.args, offset_provider=OFFSET_PROVIDER, **pitir.kwargs) ``` ```python @@ -356,7 +356,7 @@ assert pitir2 == pitir #### Pass The result to the compile workflow and execute ```python -example_compiled = gtx.gtfn_cpu.executor.otf_workflow( +example_compiled = gtx.program_processors.runners.roundtrip.executor.otf_workflow( dataclasses.replace(pitir2, kwargs=pitir2.kwargs | {"offset_provider": OFFSET_PROVIDER}) ) ``` @@ -512,203 +512,3 @@ p_itir2 = toolchain(p_past) ```python assert p_itir1 == p_itir2 ``` - -```python -!jupyter nbconvert WorkflowPatterns.ipynb --to slides -``` - - [NbConvertApp] WARNING | pattern 'WorkflowPatterns.ipynb' matched no files - This application is used to convert notebook files (*.ipynb) - to various other formats. - - WARNING: THE COMMANDLINE INTERFACE MAY CHANGE IN FUTURE RELEASES. - - Options - ======= - The options below are convenience aliases to configurable class-options, - as listed in the "Equivalent to" description-line of the aliases. - To see all configurable class-options for some , use: - --help-all - - --debug - set log level to logging.DEBUG (maximize logging output) - Equivalent to: [--Application.log_level=10] - --show-config - Show the application's configuration (human-readable format) - Equivalent to: [--Application.show_config=True] - --show-config-json - Show the application's configuration (json format) - Equivalent to: [--Application.show_config_json=True] - --generate-config - generate default config file - Equivalent to: [--JupyterApp.generate_config=True] - -y - Answer yes to any questions instead of prompting. - Equivalent to: [--JupyterApp.answer_yes=True] - --execute - Execute the notebook prior to export. 
- Equivalent to: [--ExecutePreprocessor.enabled=True] - --allow-errors - Continue notebook execution even if one of the cells throws an error and include the error message in the cell output (the default behaviour is to abort conversion). This flag is only relevant if '--execute' was specified, too. - Equivalent to: [--ExecutePreprocessor.allow_errors=True] - --stdin - read a single notebook file from stdin. Write the resulting notebook with default basename 'notebook.*' - Equivalent to: [--NbConvertApp.from_stdin=True] - --stdout - Write notebook output to stdout instead of files. - Equivalent to: [--NbConvertApp.writer_class=StdoutWriter] - --inplace - Run nbconvert in place, overwriting the existing notebook (only - relevant when converting to notebook format) - Equivalent to: [--NbConvertApp.use_output_suffix=False --NbConvertApp.export_format=notebook --FilesWriter.build_directory=] - --clear-output - Clear output of current file and save in place, - overwriting the existing notebook. - Equivalent to: [--NbConvertApp.use_output_suffix=False --NbConvertApp.export_format=notebook --FilesWriter.build_directory= --ClearOutputPreprocessor.enabled=True] - --coalesce-streams - Coalesce consecutive stdout and stderr outputs into one stream (within each cell). - Equivalent to: [--NbConvertApp.use_output_suffix=False --NbConvertApp.export_format=notebook --FilesWriter.build_directory= --CoalesceStreamsPreprocessor.enabled=True] - --no-prompt - Exclude input and output prompts from converted document. - Equivalent to: [--TemplateExporter.exclude_input_prompt=True --TemplateExporter.exclude_output_prompt=True] - --no-input - Exclude input cells and output prompts from converted document. - This mode is ideal for generating code-free reports. 
- Equivalent to: [--TemplateExporter.exclude_output_prompt=True --TemplateExporter.exclude_input=True --TemplateExporter.exclude_input_prompt=True] - --allow-chromium-download - Whether to allow downloading chromium if no suitable version is found on the system. - Equivalent to: [--WebPDFExporter.allow_chromium_download=True] - --disable-chromium-sandbox - Disable chromium security sandbox when converting to PDF.. - Equivalent to: [--WebPDFExporter.disable_sandbox=True] - --show-input - Shows code input. This flag is only useful for dejavu users. - Equivalent to: [--TemplateExporter.exclude_input=False] - --embed-images - Embed the images as base64 dataurls in the output. This flag is only useful for the HTML/WebPDF/Slides exports. - Equivalent to: [--HTMLExporter.embed_images=True] - --sanitize-html - Whether the HTML in Markdown cells and cell outputs should be sanitized.. - Equivalent to: [--HTMLExporter.sanitize_html=True] - --log-level= - Set the log level by value or name. - Choices: any of [0, 10, 20, 30, 40, 50, 'DEBUG', 'INFO', 'WARN', 'ERROR', 'CRITICAL'] - Default: 30 - Equivalent to: [--Application.log_level] - --config= - Full path of a config file. - Default: '' - Equivalent to: [--JupyterApp.config_file] - --to= - The export format to be used, either one of the built-in formats - ['asciidoc', 'custom', 'html', 'latex', 'markdown', 'notebook', 'pdf', 'python', 'qtpdf', 'qtpng', 'rst', 'script', 'slides', 'webpdf'] - or a dotted object name that represents the import path for an - ``Exporter`` class - Default: '' - Equivalent to: [--NbConvertApp.export_format] - --template= - Name of the template to use - Default: '' - Equivalent to: [--TemplateExporter.template_name] - --template-file= - Name of the template file to use - Default: None - Equivalent to: [--TemplateExporter.template_file] - --theme= - Template specific theme(e.g. 
the name of a JupyterLab CSS theme distributed - as prebuilt extension for the lab template) - Default: 'light' - Equivalent to: [--HTMLExporter.theme] - --sanitize_html= - Whether the HTML in Markdown cells and cell outputs should be sanitized.This - should be set to True by nbviewer or similar tools. - Default: False - Equivalent to: [--HTMLExporter.sanitize_html] - --writer= - Writer class used to write the - results of the conversion - Default: 'FilesWriter' - Equivalent to: [--NbConvertApp.writer_class] - --post= - PostProcessor class used to write the - results of the conversion - Default: '' - Equivalent to: [--NbConvertApp.postprocessor_class] - --output= - Overwrite base name use for output files. - Supports pattern replacements '{notebook_name}'. - Default: '{notebook_name}' - Equivalent to: [--NbConvertApp.output_base] - --output-dir= - Directory to write output(s) to. Defaults - to output to the directory of each notebook. To recover - previous default behaviour (outputting to the current - working directory) use . as the flag value. - Default: '' - Equivalent to: [--FilesWriter.build_directory] - --reveal-prefix= - The URL prefix for reveal.js (version 3.x). - This defaults to the reveal CDN, but can be any url pointing to a copy - of reveal.js. - For speaker notes to work, this must be a relative path to a local - copy of reveal.js: e.g., "reveal.js". - If a relative path is given, it must be a subdirectory of the - current directory (from which the server is run). - See the usage documentation - (https://nbconvert.readthedocs.io/en/latest/usage.html#reveal-js-html-slideshow) - for more details. - Default: '' - Equivalent to: [--SlidesExporter.reveal_url_prefix] - --nbformat= - The nbformat version to write. - Use this to downgrade notebooks. 
- Choices: any of [1, 2, 3, 4] - Default: 4 - Equivalent to: [--NotebookExporter.nbformat_version] - - Examples - -------- - - The simplest way to use nbconvert is - - > jupyter nbconvert mynotebook.ipynb --to html - - Options include ['asciidoc', 'custom', 'html', 'latex', 'markdown', 'notebook', 'pdf', 'python', 'qtpdf', 'qtpng', 'rst', 'script', 'slides', 'webpdf']. - - > jupyter nbconvert --to latex mynotebook.ipynb - - Both HTML and LaTeX support multiple output templates. LaTeX includes - 'base', 'article' and 'report'. HTML includes 'basic', 'lab' and - 'classic'. You can specify the flavor of the format used. - - > jupyter nbconvert --to html --template lab mynotebook.ipynb - - You can also pipe the output to stdout, rather than a file - - > jupyter nbconvert mynotebook.ipynb --stdout - - PDF is generated via latex - - > jupyter nbconvert mynotebook.ipynb --to pdf - - You can get (and serve) a Reveal.js-powered slideshow - - > jupyter nbconvert myslides.ipynb --to slides --post serve - - Multiple notebooks can be given at the command line in a couple of - different ways: - - > jupyter nbconvert notebook*.ipynb - > jupyter nbconvert notebook1.ipynb notebook2.ipynb - - or you can specify the notebooks list in a config file, containing:: - - c.NbConvertApp.notebooks = ["my_notebook.ipynb"] - - > jupyter nbconvert --config mycfg.py - - To see all available configurables, use `--help-all`. 
- -```python - -``` From 678bd6796297fe652be5b5d4761b85df1dd0c476 Mon Sep 17 00:00:00 2001 From: DropD Date: Thu, 18 Apr 2024 10:16:31 +0200 Subject: [PATCH 16/30] put default toolchain steps into definitions --- src/gt4py/next/backend.py | 121 +++++++++++++++++---------------- src/gt4py/next/otf/workflow.py | 6 +- 2 files changed, 68 insertions(+), 59 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 4842bca515..a50c9cefb1 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -35,6 +35,21 @@ from gt4py.next.program_processors import processor_interface as ppi +@dataclasses.dataclass(frozen=True) +class FopArgsInjector(workflow.Workflow): + args: tuple[Any, ...] = dataclasses.field(default_factory=tuple) + kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) + from_fieldop: Any = None + + def __call__(self, inp: ffront_stages.FoastOperatorDefinition) -> ffront_stages.FoastClosure: + return ffront_stages.FoastClosure( + foast_op_def=inp, + args=self.args, + kwargs=self.kwargs, + closure_vars={inp.foast_node.id: self.from_fieldop}, + ) + + @dataclasses.dataclass(frozen=True) class FieldopTransformWorkflow(workflow.NamedStepSequence): """Modular workflow for transformations with access to intermediates.""" @@ -42,14 +57,35 @@ class FieldopTransformWorkflow(workflow.NamedStepSequence): func_to_foast: workflow.SkippableStep[ ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition, ffront_stages.FoastOperatorDefinition, - ] + ] = dataclasses.field( + default_factory=lambda: func_to_foast.OptionalFuncToFoastFactory(cached=True) + ) foast_inject_args: workflow.Workflow[ ffront_stages.FoastOperatorDefinition, ffront_stages.FoastClosure - ] - foast_to_past_closure: workflow.Workflow[ffront_stages.FoastClosure, ffront_stages.PastClosure] - past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] - past_to_itir: 
workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] - foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] + ] = dataclasses.field(default_factory=FopArgsInjector) + foast_to_past_closure: workflow.Workflow[ + ffront_stages.FoastClosure, ffront_stages.PastClosure + ] = dataclasses.field( + default_factory=lambda: foast_to_past.FoastToPastClosure( + foast_to_past=workflow.CachedStep( + foast_to_past.foast_to_past, hash_function=eve_utils.content_hash + ) + ) + ) + past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = ( + dataclasses.field(default=past_process_args.past_process_args) + ) + past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( + dataclasses.field(default_factory=past_to_itir.PastToItirFactory) + ) + + foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] = ( + dataclasses.field( + default_factory=lambda: workflow.CachedStep( + step=foast_to_itir.foast_to_itir, hash_function=eve_utils.content_hash + ) + ) + ) @property def step_order(self) -> list[str]: @@ -62,22 +98,7 @@ def step_order(self) -> list[str]: ] -@dataclasses.dataclass(frozen=True) -class ProgramTransformWorkflow(workflow.NamedStepSequence): - """Modular workflow for transformations with access to intermediates.""" - - func_to_past: workflow.SkippableStep[ - ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition, - ffront_stages.PastProgramDefinition, - ] - past_lint: workflow.Workflow[ - ffront_stages.PastProgramDefinition, ffront_stages.PastProgramDefinition - ] - past_inject_args: workflow.Workflow[ - ffront_stages.PastProgramDefinition, ffront_stages.PastClosure - ] - past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] - past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] +DEFAULT_FIELDOP_TRANSFORMS = FieldopTransformWorkflow() @dataclasses.dataclass(frozen=True) 
@@ -96,44 +117,30 @@ def __call__(self, inp: ffront_stages.PastProgramDefinition) -> ffront_stages.Pa @dataclasses.dataclass(frozen=True) -class FopArgsInjector(workflow.Workflow): - args: tuple[Any, ...] = dataclasses.field(default_factory=tuple) - kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) - from_fieldop: Any = None - - def __call__(self, inp: ffront_stages.FoastOperatorDefinition) -> ffront_stages.FoastClosure: - return ffront_stages.FoastClosure( - foast_op_def=inp, - args=self.args, - kwargs=self.kwargs, - closure_vars={inp.foast_node.id: self.from_fieldop}, - ) - +class ProgramTransformWorkflow(workflow.NamedStepSequence): + """Modular workflow for transformations with access to intermediates.""" -DEFAULT_FIELDOP_TRANSFORMS = FieldopTransformWorkflow( - func_to_foast=func_to_foast.OptionalFuncToFoastFactory(cached=True), - foast_inject_args=FopArgsInjector(), - foast_to_past_closure=foast_to_past.FoastToPastClosure( - foast_to_past=workflow.CachedStep( - foast_to_past.foast_to_past, - hash_function=eve_utils.content_hash, - ) - ), - past_transform_args=past_process_args.past_process_args, - past_to_itir=past_to_itir.PastToItirFactory(), - foast_to_itir=workflow.CachedStep( - step=foast_to_itir.foast_to_itir, hash_function=eve_utils.content_hash - ), -) + func_to_past: workflow.SkippableStep[ + ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition, + ffront_stages.PastProgramDefinition, + ] = dataclasses.field( + default_factory=lambda: func_to_past.OptionalFuncToPastFactory(cached=True) + ) + past_lint: workflow.Workflow[ + ffront_stages.PastProgramDefinition, ffront_stages.PastProgramDefinition + ] = dataclasses.field(default_factory=past_linters.LinterFactory) + past_inject_args: workflow.Workflow[ + ffront_stages.PastProgramDefinition, ffront_stages.PastClosure + ] = dataclasses.field(default_factory=ProgArgsInjector) + past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = ( + 
dataclasses.field(default=past_process_args.past_process_args) + ) + past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( + dataclasses.field(default_factory=past_to_itir.PastToItirFactory) + ) -DEFAULT_PROG_TRANSFORMS = ProgramTransformWorkflow( - func_to_past=func_to_past.OptionalFuncToPastFactory(cached=True), - past_lint=past_linters.LinterFactory(), - past_inject_args=ProgArgsInjector(), - past_transform_args=past_process_args.past_process_args, - past_to_itir=past_to_itir.PastToItirFactory(), -) +DEFAULT_PROG_TRANSFORMS = ProgramTransformWorkflow() @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 8ae741195f..c83748dece 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -84,8 +84,10 @@ def replace(self, **kwargs: Any) -> Self: return dataclasses.replace(self, **kwargs) -class ChainableWorkflowMixin(Workflow[StartT, EndT]): - def chain(self, next_step: Workflow[EndT, NewEndT]) -> ChainableWorkflowMixin[StartT, NewEndT]: +class ChainableWorkflowMixin(Workflow[StartT, EndT_co], Protocol[StartT, EndT_co]): + def chain( + self, next_step: Workflow[EndT_co, NewEndT] + ) -> ChainableWorkflowMixin[StartT, NewEndT]: return make_step(self).chain(next_step) From fad81c0e30c25fab80e3a64fb988b6ed651b6f4f Mon Sep 17 00:00:00 2001 From: DropD Date: Thu, 18 Apr 2024 16:31:25 +0200 Subject: [PATCH 17/30] replace content_hash with dedicated cache key gen for ffront stages --- src/gt4py/next/backend.py | 5 +- src/gt4py/next/ffront/func_to_foast.py | 4 +- src/gt4py/next/ffront/func_to_past.py | 5 +- src/gt4py/next/ffront/past_passes/linters.py | 3 +- src/gt4py/next/ffront/stages.py | 201 ++++++++++-------- .../unit_tests/ffront_tests/test_stages.py | 162 ++++++++++++++ 6 files changed, 287 insertions(+), 93 deletions(-) create mode 100644 tests/next_tests/unit_tests/ffront_tests/test_stages.py diff --git a/src/gt4py/next/backend.py 
b/src/gt4py/next/backend.py index a50c9cefb1..e541351b89 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -18,7 +18,6 @@ from typing import Any, Generic from gt4py._core import definitions as core_defs -from gt4py.eve import utils as eve_utils from gt4py.next import allocators as next_allocators from gt4py.next.ffront import ( foast_to_itir, @@ -68,7 +67,7 @@ class FieldopTransformWorkflow(workflow.NamedStepSequence): ] = dataclasses.field( default_factory=lambda: foast_to_past.FoastToPastClosure( foast_to_past=workflow.CachedStep( - foast_to_past.foast_to_past, hash_function=eve_utils.content_hash + foast_to_past.foast_to_past, hash_function=ffront_stages.cache_key ) ) ) @@ -82,7 +81,7 @@ class FieldopTransformWorkflow(workflow.NamedStepSequence): foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] = ( dataclasses.field( default_factory=lambda: workflow.CachedStep( - step=foast_to_itir.foast_to_itir, hash_function=eve_utils.content_hash + step=foast_to_itir.foast_to_itir, hash_function=ffront_stages.cache_key ) ) ) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index bcb400df4a..0c1b084a1a 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -104,7 +104,9 @@ class Params: ] = func_to_foast cached = factory.Trait( step=factory.LazyAttribute( - lambda o: workflow.CachedStep(step=o.workflow, hash_function=eve.utils.content_hash) + lambda o: workflow.CachedStep( + step=o.workflow, hash_function=ffront_stages.cache_key + ) ) ) diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 6864993f4c..606202f923 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -21,7 +21,6 @@ import factory -from gt4py.eve import utils as eve_utils from gt4py.next import errors from gt4py.next.ffront import ( dialect_ast_enums, @@ -73,7 +72,9 @@ class Params: 
workflow = func_to_past cached = factory.Trait( step=factory.LazyAttribute( - lambda o: workflow.CachedStep(step=o.workflow, hash_function=eve_utils.content_hash) + lambda o: workflow.CachedStep( + step=o.workflow, hash_function=ffront_stages.cache_key + ) ) ) diff --git a/src/gt4py/next/ffront/past_passes/linters.py b/src/gt4py/next/ffront/past_passes/linters.py index a1a734da66..693a436f1d 100644 --- a/src/gt4py/next/ffront/past_passes/linters.py +++ b/src/gt4py/next/ffront/past_passes/linters.py @@ -14,7 +14,6 @@ import factory -from gt4py import eve from gt4py.next.ffront import gtcallable, stages as ffront_stages, transform_utils from gt4py.next.otf import workflow @@ -57,4 +56,4 @@ class Meta: model = workflow.CachedStep step = lint_misnamed_functions.chain(lint_undefined_symbols) - hash_function = eve.utils.content_hash + hash_function = ffront_stages.cache_key diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index a220ba995f..f6e0796ef4 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -14,136 +14,167 @@ from __future__ import annotations +import collections import dataclasses -import inspect +import functools +import hashlib import types +import typing from typing import Any, Generic, Optional, TypeVar +import xxhash + +from gt4py import eve from gt4py.next import common -from gt4py.next.ffront import field_operator_ast as foast, program_ast as past +from gt4py.next.ffront import field_operator_ast as foast, program_ast as past, source_utils from gt4py.next.type_system import type_specifications as ts -OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) - - -class ContentHashableMixin: - """ - Allows deriving dataclasses to modify what goes into the content hash per-field. - - Warning: Using this will modify how the class gets pickled. If unpickling is desired, - extra care has to be taken. 
The hasher must not remove crucial data and any modifications - have to be undone while loading in the __setstate__ method (which needs to be implemented). - - In fact, when unpickling is required, it is probably best to implement both - __setstate__ and __getstate__ by hand for the entire class rather than per-field. - """ - - def __getstate__(self) -> dict[str, Any]: - if not dataclasses.is_dataclass(self): - raise TypeError(f"'{self.__class__}' is not a dataclass.") - state = self.__dict__.copy() - for field in dataclasses.fields(self): - if "content_hasher" in field.metadata: - field.metadata["content_hasher"](state, getattr(self, field.name), field.name) - return state - - -def function_type_hasher(state: dict[str, Any], value: types.FunctionType, name: str) -> None: - state[f"{name}__name"] = value.__name__ - state[f"{name}__source"] = inspect.getsource(value) - del state[name] +if typing.TYPE_CHECKING: + from gt4py.next.ffront import decorator -def simple_dict_hasher(state: dict[str, Any], value: dict[str, Any], name: str) -> None: - for k, v in value.items(): - state[f"{name}__{k}"] = v - - -def closure_vars_hasher(state: dict[str, Any], value: dict[str, Any], name: str) -> None: - hashable_closure_vars = {} - for k, v in value.items(): - # replace the decorator with the earliest canonical dsl representation available - if hasattr(v, "definition_stage"): - hashable_closure_vars[k] = v.definition_stage - elif hasattr(v, "foast_stage"): - hashable_closure_vars[k] = v.foast_stage - elif hasattr(v, "past_stage"): - hashable_closure_vars[k] = v.past_stage - # put the backend into the hash because it may influence the toolchain - # TODO(ricoh): This is not perfect, since backend names are allowed to clash (low priority). 
- if hasattr(v, "backend"): - hashable_closure_vars[f"{k}_backend"] = v.backend.__name__ if v.backend else "None" - state[name] = hashable_closure_vars +OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) @dataclasses.dataclass(frozen=True) -class FieldOperatorDefinition(ContentHashableMixin, Generic[OperatorNodeT]): - definition: types.FunctionType = dataclasses.field( - metadata={"content_hasher": function_type_hasher} - ) +class FieldOperatorDefinition(Generic[OperatorNodeT]): + definition: types.FunctionType grid_type: Optional[common.GridType] = None node_class: type[OperatorNodeT] = dataclasses.field(default=foast.FieldOperator) # type: ignore[assignment] # TODO(ricoh): understand why mypy complains - attributes: dict[str, Any] = dataclasses.field( - default_factory=dict, metadata={"content_hasher": simple_dict_hasher} - ) + attributes: dict[str, Any] = dataclasses.field(default_factory=dict) @dataclasses.dataclass(frozen=True) -class FoastOperatorDefinition(ContentHashableMixin, Generic[OperatorNodeT]): +class FoastOperatorDefinition(Generic[OperatorNodeT]): foast_node: OperatorNodeT - closure_vars: dict[str, Any] = dataclasses.field( - metadata={"content_hasher": closure_vars_hasher} - ) + closure_vars: dict[str, Any] grid_type: Optional[common.GridType] = None - attributes: dict[str, Any] = dataclasses.field( - default_factory=dict, metadata={"content_hasher": simple_dict_hasher} - ) + attributes: dict[str, Any] = dataclasses.field(default_factory=dict) @dataclasses.dataclass(frozen=True) -class FoastWithTypes(ContentHashableMixin, Generic[OperatorNodeT]): +class FoastWithTypes(Generic[OperatorNodeT]): foast_op_def: FoastOperatorDefinition[OperatorNodeT] arg_types: tuple[ts.TypeSpec, ...] 
kwarg_types: dict[str, ts.TypeSpec] - closure_vars: dict[str, Any] = dataclasses.field( - metadata={"content_hasher": closure_vars_hasher} - ) + closure_vars: dict[str, Any] @dataclasses.dataclass(frozen=True) -class FoastClosure(ContentHashableMixin, Generic[OperatorNodeT]): +class FoastClosure(Generic[OperatorNodeT]): foast_op_def: FoastOperatorDefinition[OperatorNodeT] args: tuple[Any, ...] kwargs: dict[str, Any] - closure_vars: dict[str, Any] = dataclasses.field( - metadata={"content_hasher": closure_vars_hasher} - ) + closure_vars: dict[str, Any] @dataclasses.dataclass(frozen=True) -class ProgramDefinition(ContentHashableMixin): - definition: types.FunctionType = dataclasses.field( - metadata={"content_hasher": function_type_hasher} - ) +class ProgramDefinition: + definition: types.FunctionType grid_type: Optional[common.GridType] = None @dataclasses.dataclass(frozen=True) -class PastProgramDefinition(ContentHashableMixin): +class PastProgramDefinition: past_node: past.Program - closure_vars: dict[str, Any] = dataclasses.field( - metadata={"content_hasher": closure_vars_hasher} - ) + closure_vars: dict[str, Any] grid_type: Optional[common.GridType] = None @dataclasses.dataclass(frozen=True) -class PastClosure(ContentHashableMixin): - closure_vars: dict[str, Any] = dataclasses.field( - metadata={"content_hasher": closure_vars_hasher} - ) +class PastClosure: + closure_vars: dict[str, Any] past_node: past.Program grid_type: Optional[common.GridType] args: tuple[Any, ...] 
kwargs: dict[str, Any] + + +Hasher_T: typing.TypeAlias = eve.extended_typing.HashlibAlgorithm | xxhash.xxh64 | hashlib._Hash + + +def cache_key(obj: Any, algorithm: Optional[str | Hasher_T] = None) -> str: + hasher: Hasher_T + if not algorithm: + hasher = xxhash.xxh64() + elif isinstance(algorithm, str): + hasher = hashlib.new(algorithm) + else: + hasher = algorithm + + update_cache_key(obj, hasher) + return hasher.hexdigest() + + +@functools.singledispatch +def update_cache_key(obj: Any, hasher: Hasher_T) -> None: + if dataclasses.is_dataclass(obj): + update_cache_key(obj.__class__, hasher) + for field in dataclasses.fields(obj): + update_cache_key(getattr(obj, field.name), hasher) + # the following is to avoid circular dependencies + elif hasattr(obj, "backend"): # assume it is a decorator wrapper + update_cache_key_fielop(obj, hasher) + else: + hasher.update(str(obj).encode()) + + +@update_cache_key.register +def update_cache_key_str(obj: str, hasher: Hasher_T) -> None: + hasher.update(str(obj).encode()) + + +@update_cache_key.register +def update_cache_key_builtins( + obj: str | None | bool | int | float, + hasher: Hasher_T, +) -> None: + hasher.update(str(obj).encode()) + + +@update_cache_key.register +def update_cache_key_func(obj: types.FunctionType, hasher: Hasher_T) -> None: + sourcedef = source_utils.SourceDefinition.from_function(obj) + for item in sourcedef: + update_cache_key(item, hasher) + + +@update_cache_key.register +def update_cache_key_dict(obj: dict, hasher: Hasher_T) -> None: + for key, value in obj.items(): + update_cache_key(key, hasher) + update_cache_key(value, hasher) + + +@update_cache_key.register +def update_cache_key_type(obj: type, hasher: Hasher_T) -> None: + hasher.update(obj.__name__.encode()) + + +@update_cache_key.register +def update_cache_key_sequence( + obj: tuple | list | collections.abc.Iterable, hasher: Hasher_T +) -> None: + for item in obj: + update_cache_key(item, hasher) + + +@update_cache_key.register +def 
update_cache_key_foast(obj: foast.LocatedNode, hasher: Hasher_T) -> None: + update_cache_key(obj.location, hasher) + update_cache_key(str(obj), hasher) + + +# not registered to avoid circular dependencies +def update_cache_key_fielop( + obj: decorator.FieldOperator | decorator.Program, + hasher: Hasher_T, +) -> None: + if hasattr(obj, "definition_stage"): + update_cache_key(obj.definition_stage, hasher) + elif hasattr(obj, "foast_stage"): + update_cache_key(obj.foast_stage, hasher) + elif hasattr(obj, "past_stage"): + update_cache_key(obj.past_stage, hasher) + update_cache_key(obj.backend, hasher) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_stages.py b/tests/next_tests/unit_tests/ffront_tests/test_stages.py new file mode 100644 index 0000000000..871a0a18d8 --- /dev/null +++ b/tests/next_tests/unit_tests/ffront_tests/test_stages.py @@ -0,0 +1,162 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest +from gt4py import next as gtx +from gt4py.next.ffront import stages + + +@pytest.fixture +def idim(): + yield gtx.Dimension("I") + + +@pytest.fixture +def jdim(): + yield gtx.Dimension("J") + + +@pytest.fixture +def fieldop(idim): + @gtx.field_operator + def copy(a: gtx.Field[[idim], gtx.int32]) -> gtx.Field[[idim], gtx.int32]: + return a + + yield copy + + +@pytest.fixture +def samecode_fieldop(idim): + @gtx.field_operator + def copy(a: gtx.Field[[idim], gtx.int32]) -> gtx.Field[[idim], gtx.int32]: + return a + + yield copy + + +@pytest.fixture +def different_fieldop(jdim): + @gtx.field_operator + def copy(a: gtx.Field[[jdim], gtx.int32]) -> gtx.Field[[jdim], gtx.int32]: + return a + + yield copy + + +@pytest.fixture +def program(fieldop, idim): + copy = fieldop + + @gtx.program + def copy_program(a: gtx.Field[[idim], gtx.int32], out: gtx.Field[[idim], gtx.int32]): + copy(a, out=out) + + yield copy_program + + +@pytest.fixture +def samecode_program(samecode_fieldop, idim): + copy = samecode_fieldop + + @gtx.program + def copy_program(a: gtx.Field[[idim], gtx.int32], out: gtx.Field[[idim], gtx.int32]): + copy(a, out=out) + + yield copy_program + + +@pytest.fixture +def different_program(different_fieldop, jdim): + copy = different_fieldop + + @gtx.program + def copy_program(a: gtx.Field[[jdim], gtx.int32], out: gtx.Field[[jdim], gtx.int32]): + copy(a, out=out) + + yield copy_program + + +def test_cache_key_field_op_def(fieldop, samecode_fieldop, different_fieldop): + assert stages.cache_key(samecode_fieldop.definition_stage) != stages.cache_key( + fieldop.definition_stage + ) + assert stages.cache_key(different_fieldop.definition_stage) != stages.cache_key( + fieldop.definition_stage + ) + + +def test_cache_key_foast_op_def(fieldop, samecode_fieldop, different_fieldop): + foast = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(fieldop.definition_stage) + samecode = 
gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast( + samecode_fieldop.definition_stage + ) + different = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast( + different_fieldop.definition_stage + ) + + assert stages.cache_key(samecode) != stages.cache_key(foast) + assert stages.cache_key(different) != stages.cache_key(foast) + + +def test_cache_key_foast_closure(fieldop, samecode_fieldop, different_fieldop, idim, jdim): + foast_closure = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( + gtx.backend.FopArgsInjector( + args=(gtx.zeros({idim: 10}, gtx.int32),), + kwargs={"out": gtx.zeros({idim: 10}, gtx.int32)}, + from_fieldop=fieldop, + ), + )(fieldop.definition_stage) + samecode = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( + gtx.backend.FopArgsInjector( + args=(gtx.zeros({idim: 10}, gtx.int32),), + kwargs={"out": gtx.zeros({idim: 10}, gtx.int32)}, + from_fieldop=samecode_fieldop, + ) + )(samecode_fieldop.definition_stage) + different = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( + gtx.backend.FopArgsInjector( + args=(gtx.zeros({jdim: 10}, gtx.int32),), + kwargs={"out": gtx.zeros({jdim: 10}, gtx.int32)}, + from_fieldop=different_fieldop, + ) + )(different_fieldop.definition_stage) + different_args = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( + gtx.backend.FopArgsInjector( + args=(gtx.zeros({idim: 11}, gtx.int32),), + kwargs={"out": gtx.zeros({idim: 11}, gtx.int32)}, + from_fieldop=fieldop, + ) + )(fieldop.definition_stage) + + assert stages.cache_key(samecode) != stages.cache_key(foast_closure) + assert stages.cache_key(different) != stages.cache_key(foast_closure) + assert stages.cache_key(different_args) != stages.cache_key(foast_closure) + + +def test_cache_key_program_def(program, samecode_program, different_program): + assert stages.cache_key(samecode_program.definition_stage) != stages.cache_key( + program.definition_stage + ) + assert stages.cache_key(different_program.definition_stage) != 
stages.cache_key( + program.definition_stage + ) + + +def test_cache_key_past_def(program, samecode_program, different_program): + past = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(program.definition_stage) + samecode = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(samecode_program.definition_stage) + different = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(different_program.definition_stage) + + assert stages.cache_key(samecode) != stages.cache_key(past) + assert stages.cache_key(different) != stages.cache_key(past) From af1a0165aec1ac83235afc81878b2813dd4fca45 Mon Sep 17 00:00:00 2001 From: DropD Date: Thu, 18 Apr 2024 16:40:23 +0200 Subject: [PATCH 18/30] add typeignores for hash algorithms --- src/gt4py/next/ffront/stages.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index f6e0796ef4..f9de2e9bc1 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -91,15 +91,17 @@ class PastClosure: kwargs: dict[str, Any] -Hasher_T: typing.TypeAlias = eve.extended_typing.HashlibAlgorithm | xxhash.xxh64 | hashlib._Hash +# TODO(ricoh): This type seems to not really catch the relevant types +# which leads to the ignores below +Hasher_T: typing.TypeAlias = eve.extended_typing.HashlibAlgorithm def cache_key(obj: Any, algorithm: Optional[str | Hasher_T] = None) -> str: hasher: Hasher_T if not algorithm: - hasher = xxhash.xxh64() + hasher = xxhash.xxh64() # type: ignore[assignment] # see todo above elif isinstance(algorithm, str): - hasher = hashlib.new(algorithm) + hasher = hashlib.new(algorithm) # type: ignore[assignment] # see todo above else: hasher = algorithm From 556e8c530bd0855ecfbbc557c94cd4caf90911ac Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 19 Apr 2024 09:23:13 +0200 Subject: [PATCH 19/30] downgrade singledispatch type hints for py < 310 --- src/gt4py/next/ffront/decorator.py | 8 ++++++++ src/gt4py/next/ffront/stages.py | 10 
+++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 820a45e70e..0c490a152b 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -506,6 +506,14 @@ def __call__(self, *args, **kwargs) -> None: @dataclasses.dataclass(frozen=True) class FieldOperatorFromFoast(FieldOperator): + """ + This version of the field operator does not have a DSL definition. + + Instead, it is defined from a FieldOperator AST directly. + Current main use case is for tests that programmatically build FOAST + trees with specific features to be tested. + """ + foast_stage: ffront_stages.FoastOperatorDefinition def __call__(self, *args, **kwargs) -> None: diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index f9de2e9bc1..30a38dc118 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -127,9 +127,11 @@ def update_cache_key_str(obj: str, hasher: Hasher_T) -> None: hasher.update(str(obj).encode()) -@update_cache_key.register +@update_cache_key.register(int) +@update_cache_key.register(bool) +@update_cache_key.register(float) def update_cache_key_builtins( - obj: str | None | bool | int | float, + obj: None, hasher: Hasher_T, ) -> None: hasher.update(str(obj).encode()) @@ -155,9 +157,7 @@ def update_cache_key_type(obj: type, hasher: Hasher_T) -> None: @update_cache_key.register -def update_cache_key_sequence( - obj: tuple | list | collections.abc.Iterable, hasher: Hasher_T -) -> None: +def update_cache_key_sequence(obj: collections.abc.Iterable, hasher: Hasher_T) -> None: for item in obj: update_cache_key(item, hasher) From fba07f1ef3e38fbcfcd6bcdf5688768c53cd82fb Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 19 Apr 2024 09:45:52 +0200 Subject: [PATCH 20/30] docstrings for AST based decorator wrappers --- src/gt4py/next/ffront/decorator.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 
deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 0c490a152b..bab7123a80 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -206,6 +206,13 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) @dataclasses.dataclass(frozen=True) class ProgramFromPast(Program): + """ + This version of program has no DSL definition associated with it. + + PAST nodes can be built programmatically from field operators or from scratch. + This wrapper provides the appropriate toolchain entry points. + """ + past_stage: ffront_stages.PastProgramDefinition def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): @@ -509,9 +516,9 @@ class FieldOperatorFromFoast(FieldOperator): """ This version of the field operator does not have a DSL definition. - Instead, it is defined from a FieldOperator AST directly. - Current main use case is for tests that programmatically build FOAST - trees with specific features to be tested. + FieldOperator AST nodes can be programmatically built, which may be + particularly useful in testing and debugging. + This class provides the appropriate toolchain entry points. 
""" foast_stage: ffront_stages.FoastOperatorDefinition From 50bdea2bfaf4d4b34cdd03c88e5490eb94904e05 Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 19 Apr 2024 09:51:22 +0200 Subject: [PATCH 21/30] todos for linting step calls in decorator wrappers --- src/gt4py/next/ffront/decorator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index bab7123a80..860e71a2a8 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -104,6 +104,7 @@ def past_stage(self): return self.backend.transforms_prog.func_to_past(self.definition_stage) return next_backend.DEFAULT_PROG_TRANSFORMS.func_to_past(self.definition_stage) + # TODO(ricoh): linting should become optional, up to the backend. def __post_init__(self): if self.backend is not None and self.backend.transforms_prog is not None: self.backend.transforms_prog.past_lint(self.past_stage) @@ -224,6 +225,7 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor) self.backend(self.past_stage, *args, **(kwargs | {"offset_provider": offset_provider})) + # TODO(ricoh): linting should become optional, up to the backend. def __post_init__(self): if self.backend is not None and self.backend.transforms_prog is not None: self.backend.transforms_prog.past_lint(self.past_stage) @@ -399,6 +401,7 @@ def from_function( backend=backend, ) + # TODO(ricoh): linting should become optional, up to the backend. 
def __post_init__(self): """This ensures that DSL linting occurs at decoration time.""" _ = self.foast_stage From 63121fd8852afa115cbb2974407f65098be2e19f Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 19 Apr 2024 10:00:30 +0200 Subject: [PATCH 22/30] comment first occurrence of backwards compat backend pattern --- src/gt4py/next/ffront/decorator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 860e71a2a8..11c9f322a0 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -100,6 +100,7 @@ def definition(self): @functools.cached_property def past_stage(self): + # backwards compatibility for backends that do not support the full toolchain if self.backend is not None and self.backend.transforms_prog is not None: return self.backend.transforms_prog.func_to_past(self.definition_stage) return next_backend.DEFAULT_PROG_TRANSFORMS.func_to_past(self.definition_stage) From 04e407ba53c671c99dc48b6cbad98820fb5e3125 Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 19 Apr 2024 14:59:06 +0200 Subject: [PATCH 23/30] stages hasher: avoid recursing into non-stage dataclasses --- src/gt4py/next/ffront/stages.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 30a38dc118..86279de137 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -111,17 +111,26 @@ def cache_key(obj: Any, algorithm: Optional[str | Hasher_T] = None) -> str: @functools.singledispatch def update_cache_key(obj: Any, hasher: Hasher_T) -> None: - if dataclasses.is_dataclass(obj): - update_cache_key(obj.__class__, hasher) - for field in dataclasses.fields(obj): - update_cache_key(getattr(obj, field.name), hasher) # the following is to avoid circular dependencies - elif hasattr(obj, "backend"): # assume it is a decorator wrapper + if hasattr(obj, "backend"): # 
assume it is a decorator wrapper update_cache_key_fielop(obj, hasher) else: hasher.update(str(obj).encode()) +@update_cache_key.register(FieldOperatorDefinition) +@update_cache_key.register(FoastOperatorDefinition) +@update_cache_key.register(FoastWithTypes) +@update_cache_key.register(FoastClosure) +@update_cache_key.register(ProgramDefinition) +@update_cache_key.register(PastProgramDefinition) +@update_cache_key.register(PastClosure) +def update_cache_key_stages(obj: Any, hasher: Hasher_T) -> None: + update_cache_key(obj.__class__, hasher) + for field in dataclasses.fields(obj): + update_cache_key(getattr(obj, field.name), hasher) + + @update_cache_key.register def update_cache_key_str(obj: str, hasher: Hasher_T) -> None: hasher.update(str(obj).encode()) From 9bb341aacc1a296f148cac5c40c0df64433a5952 Mon Sep 17 00:00:00 2001 From: DropD Date: Mon, 22 Apr 2024 10:35:25 +0200 Subject: [PATCH 24/30] rename `ffront.stages.cache_key` -> `ffront.stages.fingerprint_stage` --- src/gt4py/next/backend.py | 4 +- src/gt4py/next/ffront/func_to_foast.py | 2 +- src/gt4py/next/ffront/func_to_past.py | 2 +- src/gt4py/next/ffront/past_passes/linters.py | 2 +- src/gt4py/next/ffront/stages.py | 82 +++++++++---------- .../unit_tests/ffront_tests/test_stages.py | 22 ++--- 6 files changed, 57 insertions(+), 57 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index e541351b89..3d3c7a27e1 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -67,7 +67,7 @@ class FieldopTransformWorkflow(workflow.NamedStepSequence): ] = dataclasses.field( default_factory=lambda: foast_to_past.FoastToPastClosure( foast_to_past=workflow.CachedStep( - foast_to_past.foast_to_past, hash_function=ffront_stages.cache_key + foast_to_past.foast_to_past, hash_function=ffront_stages.fingerprint_stage ) ) ) @@ -81,7 +81,7 @@ class FieldopTransformWorkflow(workflow.NamedStepSequence): foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] = 
( dataclasses.field( default_factory=lambda: workflow.CachedStep( - step=foast_to_itir.foast_to_itir, hash_function=ffront_stages.cache_key + step=foast_to_itir.foast_to_itir, hash_function=ffront_stages.fingerprint_stage ) ) ) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 0c1b084a1a..6127cbdef5 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -105,7 +105,7 @@ class Params: cached = factory.Trait( step=factory.LazyAttribute( lambda o: workflow.CachedStep( - step=o.workflow, hash_function=ffront_stages.cache_key + step=o.workflow, hash_function=ffront_stages.fingerprint_stage ) ) ) diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 606202f923..372386aaf4 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -73,7 +73,7 @@ class Params: cached = factory.Trait( step=factory.LazyAttribute( lambda o: workflow.CachedStep( - step=o.workflow, hash_function=ffront_stages.cache_key + step=o.workflow, hash_function=ffront_stages.fingerprint_stage ) ) ) diff --git a/src/gt4py/next/ffront/past_passes/linters.py b/src/gt4py/next/ffront/past_passes/linters.py index 693a436f1d..6e77262fd1 100644 --- a/src/gt4py/next/ffront/past_passes/linters.py +++ b/src/gt4py/next/ffront/past_passes/linters.py @@ -56,4 +56,4 @@ class Meta: model = workflow.CachedStep step = lint_misnamed_functions.chain(lint_undefined_symbols) - hash_function = ffront_stages.cache_key + hash_function = ffront_stages.fingerprint_stage diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 86279de137..430b068932 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -96,7 +96,7 @@ class PastClosure: Hasher_T: typing.TypeAlias = eve.extended_typing.HashlibAlgorithm -def cache_key(obj: Any, algorithm: Optional[str | Hasher_T] = None) -> str: +def 
fingerprint_stage(obj: Any, algorithm: Optional[str | Hasher_T] = None) -> str: hasher: Hasher_T if not algorithm: hasher = xxhash.xxh64() # type: ignore[assignment] # see todo above @@ -105,87 +105,87 @@ def cache_key(obj: Any, algorithm: Optional[str | Hasher_T] = None) -> str: else: hasher = algorithm - update_cache_key(obj, hasher) + add_content_to_fingerprint(obj, hasher) return hasher.hexdigest() @functools.singledispatch -def update_cache_key(obj: Any, hasher: Hasher_T) -> None: +def add_content_to_fingerprint(obj: Any, hasher: Hasher_T) -> None: # the following is to avoid circular dependencies if hasattr(obj, "backend"): # assume it is a decorator wrapper - update_cache_key_fielop(obj, hasher) + add_content_to_fingerprint_fielop(obj, hasher) else: hasher.update(str(obj).encode()) -@update_cache_key.register(FieldOperatorDefinition) -@update_cache_key.register(FoastOperatorDefinition) -@update_cache_key.register(FoastWithTypes) -@update_cache_key.register(FoastClosure) -@update_cache_key.register(ProgramDefinition) -@update_cache_key.register(PastProgramDefinition) -@update_cache_key.register(PastClosure) -def update_cache_key_stages(obj: Any, hasher: Hasher_T) -> None: - update_cache_key(obj.__class__, hasher) +@add_content_to_fingerprint.register(FieldOperatorDefinition) +@add_content_to_fingerprint.register(FoastOperatorDefinition) +@add_content_to_fingerprint.register(FoastWithTypes) +@add_content_to_fingerprint.register(FoastClosure) +@add_content_to_fingerprint.register(ProgramDefinition) +@add_content_to_fingerprint.register(PastProgramDefinition) +@add_content_to_fingerprint.register(PastClosure) +def add_content_to_fingerprint_stages(obj: Any, hasher: Hasher_T) -> None: + add_content_to_fingerprint(obj.__class__, hasher) for field in dataclasses.fields(obj): - update_cache_key(getattr(obj, field.name), hasher) + add_content_to_fingerprint(getattr(obj, field.name), hasher) -@update_cache_key.register -def update_cache_key_str(obj: str, hasher: 
Hasher_T) -> None: +@add_content_to_fingerprint.register +def add_content_to_fingerprint_str(obj: str, hasher: Hasher_T) -> None: hasher.update(str(obj).encode()) -@update_cache_key.register(int) -@update_cache_key.register(bool) -@update_cache_key.register(float) -def update_cache_key_builtins( +@add_content_to_fingerprint.register(int) +@add_content_to_fingerprint.register(bool) +@add_content_to_fingerprint.register(float) +def add_content_to_fingerprint_builtins( obj: None, hasher: Hasher_T, ) -> None: hasher.update(str(obj).encode()) -@update_cache_key.register -def update_cache_key_func(obj: types.FunctionType, hasher: Hasher_T) -> None: +@add_content_to_fingerprint.register +def add_content_to_fingerprint_func(obj: types.FunctionType, hasher: Hasher_T) -> None: sourcedef = source_utils.SourceDefinition.from_function(obj) for item in sourcedef: - update_cache_key(item, hasher) + add_content_to_fingerprint(item, hasher) -@update_cache_key.register -def update_cache_key_dict(obj: dict, hasher: Hasher_T) -> None: +@add_content_to_fingerprint.register +def add_content_to_fingerprint_dict(obj: dict, hasher: Hasher_T) -> None: for key, value in obj.items(): - update_cache_key(key, hasher) - update_cache_key(value, hasher) + add_content_to_fingerprint(key, hasher) + add_content_to_fingerprint(value, hasher) -@update_cache_key.register -def update_cache_key_type(obj: type, hasher: Hasher_T) -> None: +@add_content_to_fingerprint.register +def add_content_to_fingerprint_type(obj: type, hasher: Hasher_T) -> None: hasher.update(obj.__name__.encode()) -@update_cache_key.register -def update_cache_key_sequence(obj: collections.abc.Iterable, hasher: Hasher_T) -> None: +@add_content_to_fingerprint.register +def add_content_to_fingerprint_sequence(obj: collections.abc.Iterable, hasher: Hasher_T) -> None: for item in obj: - update_cache_key(item, hasher) + add_content_to_fingerprint(item, hasher) -@update_cache_key.register -def update_cache_key_foast(obj: foast.LocatedNode, 
hasher: Hasher_T) -> None: - update_cache_key(obj.location, hasher) - update_cache_key(str(obj), hasher) +@add_content_to_fingerprint.register +def add_content_to_fingerprint_foast(obj: foast.LocatedNode, hasher: Hasher_T) -> None: + add_content_to_fingerprint(obj.location, hasher) + add_content_to_fingerprint(str(obj), hasher) # not registered to avoid circular dependencies -def update_cache_key_fielop( +def add_content_to_fingerprint_fielop( obj: decorator.FieldOperator | decorator.Program, hasher: Hasher_T, ) -> None: if hasattr(obj, "definition_stage"): - update_cache_key(obj.definition_stage, hasher) + add_content_to_fingerprint(obj.definition_stage, hasher) elif hasattr(obj, "foast_stage"): - update_cache_key(obj.foast_stage, hasher) + add_content_to_fingerprint(obj.foast_stage, hasher) elif hasattr(obj, "past_stage"): - update_cache_key(obj.past_stage, hasher) - update_cache_key(obj.backend, hasher) + add_content_to_fingerprint(obj.past_stage, hasher) + add_content_to_fingerprint(obj.backend, hasher) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_stages.py b/tests/next_tests/unit_tests/ffront_tests/test_stages.py index 871a0a18d8..67ac96d653 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_stages.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_stages.py @@ -88,10 +88,10 @@ def copy_program(a: gtx.Field[[jdim], gtx.int32], out: gtx.Field[[jdim], gtx.int def test_cache_key_field_op_def(fieldop, samecode_fieldop, different_fieldop): - assert stages.cache_key(samecode_fieldop.definition_stage) != stages.cache_key( + assert stages.fingerprint_stage(samecode_fieldop.definition_stage) != stages.fingerprint_stage( fieldop.definition_stage ) - assert stages.cache_key(different_fieldop.definition_stage) != stages.cache_key( + assert stages.fingerprint_stage(different_fieldop.definition_stage) != stages.fingerprint_stage( fieldop.definition_stage ) @@ -105,8 +105,8 @@ def test_cache_key_foast_op_def(fieldop, samecode_fieldop, 
different_fieldop): different_fieldop.definition_stage ) - assert stages.cache_key(samecode) != stages.cache_key(foast) - assert stages.cache_key(different) != stages.cache_key(foast) + assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast) + assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast) def test_cache_key_foast_closure(fieldop, samecode_fieldop, different_fieldop, idim, jdim): @@ -139,16 +139,16 @@ def test_cache_key_foast_closure(fieldop, samecode_fieldop, different_fieldop, i ) )(fieldop.definition_stage) - assert stages.cache_key(samecode) != stages.cache_key(foast_closure) - assert stages.cache_key(different) != stages.cache_key(foast_closure) - assert stages.cache_key(different_args) != stages.cache_key(foast_closure) + assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast_closure) + assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast_closure) + assert stages.fingerprint_stage(different_args) != stages.fingerprint_stage(foast_closure) def test_cache_key_program_def(program, samecode_program, different_program): - assert stages.cache_key(samecode_program.definition_stage) != stages.cache_key( + assert stages.fingerprint_stage(samecode_program.definition_stage) != stages.fingerprint_stage( program.definition_stage ) - assert stages.cache_key(different_program.definition_stage) != stages.cache_key( + assert stages.fingerprint_stage(different_program.definition_stage) != stages.fingerprint_stage( program.definition_stage ) @@ -158,5 +158,5 @@ def test_cache_key_past_def(program, samecode_program, different_program): samecode = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(samecode_program.definition_stage) different = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(different_program.definition_stage) - assert stages.cache_key(samecode) != stages.cache_key(past) - assert stages.cache_key(different) != stages.cache_key(past) + assert 
stages.fingerprint_stage(samecode) != stages.fingerprint_stage(past) + assert stages.fingerprint_stage(different) != stages.fingerprint_stage(past) From 888017d9a867ac1e8cc9ee1ee7c7e1965019bec7 Mon Sep 17 00:00:00 2001 From: DropD Date: Mon, 22 Apr 2024 11:10:34 +0200 Subject: [PATCH 25/30] improve ffront.stage fingerprinting --- src/gt4py/next/ffront/decorator.py | 28 +++++++ src/gt4py/next/ffront/foast_pretty_printer.py | 11 +++ src/gt4py/next/ffront/stages.py | 74 +++++++++---------- 3 files changed, 76 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index a0e3494c90..d20636ca2e 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -642,3 +642,31 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: ) return scan_operator_inner if definition is None else scan_operator_inner(definition) + + +@ffront_stages.add_content_to_fingerprint.register +def add_fieldop_to_fingerprint(obj: FieldOperator, hasher: ffront_stages.HashlibAlgorithm) -> None: + ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher) + ffront_stages.add_content_to_fingerprint(obj.backend, hasher) + + +@ffront_stages.add_content_to_fingerprint.register +def add_foast_fieldop_to_fingerprint( + obj: FieldOperatorFromFoast, hasher: ffront_stages.HashlibAlgorithm +) -> None: + ffront_stages.add_content_to_fingerprint(obj.foast_stage, hasher) + ffront_stages.add_content_to_fingerprint(obj.backend, hasher) + + +@ffront_stages.add_content_to_fingerprint.register +def add_program_to_fingerprint(obj: Program, hasher: ffront_stages.HashlibAlgorithm) -> None: + ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher) + ffront_stages.add_content_to_fingerprint(obj.backend, hasher) + + +@ffront_stages.add_content_to_fingerprint.register +def add_past_program_to_fingerprint( + obj: ProgramFromPast, hasher: ffront_stages.HashlibAlgorithm +) -> None: + 
ffront_stages.add_content_to_fingerprint(obj.past_stage, hasher) + ffront_stages.add_content_to_fingerprint(obj.backend, hasher) diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 4fa80c4892..6194647e1f 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -126,6 +126,17 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o UnaryOp = as_fmt("{op}{operand}") + IfStmt = as_fmt( + textwrap.dedent( + """ + if {condition}: + {true_branch} + else: + {false_branch} + """ + ).strip() + ) + def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str: if node.op is dialect_ast_enums.UnaryOperator.NOT: op = "not " diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 430b068932..ae1096ca60 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -31,7 +31,7 @@ if typing.TYPE_CHECKING: - from gt4py.next.ffront import decorator + pass OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) @@ -91,17 +91,35 @@ class PastClosure: kwargs: dict[str, Any] -# TODO(ricoh): This type seems to not really catch the relevant types -# which leads to the ignores below -Hasher_T: typing.TypeAlias = eve.extended_typing.HashlibAlgorithm +class HashlibAlgorithm(typing.Protocol): + """Used in the hashlib module of the standard library.""" + @property + def block_size(self) -> int: ... -def fingerprint_stage(obj: Any, algorithm: Optional[str | Hasher_T] = None) -> str: - hasher: Hasher_T + @property + def digest_size(self) -> int: ... + + @property + def name(self) -> str: ... + + def __init__(self, data: eve.extended_typing.ReadableBuffer = ...) -> None: ... + + def copy(self) -> eve.extended_typing.Self: ... + + def update(self, data: eve.extended_typing.Buffer, /) -> None: ... + + def digest(self) -> bytes: ... + + def hexdigest(self) -> str: ... 
+ + +def fingerprint_stage(obj: Any, algorithm: Optional[str | HashlibAlgorithm] = None) -> str: + hasher: HashlibAlgorithm if not algorithm: - hasher = xxhash.xxh64() # type: ignore[assignment] # see todo above + hasher = xxhash.xxh64() elif isinstance(algorithm, str): - hasher = hashlib.new(algorithm) # type: ignore[assignment] # see todo above + hasher = hashlib.new(algorithm) else: hasher = algorithm @@ -110,12 +128,8 @@ def fingerprint_stage(obj: Any, algorithm: Optional[str | Hasher_T] = None) -> s @functools.singledispatch -def add_content_to_fingerprint(obj: Any, hasher: Hasher_T) -> None: - # the following is to avoid circular dependencies - if hasattr(obj, "backend"): # assume it is a decorator wrapper - add_content_to_fingerprint_fielop(obj, hasher) - else: - hasher.update(str(obj).encode()) +def add_content_to_fingerprint(obj: Any, hasher: HashlibAlgorithm) -> None: + hasher.update(str(obj).encode()) @add_content_to_fingerprint.register(FieldOperatorDefinition) @@ -125,67 +139,53 @@ def add_content_to_fingerprint(obj: Any, hasher: Hasher_T) -> None: @add_content_to_fingerprint.register(ProgramDefinition) @add_content_to_fingerprint.register(PastProgramDefinition) @add_content_to_fingerprint.register(PastClosure) -def add_content_to_fingerprint_stages(obj: Any, hasher: Hasher_T) -> None: +def add_content_to_fingerprint_stages(obj: Any, hasher: HashlibAlgorithm) -> None: add_content_to_fingerprint(obj.__class__, hasher) for field in dataclasses.fields(obj): add_content_to_fingerprint(getattr(obj, field.name), hasher) @add_content_to_fingerprint.register -def add_content_to_fingerprint_str(obj: str, hasher: Hasher_T) -> None: +def add_str_to_fingerprint(obj: str, hasher: HashlibAlgorithm) -> None: hasher.update(str(obj).encode()) @add_content_to_fingerprint.register(int) @add_content_to_fingerprint.register(bool) @add_content_to_fingerprint.register(float) -def add_content_to_fingerprint_builtins( +def add_builtin_to_fingerprint( obj: None, - hasher: 
Hasher_T, + hasher: HashlibAlgorithm, ) -> None: hasher.update(str(obj).encode()) @add_content_to_fingerprint.register -def add_content_to_fingerprint_func(obj: types.FunctionType, hasher: Hasher_T) -> None: +def add_func_to_fingerprint(obj: types.FunctionType, hasher: HashlibAlgorithm) -> None: sourcedef = source_utils.SourceDefinition.from_function(obj) for item in sourcedef: add_content_to_fingerprint(item, hasher) @add_content_to_fingerprint.register -def add_content_to_fingerprint_dict(obj: dict, hasher: Hasher_T) -> None: +def add_dict_to_fingerprint(obj: dict, hasher: HashlibAlgorithm) -> None: for key, value in obj.items(): add_content_to_fingerprint(key, hasher) add_content_to_fingerprint(value, hasher) @add_content_to_fingerprint.register -def add_content_to_fingerprint_type(obj: type, hasher: Hasher_T) -> None: +def add_type_to_fingerprint(obj: type, hasher: HashlibAlgorithm) -> None: hasher.update(obj.__name__.encode()) @add_content_to_fingerprint.register -def add_content_to_fingerprint_sequence(obj: collections.abc.Iterable, hasher: Hasher_T) -> None: +def add_sequence_to_fingerprint(obj: collections.abc.Iterable, hasher: HashlibAlgorithm) -> None: for item in obj: add_content_to_fingerprint(item, hasher) @add_content_to_fingerprint.register -def add_content_to_fingerprint_foast(obj: foast.LocatedNode, hasher: Hasher_T) -> None: +def add_foast_located_node_to_fingerprint(obj: foast.LocatedNode, hasher: HashlibAlgorithm) -> None: add_content_to_fingerprint(obj.location, hasher) add_content_to_fingerprint(str(obj), hasher) - - -# not registered to avoid circular dependencies -def add_content_to_fingerprint_fielop( - obj: decorator.FieldOperator | decorator.Program, - hasher: Hasher_T, -) -> None: - if hasattr(obj, "definition_stage"): - add_content_to_fingerprint(obj.definition_stage, hasher) - elif hasattr(obj, "foast_stage"): - add_content_to_fingerprint(obj.foast_stage, hasher) - elif hasattr(obj, "past_stage"): - 
add_content_to_fingerprint(obj.past_stage, hasher) - add_content_to_fingerprint(obj.backend, hasher) From 345ce8efd34465a3b9abd852fe77acb509a88787 Mon Sep 17 00:00:00 2001 From: DropD Date: Mon, 22 Apr 2024 13:31:54 +0200 Subject: [PATCH 26/30] update HashlibAlgorithm in eve --- src/gt4py/eve/extended_typing.py | 15 ++++++--- src/gt4py/eve/utils.py | 8 ++--- src/gt4py/next/ffront/decorator.py | 14 ++++---- src/gt4py/next/ffront/stages.py | 51 ++++++++++-------------------- 4 files changed, 36 insertions(+), 52 deletions(-) diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 42473bea63..e406a5f097 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -207,15 +207,20 @@ def __delete__(self, _instance: _C) -> None: ... class HashlibAlgorithm(Protocol): """Used in the hashlib module of the standard library.""" - digest_size: int - block_size: int - name: str + @property + def block_size(self) -> int: ... + + @property + def digest_size(self) -> int: ... + + @property + def name(self) -> str: ... def __init__(self, data: ReadableBuffer = ...) -> None: ... - def copy(self) -> HashlibAlgorithm: ... + def copy(self) -> Self: ... - def update(self, data: ReadableBuffer) -> None: ... + def update(self, data: Buffer, /) -> None: ... def digest(self) -> bytes: ... 
diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 01c066ca91..d1f9d0f7d5 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -401,12 +401,12 @@ def content_hash(*args: Any, hash_algorithm: str | xtyping.HashlibAlgorithm | No """ if hash_algorithm is None: - hash_algorithm = xxhash.xxh64() # type: ignore[assignment] + hash_algorithm = xxhash.xxh64() elif isinstance(hash_algorithm, str): - hash_algorithm = hashlib.new(hash_algorithm) # type: ignore[assignment] + hash_algorithm = hashlib.new(hash_algorithm) - hash_algorithm.update(pickle.dumps(args)) # type: ignore[union-attr] - result = hash_algorithm.hexdigest() # type: ignore[union-attr] + hash_algorithm.update(pickle.dumps(args)) + result = hash_algorithm.hexdigest() assert isinstance(result, str) return result diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index d20636ca2e..be1b3c1fa8 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -24,11 +24,11 @@ import typing import warnings from collections.abc import Callable -from typing import Generic, TypeVar +from typing import Any, Generic, Optional, TypeVar from gt4py import eve from gt4py._core import definitions as core_defs -from gt4py.eve.extended_typing import Any, Optional +from gt4py.eve import extended_typing as xtyping from gt4py.next import ( allocators as next_allocators, backend as next_backend, @@ -645,28 +645,26 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: @ffront_stages.add_content_to_fingerprint.register -def add_fieldop_to_fingerprint(obj: FieldOperator, hasher: ffront_stages.HashlibAlgorithm) -> None: +def add_fieldop_to_fingerprint(obj: FieldOperator, hasher: xtyping.HashlibAlgorithm) -> None: ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher) ffront_stages.add_content_to_fingerprint(obj.backend, hasher) @ffront_stages.add_content_to_fingerprint.register def 
add_foast_fieldop_to_fingerprint( - obj: FieldOperatorFromFoast, hasher: ffront_stages.HashlibAlgorithm + obj: FieldOperatorFromFoast, hasher: xtyping.HashlibAlgorithm ) -> None: ffront_stages.add_content_to_fingerprint(obj.foast_stage, hasher) ffront_stages.add_content_to_fingerprint(obj.backend, hasher) @ffront_stages.add_content_to_fingerprint.register -def add_program_to_fingerprint(obj: Program, hasher: ffront_stages.HashlibAlgorithm) -> None: +def add_program_to_fingerprint(obj: Program, hasher: xtyping.HashlibAlgorithm) -> None: ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher) ffront_stages.add_content_to_fingerprint(obj.backend, hasher) @ffront_stages.add_content_to_fingerprint.register -def add_past_program_to_fingerprint( - obj: ProgramFromPast, hasher: ffront_stages.HashlibAlgorithm -) -> None: +def add_past_program_to_fingerprint(obj: ProgramFromPast, hasher: xtyping.HashlibAlgorithm) -> None: ffront_stages.add_content_to_fingerprint(obj.past_stage, hasher) ffront_stages.add_content_to_fingerprint(obj.backend, hasher) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index ae1096ca60..559e0cabd1 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -24,7 +24,7 @@ import xxhash -from gt4py import eve +from gt4py.eve import extended_typing as xtyping from gt4py.next import common from gt4py.next.ffront import field_operator_ast as foast, program_ast as past, source_utils from gt4py.next.type_system import type_specifications as ts @@ -91,31 +91,8 @@ class PastClosure: kwargs: dict[str, Any] -class HashlibAlgorithm(typing.Protocol): - """Used in the hashlib module of the standard library.""" - - @property - def block_size(self) -> int: ... - - @property - def digest_size(self) -> int: ... - - @property - def name(self) -> str: ... - - def __init__(self, data: eve.extended_typing.ReadableBuffer = ...) -> None: ... - - def copy(self) -> eve.extended_typing.Self: ... 
- - def update(self, data: eve.extended_typing.Buffer, /) -> None: ... - - def digest(self) -> bytes: ... - - def hexdigest(self) -> str: ... - - -def fingerprint_stage(obj: Any, algorithm: Optional[str | HashlibAlgorithm] = None) -> str: - hasher: HashlibAlgorithm +def fingerprint_stage(obj: Any, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None) -> str: + hasher: xtyping.HashlibAlgorithm if not algorithm: hasher = xxhash.xxh64() elif isinstance(algorithm, str): @@ -128,7 +105,7 @@ def fingerprint_stage(obj: Any, algorithm: Optional[str | HashlibAlgorithm] = No @functools.singledispatch -def add_content_to_fingerprint(obj: Any, hasher: HashlibAlgorithm) -> None: +def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None: hasher.update(str(obj).encode()) @@ -139,14 +116,14 @@ def add_content_to_fingerprint(obj: Any, hasher: HashlibAlgorithm) -> None: @add_content_to_fingerprint.register(ProgramDefinition) @add_content_to_fingerprint.register(PastProgramDefinition) @add_content_to_fingerprint.register(PastClosure) -def add_content_to_fingerprint_stages(obj: Any, hasher: HashlibAlgorithm) -> None: +def add_content_to_fingerprint_stages(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None: add_content_to_fingerprint(obj.__class__, hasher) for field in dataclasses.fields(obj): add_content_to_fingerprint(getattr(obj, field.name), hasher) @add_content_to_fingerprint.register -def add_str_to_fingerprint(obj: str, hasher: HashlibAlgorithm) -> None: +def add_str_to_fingerprint(obj: str, hasher: xtyping.HashlibAlgorithm) -> None: hasher.update(str(obj).encode()) @@ -155,37 +132,41 @@ def add_str_to_fingerprint(obj: str, hasher: HashlibAlgorithm) -> None: @add_content_to_fingerprint.register(float) def add_builtin_to_fingerprint( obj: None, - hasher: HashlibAlgorithm, + hasher: xtyping.HashlibAlgorithm, ) -> None: hasher.update(str(obj).encode()) @add_content_to_fingerprint.register -def add_func_to_fingerprint(obj: types.FunctionType, hasher: 
HashlibAlgorithm) -> None: +def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgorithm) -> None: sourcedef = source_utils.SourceDefinition.from_function(obj) for item in sourcedef: add_content_to_fingerprint(item, hasher) @add_content_to_fingerprint.register -def add_dict_to_fingerprint(obj: dict, hasher: HashlibAlgorithm) -> None: +def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: for key, value in obj.items(): add_content_to_fingerprint(key, hasher) add_content_to_fingerprint(value, hasher) @add_content_to_fingerprint.register -def add_type_to_fingerprint(obj: type, hasher: HashlibAlgorithm) -> None: +def add_type_to_fingerprint(obj: type, hasher: xtyping.HashlibAlgorithm) -> None: hasher.update(obj.__name__.encode()) @add_content_to_fingerprint.register -def add_sequence_to_fingerprint(obj: collections.abc.Iterable, hasher: HashlibAlgorithm) -> None: +def add_sequence_to_fingerprint( + obj: collections.abc.Iterable, hasher: xtyping.HashlibAlgorithm +) -> None: for item in obj: add_content_to_fingerprint(item, hasher) @add_content_to_fingerprint.register -def add_foast_located_node_to_fingerprint(obj: foast.LocatedNode, hasher: HashlibAlgorithm) -> None: +def add_foast_located_node_to_fingerprint( + obj: foast.LocatedNode, hasher: xtyping.HashlibAlgorithm +) -> None: add_content_to_fingerprint(obj.location, hasher) add_content_to_fingerprint(str(obj), hasher) From 7be23de231051ed0381bb7069161659b450d2d5a Mon Sep 17 00:00:00 2001 From: DropD Date: Tue, 23 Apr 2024 13:28:55 +0200 Subject: [PATCH 27/30] remove redundant singledispatch methods --- src/gt4py/next/ffront/stages.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 559e0cabd1..f3fcad0219 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -122,21 +122,6 @@ def add_content_to_fingerprint_stages(obj: Any, hasher: 
xtyping.HashlibAlgorithm add_content_to_fingerprint(getattr(obj, field.name), hasher) -@add_content_to_fingerprint.register -def add_str_to_fingerprint(obj: str, hasher: xtyping.HashlibAlgorithm) -> None: - hasher.update(str(obj).encode()) - - -@add_content_to_fingerprint.register(int) -@add_content_to_fingerprint.register(bool) -@add_content_to_fingerprint.register(float) -def add_builtin_to_fingerprint( - obj: None, - hasher: xtyping.HashlibAlgorithm, -) -> None: - hasher.update(str(obj).encode()) - - @add_content_to_fingerprint.register def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgorithm) -> None: sourcedef = source_utils.SourceDefinition.from_function(obj) From 250b26d8562b255703f0f8575bbb585e7790b7f9 Mon Sep 17 00:00:00 2001 From: DropD Date: Tue, 23 Apr 2024 15:43:39 +0200 Subject: [PATCH 28/30] [wip] make toolchains static with args as inputs --- src/gt4py/next/backend.py | 88 ++++++++++++++++----------------- src/gt4py/next/ffront/stages.py | 12 +++++ src/gt4py/next/otf/workflow.py | 23 +++++++++ 3 files changed, 77 insertions(+), 46 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 3d3c7a27e1..95cead8fef 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -34,23 +34,21 @@ from gt4py.next.program_processors import processor_interface as ppi -@dataclasses.dataclass(frozen=True) -class FopArgsInjector(workflow.Workflow): - args: tuple[Any, ...] 
= dataclasses.field(default_factory=tuple) - kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) - from_fieldop: Any = None - - def __call__(self, inp: ffront_stages.FoastOperatorDefinition) -> ffront_stages.FoastClosure: - return ffront_stages.FoastClosure( - foast_op_def=inp, - args=self.args, - kwargs=self.kwargs, - closure_vars={inp.foast_node.id: self.from_fieldop}, - ) +@workflow.make_step +def foast_to_foast_closure( + inp: workflow.InputWithArgs[ffront_stages.FoastOperatorDefinition], +) -> ffront_stages.FoastClosure: + from_fieldop = inp.kwargs.pop("from_fieldop") + return ffront_stages.FoastClosure( + foast_op_def=inp.data, + args=inp.args, + kwargs=inp.kwargs, + closure_vars={inp.data.foast_node.id: from_fieldop}, + ) @dataclasses.dataclass(frozen=True) -class FieldopTransformWorkflow(workflow.NamedStepSequence): +class FieldopTransformWorkflow(workflow.NamedStepSequenceWithArgs): """Modular workflow for transformations with access to intermediates.""" func_to_foast: workflow.SkippableStep[ @@ -59,9 +57,9 @@ class FieldopTransformWorkflow(workflow.NamedStepSequence): ] = dataclasses.field( default_factory=lambda: func_to_foast.OptionalFuncToFoastFactory(cached=True) ) - foast_inject_args: workflow.Workflow[ - ffront_stages.FoastOperatorDefinition, ffront_stages.FoastClosure - ] = dataclasses.field(default_factory=FopArgsInjector) + foast_to_foast_closure: workflow.Workflow[ + workflow.InputWithArgs[ffront_stages.FoastOperatorDefinition], ffront_stages.FoastClosure + ] = dataclasses.field(default=foast_to_foast_closure, metadata={"takes_args": True}) foast_to_past_closure: workflow.Workflow[ ffront_stages.FoastClosure, ffront_stages.PastClosure ] = dataclasses.field( @@ -90,7 +88,7 @@ class FieldopTransformWorkflow(workflow.NamedStepSequence): def step_order(self) -> list[str]: return [ "func_to_foast", - "foast_inject_args", + "foast_to_foast_closure", "foast_to_past_closure", "past_transform_args", "past_to_itir", @@ -101,22 +99,7 @@ def 
step_order(self) -> list[str]: @dataclasses.dataclass(frozen=True) -class ProgArgsInjector(workflow.Workflow): - args: tuple[Any, ...] = dataclasses.field(default_factory=tuple) - kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) - - def __call__(self, inp: ffront_stages.PastProgramDefinition) -> ffront_stages.PastClosure: - return ffront_stages.PastClosure( - past_node=inp.past_node, - closure_vars=inp.closure_vars, - grid_type=inp.grid_type, - args=self.args, - kwargs=self.kwargs, - ) - - -@dataclasses.dataclass(frozen=True) -class ProgramTransformWorkflow(workflow.NamedStepSequence): +class ProgramTransformWorkflow(workflow.NamedStepSequenceWithArgs): """Modular workflow for transformations with access to intermediates.""" func_to_past: workflow.SkippableStep[ @@ -128,11 +111,22 @@ class ProgramTransformWorkflow(workflow.NamedStepSequence): past_lint: workflow.Workflow[ ffront_stages.PastProgramDefinition, ffront_stages.PastProgramDefinition ] = dataclasses.field(default_factory=past_linters.LinterFactory) - past_inject_args: workflow.Workflow[ + past_to_past_closure: workflow.Workflow[ ffront_stages.PastProgramDefinition, ffront_stages.PastClosure - ] = dataclasses.field(default_factory=ProgArgsInjector) + ] = dataclasses.field( + default=lambda inp: ffront_stages.PastClosure( + past_node=inp.data.past_node, + closure_vars=inp.data.closure_vars, + grid_type=inp.data.grid_type, + args=inp.args, + kwargs=inp.kwargs, + ), + metadata={"takes_args": True}, + ) past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = ( - dataclasses.field(default=past_process_args.past_process_args) + dataclasses.field( + default=past_process_args.past_process_args, metadata={"takes_args": False} + ) ) past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( dataclasses.field(default_factory=past_to_itir.PastToItirFactory) @@ -160,20 +154,22 @@ def __call__( ): offset_provider = 
kwargs.pop("offset_provider") from_fieldop = kwargs.pop("from_fieldop") - transforms_fop = self.transforms_fop.replace( - foast_inject_args=FopArgsInjector( - args=args, kwargs=kwargs, from_fieldop=from_fieldop - ) + # transforms_fop = self.transforms_fop.replace( + # foast_inject_args=FopArgsInjector( + # args=args, kwargs=kwargs, from_fieldop=from_fieldop + # ) + # ) + program_call = self.transforms_fop( + workflow.InputWithArgs(program, args, kwargs | {"from_fieldop": from_fieldop}) ) - program_call = transforms_fop(program) program_call = dataclasses.replace( program_call, kwargs=program_call.kwargs | {"offset_provider": offset_provider} ) else: - transforms_prog = self.transforms_prog.replace( - past_inject_args=ProgArgsInjector(args=args, kwargs=kwargs) - ) - program_call = transforms_prog(program) + # transforms_prog = self.transforms_prog.replace( + # past_inject_args=ProgArgsInjector(args=args, kwargs=kwargs) + # ) + program_call = self.transforms_prog(workflow.InputWithArgs(program, args, kwargs)) self.executor(program_call.program, *program_call.args, **program_call.kwargs) @property diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index f3fcad0219..08336f5318 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -109,6 +109,18 @@ def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> No hasher.update(str(obj).encode()) +@add_content_to_fingerprint.register +def add_str_to_fingerprint(obj: str, hasher: xtyping.HashlibAlgorithm) -> None: + hasher.update(str(obj).encode()) + + +@add_content_to_fingerprint.register(int) +@add_content_to_fingerprint.register(float) +@add_content_to_fingerprint.register(bool) +def add_builtin_to_fingerprint(obj: None, hasher: xtyping.HashlibAlgorithm) -> None: + hasher.update(str(obj).encode()) + + @add_content_to_fingerprint.register(FieldOperatorDefinition) @add_content_to_fingerprint.register(FoastOperatorDefinition) 
@add_content_to_fingerprint.register(FoastWithTypes) diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index c83748dece..2ab46e4cf9 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -265,3 +265,26 @@ def __call__(self, inp: StartT) -> EndT: def skip_condition(self, inp: StartT) -> bool: raise NotImplementedError() + + +@dataclasses.dataclass +class InputWithArgs(Generic[StartT]): + data: StartT + args: tuple[Any] + kwargs: dict[str, Any] + + +@dataclasses.dataclass(frozen=True) +class NamedStepSequenceWithArgs(NamedStepSequence[InputWithArgs[StartT], EndT]): + def __call__(self, inp: InputWithArgs[StartT]) -> EndT: + args = inp.args + kwargs = inp.kwargs + step_result: Any = inp.data + fields = {f.name: f for f in dataclasses.fields(self)} + for step_name in self.step_order: + step = getattr(self, step_name) + if fields[step_name].metadata.get("takes_args", False): + step_result = step(InputWithArgs(step_result, args, kwargs)) + else: + step_result = step(step_result) + return step_result From 8bb5f66d1b94b8124041a1695916e63e38d556fc Mon Sep 17 00:00:00 2001 From: DropD Date: Mon, 29 Apr 2024 17:20:29 +0200 Subject: [PATCH 29/30] update and expand toolchain documentation notebooks --- docs/user/next/advanced/HackTheToolchain.md | 129 +++++ .../ToolchainWalkthrough.md} | 286 ++++++++-- docs/user/next/advanced/WorkflowPatterns.md | 492 ++++++++++++++++++ tox.ini | 4 +- 4 files changed, 857 insertions(+), 54 deletions(-) create mode 100644 docs/user/next/advanced/HackTheToolchain.md rename docs/user/next/{Advanced_ToolchainWalkthrough.md => advanced/ToolchainWalkthrough.md} (64%) create mode 100644 docs/user/next/advanced/WorkflowPatterns.md diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md new file mode 100644 index 0000000000..70681796ee --- /dev/null +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -0,0 +1,129 @@ +```python +import dataclasses 
+import typing + +from gt4py import next as gtx +from gt4py.next.otf import workflow +from gt4py import eve +``` + + + + +## Replace Steps + +```python +cached_lowering_toolchain = gtx.backend.DEFAULT_PROG_TRANSFORMS.replace( + past_to_itir=workflow.CachedStep( + step=gtx.ffront.past_to_itir.PastToItirFactory(), + hash_function=eve.utils.content_hash + ) +) +``` + +## Skip Steps / Change Order + +```python +gtx.backend.DEFAULT_PROG_TRANSFORMS.step_order +``` + + ['func_to_past', + 'past_lint', + 'past_inject_args', + 'past_transform_args', + 'past_to_itir'] + +```python +@dataclasses.dataclass(frozen=True) +class SkipLinting(gtx.backend.ProgramTransformWorkflow): + @property + def step_order(self): + return [ + "func_to_past", + # not running "past_lint" + "past_inject_args", + "past_transform_args", + "past_to_itir", + ] + +same_steps = dataclasses.asdict(gtx.backend.DEFAULT_PROG_TRANSFORMS) +skip_linting_transforms = SkipLinting( + **same_steps +) +``` + +## Alternative Factory + +```python +class MyCodeGen: + ... + +class Cpp2BindingsGen: + ... 
+ +class PureCpp2WorkflowFactory(gtx.program_processors.runners.gtfn.GTFNCompileWorkflowFactory): + translation: workflow.Workflow[ + gtx.otf.stages.ProgramCall, gtx.otf.stages.ProgramSource] = MyCodeGen() + bindings: workflow.Workflow[ + gtx.otf.stages.ProgramSource, gtx.otf.stages.CompilableSource] = Cpp2BindingsGen() + +PureCpp2WorkflowFactory(cmake_build_type=gtx.config.CMAKE_BUILD_TYPE.DEBUG) +``` + +## Invent new Workflow Types + +````mermaid +graph LR + +IN_T --> i{{split}} --> A_T --> a{{track_a}} --> B_T --> o{{combine}} --> OUT_T +i --> X_T --> x{{track_x}} --> Y_T --> o + + +```python +IN_T = typing.TypeVar("IN_T") +A_T = typing.TypeVar("A_T") +B_T = typing.TypeVar("B_T") +X_T = typing.TypeVar("X_T") +Y_T = typing.TypeVar("Y_T") +OUT_T = typing.TypeVar("OUT_T") + +@dataclasses.dataclass(frozen=True) +class FullyModularDiamond( + workflow.ChainableWorkflowMixin[IN_T, OUT_T], + workflow.ReplaceEnabledWorkflowMixin[IN_T, OUT_T], + typing.Protocol[IN_T, OUT_T, A_T, B_T, X_T, Y_T] +): + split: workflow.Workflow[IN_T, tuple[A_T, X_T]] + track_a: workflow.Workflow[A_T, B_T] + track_x: workflow.Workflow[X_T, Y_T] + combine: workflow.Workflow[tuple[B_T, Y_T], OUT_T] + + def __call__(self, inp: IN_T) -> OUT_T: + a, x = self.split(inp) + b = self.track_a(a) + y = self.track_x(x) + return self.combine((b, y)) + + +@dataclasses.dataclass(frozen=True) +class PartiallyModularDiamond( + workflow.ChainableWorkflowMixin[IN_T, OUT_T], + workflow.ReplaceEnabledWorkflowMixin[IN_T, OUT_T], + typing.Protocol[IN_T, OUT_T, A_T, B_T, X_T, Y_T] +): + track_a: workflow.Workflow[A_T, B_T] + track_x: workflow.Workflow[X_T, Y_T] + + def split(inp: IN_T) -> tuple[A_T, X_T]: + ... + + def combine(b: B_T, y: Y_T) -> OUT_T: + ... 
+ + def __call__(inp: IN_T) -> OUT_T: + a, x = self.split(inp) + return self.combine( + b=self.track_a(a), + y=self.track_x(x) + ) +```` diff --git a/docs/user/next/Advanced_ToolchainWalkthrough.md b/docs/user/next/advanced/ToolchainWalkthrough.md similarity index 64% rename from docs/user/next/Advanced_ToolchainWalkthrough.md rename to docs/user/next/advanced/ToolchainWalkthrough.md index 94a7bfa7e2..d44663a72c 100644 --- a/docs/user/next/Advanced_ToolchainWalkthrough.md +++ b/docs/user/next/advanced/ToolchainWalkthrough.md @@ -24,14 +24,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta ``` # Walkthrough from Field Operator @@ -71,14 +83,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past 
-->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta style fdef fill:red style foast fill:red @@ -114,14 +138,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta style foast fill:red style itir_expr fill:red @@ -147,34 +183,53 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past 
-->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) -style foast fill:red +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta + +style foasta fill:red style fclos fill:red linkStyle 2 stroke:red,stroke-width:4px,color:pink ``` -Here we have to dynamically generate a workflow step, because the arguments were not known before. +Here we have to manually combine the previous result with the call arguments. When we call the toolchain as a whole later we will only have to do this once at the beginning. + +```python +fclos = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_foast_closure( + gtx.otf.workflow.InputWithArgs( + data=foast, + args=(gtx.ones(domain={I: 10}, dtype=gtx.float64),), + kwargs={ + "out": gtx.zeros(domain={I: 10}, dtype=gtx.float64), + "from_fieldop": example_fo + }, + ) +) +``` ```python -fclos = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_inject_args.__class__( - args=(gtx.ones(domain={I: 10}, dtype=gtx.float64),), - kwargs={ - "out": gtx.zeros(domain={I: 10}, dtype=gtx.float64) - }, - from_fieldop=example_fo -)(foast) +fclos.closure_vars["example_fo"].backend ``` ```python -gtx.ffront.stages.FoastClosure? +gtx.ffront.stages.FoastClosure?? ``` Init signature: @@ -185,6 +240,13 @@ gtx.ffront.stages.FoastClosure?  closure_vars: 'dict[str, Any]', ) -> None Docstring: FoastClosure(foast_op_def: 'FoastOperatorDefinition[OperatorNodeT]', args: 'tuple[Any, ...]', kwargs: 'dict[str, Any]', closure_vars: 'dict[str, Any]') + Source: + @dataclasses.dataclass(frozen=True) + class FoastClosure(Generic[OperatorNodeT]): +  foast_op_def: FoastOperatorDefinition[OperatorNodeT] +  args: tuple[Any, ...] 
+  kwargs: dict[str, Any] +  closure_vars: dict[str, Any] File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py Type: type Subclasses: @@ -198,14 +260,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta style fclos fill:red style pclos fill:red @@ -242,14 +316,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) 
+ +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta style pclos fill:red %%style pclos fill:red @@ -260,6 +346,12 @@ linkStyle 4 stroke:red,stroke-width:4px,color:pink pclost = backend.DEFAULT_PROG_TRANSFORMS.past_transform_args(pclos) ``` +```python +pclost.kwargs +``` + + {} + ## Lower PAST -> ITIR still forwarding the call arguments @@ -269,14 +361,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta style pclos fill:red style pcall fill:red @@ -326,30 +430,46 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) +fdefa(InputWithArgs) --> fuwr{{"internal 
unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta + +style fdefa fill:red +style fuwr fill:red style fdef fill:red +style fargs fill:red style foast fill:red +style fiwr fill:red +style foasta fill:red style fclos fill:red style pclos fill:red style pcall fill:red -linkStyle 0,2,3,4,5 stroke:red,stroke-width:4px,color:pink +linkStyle 0,2,3,4,5,9,10,11,12,13,14 stroke:red,stroke-width:4px,color:pink ``` ### Starting from DSL ```python -foast_toolchain = backend.DEFAULT_FIELDOP_TRANSFORMS.replace( - foast_inject_args=backend.FopArgsInjector(args=fclos.args, kwargs=fclos.kwargs, from_fieldop=example_fo) +pitir2 = backend.DEFAULT_FIELDOP_TRANSFORMS( + gtx.otf.workflow.InputWithArgs(data=start, args=fclos.args, kwargs=fclos.kwargs | {"from_fieldop": example_fo}) ) -pitir2 = foast_toolchain(start) assert pitir2 == pitir ``` @@ -365,22 +485,39 @@ example_compiled = gtx.program_processors.runners.roundtrip.executor.otf_workflo example_compiled(*pitir2.args, offset_provider=OFFSET_PROVIDER) ``` +We can re-run with the output from the previous run as in- and output. 
+ ```python example_compiled(pitir2.args[1], *pitir2.args[1:], offset_provider=OFFSET_PROVIDER) ``` ```python -pitir2.args[1].asnumpy() +pitir2.args[2] ``` - array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]) + 10 + +```python +pitir.args +``` + + (NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), + NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.])), + 10, + 10) ### Starting from FOAST Note that it is the exact same call but with a different input stage ```python -pitir3 = foast_toolchain(foast) +pitir3 = backend.DEFAULT_FIELDOP_TRANSFORMS( + gtx.otf.workflow.InputWithArgs( + data=foast, + args=fclos.args, + kwargs=fclos.kwargs | {"from_fieldop": example_fo} + ) +) assert pitir3 == pitir ``` @@ -419,14 +556,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta style pdef fill:red style past fill:red @@ -444,27 +593,40 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| 
foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) -style past fill:red +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta + +style pasta fill:red style pclos fill:red -linkStyle 7 stroke:red,stroke-width:4px,color:pink +linkStyle 8 stroke:red,stroke-width:4px,color:pink ``` ```python -pclos = backend.DEFAULT_PROG_TRANSFORMS.replace( - past_inject_args=backend.ProgArgsInjector( +pclos = backend.DEFAULT_PROG_TRANSFORMS( + gtx.otf.workflow.InputWithArgs( + data=p_past, args=fclos.args, kwargs=fclos.kwargs ) -)(p_past) +) ``` ## Full Program Toolchain @@ -474,27 +636,45 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> 
fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta + +style pdefa fill:red +style puwr fill:red style pdef fill:red +style pargs fill:red style past fill:red +style piwr fill:red +style pasta fill:red style pclos fill:red style pcall fill:red -linkStyle 4,5,6,7 stroke:red,stroke-width:4px,color:pink +linkStyle 4,5,6,7,8,15,16,17,18,19,20 stroke:red,stroke-width:4px,color:pink ``` ### Starting from DSL ```python -toolchain = backend.DEFAULT_PROG_TRANSFORMS.replace( - past_inject_args=backend.ProgArgsInjector( +p_itir1 = backend.DEFAULT_PROG_TRANSFORMS( + gtx.otf.workflow.InputWithArgs( + data=p_start, args=fclos.args, kwargs=fclos.kwargs ) @@ -502,11 +682,13 @@ toolchain = backend.DEFAULT_PROG_TRANSFORMS.replace( ``` ```python -p_itir1 = toolchain(p_start) -``` - -```python -p_itir2 = toolchain(p_past) +p_itir2 = backend.DEFAULT_PROG_TRANSFORMS( + gtx.otf.workflow.InputWithArgs( + data=p_past, + args=fclos.args, + kwargs=fclos.kwargs + ) +) ``` ```python diff --git a/docs/user/next/advanced/WorkflowPatterns.md b/docs/user/next/advanced/WorkflowPatterns.md new file mode 100644 index 0000000000..76880d86f0 --- /dev/null +++ b/docs/user/next/advanced/WorkflowPatterns.md @@ -0,0 +1,492 @@ +--- +jupyter: + jupytext: + formats: ipynb,md + text_representation: + extension: .md + format_name: markdown + format_version: "1.3" + jupytext_version: 1.16.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +```python editable=true slideshow={"slide_type": ""} +import dataclasses +import re + +import factory + +import gt4py.next as gtx + +import devtools +``` + + + +# How to read (toolchain) workflows + + + + + +## Basic workflow (single step) + +```mermaid +graph LR + +StageA -->|basic workflow| StageB +``` + 
+Where "Stage" describes any data structure, and where `StageA` contains all the input data and `StageB` contains all the output data. + + + + + +### Simplest possible + + + +```python editable=true slideshow={"slide_type": ""} +def simple_add_1(inp: int) -> int: + return inp + 1 + +simple_add_1(1) +``` + + + +This is already a (single step) workflow. We can build a more complex one by chaining it multiple times. + +```mermaid +graph LR + +inp(A: int) -->|simple_add_1| b(A + 1) -->|simple_add_1| c(A + 2) -->|simple_add_1| out(A + 3) +``` + + + +```python editable=true slideshow={"slide_type": ""} +manual_add_3 = gtx.otf.workflow.StepSequence.start( + simple_add_1 +).chain(simple_add_1).chain(simple_add_1) + +manual_add_3(1) +``` + + + +### Simplest Composable Step + +All we have to do for chaining to work out of the box is add the `make_step` decorator! + + + +```python editable=true slideshow={"slide_type": ""} +@gtx.otf.workflow.make_step +def chainable_add_1(inp: int) -> int: + return inp + 1 +``` + +```python editable=true slideshow={"slide_type": ""} +add_3 = chainable_add_1.chain(chainable_add_1).chain(chainable_add_1) +add_3(1) +``` + +### Example in the Wild + +```python jupyter={"outputs_hidden": true} +gtx.ffront.func_to_past.func_to_past.steps.inner[0]?? +``` + + + +### Step with Parameters + +Sometimes we want to allow for different configurations of a step. 
+ + + +```python editable=true slideshow={"slide_type": ""} +@dataclasses.dataclass(frozen=True) +class MathOp(gtx.otf.workflow.ChainableWorkflowMixin[int, int]): + op: str + rhs: int = 0 + + def __call__(self, inp: int) -> int: + return getattr(self, self.op)(inp, self.rhs) + + def add(self, lhs: int, rhs: int) -> int: + return lhs + rhs + + def mul(self, lhs: int, rhs: int) -> int: + return lhs * rhs + +add_3_times_2 = ( + MathOp("add", 3) + .chain(MathOp("mul", 2)) +) +add_3_times_2(1) +``` + +### Example in the Wild + +```python jupyter={"outputs_hidden": true} +gtx.program_processors.runners.roundtrip.Roundtrip?? +``` + + + +### Wrapper Steps + +Sometimes we want to make a step behave slightly differently without modifying the step itself. In this case we can wrap it into a wrapper step. These behave a little bit like (limited) decorators. +Below we will go through the existing wrapper steps, which you might encounter. + +#### Caching / memoizing + +For example we might want to cache the output (memoize) for which we need to add a way of hashing the input: + +```mermaid +graph LR + + +inp --> calc +inp(A: int) --> ha{{"hash_function(A)"}} --> h("hash(A)") --> ck{{"check cache"}} -->|miss| miss("not in cache") --> calc{{add_3_times_2}} --> out(result) +ck -->|hit| hit("in cache") --> out +``` + +For this we can use the `CachedStep`, you will see something like below + + + +```python editable=true slideshow={"slide_type": ""} +@gtx.otf.workflow.make_step +def debug_print(inp: int) -> int: + print("cache miss!") + return inp + +cached_calc = gtx.otf.workflow.CachedStep( + step=debug_print.chain(add_3_times_2), + hash_function=lambda i: str(i) # using ints as their own hash +) + +cached_calc(1) +cached_calc(1) +cached_calc(1) +``` + +### Example in the Wild + +```python jupyter={"outputs_hidden": true} +gtx.backend.DEFAULT_PROG_TRANSFORMS.past_lint?? +``` + + + +Though we execute the workflow three times we only get the debug print once, it worked! 
Btw, hashing is rarely that easy in the wild... + +#### Conditionally skipping steps + +The `SkippableStep` pattern can be used to skip a step under a given condition. A main use case is when you might want to run a workflow either from the start or from further along (with the same interface). + +Let's say we want to make our calculation workflow compatible with string input. We can add a conversion step (which only works with strings). + + + +```python editable=true slideshow={"slide_type": ""} +@gtx.otf.workflow.make_step +def to_int(inp: str) -> int: + assert isinstance(inp, str), "Can not work with 'int'!" # yes, this is horribly contrived + return int(inp) + +str_calc = to_int.chain(add_3_times_2) + +str_calc("1") +``` + + + +Now we can start from a string that contains an int. But if we already have an int, it will fail. + + + +```python editable=true slideshow={"slide_type": ""} +try: + str_calc(1) +except AssertionError as err: + print(err) +``` + + + +What to do? What we want is a to conditionally skip the first step, so we replace it with a `SkippableStep`: + +```python +class OptionalStrToInt(SkippableStep[str | int, int]): + step: Workflow[str, int] + + def skip_condition(self, inp: str | int) -> bool: + ... 
# return True to skip (if we get an int) or False to run the conversion (str case) + +``` + +```mermaid +graph LR + +int(A: int = 1) --> calc{{"add_3_times_2(1)"}} --> result(8) +int --> ski{{"skip_condition(1)"}} -->|True| calc +str("B: str = '1'") --> sks{{"skip_condition('1')"}} -->|False| conv{{to_int}} --> b2("int(B) = 1") --> calc +``` + + + +```python editable=true slideshow={"slide_type": ""} +@dataclasses.dataclass(frozen=True) +class OptionalStrToInt(gtx.otf.workflow.SkippableStep[str | int, int]): + step: gtx.otf.workflow.Workflow[str, int] = to_int + + def skip_condition(self, inp: str | int) -> bool: + match inp: + case int(): + return True + case str(): + return False + case _: + # optionally raise an error with good advice + return False + +strint_calc = OptionalStrToInt().chain(add_3_times_2) +strint_calc(1) == strint_calc("1") +``` + + + +### Example in the Wild + + + +```python jupyter={"outputs_hidden": true} editable=true slideshow={"slide_type": ""} +gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past?? +``` + + + +### Step with factory (builder) + +If a step can be useful with different combinations of parameters and wrappers, it should have a factory. 
In this case we will add a neutral wrapper around it, so we can put any combination of wrappers into that: + + + +```python editable=true slideshow={"slide_type": ""} +@dataclasses.dataclass(frozen=True) +class AnyStrToInt(gtx.otf.workflow.ChainableWorkflowMixin[str | int, int]): + inner_step: gtx.otf.workflow.Workflow[str, int] = to_int + + def __call__(self, inp: str | int) -> int: + return self.inner_step(inp) + + +class StrToIntFactory(factory.Factory): + class Meta: + model = AnyStrToInt + + class Params: + default_step = to_int + optional: bool = False + optional_or_not = factory.LazyAttribute(lambda o: OptionalStrToInt(step=o.default_step) if o.optional else o.default_step) + cached = factory.Trait( + inner_step = factory.LazyAttribute( + lambda o: gtx.otf.workflow.CachedStep(step=(o.optional_or_not), hash_function=str) + ) + ) + inner_step = factory.LazyAttribute(lambda o: o.optional_or_not) + +cached = StrToIntFactory(cached=True) +optional = StrToIntFactory(optional=True) +both = StrToIntFactory(cached=True, optional=True) +neither = StrToIntFactory() +neither.inner_step +``` + +### Example in the Wild + +```python +gtx.ffront.past_passes.linters.LinterFactory?? +``` + + + +## Composition 1: Chaining + +So far we have only seen composition of workflows by chaining. Any sequence of steps can be represented as a chain. Chains can be built of smaller chains, so a Workflow could be composed and then reused in a bigger workflow. + +However, chains are of limited use in the real world, because it's a pain to access a specific step. This we might want to do in order to: + +- run that step in isolation for debugging or other purposes +- build a new chain with a step swapped out (workflows are immutable). 
+ +Imagine swapping out `sub_third` in `complicated_workflow` below (without copy pasting code): + +```python +complicated_workflow = ( + start_step + .chain(first_sub_first.chain(first_sub_second).chain(first_sub_third)) + .chain(second_sub_first.chain(second_sub_second)) + .chain(last) +) +``` + +```mermaid +graph TD +c{{complicated_workflow}} --> 0 --> s{{start_step}} +c --> 1 -->|0| a1{{first_sub_first}} +1 -->|1| a2{{first_sub_second}} +1 -->|2| a3{{first_sub_third}} +c --> 2 -->|0| b1{{second_sub_first}} +2 -->|1| b2{{second_sub_second}} +c --> 3 -->|0| l{{last}} +``` + + + + + +## Composition 2: Sequence of Named Steps + +Let's say we want a string processing workflow where the intermediate stages are also of value on their own. We would want to access individual steps, specifically each step as it was configured for this workflow (with parameters, caching, etc identical). + +For this we can use `NamedStepSequence`, giving each step a name, by which we can access it later. For this we have to create a dataclass and derive from `NamedStepSequence`. Each step is then a field of the dataclass, type hinted as a `Workflow`. The resulting workflow will run the steps in order of their appearance in the class body. 
+ +To use the same "complicated workflow" example from above: + +```python +@dataclasses.dataclass(frozen=True) +class FirstSub(gtx.otf.workflow.NamedStepSequence[B, E]): + first: Workflow[B, C] + second: Workflow[C, D] + third: Workflow[D, E] + + +@dataclasses.dataclass(frozen=True) +class SecondSub(gtx.otf.workflow.NamedStepSequence[E, G]): + first: Workflow[E, F] + second: Workflow[F, G] + + +@dataclasses.dataclass(frozen=True) +class ComplicatedWorkflow(gtx.otf.workflow.NamedStepSequence[A, F]): + start_step: Workflow[A, B] + first_sub: Workflow[B, E] + second_sub: Workflow[E, G] + last: Workflow[G, F] + +complicated_workflow = ComplicatedWorkflow( + start_step=start_step, + first_sub=FirstSub( + first=first_sub_first, + second=first_sub_second, + third=first_sub_third + ), + second_sub=SecondSub( + first=second_sub_first, + second=second_sub_second + ), + last=last +) + +``` + +```mermaid +graph TD + +w{{complicated_workflow: ComplicatedWorkflow}} -->|".start_step"| a{{start_step}} +w -->|".first_sub.first"| b{{first_sub_first}} +w -->|".first_sub.second"| c{{first_sub_second}} +w -->|".first_sub.third"| d{{first_sub_third}} +w -->|".second_sub.first"| e{{second_sub_first}} +w -->|".second_sub.second"| f{{second_sub_second}} +w -->|".last"| g{{last}} +``` + + + +```python editable=true slideshow={"slide_type": ""} +## Here we define how the steps are composed +@dataclasses.dataclass(frozen=True) +class StrProcess(gtx.otf.workflow.NamedStepSequence): + hexify_colors: gtx.otf.workflow.Workflow[str, str] + replace_tabs: gtx.otf.workflow.Workflow[str, str] + + +## Here we define the steps themselves +@dataclasses.dataclass(frozen=True) +class HexifyColors(gtx.otf.workflow.ChainableWorkflowMixin): + color_scheme: dict[str, str] = dataclasses.field( + default_factory=lambda: {"blue": "#0000ff", "green": "#00ff00", "red": "#ff0000"} + ) + + def __call__(self, inp: str) -> str: + result = inp + for color, hexcode in self.color_scheme.items(): + result = 
result.replace(color, hexcode) + return result + + +def spaces_to_tabs(inp: str) -> str: + return re.sub(r" ", r"\t", inp) +``` + + + +Note that with all this there comes an extra feature: We can easily create variants with different steps, without having to change the code that will use the composed workflow. Even if the calling code calls steps in isolation! + + + +```python editable=true slideshow={"slide_type": ""} +CUSTOM_COLORS = {"blue": "#55aaff", "green": "#00ff00", "red": "#ff0000"} + +proc = StrProcess( + hexify_colors=HexifyColors( + color_scheme=CUSTOM_COLORS + ), + replace_tabs=spaces_to_tabs +) + +proc(""" +p { + background-color: blue; + color: red; +} +""") +``` + +```python editable=true slideshow={"slide_type": ""} +proc.hexify_colors("blue") +``` + + + +`NamedStepSequence`s still work with wrapper steps, parameters and chaining. They can also be nested. So for a complex workflow there would be innumerable possible variants. Therefore expect to often see them paired with factories. + + + +### Example in the Wild + +```python editable=true slideshow={"slide_type": ""} +gtx.backend.DEFAULT_PROG_TRANSFORMS?? +``` + +```python +gtx.program_processors.runners.gtfn.run_gtfn_gpu.executor.otf_workflow?? +``` + +```python +gtx.program_processors.runners.gtfn.GTFNBackendFactory?? 
+``` + +```python + +``` diff --git a/tox.ini b/tox.ini index 8479e4c52c..2bc761fef8 100644 --- a/tox.ini +++ b/tox.ini @@ -109,12 +109,12 @@ commands = description = Run notebooks commands_pre = jupytext docs/user/next/QuickstartGuide.md --to .ipynb - jupytext docs/user/next/Advanced_ToolchainWalkthrough.md --to .ipynb + jupytext docs/user/next/advanced/*.md --to .ipynb commands = python -m pytest --nbmake docs/user/next/workshop/slides -v -n {env:NUM_PROCESSES:1} python -m pytest --nbmake docs/user/next/workshop/exercises -k 'solutions' -v -n {env:NUM_PROCESSES:1} python -m pytest --nbmake docs/user/next/QuickstartGuide.ipynb -v -n {env:NUM_PROCESSES:1} - python -m pytest --nbmake docs/user/next/Advanced_ToolchainWalkthrough.ipynb -v -n {env:NUM_PROCESSES:1} + python -m pytest --nbmake docs/user/next/advanced -v -n {env:NUM_PROCESSES:1} python -m pytest --nbmake examples -v -n {env:NUM_PROCESSES:1} # -- Other artefacts -- From ef8057125c9f0d025f1bf2d6dbffe981e0f65b4a Mon Sep 17 00:00:00 2001 From: DropD Date: Tue, 30 Apr 2024 13:30:21 +0200 Subject: [PATCH 30/30] update stage fingerprinting tests --- .../unit_tests/ffront_tests/test_stages.py | 76 +++++++++++++------ 1 file changed, 51 insertions(+), 25 deletions(-) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_stages.py b/tests/next_tests/unit_tests/ffront_tests/test_stages.py index 67ac96d653..29dcda9e1d 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_stages.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_stages.py @@ -12,9 +12,13 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import dataclasses + import pytest + from gt4py import next as gtx from gt4py.next.ffront import stages +from gt4py.next.otf import workflow @pytest.fixture @@ -87,7 +91,7 @@ def copy_program(a: gtx.Field[[jdim], gtx.int32], out: gtx.Field[[jdim], gtx.int yield copy_program -def test_cache_key_field_op_def(fieldop, samecode_fieldop, different_fieldop): +def test_fingerprint_stage_field_op_def(fieldop, 
samecode_fieldop, different_fieldop): assert stages.fingerprint_stage(samecode_fieldop.definition_stage) != stages.fingerprint_stage( fieldop.definition_stage ) @@ -96,7 +100,7 @@ def test_cache_key_field_op_def(fieldop, samecode_fieldop, different_fieldop): ) -def test_cache_key_foast_op_def(fieldop, samecode_fieldop, different_fieldop): +def test_fingerprint_stage_foast_op_def(fieldop, samecode_fieldop, different_fieldop): foast = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(fieldop.definition_stage) samecode = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast( samecode_fieldop.definition_stage @@ -109,42 +113,64 @@ def test_cache_key_foast_op_def(fieldop, samecode_fieldop, different_fieldop): assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast) -def test_cache_key_foast_closure(fieldop, samecode_fieldop, different_fieldop, idim, jdim): - foast_closure = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( - gtx.backend.FopArgsInjector( +@dataclasses.dataclass(frozen=True) +class ToFoastClosure(workflow.NamedStepSequenceWithArgs): + func_to_foast: workflow.Workflow = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast + foast_to_closure: workflow.Workflow = dataclasses.field( + default=gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_foast_closure, + metadata={"takes_args": True}, + ) + + +def test_fingerprint_stage_foast_closure(fieldop, samecode_fieldop, different_fieldop, idim, jdim): + toolchain = ToFoastClosure() + foast_closure = toolchain( + workflow.InputWithArgs( + data=fieldop.definition_stage, args=(gtx.zeros({idim: 10}, gtx.int32),), - kwargs={"out": gtx.zeros({idim: 10}, gtx.int32)}, - from_fieldop=fieldop, + kwargs={ + "out": gtx.zeros({idim: 10}, gtx.int32), + "from_fieldop": fieldop, + }, ), - )(fieldop.definition_stage) - samecode = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( - gtx.backend.FopArgsInjector( + ) + samecode = toolchain( + workflow.InputWithArgs( + 
data=samecode_fieldop.definition_stage, args=(gtx.zeros({idim: 10}, gtx.int32),), - kwargs={"out": gtx.zeros({idim: 10}, gtx.int32)}, - from_fieldop=samecode_fieldop, + kwargs={ + "out": gtx.zeros({idim: 10}, gtx.int32), + "from_fieldop": samecode_fieldop, + }, ) - )(samecode_fieldop.definition_stage) - different = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( - gtx.backend.FopArgsInjector( + ) + different = toolchain( + workflow.InputWithArgs( + data=different_fieldop.definition_stage, args=(gtx.zeros({jdim: 10}, gtx.int32),), - kwargs={"out": gtx.zeros({jdim: 10}, gtx.int32)}, - from_fieldop=different_fieldop, + kwargs={ + "out": gtx.zeros({jdim: 10}, gtx.int32), + "from_fieldop": different_fieldop, + }, ) - )(different_fieldop.definition_stage) - different_args = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( - gtx.backend.FopArgsInjector( + ) + different_args = toolchain( + workflow.InputWithArgs( + data=fieldop.definition_stage, args=(gtx.zeros({idim: 11}, gtx.int32),), - kwargs={"out": gtx.zeros({idim: 11}, gtx.int32)}, - from_fieldop=fieldop, + kwargs={ + "out": gtx.zeros({idim: 11}, gtx.int32), + "from_fieldop": fieldop, + }, ) - )(fieldop.definition_stage) + ) assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast_closure) assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast_closure) assert stages.fingerprint_stage(different_args) != stages.fingerprint_stage(foast_closure) -def test_cache_key_program_def(program, samecode_program, different_program): +def test_fingerprint_stage_program_def(program, samecode_program, different_program): assert stages.fingerprint_stage(samecode_program.definition_stage) != stages.fingerprint_stage( program.definition_stage ) @@ -153,7 +179,7 @@ def test_cache_key_program_def(program, samecode_program, different_program): ) -def test_cache_key_past_def(program, samecode_program, different_program): +def test_fingerprint_stage_past_def(program, 
samecode_program, different_program): past = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(program.definition_stage) samecode = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(samecode_program.definition_stage) different = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(different_program.definition_stage)