Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use modern-style type hints #1632

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions src/gt4py/cartesian/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import abc
import copy
import hashlib
Expand Down Expand Up @@ -43,7 +45,7 @@
REGISTRY = gt_utils.Registry()


def from_name(name: str) -> Optional[Type["Backend"]]:
def from_name(name: str) -> Optional[Type[Backend]]:
backend = REGISTRY.get(name, None)
if not backend:
raise NotImplementedError(
Expand All @@ -52,7 +54,7 @@ def from_name(name: str) -> Optional[Type["Backend"]]:
return backend


def register(backend_cls: Type["Backend"]) -> Type["Backend"]:
def register(backend_cls: Type[Backend]) -> Type[Backend]:
assert issubclass(backend_cls, Backend) and backend_cls.name is not None

if isinstance(backend_cls.name, str):
Expand Down Expand Up @@ -101,9 +103,9 @@ class Backend(abc.ABC):
# "disable-code-generation": bool
# "disable-cache-validation": bool

builder: "StencilBuilder"
builder: StencilBuilder

def __init__(self, builder: "StencilBuilder"):
def __init__(self, builder: StencilBuilder):
self.builder = builder

@classmethod
Expand All @@ -120,7 +122,7 @@ def filter_options_for_id(
return filtered_options

@abc.abstractmethod
def load(self) -> Optional[Type["StencilObject"]]:
def load(self) -> Optional[Type[StencilObject]]:
"""
Load the stencil class from the generated python module.

Expand All @@ -135,7 +137,7 @@ def load(self) -> Optional[Type["StencilObject"]]:
pass

@abc.abstractmethod
def generate(self) -> Type["StencilObject"]:
def generate(self) -> Type[StencilObject]:
"""
Generate the stencil class from GTScript's internal representation.

Expand All @@ -152,7 +154,7 @@ def generate(self) -> Type["StencilObject"]:

@property
def extra_cache_info(self) -> Dict[str, Any]:
"""Provide additional data to be stored in cache info file (sublass hook)."""
"""Provide additional data to be stored in cache info file (subclass hook)."""
return {}

@property
Expand Down Expand Up @@ -240,9 +242,9 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]:


class BaseBackend(Backend):
MODULE_GENERATOR_CLASS: ClassVar[Type["BaseModuleGenerator"]]
MODULE_GENERATOR_CLASS: ClassVar[Type[BaseModuleGenerator]]

def load(self) -> Optional[Type["StencilObject"]]:
def load(self) -> Optional[Type[StencilObject]]:
build_info = self.builder.options.build_info
if build_info is not None:
start_time = time.perf_counter()
Expand All @@ -263,11 +265,11 @@ def load(self) -> Optional[Type["StencilObject"]]:

return stencil_class

def generate(self) -> Type["StencilObject"]:
def generate(self) -> Type[StencilObject]:
self.check_options(self.builder.options)
return self.make_module()

def _load(self) -> Type["StencilObject"]:
def _load(self) -> Type[StencilObject]:
stencil_class_name = self.builder.class_name
file_name = str(self.builder.module_path)
stencil_module = gt_utils.make_module_from_file(stencil_class_name, file_name)
Expand All @@ -289,7 +291,7 @@ def check_options(self, options: gt_definitions.BuildOptions) -> None:
stacklevel=2,
)

def make_module(self, **kwargs: Any) -> Type["StencilObject"]:
def make_module(self, **kwargs: Any) -> Type[StencilObject]:
build_info = self.builder.options.build_info
if build_info is not None:
start_time = time.perf_counter()
Expand Down Expand Up @@ -323,7 +325,7 @@ def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) ->
class PurePythonBackendCLIMixin(CLIBackendMixin):
"""Mixin for CLI support for backends deriving from BaseBackend."""

builder: "StencilBuilder"
builder: StencilBuilder

#: stencil python source generator method:
#: In order to use this mixin, the backend class must implement
Expand Down Expand Up @@ -378,7 +380,7 @@ def extra_cache_validation_keys(self) -> List[str]:
return keys

@abc.abstractmethod
def generate(self) -> Type["StencilObject"]:
def generate(self) -> Type[StencilObject]:
pass

