From 765ef5370409ab48d07c7f3e005f62ff6dd61112 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Fri, 6 Sep 2024 16:57:26 +0200 Subject: [PATCH] refactor: use modern-style type hints In the past, the only way to create type hints for yet-to-be-defined types was to quote them. Since python 3.7 we can get rid of the quotes when using `annotations` from `__future__` instead. This is supposed to become the default in the future. --- src/gt4py/cartesian/backend/base.py | 32 ++++++------ src/gt4py/cartesian/backend/cuda_backend.py | 4 +- src/gt4py/cartesian/backend/dace_backend.py | 8 +-- .../cartesian/backend/dace_lazy_stencil.py | 4 +- .../cartesian/backend/dace_stencil_object.py | 10 ++-- src/gt4py/cartesian/backend/gtc_common.py | 6 ++- src/gt4py/cartesian/backend/gtcpp_backend.py | 4 +- .../cartesian/backend/module_generator.py | 12 +++-- src/gt4py/cartesian/backend/numpy_backend.py | 6 ++- src/gt4py/cartesian/caching.py | 10 ++-- src/gt4py/cartesian/frontend/base.py | 6 ++- src/gt4py/cartesian/gtc/common.py | 12 ++--- src/gt4py/cartesian/gtc/cuir/cuir.py | 16 +++--- src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py | 8 +-- .../gtc/dace/expansion/daceir_builder.py | 50 +++++++++---------- .../cartesian/gtc/dace/expansion/expansion.py | 8 +-- .../gtc/dace/expansion/sdfg_builder.py | 30 +++++------ .../cartesian/gtc/dace/expansion/utils.py | 8 +-- .../gtc/dace/expansion_specification.py | 18 ++++--- src/gt4py/cartesian/gtc/dace/nodes.py | 6 ++- src/gt4py/cartesian/gtc/dace/oir_to_dace.py | 4 +- src/gt4py/cartesian/gtc/dace/symbol_utils.py | 4 +- src/gt4py/cartesian/gtc/dace/utils.py | 34 ++++++------- src/gt4py/cartesian/gtc/daceir.py | 14 +++--- src/gt4py/cartesian/gtc/definitions.py | 2 +- src/gt4py/cartesian/gtc/gtcpp/gtcpp.py | 6 ++- src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py | 12 +++-- src/gt4py/cartesian/gtc/gtir.py | 6 ++- src/gt4py/cartesian/gtc/numpy/npir_codegen.py | 4 +- src/gt4py/cartesian/gtc/oir.py | 22 ++++---- .../horizontal_execution_merging.py | 2 +- .../gtc/passes/oir_optimizations/utils.py | 8 +-- src/gt4py/cartesian/lazy_stencil.py | 10 ++-- src/gt4py/cartesian/loader.py | 10 ++-- src/gt4py/cartesian/stencil_builder.py | 34 +++++++------ src/gt4py/cartesian/stencil_object.py | 6 +-- src/gt4py/cartesian/utils/base.py | 4 -- .../transformations/gpu_utils.py | 4 +- .../runners/dace_iterator/itir_to_tasklet.py | 22 ++++---- 39 files changed, 262 insertions(+), 204 deletions(-) diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index ada259680d..5bab0453a9 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import abc import copy import hashlib @@ -43,7 +45,7 @@ REGISTRY = gt_utils.Registry() -def from_name(name: str) -> Optional[Type["Backend"]]: +def from_name(name: str) -> Optional[Type[Backend]]: backend = REGISTRY.get(name, None) if not backend: raise NotImplementedError( @@ -52,7 +54,7 @@ def from_name(name: str) -> Optional[Type["Backend"]]: return backend -def register(backend_cls: Type["Backend"]) -> Type["Backend"]: +def register(backend_cls: Type[Backend]) -> Type[Backend]: assert issubclass(backend_cls, Backend) and backend_cls.name is not None if isinstance(backend_cls.name, str): @@ -101,9 +103,9 @@ class Backend(abc.ABC): # "disable-code-generation": bool # "disable-cache-validation": bool - builder: "StencilBuilder" + builder: StencilBuilder - def __init__(self, builder: "StencilBuilder"): + def __init__(self, builder: StencilBuilder): self.builder = builder @classmethod @@ -120,7 +122,7 @@ def filter_options_for_id( return filtered_options @abc.abstractmethod - def load(self) -> Optional[Type["StencilObject"]]: + def load(self) -> Optional[Type[StencilObject]]: """ Load the stencil class from the generated python module. @@ -135,7 +137,7 @@ def load(self) -> Optional[Type["StencilObject"]]: pass @abc.abstractmethod - def generate(self) -> Type["StencilObject"]: + def generate(self) -> Type[StencilObject]: """ Generate the stencil class from GTScript's internal representation. @@ -152,7 +154,7 @@ def generate(self) -> Type["StencilObject"]: @property def extra_cache_info(self) -> Dict[str, Any]: - """Provide additional data to be stored in cache info file (sublass hook).""" + """Provide additional data to be stored in cache info file (subclass hook).""" return {} @property @@ -240,9 +242,9 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: class BaseBackend(Backend): - MODULE_GENERATOR_CLASS: ClassVar[Type["BaseModuleGenerator"]] + MODULE_GENERATOR_CLASS: ClassVar[Type[BaseModuleGenerator]] - def load(self) -> Optional[Type["StencilObject"]]: + def load(self) -> Optional[Type[StencilObject]]: build_info = self.builder.options.build_info if build_info is not None: start_time = time.perf_counter() @@ -263,11 +265,11 @@ def load(self) -> Optional[Type["StencilObject"]]: return stencil_class - def generate(self) -> Type["StencilObject"]: + def generate(self) -> Type[StencilObject]: self.check_options(self.builder.options) return self.make_module() - def _load(self) -> Type["StencilObject"]: + def _load(self) -> Type[StencilObject]: stencil_class_name = self.builder.class_name file_name = str(self.builder.module_path) stencil_module = gt_utils.make_module_from_file(stencil_class_name, file_name) @@ -289,7 +291,7 @@ def check_options(self, options: gt_definitions.BuildOptions) -> None: stacklevel=2, ) - def make_module(self, **kwargs: Any) -> Type["StencilObject"]: + def make_module(self, **kwargs: Any) -> Type[StencilObject]: build_info = self.builder.options.build_info if build_info is not None: start_time = time.perf_counter() @@ -323,7 +325,7 @@ def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> class PurePythonBackendCLIMixin(CLIBackendMixin): """Mixin for CLI support for backends deriving from BaseBackend.""" - builder: "StencilBuilder" + builder: StencilBuilder #: stencil python source generator method: #: In order to use this mixin, the backend class must implement @@ -378,7 +380,7 @@ def extra_cache_validation_keys(self) -> List[str]: return keys @abc.abstractmethod - def generate(self) -> Type["StencilObject"]: + def generate(self) -> Type[StencilObject]: pass def build_extension_module( @@ -436,7 +438,7 @@ def disabled(message: str, *, enabled_env_var: str) -> Callable[[Type[Backend]], else: def _decorator(cls: Type[Backend]) -> Type[Backend]: - def _no_generate(obj) -> Type["StencilObject"]: + def _no_generate(obj) -> Type[StencilObject]: raise NotImplementedError( f"Disabled '{cls.name}' backend: 'f{message}'\n", f"You can still enable the backend by hand using the environment variable '{enabled_env_var}=1'", diff --git a/src/gt4py/cartesian/backend/cuda_backend.py b/src/gt4py/cartesian/backend/cuda_backend.py index 91d4d07080..f0238e309b 100644 --- a/src/gt4py/cartesian/backend/cuda_backend.py +++ b/src/gt4py/cartesian/backend/cuda_backend.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type from gt4py import storage as gt_storage @@ -141,7 +143,7 @@ class CudaBackend(BaseGTBackend, CLIBackendMixin): def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=True) - def generate(self) -> Type["StencilObject"]: + def generate(self) -> Type[StencilObject]: self.check_options(self.builder.options) pyext_module_name: Optional[str] diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index b9cfa8c727..163a0dee3f 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import copy import os import pathlib @@ -302,7 +304,7 @@ def freeze_origin_domain_sdfg(inner_sdfg, arg_names, field_info, *, origin, doma for node in states[0].nodes(): state.remove_node(node) - # make sure that symbols are passed throught o inner sdfg + # make sure that symbols are passed through to inner sdfg for symbol in nsdfg.sdfg.free_symbols: if symbol not in wrapper_sdfg.symbols: wrapper_sdfg.add_symbol(symbol, nsdfg.sdfg.symbols[symbol]) @@ -531,7 +533,7 @@ def keep_line(line): return generated_code @classmethod - def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.SDFG): + def apply(cls, stencil_ir: gtir.Stencil, builder: StencilBuilder, sdfg: dace.SDFG): self = cls() with dace.config.temporary_config(): # To prevent conflict with 3rd party usage of DaCe config always make sure that any @@ -765,7 +767,7 @@ class BaseDaceBackend(BaseGTBackend, CLIBackendMixin): GT_BACKEND_T = "dace" PYEXT_GENERATOR_CLASS = DaCeExtGenerator # type: ignore - def generate(self) -> Type["StencilObject"]: + def generate(self) -> Type[StencilObject]: self.check_options(self.builder.options) pyext_module_name: Optional[str] diff --git a/src/gt4py/cartesian/backend/dace_lazy_stencil.py b/src/gt4py/cartesian/backend/dace_lazy_stencil.py index 2b3cf6fe45..6a09258889 100644 --- a/src/gt4py/cartesian/backend/dace_lazy_stencil.py +++ b/src/gt4py/cartesian/backend/dace_lazy_stencil.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Tuple import dace @@ -22,7 +24,7 @@ class DaCeLazyStencil(LazyStencil, SDFGConvertible): - def __init__(self, builder: "StencilBuilder"): + def __init__(self, builder: StencilBuilder): if "dace" not in builder.backend.name: raise ValueError("Trying to build a DaCeLazyStencil for non-dace backend.") super().__init__(builder=builder) diff --git a/src/gt4py/cartesian/backend/dace_stencil_object.py b/src/gt4py/cartesian/backend/dace_stencil_object.py index 0c9337723f..21006475a0 100644 --- a/src/gt4py/cartesian/backend/dace_stencil_object.py +++ b/src/gt4py/cartesian/backend/dace_stencil_object.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import copy import inspect import os @@ -58,7 +60,7 @@ def add_optional_fields( @dataclass(frozen=True) class DaCeFrozenStencil(FrozenStencil, SDFGConvertible): - stencil_object: "DaCeStencilObject" + stencil_object: DaCeStencilObject origin: Dict[str, Tuple[int, ...]] domain: Tuple[int, ...] sdfg: dace.SDFG @@ -95,7 +97,7 @@ def _get_domain_origin_key(domain, origin): return domain, origins_tuple def freeze( - self: "DaCeStencilObject", *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, ...] + self: DaCeStencilObject, *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, ...] ) -> DaCeFrozenStencil: key = DaCeStencilObject._get_domain_origin_key(domain, origin) if key in self._frozen_cache: @@ -135,8 +137,8 @@ def closure_resolver( self, constant_args: Dict[str, Any], given_args: Set[str], - parent_closure: Optional["dace.frontend.python.common.SDFGClosure"] = None, - ) -> "dace.frontend.python.common.SDFGClosure": + parent_closure: Optional[dace.frontend.python.common.SDFGClosure] = None, + ) -> dace.frontend.python.common.SDFGClosure: return dace.frontend.python.common.SDFGClosure() def __sdfg__(self, *args, **kwargs) -> dace.SDFG: diff --git a/src/gt4py/cartesian/backend/gtc_common.py b/src/gt4py/cartesian/backend/gtc_common.py index c55c8f3513..abc4baede1 100644 --- a/src/gt4py/cartesian/backend/gtc_common.py +++ b/src/gt4py/cartesian/backend/gtc_common.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import abc import os import textwrap @@ -134,7 +136,7 @@ def __init__(self): self.pyext_file_path = None def __call__( - self, args_data: ModuleData, builder: Optional["StencilBuilder"] = None, **kwargs: Any + self, args_data: ModuleData, builder: Optional[StencilBuilder] = None, **kwargs: Any ) -> str: self.pyext_module_name = kwargs["pyext_module_name"] self.pyext_file_path = kwargs["pyext_file_path"] @@ -229,7 +231,7 @@ class BaseGTBackend(gt_backend.BasePyExtBackend, gt_backend.CLIBackendMixin): PYEXT_GENERATOR_CLASS: Type[BackendCodegen] @abc.abstractmethod - def generate(self) -> Type["StencilObject"]: + def generate(self) -> Type[StencilObject]: pass def generate_computation(self) -> Dict[str, Union[str, Dict]]: diff --git a/src/gt4py/cartesian/backend/gtcpp_backend.py b/src/gt4py/cartesian/backend/gtcpp_backend.py index b071028744..5d3fd623d9 100644 --- a/src/gt4py/cartesian/backend/gtcpp_backend.py +++ b/src/gt4py/cartesian/backend/gtcpp_backend.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type from gt4py import storage as gt_storage @@ -129,7 +131,7 @@ class GTBaseBackend(BaseGTBackend, CLIBackendMixin): def _generate_extension(self, uses_cuda: bool) -> Tuple[str, str]: return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=uses_cuda) - def generate(self) -> Type["StencilObject"]: + def generate(self) -> Type[StencilObject]: self.check_options(self.builder.options) pyext_module_name: Optional[str] diff --git a/src/gt4py/cartesian/backend/module_generator.py b/src/gt4py/cartesian/backend/module_generator.py index 6bb5621889..e2266b709c 100644 --- a/src/gt4py/cartesian/backend/module_generator.py +++ b/src/gt4py/cartesian/backend/module_generator.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import abc import numbers import sys @@ -121,11 +123,11 @@ class BaseModuleGenerator(abc.ABC): TEMPLATE_RESOURCE = "stencil_module.py.in" - _builder: Optional["StencilBuilder"] + _builder: Optional[StencilBuilder] args_data: ModuleData template: jinja2.Template - def __init__(self, builder: Optional["StencilBuilder"] = None): + def __init__(self, builder: Optional[StencilBuilder] = None): self._builder = builder self.args_data = ModuleData() self.template = jinja2.Template( @@ -135,7 +137,7 @@ def __init__(self, builder: Optional["StencilBuilder"] = None): ) def __call__( - self, args_data: ModuleData, builder: Optional["StencilBuilder"] = None, **kwargs: Any + self, args_data: ModuleData, builder: Optional[StencilBuilder] = None, **kwargs: Any ) -> str: """ Generate source code for a Python module containing a StencilObject. @@ -176,7 +178,7 @@ def __call__( return module_source @property - def builder(self) -> "StencilBuilder": + def builder(self) -> StencilBuilder: """ Expose the builder reference. @@ -205,7 +207,7 @@ def generate_class_name(self) -> str: """ Generate the name of the stencil class. - This should ususally be deferred to the chosen caching strategy via + This should usually be deferred to the chosen caching strategy via the builder object (see default implementation). """ return self.builder.class_name diff --git a/src/gt4py/cartesian/backend/numpy_backend.py b/src/gt4py/cartesian/backend/numpy_backend.py index 8e0813d0c1..160bd5eaa8 100644 --- a/src/gt4py/cartesian/backend/numpy_backend.py +++ b/src/gt4py/cartesian/backend/numpy_backend.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import pathlib from typing import TYPE_CHECKING, Any, ClassVar, Dict, Type, Union, cast @@ -51,7 +53,7 @@ def generate_implementation(self) -> str: return f"computation.run({', '.join(params)})" @property - def backend(self) -> "NumpyBackend": + def backend(self) -> NumpyBackend: return cast(NumpyBackend, self.builder.backend) @@ -98,7 +100,7 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: super().generate_bindings(language_name) return {self.builder.module_path.name: self.make_module_source()} - def generate(self) -> Type["StencilObject"]: + def generate(self) -> Type[StencilObject]: self.check_options(self.builder.options) src_dir = self.builder.module_path.parent if not self.builder.options._impl_opts.get("disable-code-generation", False): diff --git a/src/gt4py/cartesian/caching.py b/src/gt4py/cartesian/caching.py index b1bf5d0ee2..20c0b49fae 100644 --- a/src/gt4py/cartesian/caching.py +++ b/src/gt4py/cartesian/caching.py @@ -8,6 +8,8 @@ """Caching strategies for stencil generation.""" +from __future__ import annotations + import abc import inspect import pathlib @@ -29,7 +31,7 @@ class CachingStrategy(abc.ABC): name: str - def __init__(self, builder: "StencilBuilder"): + def __init__(self, builder: StencilBuilder): self.builder = builder @property @@ -176,7 +178,7 @@ class JITCachingStrategy(CachingStrategy): def __init__( self, - builder: "StencilBuilder", + builder: StencilBuilder, *, root_path: Optional[str] = None, dir_name: Optional[str] = None, @@ -368,7 +370,7 @@ class NoCachingStrategy(CachingStrategy): name = "nocaching" - def __init__(self, builder: "StencilBuilder", *, output_path: pathlib.Path = pathlib.Path(".")): + def __init__(self, builder: StencilBuilder, *, output_path: pathlib.Path = pathlib.Path(".")): super().__init__(builder) self._output_path = output_path @@ -411,7 +413,7 @@ def stencil_id(self) -> StencilID: def strategy_factory( - name: str, builder: "StencilBuilder", *args: Any, **kwargs: Any + name: str, builder: StencilBuilder, *args: Any, **kwargs: Any ) -> CachingStrategy: strategies = {"jit": JITCachingStrategy, "nocaching": NoCachingStrategy} return strategies[name](builder, *args, **kwargs) diff --git a/src/gt4py/cartesian/frontend/base.py b/src/gt4py/cartesian/frontend/base.py index da7ff22548..3ba54f3356 100644 --- a/src/gt4py/cartesian/frontend/base.py +++ b/src/gt4py/cartesian/frontend/base.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import abc from typing import Any, Dict, Optional, Type, Union @@ -19,12 +21,12 @@ AnyStencilFunc = Union[StencilFunc, AnnotatedStencilFunc] -def from_name(name: str) -> Optional[Type["Frontend"]]: +def from_name(name: str) -> Optional[Type[Frontend]]: """Return frontend by name.""" return REGISTRY.get(name, None) -def register(frontend_cls: Type["Frontend"]) -> None: +def register(frontend_cls: Type[Frontend]) -> None: """Register a new frontend.""" return REGISTRY.register(frontend_cls.name, frontend_cls) diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index 39246cbea7..bfe434e7f3 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -178,7 +178,7 @@ class NativeFunction(eve.StrEnum): CEIL = "ceil" TRUNC = "trunc" - IR_OP_TO_NUM_ARGS: ClassVar[Dict["NativeFunction", int]] + IR_OP_TO_NUM_ARGS: ClassVar[Dict[NativeFunction, int]] @property def arity(self) -> int: @@ -229,8 +229,8 @@ class LevelMarker(eve.StrEnum): @enum.unique class ExprKind(eve.IntEnum): - SCALAR: "ExprKind" = typing.cast("ExprKind", enum.auto()) - FIELD: "ExprKind" = typing.cast("ExprKind", enum.auto()) + SCALAR: ExprKind = typing.cast("ExprKind", enum.auto()) + FIELD: ExprKind = typing.cast("ExprKind", enum.auto()) class LocNode(eve.Node): @@ -343,7 +343,7 @@ class FieldAccess(eve.GenericNode, Generic[ExprT, VariableKOffsetT]): kind: ExprKind = ExprKind.FIELD @classmethod - def centered(cls, *, name: str, loc: Optional[eve.SourceLocation] = None) -> "FieldAccess": + def centered(cls, *, name: str, loc: Optional[eve.SourceLocation] = None) -> FieldAccess: return cls(name=name, loc=loc, offset=CartesianOffset.zero()) @datamodels.validator("data_index") @@ -721,7 +721,7 @@ class HorizontalInterval(eve.Node): end: Optional[AxisBound] @classmethod - def compute_domain(cls, start_offset: int = 0, end_offset: int = 0) -> "HorizontalInterval": + def compute_domain(cls, start_offset: int = 0, end_offset: int = 0) -> HorizontalInterval: return cls(start=AxisBound.start(start_offset), end=AxisBound.end(end_offset)) @classmethod @@ -731,7 +731,7 @@ def full(cls) -> HorizontalInterval: @classmethod def at_endpt( cls, level: LevelMarker, start_offset: int, end_offset: Optional[int] = None - ) -> "HorizontalInterval": + ) -> HorizontalInterval: if end_offset is None: end_offset = start_offset + 1 return cls( diff --git a/src/gt4py/cartesian/gtc/cuir/cuir.py b/src/gt4py/cartesian/gtc/cuir/cuir.py index 92cb2e6ac4..62c3c520ac 100644 --- a/src/gt4py/cartesian/gtc/cuir/cuir.py +++ b/src/gt4py/cartesian/gtc/cuir/cuir.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from typing import Any, List, Optional, Tuple, Union from gt4py import eve @@ -156,22 +158,22 @@ class IJExtent(LocNode): j: Tuple[int, int] @classmethod - def zero(cls) -> "IJExtent": + def zero(cls) -> IJExtent: return cls(i=(0, 0), j=(0, 0)) @classmethod - def from_offset(cls, offset: Union[CartesianOffset, VariableKOffset]) -> "IJExtent": + def from_offset(cls, offset: Union[CartesianOffset, VariableKOffset]) -> IJExtent: if isinstance(offset, VariableKOffset): return cls(i=(0, 0), j=(0, 0)) return cls(i=(offset.i, offset.i), j=(offset.j, offset.j)) - def union(*extents: "IJExtent") -> "IJExtent": + def union(*extents: IJExtent) -> IJExtent: return IJExtent( i=(min(e.i[0] for e in extents), max(e.i[1] for e in extents)), j=(min(e.j[0] for e in extents), max(e.j[1] for e in extents)), ) - def __add__(self, other: "IJExtent") -> "IJExtent": + def __add__(self, other: IJExtent) -> IJExtent: return IJExtent( i=(self.i[0] + other.i[0], self.i[1] + other.i[1]), j=(self.j[0] + other.j[0], self.j[1] + other.j[1]), @@ -182,17 +184,17 @@ class KExtent(LocNode): k: Tuple[int, int] @classmethod - def zero(cls) -> "KExtent": + def zero(cls) -> KExtent: return cls(k=(0, 0)) @classmethod - def from_offset(cls, offset: Union[CartesianOffset, VariableKOffset]) -> "KExtent": + def from_offset(cls, offset: Union[CartesianOffset, VariableKOffset]) -> KExtent: MAX_OFFSET = 1000 if isinstance(offset, VariableKOffset): return cls(k=(-MAX_OFFSET, MAX_OFFSET)) return cls(k=(offset.k, offset.k)) - def union(*extents: "KExtent") -> "KExtent": + def union(*extents: KExtent) -> KExtent: return KExtent(k=(min(e.k[0] for e in extents), max(e.k[1] for e in extents))) diff --git a/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py b/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py index 7b9e66d77f..fa95ec8cba 100644 --- a/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py +++ b/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import functools from dataclasses import dataclass, field from typing import Any, Dict, List, Set, Union, cast @@ -93,7 +95,7 @@ def visit_VariableKOffset( ) -> cuir.VariableKOffset: return cuir.VariableKOffset(k=self.visit(node.k, **kwargs)) - def _mask_to_expr(self, mask: common.HorizontalMask, ctx: "Context") -> cuir.Expr: + def _mask_to_expr(self, mask: common.HorizontalMask, ctx: Context) -> cuir.Expr: mask_expr: List[cuir.Expr] = [] for axis_index, interval in enumerate(mask.intervals): if interval.is_single_index(): @@ -140,7 +142,7 @@ def visit_FieldAccess( *, ij_caches: Dict[str, cuir.IJCacheDecl], k_caches: Dict[str, cuir.KCacheDecl], - ctx: "Context", + ctx: Context, **kwargs: Any, ) -> Union[cuir.FieldAccess, cuir.IJCacheAccess, cuir.KCacheAccess]: data_index = self.visit( @@ -226,7 +228,7 @@ def visit_VerticalLoopSection( ) def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, symtable: Dict[str, Any], ctx: "Context", **kwargs: Any + self, node: oir.VerticalLoop, *, symtable: Dict[str, Any], ctx: Context, **kwargs: Any ) -> cuir.Kernel: assert not any(c.fill or c.flush for c in node.caches if isinstance(c, oir.KCache)) ij_caches = { diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 11b12da790..399d4d7af5 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -150,13 +150,13 @@ def all_regions_same(scope_nodes): class DaCeIRBuilder(eve.NodeTranslator): @dataclass class GlobalContext: - library_node: "StencilComputation" + library_node: StencilComputation arrays: Dict[str, dace.data.Data] def get_dcir_decls( self, access_infos: Dict[eve.SymbolRef, dcir.FieldAccessInfo], - symbol_collector: "DaCeIRBuilder.SymbolCollector", + symbol_collector: DaCeIRBuilder.SymbolCollector, ) -> List[dcir.FieldDecl]: return [ self._get_dcir_decl(field, access_info, symbol_collector=symbol_collector) @@ -167,7 +167,7 @@ def _get_dcir_decl( self, field: eve.SymbolRef, access_info: dcir.FieldAccessInfo, - symbol_collector: "DaCeIRBuilder.SymbolCollector", + symbol_collector: DaCeIRBuilder.SymbolCollector, ) -> dcir.FieldDecl: oir_decl: oir.Decl = self.library_node.declarations[field] assert isinstance(oir_decl, oir.FieldDecl) @@ -190,14 +190,14 @@ def _get_dcir_decl( @dataclass class IterationContext: grid_subset: dcir.GridSubset - parent: Optional["DaCeIRBuilder.IterationContext"] + parent: Optional[DaCeIRBuilder.IterationContext] @classmethod def init(cls, *args, **kwargs): res = cls(*args, parent=None, **kwargs) return res - def push_axes_extents(self, axes_extents) -> "DaCeIRBuilder.IterationContext": + def push_axes_extents(self, axes_extents) -> DaCeIRBuilder.IterationContext: res = self.grid_subset for axis, extent in axes_extents.items(): axis_interval = res.intervals[axis] @@ -225,12 +225,12 @@ def push_axes_extents(self, axes_extents) -> "DaCeIRBuilder.IterationContext": def push_interval( self, axis: dcir.Axis, interval: Union[dcir.DomainInterval, oir.Interval] - ) -> "DaCeIRBuilder.IterationContext": + ) -> DaCeIRBuilder.IterationContext: return DaCeIRBuilder.IterationContext( grid_subset=self.grid_subset.set_interval(axis, interval), parent=self ) - def push_expansion_item(self, item: Union[Map, Loop]) -> "DaCeIRBuilder.IterationContext": + def push_expansion_item(self, item: Union[Map, Loop]) -> DaCeIRBuilder.IterationContext: if not isinstance(item, (Map, Loop)): raise ValueError @@ -251,13 +251,13 @@ def push_expansion_item(self, item: Union[Map, Loop]) -> "DaCeIRBuilder.Iteratio def push_expansion_items( self, items: Iterable[Union[Map, Loop]] - ) -> "DaCeIRBuilder.IterationContext": + ) -> DaCeIRBuilder.IterationContext: res = self for item in items: res = res.push_expansion_item(item) return res - def pop(self) -> "DaCeIRBuilder.IterationContext": + def pop(self) -> DaCeIRBuilder.IterationContext: assert self.parent is not None return self.parent @@ -293,7 +293,7 @@ def visit_HorizontalRestriction( self, node: oir.HorizontalRestriction, *, - symbol_collector: "DaCeIRBuilder.SymbolCollector", + symbol_collector: DaCeIRBuilder.SymbolCollector, **kwargs: Any, ) -> dcir.HorizontalRestriction: for axis, interval in zip(dcir.Axis.dims_horizontal(), node.mask.intervals): @@ -354,8 +354,8 @@ def visit_ScalarAccess( self, node: oir.ScalarAccess, *, - global_ctx: "DaCeIRBuilder.GlobalContext", - symbol_collector: "DaCeIRBuilder.SymbolCollector", + global_ctx: DaCeIRBuilder.GlobalContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, **kwargs: Any, ) -> dcir.ScalarAccess: if node.name in global_ctx.library_node.declarations: @@ -400,9 +400,9 @@ def visit_HorizontalExecution( self, node: oir.HorizontalExecution, *, - global_ctx: "DaCeIRBuilder.GlobalContext", - iteration_ctx: "DaCeIRBuilder.IterationContext", - symbol_collector: "DaCeIRBuilder.SymbolCollector", + global_ctx: DaCeIRBuilder.GlobalContext, + iteration_ctx: DaCeIRBuilder.IterationContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, loop_order, k_interval, **kwargs, @@ -509,9 +509,9 @@ def visit_VerticalLoopSection( node: oir.VerticalLoopSection, *, loop_order, - iteration_ctx: "DaCeIRBuilder.IterationContext", - global_ctx: "DaCeIRBuilder.GlobalContext", - symbol_collector: "DaCeIRBuilder.SymbolCollector", + iteration_ctx: DaCeIRBuilder.IterationContext, + global_ctx: DaCeIRBuilder.GlobalContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, **kwargs, ): sections_idx, stages_idx = [ @@ -561,8 +561,8 @@ def to_dataflow( self, nodes, *, - global_ctx: "DaCeIRBuilder.GlobalContext", - symbol_collector: "DaCeIRBuilder.SymbolCollector", + global_ctx: DaCeIRBuilder.GlobalContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, ): nodes = flatten_list(nodes) if all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes): @@ -612,8 +612,8 @@ def _process_map_item( item: Map, *, global_ctx, - iteration_ctx: "DaCeIRBuilder.IterationContext", - symbol_collector: "DaCeIRBuilder.SymbolCollector", + iteration_ctx: DaCeIRBuilder.IterationContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, **kwargs, ): grid_subset = iteration_ctx.grid_subset @@ -710,8 +710,8 @@ def _process_loop_item( item: Loop, *, global_ctx, - iteration_ctx: "DaCeIRBuilder.IterationContext", - symbol_collector: "DaCeIRBuilder.SymbolCollector", + iteration_ctx: DaCeIRBuilder.IterationContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, **kwargs, ): grid_subset = union_node_grid_subsets(list(scope_nodes)) @@ -785,7 +785,7 @@ def _process_iteration_item(self, scope, item, **kwargs): raise ValueError("Invalid expansion specification set.") def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, global_ctx: "DaCeIRBuilder.GlobalContext", **kwargs + self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs ): start, end = (node.sections[0].interval.start, node.sections[0].interval.end) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py index c6cc105ea6..56bb6c1b3f 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import copy from typing import TYPE_CHECKING, Dict, List @@ -65,7 +67,7 @@ def _solve_for_domain(field_decls: Dict[str, dcir.FieldDecl], outer_subsets): @staticmethod def _fix_context( - nsdfg, node: "StencilComputation", parent_state: dace.SDFGState, daceir: dcir.NestedSDFG + nsdfg, node: StencilComputation, parent_state: dace.SDFGState, daceir: dcir.NestedSDFG ): """Apply changes to StencilComputation and the SDFG it is embedded in to satisfy post-expansion constraints. @@ -120,7 +122,7 @@ def _fix_context( @staticmethod def _get_parent_arrays( - node: "StencilComputation", parent_state: dace.SDFGState, parent_sdfg: dace.SDFG + node: StencilComputation, parent_state: dace.SDFGState, parent_sdfg: dace.SDFG ) -> Dict[str, dace.data.Data]: parent_arrays: Dict[str, dace.data.Data] = {} for edge in (e for e in parent_state.in_edges(node) if e.dst_conn is not None): @@ -131,7 +133,7 @@ def _get_parent_arrays( @staticmethod def expansion( - node: "StencilComputation", parent_state: dace.SDFGState, parent_sdfg: dace.SDFG + node: StencilComputation, parent_state: dace.SDFGState, parent_sdfg: dace.SDFG ) -> dace.nodes.NestedSDFG: """Expand the coarse SDFG in parent_sdfg to a NestedSDFG with all the states.""" split_horizontal_executions_regions(node) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index 4c415e3260..f6aa725b01 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import dataclasses from dataclasses import dataclass from typing import Any, ChainMap, Dict, List, Optional, Set, Tuple @@ -85,8 +87,8 @@ def visit_Memlet( node: dcir.Memlet, *, scope_node: dcir.ComputationNode, - sdfg_ctx: "StencilComputationSDFGBuilder.SDFGContext", - node_ctx: "StencilComputationSDFGBuilder.NodeContext", + sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, + node_ctx: StencilComputationSDFGBuilder.NodeContext, connector_prefix="", symtable: ChainMap[eve.SymbolRef, dcir.Decl], ) -> None: @@ -118,8 +120,8 @@ def _add_empty_edges( entry_node: dace.nodes.Node, exit_node: dace.nodes.Node, *, - sdfg_ctx: "StencilComputationSDFGBuilder.SDFGContext", - node_ctx: "StencilComputationSDFGBuilder.NodeContext", + sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, + node_ctx: StencilComputationSDFGBuilder.NodeContext, ) -> None: if not sdfg_ctx.state.in_degree(entry_node) and None in node_ctx.input_node_and_conns: sdfg_ctx.state.add_edge( @@ -134,8 +136,8 @@ def visit_Tasklet( self, node: dcir.Tasklet, *, - sdfg_ctx: "StencilComputationSDFGBuilder.SDFGContext", - node_ctx: "StencilComputationSDFGBuilder.NodeContext", + sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, + node_ctx: StencilComputationSDFGBuilder.NodeContext, symtable: ChainMap[eve.SymbolRef, dcir.Decl], **kwargs, ) -> None: @@ -183,8 +185,8 @@ def visit_DomainMap( self, node: dcir.DomainMap, *, - node_ctx: "StencilComputationSDFGBuilder.NodeContext", - sdfg_ctx: "StencilComputationSDFGBuilder.SDFGContext", + node_ctx: StencilComputationSDFGBuilder.NodeContext, + sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, **kwargs, ) -> None: ndranges = { @@ -245,7 +247,7 @@ def visit_DomainLoop( self, node: dcir.DomainLoop, *, - sdfg_ctx: "StencilComputationSDFGBuilder.SDFGContext", + sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, **kwargs, ) -> None: sdfg_ctx = sdfg_ctx.add_loop(node.index_range) @@ -256,7 +258,7 @@ def visit_ComputationState( self, node: dcir.ComputationState, *, - sdfg_ctx: "StencilComputationSDFGBuilder.SDFGContext", + sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, **kwargs, ) -> None: sdfg_ctx.add_state() @@ -285,7 +287,7 @@ def visit_FieldDecl( self, node: dcir.FieldDecl, *, - sdfg_ctx: "StencilComputationSDFGBuilder.SDFGContext", + sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, non_transients: Set[eve.SymbolRef], **kwargs, ) -> None: @@ -304,7 +306,7 @@ def visit_SymbolDecl( self, node: dcir.SymbolDecl, *, - sdfg_ctx: "StencilComputationSDFGBuilder.SDFGContext", + sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, **kwargs, ) -> None: if node.name not in sdfg_ctx.sdfg.symbols: @@ -314,8 +316,8 @@ def visit_NestedSDFG( self, node: dcir.NestedSDFG, *, - sdfg_ctx: Optional["StencilComputationSDFGBuilder.SDFGContext"] = None, - node_ctx: Optional["StencilComputationSDFGBuilder.NodeContext"] = None, + sdfg_ctx: Optional[StencilComputationSDFGBuilder.SDFGContext] = None, + node_ctx: Optional[StencilComputationSDFGBuilder.NodeContext] = None, symtable: ChainMap[eve.SymbolRef, Any], **kwargs, ) -> dace.nodes.NestedSDFG: diff --git a/src/gt4py/cartesian/gtc/dace/expansion/utils.py b/src/gt4py/cartesian/gtc/dace/expansion/utils.py index 369e8c7c38..d8d2ce3176 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/utils.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/utils.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from typing import TYPE_CHECKING, List import dace @@ -32,7 +34,7 @@ def get_dace_debuginfo(node: common.LocNode): class HorizontalIntervalRemover(eve.NodeTranslator): - def visit_HorizontalMask(self, node: common.HorizontalMask, *, axis: "dcir.Axis"): + def visit_HorizontalMask(self, node: common.HorizontalMask, *, axis: dcir.Axis): mask_attrs = dict(i=node.i, j=node.j) mask_attrs[axis.lower()] = self.visit(getattr(node, axis.lower())) return common.HorizontalMask(**mask_attrs) @@ -42,7 +44,7 @@ def visit_HorizontalInterval(self, node: common.HorizontalInterval): class HorizontalMaskRemover(eve.NodeTranslator): - def visit_Tasklet(self, node: "dcir.Tasklet"): + def visit_Tasklet(self, node: dcir.Tasklet): res_body = [] for stmt in node.stmts: newstmt = self.visit(stmt) @@ -161,7 +163,7 @@ def visit_VerticalLoopSection(self, node: oir.VerticalLoopSection, **kwargs): return oir.VerticalLoopSection(interval=node.interval, horizontal_executions=res_hes) -def split_horizontal_executions_regions(node: "StencilComputation"): +def split_horizontal_executions_regions(node: StencilComputation): extents: List[Extent] = [] node.oir_node = HorizontalExecutionSplitter().visit( diff --git a/src/gt4py/cartesian/gtc/dace/expansion_specification.py b/src/gt4py/cartesian/gtc/dace/expansion_specification.py index a5f1acc5fa..091849a1b7 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion_specification.py +++ b/src/gt4py/cartesian/gtc/dace/expansion_specification.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import copy from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, List, Optional, Set, Union @@ -368,7 +370,7 @@ def _collapse_maps(self, expansion_specification): def make_expansion_order( - node: "StencilComputation", expansion_order: Union[List[str], List[ExpansionItem]] + node: StencilComputation, expansion_order: Union[List[str], List[ExpansionItem]] ) -> List[ExpansionItem]: if expansion_order is None: return None @@ -388,7 +390,7 @@ def make_expansion_order( return expansion_specification -def _k_inside_dims(node: "StencilComputation"): +def _k_inside_dims(node: StencilComputation): # Putting K inside of i or j is valid if # * K parallel or # * All reads with k-offset to values modified in same HorizontalExecution are not @@ -438,7 +440,7 @@ def _k_inside_dims(node: "StencilComputation"): return res -def _k_inside_stages(node: "StencilComputation"): +def _k_inside_stages(node: StencilComputation): # Putting K inside of stages is valid if # * K parallel # * not "ahead" in order of iteration to fields that are modified in previous @@ -482,7 +484,7 @@ def _k_inside_stages(node: "StencilComputation"): @_register_validity_check def _sequential_as_loops( - node: "StencilComputation", expansion_specification: List[ExpansionItem] + node: StencilComputation, expansion_specification: List[ExpansionItem] ) -> bool: # K can't be Map if not parallel if node.oir_node.loop_order != common.LoopOrder.PARALLEL and any( @@ -510,7 +512,7 @@ def _stages_inside_sections(expansion_specification: List[ExpansionItem], **kwar @_register_validity_check def _k_inside_ij_valid( - node: "StencilComputation", expansion_specification: List[ExpansionItem] + node: StencilComputation, expansion_specification: List[ExpansionItem] ) -> bool: # OIR defines that horizontal maps go inside vertical K loop (i.e. all grid points are updated in a # HorizontalExecution before the computation of the next one is executed.). Under certain conditions the semantics @@ -527,7 +529,7 @@ def _k_inside_ij_valid( @_register_validity_check def _k_inside_stages_valid( - node: "StencilComputation", expansion_specification: List[ExpansionItem] + node: StencilComputation, expansion_specification: List[ExpansionItem] ) -> bool: # OIR defines that all horizontal executions of a VerticalLoopSection are run per level. Under certain conditions # the semantics remain unchanged even if the k loop is run per horizontal execution. See `_k_inside_stages` for @@ -544,7 +546,7 @@ def _k_inside_stages_valid( @_register_validity_check def _ij_outside_sections_valid( - node: "StencilComputation", expansion_specification: List[ExpansionItem] + node: StencilComputation, expansion_specification: List[ExpansionItem] ) -> bool: # If there are multiple horizontal executions in any section, IJ iteration must go inside sections. # TODO: do mergeability checks on a per-axis basis. @@ -598,7 +600,7 @@ def _iterates_domain(expansion_specification: List[ExpansionItem], **kwargs) -> return True -def is_expansion_order_valid(node: "StencilComputation", expansion_order) -> bool: +def is_expansion_order_valid(node: StencilComputation, expansion_order) -> bool: """Check if a given expansion specification valid. That is, it is semantically valid for the StencilComputation node that is to be configured and currently diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index 926265b2c7..5c2f11f30d 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import base64 import pickle import typing @@ -28,13 +30,13 @@ def _set_expansion_order( - node: "StencilComputation", expansion_order: Union[List[ExpansionItem], List[str]] + node: StencilComputation, expansion_order: Union[List[ExpansionItem], List[str]] ): res = make_expansion_order(node, expansion_order) node._expansion_specification = res -def _set_tile_sizes_interpretation(node: "StencilComputation", tile_sizes_interpretation: str): +def _set_tile_sizes_interpretation(node: StencilComputation, tile_sizes_interpretation: str): valid_values = {"shape", "strides"} if tile_sizes_interpretation not in valid_values: raise ValueError(f"tile_sizes_interpretation must be one in {valid_values}.") diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index dba6c5a700..283402e1ac 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from dataclasses import dataclass from typing import Dict @@ -92,7 +94,7 @@ def _make_dace_subset(self, local_access_info, field): ) def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, ctx: "OirSDFGBuilder.SDFGContext", **kwargs + self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext, **kwargs ): declarations = { acc.name: ctx.decls[acc.name] diff --git a/src/gt4py/cartesian/gtc/dace/symbol_utils.py b/src/gt4py/cartesian/gtc/dace/symbol_utils.py index 42a4478a6f..86823304db 100644 --- a/src/gt4py/cartesian/gtc/dace/symbol_utils.py +++ b/src/gt4py/cartesian/gtc/dace/symbol_utils.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from functools import lru_cache from typing import TYPE_CHECKING @@ -36,7 +38,7 @@ def get_axis_bound_str(axis_bound, var_name): return f"{axis_bound.offset}" -def get_axis_bound_dace_symbol(axis_bound: "dcir.AxisBound"): +def get_axis_bound_dace_symbol(axis_bound: dcir.AxisBound): from gt4py.cartesian.gtc.common import LevelMarker if axis_bound is None: diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index fc68723e1b..f4dade581d 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import re from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -83,19 +85,19 @@ def __init__(self, collect_read: bool, collect_write: bool, include_full_domain: @dataclass class Context: - axes: Dict[str, List["dcir.Axis"]] - access_infos: Dict[str, "dcir.FieldAccessInfo"] = field(default_factory=dict) + axes: Dict[str, List[dcir.Axis]] + access_infos: Dict[str, dcir.FieldAccessInfo] = field(default_factory=dict) def visit_VerticalLoop( self, node: oir.VerticalLoop, *, block_extents, ctx, **kwargs: Any - ) -> Dict[str, "dcir.FieldAccessInfo"]: + ) -> Dict[str, dcir.FieldAccessInfo]: for section in reversed(node.sections): self.visit(section, block_extents=block_extents, ctx=ctx, **kwargs) return ctx.access_infos def visit_VerticalLoopSection( self, node: oir.VerticalLoopSection, *, block_extents, ctx, grid_subset=None, **kwargs: Any - ) -> Dict[str, "dcir.FieldAccessInfo"]: + ) -> Dict[str, dcir.FieldAccessInfo]: inner_ctx = self.Context(axes=ctx.axes) if grid_subset is None: @@ -135,7 +137,7 @@ def visit_HorizontalExecution( k_interval, grid_subset=None, **kwargs, - ) -> Dict[str, "dcir.FieldAccessInfo"]: + ) -> Dict[str, dcir.FieldAccessInfo]: horizontal_extent = block_extents(node) inner_ctx = self.Context(axes=ctx.axes) @@ -186,7 +188,7 @@ def visit_While(self, node: oir.While, *, is_conditional=False, **kwargs): @staticmethod def _global_grid_subset( - region: common.HorizontalMask, he_grid: "dcir.GridSubset", offset: List[Optional[int]] + region: common.HorizontalMask, he_grid: dcir.GridSubset, offset: List[Optional[int]] ): res: Dict[ dcir.Axis, Union[dcir.DomainInterval, dcir.TileInterval, dcir.IndexWithExtent] @@ -225,7 +227,7 @@ def _make_access_info( region, he_grid, grid_subset, - ) -> "dcir.FieldAccessInfo": + ) -> dcir.FieldAccessInfo: offset = [offset_node.to_dict()[k] for k in "ijk"] if isinstance(offset_node, oir.VariableKOffset): variable_offset_axes = [dcir.Axis.K] @@ -262,7 +264,7 @@ def visit_FieldAccess( is_write: bool = False, is_conditional: bool = False, region=None, - ctx: "AccessInfoCollector.Context", + ctx: AccessInfoCollector.Context, **kwargs, ): self.visit( @@ -331,8 +333,8 @@ def compute_dcir_access_infos( def make_dace_subset( - context_info: "dcir.FieldAccessInfo", - access_info: "dcir.FieldAccessInfo", + context_info: dcir.FieldAccessInfo, + access_info: dcir.FieldAccessInfo, data_dims: Tuple[int, ...], ) -> dace.subsets.Range: clamped_access_info = access_info @@ -354,10 +356,8 @@ def make_dace_subset( return dace.subsets.Range(res_ranges) -def untile_memlets( - memlets: Sequence["dcir.Memlet"], axes: Sequence["dcir.Axis"] -) -> List["dcir.Memlet"]: - res_memlets: List["dcir.Memlet"] = [] +def untile_memlets(memlets: Sequence[dcir.Memlet], axes: Sequence[dcir.Axis]) -> List[dcir.Memlet]: + res_memlets: List[dcir.Memlet] = [] for memlet in memlets: res_memlets.append( dcir.Memlet( @@ -382,7 +382,7 @@ def union_node_grid_subsets(nodes: List[eve.Node]): return grid_subset -def _union_memlets(*memlets: "dcir.Memlet") -> List["dcir.Memlet"]: +def _union_memlets(*memlets: dcir.Memlet) -> List[dcir.Memlet]: res: Dict[str, dcir.Memlet] = {} for memlet in memlets: res[memlet.field] = memlet.union(res.get(memlet.field, memlet)) @@ -408,7 +408,7 @@ def flatten_list(list_or_node: Union[List[Any], eve.Node]): def collect_toplevel_computation_nodes( list_or_node: Union[List[Any], eve.Node], -) -> List["dcir.ComputationNode"]: +) -> List[dcir.ComputationNode]: class ComputationNodeCollector(eve.NodeVisitor): def visit_ComputationNode(self, node: dcir.ComputationNode, *, collection: List): collection.append(node) @@ -420,7 +420,7 @@ def visit_ComputationNode(self, node: dcir.ComputationNode, *, collection: List) def collect_toplevel_iteration_nodes( list_or_node: Union[List[Any], eve.Node], -) -> List["dcir.IterationNode"]: +) -> List[dcir.IterationNode]: class IterationNodeCollector(eve.NodeVisitor): def visit_IterationNode(self, node: dcir.IterationNode, *, collection: List): collection.append(node) diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py index 93bd8623e5..0ecb02b50f 100644 --- a/src/gt4py/cartesian/gtc/daceir.py +++ b/src/gt4py/cartesian/gtc/daceir.py @@ -170,7 +170,7 @@ def size(self): def overapproximated_size(self): return self.size - def union(self, other: "IndexWithExtent"): + def union(self, other: IndexWithExtent): assert self.axis == other.axis if isinstance(self.value, int) or (isinstance(self.value, str) and self.value.isdigit()): value = other.value @@ -273,7 +273,7 @@ def shifted(self, offset: int): ), ) - def is_subset_of(self, other: "DomainInterval") -> bool: + def is_subset_of(self, other: DomainInterval) -> bool: return self.start >= other.start and self.end <= other.end @@ -390,7 +390,7 @@ def shape(self): def overapproximated_shape(self): return tuple(interval.overapproximated_size for _, interval in self.items()) - def restricted_to_index(self, axis: Axis, extent=(0, 0)) -> "GridSubset": + def restricted_to_index(self, axis: Axis, extent=(0, 0)) -> GridSubset: intervals = dict(self.intervals) intervals[axis] = IndexWithExtent.from_axis(axis, extent=extent) return GridSubset(intervals=intervals) @@ -399,7 +399,7 @@ def set_interval( self, axis: Axis, interval: Union[DomainInterval, IndexWithExtent, TileInterval, oir.Interval], - ) -> "GridSubset": + ) -> GridSubset: if isinstance(interval, oir.Interval): interval = DomainInterval( start=AxisBound( @@ -587,7 +587,7 @@ def apply_iteration(self, grid_subset: GridSubset): global_grid_subset=self.global_grid_subset, ) - def union(self, other: "FieldAccessInfo"): + def union(self, other: FieldAccessInfo): grid_subset = self.grid_subset.union(other.grid_subset) global_subset = self.global_grid_subset.union(other.global_grid_subset) variable_offset_axes = [ @@ -625,7 +625,7 @@ def clamp_full_axis(self, axis): global_grid_subset=self.global_grid_subset, ) - def untile(self, tile_axes: Sequence[Axis]) -> "FieldAccessInfo": + def untile(self, tile_axes: Sequence[Axis]) -> FieldAccessInfo: res_intervals = {} for axis, interval in self.grid_subset.intervals.items(): if isinstance(interval, TileInterval) and axis in tile_axes: @@ -715,7 +715,7 @@ def axes(self): def is_dynamic(self) -> bool: return self.access_info.is_dynamic - def with_set_access_info(self, access_info: FieldAccessInfo) -> "FieldDecl": + def with_set_access_info(self, access_info: FieldAccessInfo) -> FieldDecl: return FieldDecl( name=self.name, dtype=self.dtype, diff --git a/src/gt4py/cartesian/gtc/definitions.py b/src/gt4py/cartesian/gtc/definitions.py index fe5fa499e4..16c7fbc46a 100644 --- a/src/gt4py/cartesian/gtc/definitions.py +++ b/src/gt4py/cartesian/gtc/definitions.py @@ -456,7 +456,7 @@ def _broadcast(self, value): class Boundary(FrameTuple): """Frame size around one central origin (pairs of integers). - Negative numbers represent a boundary region substracting from + Negative numbers represent a boundary region subtracting from the wrapped area. """ diff --git a/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py b/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py index 6b59212722..0d19814b9c 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py +++ b/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import enum from typing import Any, List, Tuple, Union @@ -134,10 +136,10 @@ class GTExtent(LocNode): k: Tuple[int, int] @classmethod - def zero(cls) -> "GTExtent": + def zero(cls) -> GTExtent: return cls(i=(0, 0), j=(0, 0), k=(0, 0)) - def __add__(self, offset: Union[common.CartesianOffset, VariableKOffset]) -> "GTExtent": + def __add__(self, offset: Union[common.CartesianOffset, VariableKOffset]) -> GTExtent: if isinstance(offset, common.CartesianOffset): return GTExtent( i=(min(self.i[0], offset.i), max(self.i[1], offset.i)), diff --git a/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py b/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py index d26619eeba..0d5b1517c5 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py +++ b/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import functools import itertools from dataclasses import dataclass, field @@ -110,7 +112,7 @@ class GTComputationContext: def add_temporaries( self, temporaries: List[gtcpp.Temporary] - ) -> "OIRToGTCpp.GTComputationContext": + ) -> OIRToGTCpp.GTComputationContext: self.temporaries.extend(temporaries) return self @@ -118,7 +120,7 @@ def add_temporaries( def arguments(self) -> List[gtcpp.Arg]: return [gtcpp.Arg(name=name) for name in self._arguments] - def add_arguments(self, arguments: Set[str]) -> "OIRToGTCpp.GTComputationContext": + def add_arguments(self, arguments: Set[str]) -> OIRToGTCpp.GTComputationContext: self._arguments.update(arguments) return self @@ -221,7 +223,7 @@ def visit_Interval(self, node: oir.Interval, **kwargs: Any) -> gtcpp.GTInterval: ) def _mask_to_expr( - self, mask: common.HorizontalMask, comp_ctx: "GTComputationContext" + self, mask: common.HorizontalMask, comp_ctx: GTComputationContext ) -> gtcpp.Expr: mask_expr: List[gtcpp.Expr] = [] for axis_index, interval in enumerate(mask.intervals): @@ -288,8 +290,8 @@ def visit_HorizontalExecution( self, node: oir.HorizontalExecution, *, - prog_ctx: "ProgramContext", - comp_ctx: "GTComputationContext", + prog_ctx: ProgramContext, + comp_ctx: GTComputationContext, interval: gtcpp.GTInterval, **kwargs: Any, ) -> gtcpp.GTStage: diff --git a/src/gt4py/cartesian/gtc/gtir.py b/src/gt4py/cartesian/gtc/gtir.py index 9c1f61d1ab..c9f58de2da 100644 --- a/src/gt4py/cartesian/gtc/gtir.py +++ b/src/gt4py/cartesian/gtc/gtir.py @@ -19,6 +19,8 @@ - `FieldIfStmt` expansion to comply with the parallel model """ +from __future__ import annotations + from typing import Any, Dict, List, Set, Tuple, Type from gt4py import eve @@ -77,7 +79,7 @@ def no_horizontal_offset_in_assignment( @datamodels.root_validator @classmethod def no_write_and_read_with_offset_of_same_field( - cls: Type["ParAssignStmt"], instance: "ParAssignStmt" + cls: Type[ParAssignStmt], instance: ParAssignStmt ) -> None: if isinstance(instance.left, FieldAccess): offset_reads = ( @@ -203,7 +205,7 @@ class VerticalLoop(LocNode): @datamodels.root_validator @classmethod def _no_write_and_read_with_horizontal_offset( - cls: Type["VerticalLoop"], instance: "VerticalLoop" + cls: Type[VerticalLoop], instance: VerticalLoop ) -> None: """ In the same VerticalLoop a field must not be written and read with a horizontal offset. diff --git a/src/gt4py/cartesian/gtc/numpy/npir_codegen.py b/src/gt4py/cartesian/gtc/numpy/npir_codegen.py index 0870c887f6..e1a9f8e8bb 100644 --- a/src/gt4py/cartesian/gtc/numpy/npir_codegen.py +++ b/src/gt4py/cartesian/gtc/numpy/npir_codegen.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import numbers import textwrap from dataclasses import dataclass, field @@ -249,7 +251,7 @@ def visit_NativeFuncCall( NativeFuncCall = as_fmt("{func}({', '.join(arg for arg in args)}{mask_arg})") def visit_VectorAssign( - self, node: npir.VectorAssign, *, ctx: "BlockContext", **kwargs: Any + self, node: npir.VectorAssign, *, ctx: BlockContext, **kwargs: Any ) -> Union[str, Collection[str]]: left = self.visit(node.left, horizontal_mask=node.horizontal_mask, **kwargs) right = self.visit(node.right, horizontal_mask=node.horizontal_mask, **kwargs) diff --git a/src/gt4py/cartesian/gtc/oir.py b/src/gt4py/cartesian/gtc/oir.py index b8fad9bc1c..df71ef26cf 100644 --- a/src/gt4py/cartesian/gtc/oir.py +++ b/src/gt4py/cartesian/gtc/oir.py @@ -13,6 +13,8 @@ e.g. stage merging, staged computations to compute-on-the-fly, cache annotations, etc. """ +from __future__ import annotations + from typing import Any, List, Optional, Tuple, Type, Union from gt4py import eve @@ -125,7 +127,7 @@ class Temporary(FieldDecl): pass -def _check_interval(instance: Union["Interval", "UnboundedInterval"]) -> None: +def _check_interval(instance: Union[Interval, UnboundedInterval]) -> None: start, end = instance.start, instance.end if ( start is not None @@ -152,18 +154,18 @@ class Interval(LocNode): @datamodels.root_validator @classmethod - def check(cls: Type["Interval"], instance: "Interval") -> None: + def check(cls: Type[Interval], instance: Interval) -> None: _check_interval(instance) - def covers(self, other: "Interval") -> bool: + def covers(self, other: Interval) -> bool: outer_starts_lower = self.start < other.start or self.start == other.start outer_ends_higher = self.end > other.end or self.end == other.end return outer_starts_lower and outer_ends_higher - def intersects(self, other: "Interval") -> bool: + def intersects(self, other: Interval) -> bool: return not (other.start >= self.end or self.start >= other.end) - def shifted(self, offset: Optional[int]) -> Union["Interval", "UnboundedInterval"]: + def shifted(self, offset: Optional[int]) -> Union[Interval, UnboundedInterval]: if offset is None: return UnboundedInterval() start = AxisBound(level=self.start.level, offset=self.start.offset + offset) @@ -181,10 +183,10 @@ class UnboundedInterval: @datamodels.root_validator @classmethod - def check(cls: Type["UnboundedInterval"], instance: "UnboundedInterval") -> None: + def check(cls: Type[UnboundedInterval], instance: UnboundedInterval) -> None: _check_interval(instance) - def covers(self, other: Union[Interval, "UnboundedInterval"]) -> bool: + def covers(self, other: Union[Interval, UnboundedInterval]) -> bool: if self.start is None and self.end is None: return True if ( @@ -209,7 +211,7 @@ def covers(self, other: Union[Interval, "UnboundedInterval"]) -> bool: assert isinstance(other, Interval) return Interval(start=self.start, end=self.end).covers(other) - def intersects(self, other: Union[Interval, "UnboundedInterval"]) -> bool: + def intersects(self, other: Union[Interval, UnboundedInterval]) -> bool: no_overlap_high = ( self.end is not None and other.start is not None and other.start >= self.end ) @@ -218,7 +220,7 @@ def intersects(self, other: Union[Interval, "UnboundedInterval"]) -> bool: ) return not (no_overlap_low or no_overlap_high) - def shifted(self, offset: Optional[int]) -> "UnboundedInterval": + def shifted(self, offset: Optional[int]) -> UnboundedInterval: if offset is None: return UnboundedInterval() @@ -274,7 +276,7 @@ def nonempty_loop(self, attribute: datamodels.Attribute, v: List[VerticalLoopSec @datamodels.root_validator @classmethod - def valid_section_intervals(cls: Type["VerticalLoop"], instance: "VerticalLoop") -> None: + def valid_section_intervals(cls: Type[VerticalLoop], instance: VerticalLoop) -> None: starts, ends = zip(*((s.interval.start, s.interval.end) for s in instance.sections)) if instance.loop_order == common.LoopOrder.BACKWARD: starts, ends = starts[:-1], ends[1:] diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/horizontal_execution_merging.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/horizontal_execution_merging.py index b8be18af6f..f2378d7f2d 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/horizontal_execution_merging.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/horizontal_execution_merging.py @@ -137,7 +137,7 @@ class OnTheFlyMerging(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): """Merges consecutive horizontal executions inside parallel vertical loops by introducing redundant computations. Limitations: - * Works on the level of whole horizontal executions, no full dependency analysis is performed (common subexpression and dead code eliminitation at a later stage can work around this limitation). + * Works on the level of whole horizontal executions, no full dependency analysis is performed (common subexpression and dead code elimination at a later stage can work around this limitation). * The chosen default merge limits are totally arbitrary. """ diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py index c9b93bdb78..e512d560ca 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import dataclasses import re from dataclasses import dataclass @@ -162,7 +164,7 @@ class CartesianAccessCollection(GenericAccessCollection[CartesianAccess, Tuple[i pass class GeneralAccessCollection(GenericAccessCollection[GeneralAccess, GeneralOffsetTuple]): - def cartesian_accesses(self) -> "AccessCollector.CartesianAccessCollection": + def cartesian_accesses(self) -> AccessCollector.CartesianAccessCollection: return AccessCollector.CartesianAccessCollection( [ CartesianAccess( @@ -182,7 +184,7 @@ def has_variable_access(self) -> bool: @classmethod def apply( cls, node: gt4py.eve.RootNode, **kwargs: Any - ) -> "AccessCollector.GeneralAccessCollection": + ) -> AccessCollector.GeneralAccessCollection: result = cls.GeneralAccessCollection([]) cls().visit(node, accesses=result._ordered_accesses, **kwargs) return result @@ -233,7 +235,7 @@ def __init__(self, add_k: bool = False): self.add_k = add_k self.zero_extent = Extent.zeros(ndims=2) - def visit_Stencil(self, node: oir.Stencil) -> "Context": + def visit_Stencil(self, node: oir.Stencil) -> Context: ctx = self.Context() for vloop in reversed(node.vertical_loops): self.visit(vloop, ctx=ctx) diff --git a/src/gt4py/cartesian/lazy_stencil.py b/src/gt4py/cartesian/lazy_stencil.py index 216246caff..a2209e0b7f 100644 --- a/src/gt4py/cartesian/lazy_stencil.py +++ b/src/gt4py/cartesian/lazy_stencil.py @@ -8,6 +8,8 @@ """Stencil Object that allows for deferred building.""" +from __future__ import annotations + from typing import TYPE_CHECKING, Any, Dict from cached_property import cached_property @@ -24,7 +26,7 @@ class LazyStencil: A stencil object which defers compilation until it is needed. Usually obtained using the :py:func:`gt4py.gtscript.lazy_stencil` decorator, not directly - instanciated. + instantiated. This is done by keeping a reference to a :py:class:`gt4py.stencil_builder.StencilBuilder` instance. @@ -32,12 +34,12 @@ class LazyStencil: Low-level build utilities are accessible through the public :code:`builder` attribute. """ - def __init__(self, builder: "StencilBuilder"): + def __init__(self, builder: StencilBuilder): self.builder = builder self.builder.caching.capture_externals() @cached_property - def implementation(self) -> "StencilObject": + def implementation(self) -> StencilObject: """ Expose the compiled backend-specific python callable which executes the stencil. @@ -48,7 +50,7 @@ def implementation(self) -> "StencilObject": return impl @property - def backend(self) -> "Backend": + def backend(self) -> Backend: """Do not trigger a build.""" return self.builder.backend diff --git a/src/gt4py/cartesian/loader.py b/src/gt4py/cartesian/loader.py index 839aa925da..889618ef0d 100644 --- a/src/gt4py/cartesian/loader.py +++ b/src/gt4py/cartesian/loader.py @@ -12,6 +12,8 @@ a high-level stencil function definition using a specific code generating backend. """ +from __future__ import annotations + import types from typing import TYPE_CHECKING, Any, Dict, Type @@ -31,8 +33,8 @@ def load_stencil( definition_func: StencilFunc, externals: Dict[str, Any], dtypes: Dict[Type, Type], - build_options: "BuildOptions", -) -> Type["StencilObject"]: + build_options: BuildOptions, +) -> Type[StencilObject]: """Generate a new class object implementing the provided definition.""" # Load components backend_cls = gt_backend.from_name(backend_name) @@ -57,10 +59,10 @@ def load_stencil( def gtscript_loader( definition_func: StencilFunc, backend: str, - build_options: "BuildOptions", + build_options: BuildOptions, externals: Dict[str, Any], dtypes: Dict[Type, Type], -) -> "StencilObject": +) -> StencilObject: if not isinstance(definition_func, types.FunctionType): raise ValueError("Invalid stencil definition object ({obj})".format(obj=definition_func)) diff --git a/src/gt4py/cartesian/stencil_builder.py b/src/gt4py/cartesian/stencil_builder.py index 84fa223b17..07d58f25f5 100644 --- a/src/gt4py/cartesian/stencil_builder.py +++ b/src/gt4py/cartesian/stencil_builder.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import pathlib from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union @@ -51,9 +53,9 @@ def __init__( self, definition_func: Union[StencilFunc, AnnotatedStencilFunc], *, - backend: Optional[Union[str, Type["BackendType"]]] = None, + backend: Optional[Union[str, Type[BackendType]]] = None, options: Optional[BuildOptions] = None, - frontend: Optional[Type["FrontendType"]] = None, + frontend: Optional[Type[FrontendType]] = None, ): self._definition = definition_func # type ignore explanation: Attribclass generated init not recognized by mypy @@ -69,13 +71,13 @@ def __init__( if frontend is None: raise RuntimeError(f"Unknown frontend: {frontend}") - self.backend: "BackendType" = backend(self) - self.frontend: Type["FrontendType"] = frontend + self.backend: BackendType = backend(self) + self.frontend: Type[FrontendType] = frontend self.with_caching("jit") self._externals: Dict[str, Any] = {} self._dtypes: Dict[Type, Type] = {} - def build(self) -> Type["StencilObject"]: + def build(self) -> Type[StencilObject]: """Generate, compile and/or load everything necessary to provide a usable stencil class.""" # load or generate stencil_class = None if self.options.rebuild else self.backend.load() @@ -96,8 +98,8 @@ def generate_bindings(self, targe_language: str) -> Dict[str, Union[str, Dict]]: return self.cli_backend.generate_bindings(targe_language) def with_caching( - self: "StencilBuilder", caching_strategy_name: str, *args: Any, **kwargs: Any - ) -> "StencilBuilder": + self: StencilBuilder, caching_strategy_name: str, *args: Any, **kwargs: Any + ) -> StencilBuilder: """ Fluidly set the caching strategy from the name. @@ -122,8 +124,8 @@ def with_caching( return self def with_options( - self: "StencilBuilder", *, name: str, module: str, **kwargs: Any - ) -> "StencilBuilder": + self: StencilBuilder, *, name: str, module: str, **kwargs: Any + ) -> StencilBuilder: """ Fluidly set the build options. @@ -144,14 +146,14 @@ def with_options( self.options = BuildOptions(name=name, module=module, **kwargs) # type: ignore return self - def with_changed_options(self: "StencilBuilder", **kwargs: Dict[str, Any]) -> "StencilBuilder": + def with_changed_options(self: StencilBuilder, **kwargs: Dict[str, Any]) -> StencilBuilder: old_options = self.options.as_dict() # BuildOptions constructor expects ``impl_opts`` keyword # but BuildOptions.as_dict outputs ``_impl_opts`` key old_options["impl_opts"] = old_options.pop("_impl_opts") return self.with_options(**{**old_options, **kwargs}) - def with_backend(self: "StencilBuilder", backend_name: str) -> "StencilBuilder": + def with_backend(self: StencilBuilder, backend_name: str) -> StencilBuilder: """ Fluidly set the backend type from backend name. @@ -219,7 +221,7 @@ def dtypes(self) -> Dict[Type, Type]: "dtypes", self._dtypes.copy() ) - def with_externals(self: "StencilBuilder", externals: Dict[str, Any]) -> "StencilBuilder": + def with_externals(self: StencilBuilder, externals: Dict[str, Any]) -> StencilBuilder: """ Fluidly set externals for this build. @@ -230,7 +232,7 @@ def with_externals(self: "StencilBuilder", externals: Dict[str, Any]) -> "Stenci self.with_caching(self.caching.name) return self - def with_dtypes(self: "StencilBuilder", dtypes: Dict[Type, Type]) -> "StencilBuilder": + def with_dtypes(self: StencilBuilder, dtypes: Dict[Type, Type]) -> StencilBuilder: self._build_data = {} self._dtypes = dtypes self.with_caching(self.caching.name) @@ -240,7 +242,7 @@ def with_dtypes(self: "StencilBuilder", dtypes: Dict[Type, Type]) -> "StencilBui def backend_data(self) -> Dict[str, Any]: return self._build_data.get("backend_data", {}).copy() - def with_backend_data(self: "StencilBuilder", data: Dict[str, Any]) -> "StencilBuilder": + def with_backend_data(self: StencilBuilder, data: Dict[str, Any]) -> StencilBuilder: self._build_data["backend_data"] = {**self.backend_data, **data} return self @@ -254,7 +256,7 @@ def root_pkg_name(self) -> str: "root_pkg_name", gt4pyc.config.code_settings["root_package_name"] ) - def with_root_pkg_name(self: "StencilBuilder", name: str) -> "StencilBuilder": + def with_root_pkg_name(self: StencilBuilder, name: str) -> StencilBuilder: self._build_data["root_pkg_name"] = name return self @@ -320,7 +322,7 @@ def is_build_data_empty(self) -> bool: return not bool(self._build_data) @property - def cli_backend(self) -> "CLIBackendMixin": + def cli_backend(self) -> CLIBackendMixin: from gt4py.cartesian.backend.base import CLIBackendMixin if not isinstance(self.backend, CLIBackendMixin): diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index cd4a55385e..b76415e17f 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -95,7 +95,7 @@ def _extract_stencil_arrays( class FrozenStencil: """Stencil with pre-computed domain and origin for each field argument.""" - stencil_object: "StencilObject" + stencil_object: StencilObject origin: Dict[str, Tuple[int, ...]] domain: Tuple[int, ...] @@ -594,7 +594,7 @@ def _call_run( exec_info["call_run_end_time"] = time.perf_counter() def freeze( - self: "StencilObject", *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, ...] + self: StencilObject, *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, ...] ) -> FrozenStencil: """Return a StencilObject wrapper with a fixed domain and origin for each argument. @@ -621,7 +621,7 @@ def freeze( """ return FrozenStencil(self, origin, domain) - def clean_call_args_cache(self: "StencilObject") -> None: + def clean_call_args_cache(self: StencilObject) -> None: """Clean the argument cache. Returns diff --git a/src/gt4py/cartesian/utils/base.py b/src/gt4py/cartesian/utils/base.py index 6fc6165926..d5d43a4103 100644 --- a/src/gt4py/cartesian/utils/base.py +++ b/src/gt4py/cartesian/utils/base.py @@ -34,10 +34,6 @@ def slugify(value: str, *, replace_spaces=True, valid_symbols="-_.()", invalid_m return slug -# def stringify(value): -# pass - - def jsonify(value, indent=2): return json.dumps(value, indent=indent, default=lambda obj: str(obj)) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 06a3e8690d..16c9600a3a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -8,6 +8,8 @@ """Functions for turning an SDFG into a GPU SDFG.""" +from __future__ import annotations + import copy from typing import Any, Optional, Sequence, Union @@ -144,7 +146,7 @@ def gt_set_gpu_blocksize( def _gpu_block_parser( - self: "GPUSetBlockSize", + self: GPUSetBlockSize, val: Any, ) -> None: """Used by the setter of `GPUSetBlockSize.block_size`.""" diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 8c93f01b99..1437f8077d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import copy import dataclasses import itertools @@ -184,7 +186,7 @@ def __init__( def _visit_lift_in_neighbors_reduction( - transformer: "PythonTaskletCodegen", + transformer: PythonTaskletCodegen, node: itir.FunCall, node_args: Sequence[IteratorExpr | list[ValueExpr]], offset_provider: Connectivity, @@ -321,7 +323,7 @@ def _visit_lift_in_neighbors_reduction( def builtin_neighbors( - transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] + transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: sdfg: dace.SDFG = transformer.context.body state: dace.SDFGState = transformer.context.state @@ -514,7 +516,7 @@ def builtin_neighbors( def builtin_can_deref( - transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] + transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: di = dace_debuginfo(node, transformer.context.body.debuginfo) # first visit shift, to get set of indices for deref @@ -556,7 +558,7 @@ def builtin_can_deref( def builtin_if( - transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] + transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: assert len(node_args) == 3 sdfg = transformer.context.body @@ -677,7 +679,7 @@ def build_if_state(arg, state): def builtin_list_get( - transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] + transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: di = dace_debuginfo(node, transformer.context.body.debuginfo) args = list(itertools.chain(*transformer.visit(node_args))) @@ -703,7 +705,7 @@ def builtin_list_get( def builtin_cast( - transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] + transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: di = dace_debuginfo(node, transformer.context.body.debuginfo) args = transformer.visit(node_args[0]) @@ -718,7 +720,7 @@ def builtin_cast( def builtin_make_const_list( - transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] + transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: di = dace_debuginfo(node, transformer.context.body.debuginfo) args = [transformer.visit(arg)[0] for arg in node_args] @@ -754,14 +756,14 @@ def builtin_make_const_list( def builtin_make_tuple( - transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] + transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: args = [transformer.visit(arg) for arg in node_args] return args def builtin_tuple_get( - transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] + transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: elements = transformer.visit(node_args[1]) index = node_args[0] @@ -771,7 +773,7 @@ def builtin_tuple_get( _GENERAL_BUILTIN_MAPPING: dict[ - str, Callable[["PythonTaskletCodegen", itir.Expr, list[itir.Expr]], list[ValueExpr]] + str, Callable[[PythonTaskletCodegen, itir.Expr, list[itir.Expr]], list[ValueExpr]] ] = { "can_deref": builtin_can_deref, "cast_": builtin_cast,