def build_extension_module(
Expand Down Expand Up @@ -436,7 +438,7 @@ def disabled(message: str, *, enabled_env_var: str) -> Callable[[Type[Backend]],
else:

def _decorator(cls: Type[Backend]) -> Type[Backend]:
def _no_generate(obj) -> Type["StencilObject"]:
def _no_generate(obj) -> Type[StencilObject]:
raise NotImplementedError(
f"Disabled '{cls.name}' backend: 'f{message}'\n",
f"You can still enable the backend by hand using the environment variable '{enabled_env_var}=1'",
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/cartesian/backend/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type

from gt4py import storage as gt_storage
Expand Down Expand Up @@ -141,7 +143,7 @@ class CudaBackend(BaseGTBackend, CLIBackendMixin):
def generate_extension(self, **kwargs: Any) -> Tuple[str, str]:
return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=True)

def generate(self) -> Type["StencilObject"]:
def generate(self) -> Type[StencilObject]:
self.check_options(self.builder.options)

pyext_module_name: Optional[str]
Expand Down
8 changes: 5 additions & 3 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import copy
import os
import pathlib
Expand Down Expand Up @@ -302,7 +304,7 @@ def freeze_origin_domain_sdfg(inner_sdfg, arg_names, field_info, *, origin, doma
for node in states[0].nodes():
state.remove_node(node)

# make sure that symbols are passed throught o inner sdfg
# make sure that symbols are passed through to inner sdfg
for symbol in nsdfg.sdfg.free_symbols:
if symbol not in wrapper_sdfg.symbols:
wrapper_sdfg.add_symbol(symbol, nsdfg.sdfg.symbols[symbol])
Expand Down Expand Up @@ -531,7 +533,7 @@ def keep_line(line):
return generated_code

@classmethod
def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.SDFG):
def apply(cls, stencil_ir: gtir.Stencil, builder: StencilBuilder, sdfg: dace.SDFG):
self = cls()
with dace.config.temporary_config():
# To prevent conflict with 3rd party usage of DaCe config always make sure that any
Expand Down Expand Up @@ -765,7 +767,7 @@ class BaseDaceBackend(BaseGTBackend, CLIBackendMixin):
GT_BACKEND_T = "dace"
PYEXT_GENERATOR_CLASS = DaCeExtGenerator # type: ignore

def generate(self) -> Type["StencilObject"]:
def generate(self) -> Type[StencilObject]:
self.check_options(self.builder.options)

pyext_module_name: Optional[str]
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/cartesian/backend/dace_lazy_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Tuple

import dace
Expand All @@ -22,7 +24,7 @@


class DaCeLazyStencil(LazyStencil, SDFGConvertible):
def __init__(self, builder: "StencilBuilder"):
def __init__(self, builder: StencilBuilder):
if "dace" not in builder.backend.name:
raise ValueError("Trying to build a DaCeLazyStencil for non-dace backend.")
super().__init__(builder=builder)
Expand Down
10 changes: 6 additions & 4 deletions src/gt4py/cartesian/backend/dace_stencil_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import copy
import inspect
import os
Expand Down Expand Up @@ -58,7 +60,7 @@ def add_optional_fields(

@dataclass(frozen=True)
class DaCeFrozenStencil(FrozenStencil, SDFGConvertible):
stencil_object: "DaCeStencilObject"
stencil_object: DaCeStencilObject
origin: Dict[str, Tuple[int, ...]]
domain: Tuple[int, ...]
sdfg: dace.SDFG
Expand Down Expand Up @@ -95,7 +97,7 @@ def _get_domain_origin_key(domain, origin):
return domain, origins_tuple

def freeze(
self: "DaCeStencilObject", *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, ...]
self: DaCeStencilObject, *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, ...]
) -> DaCeFrozenStencil:
key = DaCeStencilObject._get_domain_origin_key(domain, origin)
if key in self._frozen_cache:
Expand Down Expand Up @@ -135,8 +137,8 @@ def closure_resolver(
self,
constant_args: Dict[str, Any],
given_args: Set[str],
parent_closure: Optional["dace.frontend.python.common.SDFGClosure"] = None,
) -> "dace.frontend.python.common.SDFGClosure":
parent_closure: Optional[dace.frontend.python.common.SDFGClosure] = None,
) -> dace.frontend.python.common.SDFGClosure:
return dace.frontend.python.common.SDFGClosure()

def __sdfg__(self, *args, **kwargs) -> dace.SDFG:
Expand Down
6 changes: 4 additions & 2 deletions src/gt4py/cartesian/backend/gtc_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import abc
import os
import textwrap
Expand Down Expand Up @@ -134,7 +136,7 @@ def __init__(self):
self.pyext_file_path = None

def __call__(
self, args_data: ModuleData, builder: Optional["StencilBuilder"] = None, **kwargs: Any
self, args_data: ModuleData, builder: Optional[StencilBuilder] = None, **kwargs: Any
) -> str:
self.pyext_module_name = kwargs["pyext_module_name"]
self.pyext_file_path = kwargs["pyext_file_path"]
Expand Down Expand Up @@ -229,7 +231,7 @@ class BaseGTBackend(gt_backend.BasePyExtBackend, gt_backend.CLIBackendMixin):
PYEXT_GENERATOR_CLASS: Type[BackendCodegen]

@abc.abstractmethod
def generate(self) -> Type["StencilObject"]:
def generate(self) -> Type[StencilObject]:
pass

def generate_computation(self) -> Dict[str, Union[str, Dict]]:
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/cartesian/backend/gtcpp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type

from gt4py import storage as gt_storage
Expand Down Expand Up @@ -129,7 +131,7 @@ class GTBaseBackend(BaseGTBackend, CLIBackendMixin):
def _generate_extension(self, uses_cuda: bool) -> Tuple[str, str]:
return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=uses_cuda)

def generate(self) -> Type["StencilObject"]:
def generate(self) -> Type[StencilObject]:
self.check_options(self.builder.options)

pyext_module_name: Optional[str]
Expand Down
12 changes: 7 additions & 5 deletions src/gt4py/cartesian/backend/module_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import abc
import numbers
import sys
Expand Down Expand Up @@ -121,11 +123,11 @@ class BaseModuleGenerator(abc.ABC):

TEMPLATE_RESOURCE = "stencil_module.py.in"

_builder: Optional["StencilBuilder"]
_builder: Optional[StencilBuilder]
args_data: ModuleData
template: jinja2.Template

def __init__(self, builder: Optional["StencilBuilder"] = None):
def __init__(self, builder: Optional[StencilBuilder] = None):
self._builder = builder
self.args_data = ModuleData()
self.template = jinja2.Template(
Expand All @@ -135,7 +137,7 @@ def __init__(self, builder: Optional["StencilBuilder"] = None):
)

def __call__(
self, args_data: ModuleData, builder: Optional["StencilBuilder"] = None, **kwargs: Any
self, args_data: ModuleData, builder: Optional[StencilBuilder] = None, **kwargs: Any
) -> str:
"""
Generate source code for a Python module containing a StencilObject.
Expand Down Expand Up @@ -176,7 +178,7 @@ def __call__(
return module_source

@property
def builder(self) -> "StencilBuilder":
def builder(self) -> StencilBuilder:
"""
Expose the builder reference.

Expand Down Expand Up @@ -205,7 +207,7 @@ def generate_class_name(self) -> str:
"""
Generate the name of the stencil class.

This should ususally be deferred to the chosen caching strategy via
This should usually be deferred to the chosen caching strategy via
the builder object (see default implementation).
"""
return self.builder.class_name
Expand Down
6 changes: 4 additions & 2 deletions src/gt4py/cartesian/backend/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import pathlib
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Type, Union, cast

Expand Down Expand Up @@ -51,7 +53,7 @@ def generate_implementation(self) -> str:
return f"computation.run({', '.join(params)})"

@property
def backend(self) -> "NumpyBackend":
def backend(self) -> NumpyBackend:
return cast(NumpyBackend, self.builder.backend)


Expand Down Expand Up @@ -98,7 +100,7 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]:
super().generate_bindings(language_name)
return {self.builder.module_path.name: self.make_module_source()}

def generate(self) -> Type["StencilObject"]:
def generate(self) -> Type[StencilObject]:
self.check_options(self.builder.options)
src_dir = self.builder.module_path.parent
if not self.builder.options._impl_opts.get("disable-code-generation", False):
Expand Down
Loading
Loading