From de13b0884fe3f3bca1713720cac4835fdd57bb81 Mon Sep 17 00:00:00 2001
From: Enrique Gonzalez Paredes
Date: Tue, 2 Apr 2024 17:35:14 +0200
Subject: [PATCH] Reformat with ruff to reduce vertical space.

---
 docs/user/next/workshop/exercises/helpers.py | 6 +-
 src/gt4py/__about__.py | 8 +-
 src/gt4py/_core/definitions.py | 4 +-
 src/gt4py/cartesian/backend/base.py | 10 +-
 src/gt4py/cartesian/backend/cuda_backend.py | 13 +-
 src/gt4py/cartesian/backend/dace_backend.py | 15 +-
 .../cartesian/backend/dace_stencil_object.py | 8 +-
 src/gt4py/cartesian/backend/gtc_common.py | 14 +-
 src/gt4py/cartesian/backend/gtcpp_backend.py | 10 +-
 .../cartesian/backend/module_generator.py | 5 +-
 src/gt4py/cartesian/backend/numpy_backend.py | 7 +-
 src/gt4py/cartesian/backend/pyext_builder.py | 12 +-
 src/gt4py/cartesian/cli.py | 15 +-
 src/gt4py/cartesian/frontend/__init__.py | 8 +-
 src/gt4py/cartesian/frontend/defir_to_gtir.py | 12 +-
 .../cartesian/frontend/gtscript_frontend.py | 10 +-
 src/gt4py/cartesian/gtc/common.py | 10 +-
 src/gt4py/cartesian/gtc/cuir/cuir_codegen.py | 5 +-
 .../cartesian/gtc/cuir/extent_analysis.py | 6 +-
 src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py | 39 +----
 .../gtc/dace/expansion/daceir_builder.py | 26 +---
 .../gtc/dace/expansion/sdfg_builder.py | 6 +-
 .../cartesian/gtc/dace/expansion/utils.py | 6 +-
 .../gtc/dace/expansion_specification.py | 20 +--
 src/gt4py/cartesian/gtc/dace/nodes.py | 8 +-
 src/gt4py/cartesian/gtc/dace/oir_to_dace.py | 16 +-
 src/gt4py/cartesian/gtc/dace/utils.py | 8 +-
 src/gt4py/cartesian/gtc/daceir.py | 47 ++----
 src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py | 10 +-
 src/gt4py/cartesian/gtc/numpy/oir_to_npir.py | 16 +-
 .../gtc/passes/oir_optimizations/caches.py | 19 +--
 .../horizontal_execution_merging.py | 5 +-
 .../gtc/passes/oir_optimizations/inlining.py | 17 +--
 .../oir_optimizations/mask_stmt_merging.py | 4 +-
 .../gtc/passes/oir_optimizations/pruning.py | 9 +-
 .../passes/oir_optimizations/temporaries.py | 9 +-
 .../gtc/passes/oir_optimizations/utils.py | 6 +-
 .../vertical_loop_merging.py | 6 +-
 src/gt4py/cartesian/gtscript_imports.py | 6 +-
 src/gt4py/eve/datamodels/core.py | 29 +---
 src/gt4py/eve/extended_typing.py | 13 +-
 src/gt4py/eve/pattern_matching.py | 5 +-
 src/gt4py/eve/type_validation.py | 9 +-
 src/gt4py/eve/utils.py | 16 +-
 src/gt4py/eve/visitors.py | 2 +-
 src/gt4py/next/common.py | 21 +--
 src/gt4py/next/constructors.py | 12 +-
 src/gt4py/next/embedded/__init__.py | 7 +-
 src/gt4py/next/embedded/common.py | 22 +--
 src/gt4py/next/embedded/nd_array_field.py | 28 +---
 src/gt4py/next/errors/excepthook.py | 6 +-
 src/gt4py/next/errors/formatting.py | 5 +-
 .../ffront/ast_passes/unchain_compares.py | 4 +-
 src/gt4py/next/ffront/decorator.py | 48 ++----
 src/gt4py/next/ffront/dialect_parser.py | 5 +-
 src/gt4py/next/ffront/experimental.py | 6 +-
 src/gt4py/next/ffront/fbuiltins.py | 30 +---
 .../foast_passes/closure_var_folding.py | 4 +-
 .../ffront/foast_passes/iterable_unpack.py | 18 +--
 .../foast_passes/type_alias_replacement.py | 11 +-
 .../ffront/foast_passes/type_deduction.py | 114 ++++--------
 src/gt4py/next/ffront/foast_to_itir.py | 38 +----
 src/gt4py/next/ffront/func_to_foast.py | 25 +--
 src/gt4py/next/ffront/func_to_past.py | 5 +-
 src/gt4py/next/ffront/lowering_utils.py | 4 +-
 .../next/ffront/past_passes/type_deduction.py | 30 +---
 src/gt4py/next/ffront/past_process_args.py | 5 +-
 src/gt4py/next/ffront/past_to_itir.py | 35 +----
 src/gt4py/next/ffront/type_info.py | 20 +--
 src/gt4py/next/iterator/embedded.py | 34 +----
 src/gt4py/next/iterator/ir.py | 15 +-
 src/gt4py/next/iterator/ir_utils/ir_makers.py | 4 +-
 src/gt4py/next/iterator/pretty_printer.py | 6 +-
 src/gt4py/next/iterator/runtime.py | 13 +-
 src/gt4py/next/iterator/tracing.py | 7 +-
 .../iterator/transforms/collapse_list_get.py | 3 +-
 .../next/iterator/transforms/fuse_maps.py | 13 +-
 .../next/iterator/transforms/global_tmps.py | 3 +-
 .../inline_center_deref_lift_vars.py | 3 +-
 .../iterator/transforms/inline_into_scan.py | 5 +-
 .../next/iterator/transforms/inline_lifts.py | 4 +-
 .../next/iterator/transforms/pass_manager.py | 8 +-
 .../iterator/transforms/power_unrolling.py | 7 +-
 .../iterator/transforms/propagate_deref.py | 3 +-
 .../next/iterator/transforms/remap_symbols.py | 5 +-
 .../next/iterator/transforms/trace_shifts.py | 7 +-
 .../next/iterator/transforms/unroll_reduce.py | 9 +-
 src/gt4py/next/iterator/type_inference.py | 124 +++------------
 src/gt4py/next/otf/binding/cpp_interface.py | 5 +-
 src/gt4py/next/otf/binding/nanobind.py | 12 +-
 .../otf/compilation/build_systems/cmake.py | 5 +-
 .../compilation/build_systems/compiledb.py | 8 +-
 src/gt4py/next/otf/recipes.py | 13 +-
 src/gt4py/next/otf/workflow.py | 6 +-
 .../codegens/gtfn/gtfn_ir.py | 6 +-
 .../codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py | 30 +---
 .../codegens/gtfn/gtfn_module.py | 10 +-
 .../codegens/gtfn/itir_to_gtfn_ir.py | 27 +---
 .../program_processors/processor_interface.py | 6 +-
 .../next/program_processors/runners/dace.py | 3 +-
 .../runners/dace_iterator/__init__.py | 7 +-
 .../runners/dace_iterator/itir_to_sdfg.py | 53 ++-----
 .../runners/dace_iterator/itir_to_tasklet.py | 75 ++-------
 .../runners/dace_iterator/workflow.py | 10 +-
 .../runners/double_roundtrip.py | 2 +-
 .../next/program_processors/runners/gtfn.py | 8 +-
 .../program_processors/runners/roundtrip.py | 14 +-
 src/gt4py/next/type_system/type_info.py | 43 ++----
 src/gt4py/storage/allocators.py | 10 +-
 src/gt4py/storage/cartesian/utils.py | 3 +-
 .../feature_tests/test_call_interface.py | 32 +---
 .../feature_tests/test_exec_info.py | 3 +-
 .../test_code_generation.py | 15 +-
 .../multi_feature_tests/test_dace_parsing.py | 11 +-
 .../multi_feature_tests/test_suites.py | 58 ++-----
 .../defir_to_gtir_definition_setup.py | 28 +---
 .../frontend_tests/test_defir_to_gtir.py | 2 +-
 .../frontend_tests/test_gtscript_frontend.py | 20 +--
 tests/cartesian_tests/unit_tests/test_cli.py | 6 +-
 .../unit_tests/test_gtc/test_common.py | 21 +--
 .../test_gtc/test_cuir_compilation.py | 11 +-
 .../test_gtc/test_cuir_kernel_fusion.py | 6 +-
 .../test_gtc/test_gtcpp_compilation.py | 23 +--
 .../unit_tests/test_gtc/test_gtir.py | 16 +-
 .../test_gtc/test_gtir_dtype_resolver.py | 11 +-
 .../unit_tests/test_gtc/test_gtir_to_oir.py | 6 +-
 .../unit_tests/test_gtc/test_gtir_upcaster.py | 6 +-
 .../unit_tests/test_gtc/test_npir_codegen.py | 26 +---
 .../unit_tests/test_gtc/test_oir.py | 10 +-
 .../test_gtc/test_oir_access_kinds.py | 4 +-
 .../unit_tests/test_gtc/test_oir_to_gtcpp.py | 14 +-
 .../unit_tests/test_gtc/test_oir_to_npir.py | 14 +-
 .../test_oir_optimizations/test_caches.py | 38 +----
 .../test_horizontal_execution_merging.py | 5 +-
 .../test_oir_optimizations/test_inlining.py | 17 +--
 .../test_oir_optimizations/test_pruning.py | 14 +-
 .../test_temporaries.py | 5 +-
 .../test_oir_optimizations/test_utils.py | 9 +-
 tests/eve_tests/definitions.py | 11 +-
 tests/eve_tests/unit_tests/test_utils.py | 7 +-
 tests/next_tests/definitions.py | 8 +-
 tests/next_tests/integration_tests/cases.py | 26 +---
 .../ffront_tests/ffront_test_utils.py | 10 +-
 .../ffront_tests/test_arg_call_interface.py | 16 +-
.../ffront_tests/test_concat_where.py | 5 +- .../ffront_tests/test_execution.py | 119 +++------------ .../ffront_tests/test_external_local_field.py | 7 +- .../ffront_tests/test_gt4py_builtins.py | 13 +- .../ffront_tests/test_program.py | 9 +- .../ffront_tests/test_scalar_if.py | 8 +- .../test_temporaries_with_sizes.py | 10 +- .../ffront_tests/test_type_deduction.py | 94 +++--------- .../feature_tests/ffront_tests/test_where.py | 8 +- .../iterator_tests/test_builtins.py | 10 +- .../test_cartesian_offset_provider.py | 10 +- .../iterator_tests/test_conditional.py | 12 +- .../test_horizontal_indirection.py | 5 +- .../test_strided_offset_provider.py | 14 +- .../iterator_tests/test_trivial.py | 5 +- .../iterator_tests/test_tuple.py | 144 ++++-------------- .../feature_tests/math_builtin_test_data.py | 95 ++---------- .../otf_tests/test_nanobind_build.py | 4 +- .../feature_tests/test_util_cases.py | 4 +- .../cpp_backend_tests/fvm_nabla.py | 7 +- .../ffront_tests/test_icon_like_scan.py | 24 +-- .../ffront_tests/test_laplacian.py | 6 +- .../multi_feature_tests/fvm_nabla_setup.py | 5 +- .../iterator_tests/test_column_stencil.py | 35 +---- .../iterator_tests/test_fvm_nabla.py | 47 +----- .../iterator_tests/test_vertical_advection.py | 8 +- .../test_with_toy_connectivity.py | 16 +- .../unit_tests/embedded_tests/test_common.py | 28 +--- .../embedded_tests/test_nd_array_field.py | 91 +++-------- .../ffront_tests/test_foast_to_itir.py | 34 +---- .../ffront_tests/test_func_to_foast.py | 55 ++----- .../test_func_to_foast_error_line_number.py | 5 +- .../ffront_tests/test_func_to_past.py | 38 ++--- .../ffront_tests/test_past_to_itir.py | 7 +- .../iterator_tests/test_pretty_parser.py | 20 +-- .../iterator_tests/test_pretty_printer.py | 52 ++----- .../iterator_tests/test_runtime_domain.py | 21 +-- .../iterator_tests/test_type_inference.py | 53 ++----- .../test_collapse_list_get.py | 5 +- .../transforms_tests/test_cse.py | 18 +-- .../transforms_tests/test_fuse_maps.py | 18 +-- .../transforms_tests/test_global_tmps.py | 41 +---- .../transforms_tests/test_inline_lambdas.py | 6 +- .../test_scan_eta_reduction.py | 3 +- .../test_simple_inline_heuristic.py | 10 +- .../transforms_tests/test_trace_shifts.py | 12 +- .../transforms_tests/test_unroll_reduce.py | 8 +- .../binding_tests/test_cpp_interface.py | 7 +- .../build_systems_tests/conftest.py | 4 +- .../gtfn_tests/test_gtfn_module.py | 5 +- .../gtfn_tests/test_itir_to_gtfn_ir.py | 7 +- tests/next_tests/unit_tests/test_common.py | 13 +- .../unit_tests/test_constructors.py | 23 +-- .../test_type_translation.py | 10 +- .../unit_tests/test_interface.py | 12 +- 199 files changed, 729 insertions(+), 2777 deletions(-) diff --git a/docs/user/next/workshop/exercises/helpers.py b/docs/user/next/workshop/exercises/helpers.py index ee6926b42f..1f23c1e242 100644 --- a/docs/user/next/workshop/exercises/helpers.py +++ b/docs/user/next/workshop/exercises/helpers.py @@ -25,11 +25,7 @@ ) -def random_mask( - sizes, - *dims, - dtype=None, -) -> MutableLocatedField: +def random_mask(sizes, *dims, dtype=None) -> MutableLocatedField: arr = np.full(shape=sizes, fill_value=False).flatten() arr[: int(arr.size * 0.5)] = True np.random.shuffle(arr) diff --git a/src/gt4py/__about__.py b/src/gt4py/__about__.py index 7107c1669a..64428efd24 100644 --- a/src/gt4py/__about__.py +++ b/src/gt4py/__about__.py @@ -19,13 +19,7 @@ from packaging import version as pkg_version -__all__ = [ - "__author__", - "__copyright__", - "__license__", - "__version__", - "__version_info__", -] +__all__ = ["__author__", 
"__copyright__", "__license__", "__version__", "__version_info__"] __author__: Final = "ETH Zurich and individual contributors" diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 440dba9455..769be9ba5a 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -143,9 +143,7 @@ def is_positive_integral_type(integral_type: type) -> TypeGuard[Type[PositiveInt ] # TODO(egparedes) figure out if PositiveIntegral can be made to work -def is_valid_tensor_shape( - value: Sequence[IntegralScalar], -) -> TypeGuard[TensorShape]: +def is_valid_tensor_shape(value: Sequence[IntegralScalar]) -> TypeGuard[TensorShape]: return isinstance(value, collections.abc.Sequence) and all( isinstance(v, numbers.Integral) and v > 0 for v in value ) diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 5325893a04..4e9cd2d1b5 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -276,10 +276,7 @@ def check_options(self, options: gt_definitions.BuildOptions) -> None: stacklevel=2, ) - def make_module( - self, - **kwargs: Any, - ) -> Type["StencilObject"]: + def make_module(self, **kwargs: Any) -> Type["StencilObject"]: build_info = self.builder.options.build_info if build_info is not None: start_time = time.perf_counter() @@ -412,10 +409,7 @@ def build_extension_module( assert module_name == qualified_pyext_name self.builder.with_backend_data( - { - "pyext_module_name": module_name, - "pyext_file_path": file_path, - } + {"pyext_module_name": module_name, "pyext_file_path": file_path} ) return module_name, file_path diff --git a/src/gt4py/cartesian/backend/cuda_backend.py b/src/gt4py/cartesian/backend/cuda_backend.py index 14855d3ffa..84d1949818 100644 --- a/src/gt4py/cartesian/backend/cuda_backend.py +++ b/src/gt4py/cartesian/backend/cuda_backend.py @@ -89,8 +89,7 @@ def visit_FieldDecl(self, node: cuir.FieldDecl, **kwargs): sid_ndim = domain_ndim + data_ndim if kwargs["external_arg"]: return "py::object {name}, std::array {name}_origin".format( - name=node.name, - sid_ndim=sid_ndim, + name=node.name, sid_ndim=sid_ndim ) else: return pybuffer_to_sid( @@ -113,12 +112,7 @@ def visit_Program(self, node: cuir.Program, **kwargs): assert "module_name" in kwargs entry_params = self.visit(node.params, external_arg=True, **kwargs) sid_params = self.visit(node.params, external_arg=False, **kwargs) - return self.generic_visit( - node, - entry_params=entry_params, - sid_params=sid_params, - **kwargs, - ) + return self.generic_visit(node, entry_params=entry_params, sid_params=sid_params, **kwargs) Program = bindings_main_template() @@ -157,6 +151,5 @@ def generate(self) -> Type["StencilObject"]: # Generate and return the Python wrapper class return self.make_module( - pyext_module_name=pyext_module_name, - pyext_file_path=pyext_file_path, + pyext_module_name=pyext_module_name, pyext_file_path=pyext_file_path ) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 6d60422d5a..dd725b6b77 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -67,8 +67,7 @@ def _serialize_sdfg(sdfg: dace.SDFG): def _specialize_transient_strides(sdfg: dace.SDFG, layout_map): repldict = replace_strides( - [array for array in sdfg.arrays.values() if array.transient], - layout_map, + [array for array in sdfg.arrays.values() if array.transient], layout_map ) sdfg.replace_dict(repldict) for state in sdfg.nodes(): @@ 
-221,9 +220,7 @@ def _sdfg_add_arrays_and_edges( ranges = [ (o - max(0, e), o - max(0, e) + s - 1, 1) for o, e, s in zip( - origin, - field_info[name].boundary.lower_indices, - inner_sdfg.arrays[name].shape, + origin, field_info[name].boundary.lower_indices, inner_sdfg.arrays[name].shape ) ] ranges += [(0, d, 1) for d in field_info[name].data_dims] @@ -786,8 +783,7 @@ def generate(self) -> Type["StencilObject"]: # Generate and return the Python wrapper class return self.make_module( - pyext_module_name=pyext_module_name, - pyext_file_path=pyext_file_path, + pyext_module_name=pyext_module_name, pyext_file_path=pyext_file_path ) @@ -826,10 +822,7 @@ class DaceGPUBackend(BaseDaceBackend): ), } MODULE_GENERATOR_CLASS = DaCeCUDAPyExtModuleGenerator - options = { - **BaseGTBackend.GT_BACKEND_OPTS, - "device_sync": {"versioning": True, "type": bool}, - } + options = {**BaseGTBackend.GT_BACKEND_OPTS, "device_sync": {"versioning": True, "type": bool}} def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=True) diff --git a/src/gt4py/cartesian/backend/dace_stencil_object.py b/src/gt4py/cartesian/backend/dace_stencil_object.py index 71444ed86b..b58cebbf9d 100644 --- a/src/gt4py/cartesian/backend/dace_stencil_object.py +++ b/src/gt4py/cartesian/backend/dace_stencil_object.py @@ -50,10 +50,7 @@ def add_optional_fields( if info.access == AccessKind.NONE and name in kwargs and name not in sdfg.arrays: outer_array = kwargs[name] sdfg.add_array( - name, - shape=outer_array.shape, - dtype=outer_array.dtype, - strides=outer_array.strides, + name, shape=outer_array.shape, dtype=outer_array.dtype, strides=outer_array.strides ) for name, info in parameter_info.items(): @@ -194,8 +191,7 @@ def normalize_args( name: (kwargs[name] if name in kwargs else next(args_iter)) for name in arg_names } arg_infos = _extract_array_infos( - field_args=args_as_kwargs, - device=backend_cls.storage_info["device"], + field_args=args_as_kwargs, device=backend_cls.storage_info["device"] ) origin = DaCeStencilObject._normalize_origins(arg_infos, field_info, origin) diff --git a/src/gt4py/cartesian/backend/gtc_common.py b/src/gt4py/cartesian/backend/gtc_common.py index 43371cf8ac..beaf70c567 100644 --- a/src/gt4py/cartesian/backend/gtc_common.py +++ b/src/gt4py/cartesian/backend/gtc_common.py @@ -59,15 +59,10 @@ def pybuffer_to_sid( sid_def = """gt::{as_sid}<{ctype}, {sid_ndim}, gt::integral_constant>({name})""".format( - name=name, - ctype=ctype, - unique_index=stride_kind_index, - sid_ndim=sid_ndim, - as_sid=as_sid, + name=name, ctype=ctype, unique_index=stride_kind_index, sid_ndim=sid_ndim, as_sid=as_sid ) sid_def = "gt::sid::shift_sid_origin({sid_def}, {name}_origin)".format( - sid_def=sid_def, - name=name, + sid_def=sid_def, name=name ) if domain_ndim != 3: gt_dims = [ @@ -145,10 +140,7 @@ def __init__(self): self.pyext_file_path = None def __call__( - self, - args_data: ModuleData, - builder: Optional["StencilBuilder"] = None, - **kwargs: Any, + self, args_data: ModuleData, builder: Optional["StencilBuilder"] = None, **kwargs: Any ) -> str: self.pyext_module_name = kwargs["pyext_module_name"] self.pyext_file_path = kwargs["pyext_file_path"] diff --git a/src/gt4py/cartesian/backend/gtcpp_backend.py b/src/gt4py/cartesian/backend/gtcpp_backend.py index c69b5b5088..671f1296fa 100644 --- a/src/gt4py/cartesian/backend/gtcpp_backend.py +++ b/src/gt4py/cartesian/backend/gtcpp_backend.py @@ -115,12 +115,7 @@ def visit_Program(self, node: gtcpp.Program, 
**kwargs): assert "module_name" in kwargs entry_params = self.visit(node.parameters, external_arg=True, **kwargs) sid_params = self.visit(node.parameters, external_arg=False, **kwargs) - return self.generic_visit( - node, - entry_params=entry_params, - sid_params=sid_params, - **kwargs, - ) + return self.generic_visit(node, entry_params=entry_params, sid_params=sid_params, **kwargs) Program = bindings_main_template() @@ -151,8 +146,7 @@ def generate(self) -> Type["StencilObject"]: # Generate and return the Python wrapper class return self.make_module( - pyext_module_name=pyext_module_name, - pyext_file_path=pyext_file_path, + pyext_module_name=pyext_module_name, pyext_file_path=pyext_file_path ) diff --git a/src/gt4py/cartesian/backend/module_generator.py b/src/gt4py/cartesian/backend/module_generator.py index 5d5714f88f..7485615459 100644 --- a/src/gt4py/cartesian/backend/module_generator.py +++ b/src/gt4py/cartesian/backend/module_generator.py @@ -141,10 +141,7 @@ def __init__(self, builder: Optional["StencilBuilder"] = None): ) def __call__( - self, - args_data: ModuleData, - builder: Optional["StencilBuilder"] = None, - **kwargs: Any, + self, args_data: ModuleData, builder: Optional["StencilBuilder"] = None, **kwargs: Any ) -> str: """ Generate source code for a Python module containing a StencilObject. diff --git a/src/gt4py/cartesian/backend/numpy_backend.py b/src/gt4py/cartesian/backend/numpy_backend.py index 6f1aab52cf..cb20198d3a 100644 --- a/src/gt4py/cartesian/backend/numpy_backend.py +++ b/src/gt4py/cartesian/backend/numpy_backend.py @@ -117,12 +117,7 @@ def _make_npir(self) -> npir.Computation: oir_pipeline = self.builder.options.backend_opts.get( "oir_pipeline", DefaultPipeline( - skip=[ - IJCacheDetection, - KCacheDetection, - PruneKCacheFills, - PruneKCacheFlushes, - ] + skip=[IJCacheDetection, KCacheDetection, PruneKCacheFills, PruneKCacheFlushes] ), ) oir_node = oir_pipeline.run(base_oir) diff --git a/src/gt4py/cartesian/backend/pyext_builder.py b/src/gt4py/cartesian/backend/pyext_builder.py index 9d04cf3413..018c050e3c 100644 --- a/src/gt4py/cartesian/backend/pyext_builder.py +++ b/src/gt4py/cartesian/backend/pyext_builder.py @@ -174,11 +174,7 @@ def get_gt_pyext_build_opts( # The following tells mypy to accept unpacking kwargs @overload def build_pybind_ext( - name: str, - sources: list, - build_path: str, - target_path: str, - **kwargs: str, + name: str, sources: list, build_path: str, target_path: str, **kwargs: str ) -> Tuple[str, str]: ... 
@@ -283,11 +279,7 @@ def build_pybind_ext( # The following tells mypy to accept unpacking kwargs @overload def build_pybind_cuda_ext( - name: str, - sources: list, - build_path: str, - target_path: str, - **kwargs: str, + name: str, sources: list, build_path: str, target_path: str, **kwargs: str ) -> Tuple[str, str]: pass diff --git a/src/gt4py/cartesian/cli.py b/src/gt4py/cartesian/cli.py index 4dcb8f1ee4..be941cca93 100644 --- a/src/gt4py/cartesian/cli.py +++ b/src/gt4py/cartesian/cli.py @@ -60,10 +60,7 @@ class BackendChoice(click.Choice): name = "backend" def convert( - self, - value: str, - param: Optional[click.Parameter], - ctx: Optional[click.Context], + self, value: str, param: Optional[click.Parameter], ctx: Optional[click.Context] ) -> Type[CLIBackendMixin]: """Convert a CLI option argument to a backend.""" name = super().convert(value, param, ctx) @@ -239,10 +236,7 @@ def write_computation_src( self.reporter.echo(f"Writing source file: {file_path}") file_path.write_text(content) - def generate_stencils( - self, - build_options: Optional[Dict[str, Any]] = None, - ) -> None: + def generate_stencils(self, build_options: Optional[Dict[str, Any]] = None) -> None: for proto_stencil in self.iterate_stencils(): self.reporter.echo(f"Building stencil {proto_stencil.builder.options.name}") builder = proto_stencil.builder.with_backend(self.backend_cls.name) @@ -314,8 +308,5 @@ def gen( ) -> None: """Generate stencils from gtscript modules or packages.""" GTScriptBuilder( - input_path=input_path, - output_path=output_path, - backend=backend, - silent=silent, + input_path=input_path, output_path=output_path, backend=backend, silent=silent ).generate_stencils(build_options=dict(options)) diff --git a/src/gt4py/cartesian/frontend/__init__.py b/src/gt4py/cartesian/frontend/__init__.py index 20b5213138..f67c917116 100644 --- a/src/gt4py/cartesian/frontend/__init__.py +++ b/src/gt4py/cartesian/frontend/__init__.py @@ -16,10 +16,4 @@ from .base import REGISTRY, Frontend, from_name, register -__all__ = [ - "gtscript_frontend", - "REGISTRY", - "Frontend", - "from_name", - "register", -] +__all__ = ["gtscript_frontend", "REGISTRY", "Frontend", "from_name", "register"] diff --git a/src/gt4py/cartesian/frontend/defir_to_gtir.py b/src/gt4py/cartesian/frontend/defir_to_gtir.py index 2471063789..635b7e6fdc 100644 --- a/src/gt4py/cartesian/frontend/defir_to_gtir.py +++ b/src/gt4py/cartesian/frontend/defir_to_gtir.py @@ -196,10 +196,7 @@ def visit_FieldRef(self, node: FieldRef, *, fields_decls: Dict[str, FieldDecl], data_type = DataType.INT32 data_index = [ScalarLiteral(value=index, data_type=data_type)] element_ref = FieldRef( - name=node.name, - offset=node.offset, - data_index=data_index, - loc=node.loc, + name=node.name, offset=node.offset, data_index=data_index, loc=node.loc ) field_list.append(element_ref) # matrix @@ -398,9 +395,7 @@ def visit_ComputationBlock(self, node: ComputationBlock) -> gtir.VerticalLoop: stmts.append(decl_or_stmt) start, end = self.visit(node.interval) interval = gtir.Interval( - start=start, - end=end, - loc=location_to_source_location(node.interval.loc), + start=start, end=end, loc=location_to_source_location(node.interval.loc) ) return gtir.VerticalLoop( interval=interval, @@ -525,8 +520,7 @@ def make_bound_or_level(bound: AxisBound, level) -> Optional[common.AxisBound]: } return gtir.HorizontalRestriction( - mask=common.HorizontalMask(**axes), - body=self.visit(node.body), + mask=common.HorizontalMask(**axes), body=self.visit(node.body) ) def visit_While(self, node: While) 
-> gtir.While: diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index fc19b8c253..3a50402c3d 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -658,8 +658,7 @@ def _make_init_computations( else: stmts.append( nodes.Assign( - target=nodes.FieldRef.at_center(name, axes=decl.axes), - value=init_values[name], + target=nodes.FieldRef.at_center(name, axes=decl.axes), value=init_values[name] ) ) @@ -982,8 +981,7 @@ def _visit_computation_node(self, node: ast.With) -> nodes.ComputationBlock: if intervals_dicts: stmts = [ nodes.HorizontalIf( - intervals=intervals_dict, - body=nodes.BlockStmt(stmts=stmts, loc=loc), + intervals=intervals_dict, body=nodes.BlockStmt(stmts=stmts, loc=loc) ) for intervals_dict in intervals_dicts ] @@ -1010,9 +1008,7 @@ def visit_Constant( elif isinstance(value, bool): return nodes.Cast( data_type=nodes.DataType.BOOL, - expr=nodes.BuiltinLiteral( - value=nodes.Builtin.from_value(value), - ), + expr=nodes.BuiltinLiteral(value=nodes.Builtin.from_value(value)), loc=nodes.Location.from_ast_node(node), ) elif isinstance(value, numbers.Number): diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index 1e0364d721..19ec1437e9 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -406,10 +406,7 @@ def _make_root_validator(impl: datamodels.RootValidator) -> datamodels.RootValid def assign_stmt_dtype_validation(*, strict: bool) -> datamodels.RootValidator: - def _impl( - cls: Type[datamodels.DataModel], - instance: datamodels.DataModel, - ) -> None: + def _impl(cls: Type[datamodels.DataModel], instance: datamodels.DataModel) -> None: assert isinstance(instance, AssignStmt) verify_and_get_common_dtype(cls, [instance.left, instance.right], strict=strict) @@ -865,10 +862,7 @@ def data_type_to_typestr(dtype: DataType) -> str: ComparisonOperator.EQ: "equal", ComparisonOperator.NE: "not_equal", }, - LogicalOperator: { - LogicalOperator.AND: "logical_and", - LogicalOperator.OR: "logical_or", - }, + LogicalOperator: {LogicalOperator.AND: "logical_and", LogicalOperator.OR: "logical_or"}, NativeFunction: { NativeFunction.ABS: "abs", NativeFunction.MIN: "minimum", diff --git a/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py b/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py index 142c721a0e..6ac304b437 100644 --- a/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py +++ b/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py @@ -127,10 +127,7 @@ def visit_IJCacheAccess( Cast = as_fmt("static_cast<{dtype}>({expr})") - BUILTIN_LITERAL_TO_CODE = { - BuiltInLiteral.TRUE: "true", - BuiltInLiteral.FALSE: "false", - } + BUILTIN_LITERAL_TO_CODE = {BuiltInLiteral.TRUE: "true", BuiltInLiteral.FALSE: "false"} def visit_BuiltInLiteral(self, builtin: BuiltInLiteral, **kwargs: Any) -> str: try: diff --git a/src/gt4py/cartesian/gtc/cuir/extent_analysis.py b/src/gt4py/cartesian/gtc/cuir/extent_analysis.py index 5b5c8e347b..7e296097a8 100644 --- a/src/gt4py/cartesian/gtc/cuir/extent_analysis.py +++ b/src/gt4py/cartesian/gtc/cuir/extent_analysis.py @@ -22,11 +22,7 @@ class CacheExtents(NodeTranslator): def visit_IJCacheDecl( - self, - node: cuir.IJCacheDecl, - *, - ij_extents: Dict[str, cuir.KExtent], - **kwargs: Any, + self, node: cuir.IJCacheDecl, *, ij_extents: Dict[str, cuir.KExtent], **kwargs: Any ) -> cuir.IJCacheDecl: return cuir.IJCacheDecl(name=node.name, dtype=node.dtype, extent=ij_extents[node.name]) diff --git 
a/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py b/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py index de1ca93557..db4bdd089b 100644 --- a/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py +++ b/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py @@ -86,9 +86,7 @@ def visit_UnaryOp(self, node: oir.UnaryOp, **kwargs: Any) -> cuir.UnaryOp: def visit_BinaryOp(self, node: oir.BinaryOp, **kwargs: Any) -> cuir.BinaryOp: return cuir.BinaryOp( - op=node.op, - left=self.visit(node.left, **kwargs), - right=self.visit(node.right, **kwargs), + op=node.op, left=self.visit(node.left, **kwargs), right=self.visit(node.right, **kwargs) ) def visit_Temporary(self, node: oir.Temporary, **kwargs: Any) -> cuir.Temporary: @@ -152,19 +150,9 @@ def visit_FieldAccess( **kwargs: Any, ) -> Union[cuir.FieldAccess, cuir.IJCacheAccess, cuir.KCacheAccess]: data_index = self.visit( - node.data_index, - ij_caches=ij_caches, - k_caches=k_caches, - ctx=ctx, - **kwargs, - ) - offset = self.visit( - node.offset, - ij_caches=ij_caches, - k_caches=k_caches, - ctx=ctx, - **kwargs, + node.data_index, ij_caches=ij_caches, k_caches=k_caches, ctx=ctx, **kwargs ) + offset = self.visit(node.offset, ij_caches=ij_caches, k_caches=k_caches, ctx=ctx, **kwargs) if node.name in ij_caches: return cuir.IJCacheAccess( name=ij_caches[node.name].name, @@ -181,10 +169,7 @@ def visit_FieldAccess( ) ctx.accessed_fields.add(node.name) return cuir.FieldAccess( - name=node.name, - offset=offset, - data_index=data_index, - dtype=node.dtype, + name=node.name, offset=offset, data_index=data_index, dtype=node.dtype ) def visit_ScalarAccess( @@ -247,12 +232,7 @@ def visit_VerticalLoopSection( ) def visit_VerticalLoop( - self, - node: oir.VerticalLoop, - *, - symtable: Dict[str, Any], - ctx: "Context", - **kwargs: Any, + self, node: oir.VerticalLoop, *, symtable: Dict[str, Any], ctx: "Context", **kwargs: Any ) -> cuir.Kernel: assert not any(c.fill or c.flush for c in node.caches if isinstance(c, oir.KCache)) ij_caches = { @@ -280,7 +260,7 @@ def visit_VerticalLoop( ij_caches=list(ij_caches.values()), k_caches=list(k_caches.values()), ) - ], + ] ) def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> cuir.Program: @@ -288,12 +268,7 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> cuir.Program: ctx = self.Context( new_symbol_name=cast(SymbolNameCreator, symbol_name_creator(collect_symbol_names(node))) ) - kernels = self.visit( - node.vertical_loops, - ctx=ctx, - block_extents=block_extents, - **kwargs, - ) + kernels = self.visit(node.vertical_loops, ctx=ctx, block_extents=block_extents, **kwargs) temporaries = [self.visit(d) for d in node.declarations if d.name in ctx.accessed_fields] return cuir.Program( name=node.name, diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 9a214441ad..3dd05f97fa 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -233,8 +233,7 @@ def push_interval( self, axis: dcir.Axis, interval: Union[dcir.DomainInterval, oir.Interval] ) -> "DaCeIRBuilder.IterationContext": return DaCeIRBuilder.IterationContext( - grid_subset=self.grid_subset.set_interval(axis, interval), - parent=self, + grid_subset=self.grid_subset.set_interval(axis, interval), parent=self ) def push_expansion_item(self, item: Union[Map, Loop]) -> "DaCeIRBuilder.IterationContext": @@ -310,8 +309,7 @@ def visit_HorizontalRestriction( if bound.level == common.LevelMarker.END: 
symbol_collector.add_symbol(axis.domain_symbol()) return dcir.HorizontalRestriction( - mask=node.mask, - body=self.visit(node.body, symbol_collector=symbol_collector, **kwargs), + mask=node.mask, body=self.visit(node.body, symbol_collector=symbol_collector, **kwargs) ) def visit_VariableKOffset(self, node: oir.VariableKOffset, **kwargs): @@ -461,10 +459,7 @@ def visit_HorizontalExecution( ) dcir_node = dcir.Tasklet( - decls=decls, - stmts=stmts, - read_memlets=read_memlets, - write_memlets=write_memlets, + decls=decls, stmts=stmts, read_memlets=read_memlets, write_memlets=write_memlets ) for memlet in [*read_memlets, *write_memlets]: @@ -648,11 +643,7 @@ def _process_map_item( ) symbol_collector.remove_symbol(axis.tile_symbol()) ranges.append( - dcir.Range( - var=axis.tile_symbol(), - interval=interval, - stride=iteration.stride, - ) + dcir.Range(var=axis.tile_symbol(), interval=interval, stride=iteration.stride) ) else: if _all_stmts_same_region(scope_nodes, axis, interval): @@ -800,11 +791,7 @@ def _process_iteration_item(self, scope, item, **kwargs): raise ValueError("Invalid expansion specification set.") def visit_VerticalLoop( - self, - node: oir.VerticalLoop, - *, - global_ctx: "DaCeIRBuilder.GlobalContext", - **kwargs, + self, node: oir.VerticalLoop, *, global_ctx: "DaCeIRBuilder.GlobalContext", **kwargs ): start, end = (node.sections[0].interval.start, node.sections[0].interval.end) @@ -863,8 +850,7 @@ def visit_VerticalLoop( read_memlets, write_memlets, field_memlets = union_inout_memlets(computations) field_decls = global_ctx.get_dcir_decls( - global_ctx.library_node.access_infos, - symbol_collector=symbol_collector, + global_ctx.library_node.access_infos, symbol_collector=symbol_collector ) read_fields = set(memlet.field for memlet in read_memlets) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index dcffd9e410..7fb4226d93 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -55,11 +55,7 @@ def add_loop(self, index_range: dcir.Range): after_state = self.sdfg.add_state() for edge in self.sdfg.out_edges(self.state): self.sdfg.remove_edge(edge) - self.sdfg.add_edge( - after_state, - edge.dst, - edge.data, - ) + self.sdfg.add_edge(after_state, edge.dst, edge.data) assert isinstance(index_range.interval, dcir.DomainInterval) if index_range.stride < 0: diff --git a/src/gt4py/cartesian/gtc/dace/expansion/utils.py b/src/gt4py/cartesian/gtc/dace/expansion/utils.py index dc10c53f21..b1879d0816 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/utils.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/utils.py @@ -31,11 +31,7 @@ def get_dace_debuginfo(node: common.LocNode): if node.loc is not None: return dace.dtypes.DebugInfo( - node.loc.line, - node.loc.column, - node.loc.line, - node.loc.column, - node.loc.filename, + node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename ) else: return dace.dtypes.DebugInfo(0) diff --git a/src/gt4py/cartesian/gtc/dace/expansion_specification.py b/src/gt4py/cartesian/gtc/dace/expansion_specification.py index 7c99146426..6f810bd932 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion_specification.py +++ b/src/gt4py/cartesian/gtc/dace/expansion_specification.py @@ -147,27 +147,11 @@ def _order_as_spec(computation_node, expansion_order): expansion_specification.append(item) elif axis := _is_tiling(item): expansion_specification.append( - Map( - iterations=[ - Iteration( - 
axis=axis, - kind="tiling", - stride=None, - ) - ] - ) + Map(iterations=[Iteration(axis=axis, kind="tiling", stride=None)]) ) elif axis := _is_domain_map(item): expansion_specification.append( - Map( - iterations=[ - Iteration( - axis=axis, - kind="contiguous", - stride=1, - ) - ] - ) + Map(iterations=[Iteration(axis=axis, kind="contiguous", stride=1)]) ) elif axis := _is_domain_loop(item): expansion_specification.append( diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index bd8c08034c..7fea45f4cb 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -92,14 +92,10 @@ class StencilComputation(library.LibraryNode): dtype=dace.DeviceType, default=dace.DeviceType.CPU, allow_none=True ) expansion_specification = PickledListProperty( - element_type=ExpansionItem, - allow_none=True, - setter=_set_expansion_order, + element_type=ExpansionItem, allow_none=True, setter=_set_expansion_order ) tile_sizes = PickledDictProperty( - key_type=dcir.Axis, - value_type=int, - default={dcir.Axis.I: 8, dcir.Axis.J: 8, dcir.Axis.K: 8}, + key_type=dcir.Axis, value_type=int, default={dcir.Axis.I: 8, dcir.Axis.J: 8, dcir.Axis.K: 8} ) tile_sizes_interpretation = dace.properties.Property( diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index 361ae41324..3c5cf292d8 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -125,30 +125,20 @@ def visit_VerticalLoop( library_node.add_in_connector("__in_" + field) subset = ctx.make_input_dace_subset(node, field) state.add_edge( - access_node, - None, - library_node, - "__in_" + field, - dace.Memlet(field, subset=subset), + access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset) ) for field in access_collection.write_fields(): access_node = state.add_access(field, debuginfo=dace.DebugInfo(0)) library_node.add_out_connector("__out_" + field) subset = ctx.make_output_dace_subset(node, field) state.add_edge( - library_node, - "__out_" + field, - access_node, - None, - dace.Memlet(field, subset=subset), + library_node, "__out_" + field, access_node, None, dace.Memlet(field, subset=subset) ) return def visit_Stencil(self, node: oir.Stencil, **kwargs): - ctx = OirSDFGBuilder.SDFGContext( - stencil=node, - ) + ctx = OirSDFGBuilder.SDFGContext(stencil=node) for param in node.params: if isinstance(param, oir.FieldDecl): dim_strs = [d for i, d in enumerate("IJK") if param.dimensions[i]] + [ diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index cfde545f40..6baf3c3c46 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -100,13 +100,7 @@ def visit_VerticalLoop( return ctx.access_infos def visit_VerticalLoopSection( - self, - node: oir.VerticalLoopSection, - *, - block_extents, - ctx, - grid_subset=None, - **kwargs: Any, + self, node: oir.VerticalLoopSection, *, block_extents, ctx, grid_subset=None, **kwargs: Any ) -> Dict[str, "dcir.FieldAccessInfo"]: inner_ctx = self.Context(axes=ctx.axes) diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py index 0f43537758..055a51e064 100644 --- a/src/gt4py/cartesian/gtc/daceir.py +++ b/src/gt4py/cartesian/gtc/daceir.py @@ -195,18 +195,12 @@ def union(self, other: "IndexWithExtent"): return IndexWithExtent( axis=self.axis, value=value, - extent=( - min(self.extent[0], other.extent[0]), - max(self.extent[1], 
other.extent[1]), - ), + extent=(min(self.extent[0], other.extent[0]), max(self.extent[1], other.extent[1])), ) @property def idx_range(self): - return ( - f"{self.value}{self.extent[0]:+d}", - f"{self.value}{self.extent[1] + 1:+d}", - ) + return (f"{self.value}{self.extent[0]:+d}", f"{self.value}{self.extent[1] + 1:+d}") def to_dace_symbolic(self): if isinstance(self.value, AxisBound): @@ -249,10 +243,7 @@ def overapproximated_size(self): @classmethod def union(cls, first, second): - return cls( - start=min(first.start, second.start), - end=max(first.end, second.end), - ) + return cls(start=min(first.start, second.start), end=max(first.end, second.end)) @classmethod def intersection(cls, axis, first, second): @@ -281,14 +272,10 @@ def to_dace_symbolic(self): def shifted(self, offset: int): return DomainInterval( start=AxisBound( - axis=self.start.axis, - level=self.start.level, - offset=self.start.offset + offset, + axis=self.start.axis, level=self.start.level, offset=self.start.offset + offset ), end=AxisBound( - axis=self.end.axis, - level=self.end.level, - offset=self.end.offset + offset, + axis=self.end.axis, level=self.end.level, offset=self.end.offset + offset ), ) @@ -305,9 +292,7 @@ class TileInterval(eve.Node): @property def free_symbols(self) -> Set[eve.SymbolRef]: - res = { - self.axis.tile_symbol(), - } + res = {self.axis.tile_symbol()} if self.domain_limit.level == common.LevelMarker.END: res.add(self.axis.domain_symbol()) return res @@ -324,8 +309,7 @@ def size(self): @property def overapproximated_size(self): return "{tile_size}{halo_size:+d}".format( - tile_size=self.tile_size, - halo_size=self.end_offset - self.start_offset, + tile_size=self.tile_size, halo_size=self.end_offset - self.start_offset ) @classmethod @@ -371,11 +355,7 @@ class Range(eve.Node): def from_axis_and_interval( cls, axis: Axis, interval: Union[DomainInterval, TileInterval], stride=1 ): - return cls( - var=axis.iteration_symbol(), - interval=interval, - stride=stride, - ) + return cls(var=axis.iteration_symbol(), interval=interval, stride=stride) @property def free_symbols(self) -> Set[eve.SymbolRef]: @@ -429,9 +409,7 @@ def set_interval( if isinstance(interval, oir.Interval): interval = DomainInterval( start=AxisBound( - level=interval.start.level, - offset=interval.start.offset, - axis=Axis.K, + level=interval.start.level, offset=interval.start.offset, axis=Axis.K ), end=AxisBound(level=interval.end.level, offset=interval.end.offset, axis=Axis.K), ) @@ -464,9 +442,7 @@ def from_interval( if isinstance(interval, (DomainInterval, oir.Interval)): res_interval = DomainInterval( start=AxisBound( - level=interval.start.level, - offset=interval.start.offset, - axis=Axis.K, + level=interval.start.level, offset=interval.start.offset, axis=Axis.K ), end=AxisBound(level=interval.end.level, offset=interval.end.offset, axis=Axis.K), ) @@ -640,8 +616,7 @@ def clamp_full_axis(self, axis): end=AxisBound(level=common.LevelMarker.END, offset=0, axis=axis), ) res_interval = DomainInterval.union( - full_interval, - self.global_grid_subset.intervals.get(axis, full_interval), + full_interval, self.global_grid_subset.intervals.get(axis, full_interval) ) if isinstance(interval, DomainInterval): interval_union = DomainInterval.union(interval, res_interval) diff --git a/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py b/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py index 06db14940b..859ed281d1 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py +++ b/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py @@ -162,9 +162,7 @@ def 
visit_UnaryOp(self, node: oir.UnaryOp, **kwargs: Any) -> gtcpp.UnaryOp: def visit_BinaryOp(self, node: oir.BinaryOp, **kwargs: Any) -> gtcpp.BinaryOp: return gtcpp.BinaryOp( - op=node.op, - left=self.visit(node.left, **kwargs), - right=self.visit(node.right, **kwargs), + op=node.op, left=self.visit(node.left, **kwargs), right=self.visit(node.right, **kwargs) ) def visit_TernaryOp(self, node: oir.TernaryOp, **kwargs: Any) -> gtcpp.TernaryOp: @@ -329,11 +327,7 @@ def visit_HorizontalExecution( return gtcpp.GTStage(functor=functor_name, args=stage_args) def visit_VerticalLoop( - self, - node: oir.VerticalLoop, - *, - comp_ctx: GTComputationContext, - **kwargs: Any, + self, node: oir.VerticalLoop, *, comp_ctx: GTComputationContext, **kwargs: Any ) -> gtcpp.GTMultiStage: # the following visit assumes that temporaries are already available in comp_ctx stages = list( diff --git a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py index 02412c9ff2..28664ee515 100644 --- a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py +++ b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py @@ -103,9 +103,7 @@ def visit_BinaryOp( self, node: oir.BinaryOp, **kwargs: Any ) -> Union[npir.VectorArithmetic, npir.VectorLogic]: args = dict( - op=node.op, - left=self.visit(node.left, **kwargs), - right=self.visit(node.right, **kwargs), + op=node.op, left=self.visit(node.left, **kwargs), right=self.visit(node.right, **kwargs) ) if isinstance(node.op, common.LogicalOperator): return npir.VectorLogic(**args) @@ -135,11 +133,7 @@ def visit_NativeFuncCall(self, node: oir.NativeFuncCall, **kwargs: Any) -> npir. # --- Statements --- def visit_MaskStmt( - self, - node: oir.MaskStmt, - *, - mask: Optional[npir.Expr] = None, - **kwargs: Any, + self, node: oir.MaskStmt, *, mask: Optional[npir.Expr] = None, **kwargs: Any ) -> List[npir.Stmt]: mask_expr = self.visit(node.mask, **kwargs) if mask: @@ -237,11 +231,7 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> npir.Computation: ] vertical_passes = utils.flatten_list( - self.visit( - node.vertical_loops, - block_extents=block_extents, - **kwargs, - ) + self.visit(node.vertical_loops, block_extents=block_extents, **kwargs) ) return npir.Computation( diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py index 913dc46b6f..ac4ce5929c 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py @@ -74,10 +74,7 @@ def has_vertical_offset(offsets: Set[Tuple[int, int, int]]) -> bool: oir.IJCache(name=field) for field in cacheable ] return oir.VerticalLoop( - sections=node.sections, - loop_order=node.loop_order, - caches=caches, - loc=node.loc, + sections=node.sections, loop_order=node.loop_order, caches=caches, loc=node.loc ) def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil: @@ -149,10 +146,7 @@ def has_variable_offset_reads(field: str) -> bool: oir.KCache(name=field, fill=True, flush=True) for field in cacheable ] return oir.VerticalLoop( - loop_order=node.loop_order, - sections=node.sections, - caches=caches, - loc=node.loc, + loop_order=node.loop_order, sections=node.sections, caches=caches, loc=node.loc ) @@ -459,9 +453,7 @@ def _fill_stmts( lmin = max(lmin, first_unfilled.get(field, lmin)) for offset in range(lmin, lmax + 1): k_offset = common.CartesianOffset( - i=0, - j=0, - k=offset if loop_order == common.LoopOrder.FORWARD else -offset, + i=0, j=0, k=offset if 
loop_order == common.LoopOrder.FORWARD else -offset ) fill_stmts.append( oir.AssignStmt( @@ -586,10 +578,7 @@ def visit_VerticalLoop( ] return oir.VerticalLoop( - loop_order=node.loop_order, - sections=sections, - caches=caches, - loc=node.loc, + loop_order=node.loop_order, sections=sections, caches=caches, loc=node.loc ) def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil: diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/horizontal_execution_merging.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/horizontal_execution_merging.py index df923c6470..61296f8c0b 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/horizontal_execution_merging.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/horizontal_execution_merging.py @@ -322,10 +322,7 @@ def first_has_horizontal_restriction() -> bool: for offset in read_offsets: merged.body = ( self.visit( - first.body, - shift=offset, - offset_symbol_map=offset_symbol_map, - scalar_map={}, + first.body, shift=offset, offset_symbol_map=offset_symbol_map, scalar_map={} ) + merged.body ) diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/inlining.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/inlining.py index 0f403cdffd..ee085e4eac 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/inlining.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/inlining.py @@ -23,10 +23,7 @@ class MaskCollector(eve.NodeVisitor): """Collects the boolean expressions defining mask statements that are boolean fields.""" def visit_AssignStmt( - self, - node: oir.AssignStmt, - *, - masks_to_inline: Dict[str, oir.Expr], + self, node: oir.AssignStmt, *, masks_to_inline: Dict[str, oir.Expr] ) -> None: if node.left.name in masks_to_inline: assert masks_to_inline[node.left.name] is None @@ -72,22 +69,14 @@ class MaskInlining(eve.NodeTranslator): """ def visit_FieldAccess( - self, - node: oir.FieldAccess, - *, - masks_to_inline: Dict[str, oir.Expr], - **kwargs: Any, + self, node: oir.FieldAccess, *, masks_to_inline: Dict[str, oir.Expr], **kwargs: Any ) -> oir.Expr: if node.name in masks_to_inline: return cp.copy(masks_to_inline[node.name]) return self.generic_visit(node, masks_to_inline=masks_to_inline, **kwargs) def visit_AssignStmt( - self, - node: oir.AssignStmt, - *, - masks_to_inline: Dict[str, oir.Expr], - **kwargs: Any, + self, node: oir.AssignStmt, *, masks_to_inline: Dict[str, oir.Expr], **kwargs: Any ) -> Union[oir.AssignStmt, eve.NothingType]: if node.left.name in masks_to_inline: return eve.NOTHING diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/mask_stmt_merging.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/mask_stmt_merging.py index b83f745106..71ea1baa33 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/mask_stmt_merging.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/mask_stmt_merging.py @@ -41,9 +41,7 @@ def _merge(self, stmts: List[oir.Stmt]) -> List[oir.Stmt]: def visit_HorizontalExecution(self, node: oir.HorizontalExecution) -> oir.HorizontalExecution: return oir.HorizontalExecution( - body=self._merge(node.body), - declarations=node.declarations, - loc=node.loc, + body=self._merge(node.body), declarations=node.declarations, loc=node.loc ) # Stmt node types with lists of Stmts within them: diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/pruning.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/pruning.py index 64929f70f6..efe9bd8b9a 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/pruning.py +++ 
b/src/gt4py/cartesian/gtc/passes/oir_optimizations/pruning.py @@ -41,9 +41,7 @@ def visit_VerticalLoopSection(self, node: oir.VerticalLoopSection) -> Any: if not horizontal_executions: return eve.NOTHING return oir.VerticalLoopSection( - interval=node.interval, - horizontal_executions=horizontal_executions, - loc=node.loc, + interval=node.interval, horizontal_executions=horizontal_executions, loc=node.loc ) def visit_VerticalLoop(self, node: oir.VerticalLoop) -> Any: @@ -51,10 +49,7 @@ def visit_VerticalLoop(self, node: oir.VerticalLoop) -> Any: if not sections: return eve.NOTHING return oir.VerticalLoop( - loop_order=node.loop_order, - sections=sections, - caches=node.caches, - loc=node.loc, + loop_order=node.loop_order, sections=sections, caches=node.caches, loc=node.loc ) def visit_Stencil(self, node: oir.Stencil, **kwargs): diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py index c97b478f77..d4855a904c 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py @@ -63,10 +63,7 @@ def visit_HorizontalExecution( ) def visit_VerticalLoop( - self, - node: oir.VerticalLoop, - tmps_to_replace: Set[str], - **kwargs: Any, + self, node: oir.VerticalLoop, tmps_to_replace: Set[str], **kwargs: Any ) -> oir.VerticalLoop: return oir.VerticalLoop( loop_order=node.loop_order, @@ -82,9 +79,7 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil: name=node.name, params=node.params, vertical_loops=self.visit( - node.vertical_loops, - new_symbol_name=symbol_name_creator(all_names), - **kwargs, + node.vertical_loops, new_symbol_name=symbol_name_creator(all_names), **kwargs ), declarations=[d for d in node.declarations if d.name not in tmps_to_replace], loc=node.loc, diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py index ddf4713757..3d66471166 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py @@ -95,11 +95,7 @@ def visit_FieldAccess( ) ) - def visit_AssignStmt( - self, - node: oir.AssignStmt, - **kwargs: Any, - ) -> None: + def visit_AssignStmt(self, node: oir.AssignStmt, **kwargs: Any) -> None: self.visit(node.right, is_write=False, **kwargs) self.visit(node.left, is_write=True, **kwargs) diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/vertical_loop_merging.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/vertical_loop_merging.py index daea8a2c02..26d5a5464e 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/vertical_loop_merging.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/vertical_loop_merging.py @@ -41,11 +41,7 @@ def _merge(a: oir.VerticalLoop, b: oir.VerticalLoop) -> oir.VerticalLoop: warnings.warn( "AdjacentLoopMerging pass removed previously declared caches", stacklevel=2 ) - return oir.VerticalLoop( - loop_order=a.loop_order, - sections=sections, - caches=[], - ) + return oir.VerticalLoop(loop_order=a.loop_order, sections=sections, caches=[]) def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil: if not node.vertical_loops: diff --git a/src/gt4py/cartesian/gtscript_imports.py b/src/gt4py/cartesian/gtscript_imports.py index 206d190f36..44fa42d8e4 100644 --- a/src/gt4py/cartesian/gtscript_imports.py +++ b/src/gt4py/cartesian/gtscript_imports.py @@ -235,11 +235,7 @@ def 
enabled(**kwargs: Any) -> Iterator: import some_other_stencil # in the same directory as some_stencil.gt.py raises error """ - backup_import_system = ( - sys.path.copy(), - sys.meta_path.copy(), - sys.modules.copy(), - ) + backup_import_system = (sys.path.copy(), sys.meta_path.copy(), sys.modules.copy()) try: yield enable(**kwargs) finally: diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index 5b6915fc38..5bd9c98f8c 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -208,15 +208,9 @@ def field_type_validator_factory( ) -> FieldTypeValidatorFactory: """Create a factory of field type validators from a factory of regular type validators.""" if use_cache: - factory = cast( - type_val.TypeValidatorFactory, - utils.optional_lru_cache(func=factory), - ) + factory = cast(type_val.TypeValidatorFactory, utils.optional_lru_cache(func=factory)) - def _field_type_validator_factory( - type_annotation: TypeAnnotation, - name: str, - ) -> FieldValidator: + def _field_type_validator_factory(type_annotation: TypeAnnotation, name: str) -> FieldValidator: """Field type validator for datamodels, supporting forward references.""" if isinstance(type_annotation, ForwardRef): return ForwardRefValidator(factory) @@ -492,11 +486,7 @@ def __init_subclass__( raise TypeError("Subclasses of a frozen DataModel cannot be unfrozen.") _make_datamodel( - cls, - slots=False, - generic=generic, - **datamodel_kwargs, - _stacklevel_offset=1, + cls, slots=False, generic=generic, **datamodel_kwargs, _stacklevel_offset=1 ) @@ -729,8 +719,7 @@ def astuple(instance: DataModel) -> Tuple[Any, ...]: def update_forward_refs( - model_cls: Type[_DataModelT], - localns: Optional[Dict[str, Any]] = None, + model_cls: Type[_DataModelT], localns: Optional[Dict[str, Any]] = None ) -> Type[_DataModelT]: """Update Data Model class meta-information replacing forwarded type annotations with actual types. 
@@ -1109,11 +1098,7 @@ def _make_datamodel( ) if attr_value_in_cls is NOTHING: # The field has no definition in the class dict, it's only an annotation - setattr( - cls, - key, - attrs.field(converter=converter, validator=type_validator), - ) + setattr(cls, key, attrs.field(converter=converter, validator=type_validator)) else: # The field contains the default value in the class dict @@ -1157,9 +1142,7 @@ def _make_datamodel( if base_field_attr: # Create a new field in the current class cloning the existing # definition and add the new validator (attrs recommendation) - field_c_attr = _make_counting_attr_from_attribute( - base_field_attr, - ) + field_c_attr = _make_counting_attr_from_attribute(base_field_attr) setattr(cls, qualified_field_name, field_c_attr) else: raise TypeError( diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index acb2103602..b5c35447e3 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -143,9 +143,7 @@ def __dir__() -> List[str]: ] else: SolvedTypeAnnotation = Union[ # type: ignore[misc] # mypy consider this assignment a redefinition - Type, - _typing._SpecialForm, - _typing._GenericAlias, # type: ignore[attr-defined] # _GenericAlias is not exported in stub + Type, _typing._SpecialForm, _typing._GenericAlias # type: ignore[attr-defined] # _GenericAlias is not exported in stub ] TypeAnnotation = Union[ForwardRef, SolvedTypeAnnotation] @@ -334,9 +332,7 @@ class NonProtocolABC(metaclass=NonProtocolABCMeta): @overload def extended_runtime_checkable( - *, - instance_check_shortcut: bool = True, - subclass_check_with_data_members: bool = False, + *, instance_check_shortcut: bool = True, subclass_check_with_data_members: bool = False ) -> Callable[[_ProtoT], _ProtoT]: ... @@ -696,10 +692,7 @@ class CallableKwargsInfo: def infer_type( - value: Any, - *, - annotate_callable_kwargs: bool = False, - none_as_type: bool = True, + value: Any, *, annotate_callable_kwargs: bool = False, none_as_type: bool = True ) -> TypeAnnotation: """Generate a typing definition from a value. diff --git a/src/gt4py/eve/pattern_matching.py b/src/gt4py/eve/pattern_matching.py index fe11b0b0c7..2baa80151e 100644 --- a/src/gt4py/eve/pattern_matching.py +++ b/src/gt4py/eve/pattern_matching.py @@ -85,10 +85,7 @@ def get_differences(a: Any, b: Any, path: str = "") -> Iterator[Tuple[str, str]] @get_differences.register def _(a: ObjectPattern, b: Any, path: str = "") -> Iterator[Tuple[str, str]]: if not isinstance(b, a.cls): - yield ( - path, - f"Expected an instance of class {a.cls.__name__}, but got {type(b).__name__}", - ) + yield (path, f"Expected an instance of class {a.cls.__name__}, but got {type(b).__name__}") else: for k in a.fields.keys(): if not hasattr(b, k): diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 43e059dc40..0c175e519f 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -78,11 +78,7 @@ def __call__( class FixedTypeValidator(Protocol): @abc.abstractmethod - def __call__( - self, - value: Any, - **kwargs: Any, - ) -> None: + def __call__(self, value: Any, **kwargs: Any) -> None: """Protocol for callables checking that ``value`` matches a fixed type_annotation. 
Arguments: @@ -485,8 +481,7 @@ def _combined_validator(value: Any, **kwargs: Any) -> None: simple_type_validator_factory: Final = cast( - TypeValidatorFactory, - utils.optional_lru_cache(SimpleTypeValidatorFactory(), typed=True), + TypeValidatorFactory, utils.optional_lru_cache(SimpleTypeValidatorFactory(), typed=True) ) """Public (with optional cache) entry point for :class:`SimpleTypeValidatorFactory`.""" diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index b779a70d3e..01c066ca91 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -428,11 +428,7 @@ def dhash(obj: Any, **kwargs: Any) -> str: def pprint_ddiff( - old: Any, - new: Any, - *, - pprint_opts: Optional[Dict[str, Any]] = None, - **kwargs: Any, + old: Any, new: Any, *, pprint_opts: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> None: """Pretty printing of deepdiff.DeepDiff objects. @@ -1030,10 +1026,7 @@ def chain(self, *others: Iterable) -> XIterable[Union[T, S]]: return XIterable(itertools.chain(self.iterator, *iterators)) def diff( - self, - *others: Iterable, - default: Any = NOTHING, - key: Union[NOTHING, Callable] = NOTHING, + self, *others: Iterable, default: Any = NOTHING, key: Union[NOTHING, Callable] = NOTHING ) -> XIterable[Tuple[T, S]]: """Diff iterators. @@ -1316,10 +1309,7 @@ def groupby( ) -> XIterable[Tuple[Any, List[T]]]: ... def groupby( - self, - key: Union[str, List[Any], Callable[[T], Any]], - *attr_keys: str, - as_dict: bool = False, + self, key: Union[str, List[Any], Callable[[T], Any]], *attr_keys: str, as_dict: bool = False ) -> Union[XIterable[Tuple[Any, List[T]]], Dict]: """Group a sequence by a given key. diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index ea76aea108..bf40db5772 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -162,7 +162,7 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: name: new_child for name, child in node.iter_children_items() if (new_child := self.visit(child, **kwargs)) is not NOTHING - }, + } ) if self.PRESERVED_ANNEX_ATTRS and (old_annex := getattr(node, "__node_annex__", None)): # access to `new_node.annex` implicitly creates the `__node_annex__` attribute in the property getter diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 2936e4163a..fa19946f8f 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -144,9 +144,7 @@ def __init__( object.__setattr__(self, "stop", 0) @classmethod - def infinite( - cls, - ) -> UnitRange: + def infinite(cls) -> UnitRange: return cls(Infinity.NEGATIVE, Infinity.POSITIVE) def __len__(self) -> int: @@ -354,10 +352,7 @@ def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSe def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence: # `cast` because mypy/typing doesn't special case 1-element tuples, i.e. `tuple[A|B] != tuple[A]|tuple[B]` - return cast( - AnyIndexSequence, - (index,) if is_any_index_element(index) else index, - ) + return cast(AnyIndexSequence, (index,) if is_any_index_element(index) else index) def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: @@ -715,9 +710,7 @@ def __xor__(self, other: Field | core_defs.ScalarT) -> Field: """Only defined for `Field` of value type `bool`.""" -def is_field( - v: Any, -) -> TypeGuard[Field]: +def is_field(v: Any) -> TypeGuard[Field]: # This function is introduced to localize the `type: ignore` because # extended_runtime_checkable does not make the protocol runtime_checkable # for mypy. 
@@ -731,9 +724,7 @@ class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.Sc def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None: ... -def is_mutable_field( - v: Field, -) -> TypeGuard[MutableField]: +def is_mutable_field(v: Field) -> TypeGuard[MutableField]: # This function is introduced to localize the `type: ignore` because # extended_runtime_checkable does not make the protocol runtime_checkable # for mypy. @@ -828,9 +819,7 @@ def __xor__(self, other: Field | core_defs.IntegralScalar) -> Never: raise TypeError("'ConnectivityField' does not support this operation.") -def is_connectivity_field( - v: Any, -) -> TypeGuard[ConnectivityField]: +def is_connectivity_field(v: Any) -> TypeGuard[ConnectivityField]: # This function is introduced to localize the `type: ignore` because # extended_runtime_checkable does not make the protocol runtime_checkable # for mypy. diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 7f92d57c1b..0140a4ab42 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -115,11 +115,7 @@ def zeros( array([0., 0., 0., 0., 0., 0., 0.]) """ field = empty( - domain=domain, - dtype=dtype, - aligned_index=aligned_index, - allocator=allocator, - device=device, + domain=domain, dtype=dtype, aligned_index=aligned_index, allocator=allocator, device=device ) field[...] = field.dtype.scalar_type(0) return field @@ -147,11 +143,7 @@ def ones( array([1., 1., 1., 1., 1., 1., 1.]) """ field = empty( - domain=domain, - dtype=dtype, - aligned_index=aligned_index, - allocator=allocator, - device=device, + domain=domain, dtype=dtype, aligned_index=aligned_index, allocator=allocator, device=device ) field[...] = field.dtype.scalar_type(1) return field diff --git a/src/gt4py/next/embedded/__init__.py b/src/gt4py/next/embedded/__init__.py index e0cb114148..276b4df818 100644 --- a/src/gt4py/next/embedded/__init__.py +++ b/src/gt4py/next/embedded/__init__.py @@ -15,9 +15,4 @@ from . import common, context, exceptions, nd_array_field -__all__ = [ - "common", - "context", - "exceptions", - "nd_array_field", -] +__all__ = ["common", "context", "exceptions", "nd_array_field"] diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index cdfa439193..d36f9409e5 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -98,9 +98,7 @@ def _absolute_sub_domain( return common.Domain(*named_ranges) -def domain_intersection( - *domains: common.Domain, -) -> common.Domain: +def domain_intersection(*domains: common.Domain) -> common.Domain: """ Return the intersection of the given domains. @@ -111,11 +109,7 @@ def domain_intersection( ... 
) # doctest: +ELLIPSIS Domain(dims=(Dimension(value='I', ...), ranges=(UnitRange(1, 3),)) """ - return functools.reduce( - operator.and_, - domains, - common.Domain(dims=tuple(), ranges=tuple()), - ) + return functools.reduce(operator.and_, domains, common.Domain(dims=tuple(), ranges=tuple())) def restrict_to_intersection( @@ -153,9 +147,7 @@ def restrict_to_intersection( ) -def iterate_domain( - domain: common.Domain, -) -> Iterator[tuple[common.NamedIndex]]: +def iterate_domain(domain: common.Domain) -> Iterator[tuple[common.NamedIndex]]: for idx in itertools.product(*(list(r) for r in domain.ranges)): yield tuple(common.NamedIndex(d, i) for d, i in zip(domain.dims, idx)) # type: ignore[misc] # trust me, `idx` is `tuple[int, ...]` @@ -198,9 +190,7 @@ def _find_index_of_dim( return None -def canonicalize_any_index_sequence( - index: common.AnyIndexSpec, -) -> common.AnyIndexSpec: +def canonicalize_any_index_sequence(index: common.AnyIndexSpec) -> common.AnyIndexSpec: # TODO: instead of canonicalizing to `NamedRange`, we should canonicalize to `NamedSlice` new_index: common.AnyIndexSpec = (index,) if isinstance(index, slice) else index if isinstance(new_index, tuple) and all(isinstance(i, slice) for i in new_index): @@ -208,9 +198,7 @@ def canonicalize_any_index_sequence( return new_index -def _named_slice_to_named_range( - idx: common.NamedSlice, -) -> common.NamedRange | common.NamedSlice: +def _named_slice_to_named_range(idx: common.NamedSlice) -> common.NamedRange | common.NamedSlice: assert hasattr(idx, "start") and hasattr(idx, "stop") if common.is_named_slice(idx): start_dim, start_value = idx.start diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index af3ac0e646..b00aed9f73 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -220,11 +220,7 @@ def remap( # finally, take the new array new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx) - return self.__class__.from_array( - new_buffer, - domain=new_domain, - dtype=self.dtype, - ) + return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype) __call__ = remap # type: ignore[assignment] @@ -395,12 +391,7 @@ def from_array( # type: ignore[override] assert isinstance(codomain, common.Dimension) - return cls( - domain, - array, - codomain, - _skip_value=skip_value, - ) + return cls(domain, array, codomain, _skip_value=skip_value) def inverse_image( self, image_range: common.UnitRange | common.NamedRange @@ -530,9 +521,7 @@ def _hypercube( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _compute_mask_ranges( - mask: core_defs.NDArrayObject, -) -> list[tuple[bool, common.UnitRange]]: +def _compute_mask_ranges(mask: core_defs.NDArrayObject) -> list[tuple[bool, common.UnitRange]]: """Take a 1-dimensional mask and return a sequence of mappings from boolean values to ranges.""" # TODO: does it make sense to upgrade this naive algorithm to numpy? 
assert mask.ndim == 1 @@ -698,10 +687,7 @@ def _concat_where( def _make_reduction( builtin_name: str, array_builtin_name: str, initial_value_op: Callable -) -> Callable[ - ..., - NdArrayField[common.DimsT, core_defs.ScalarT], -]: +) -> Callable[..., NdArrayField[common.DimsT, core_defs.ScalarT]]: def _builtin_op( field: NdArrayField[common.DimsT, core_defs.ScalarT], axis: common.Dimension ) -> NdArrayField[common.DimsT, core_defs.ScalarT]: @@ -735,11 +721,7 @@ def _builtin_op( ) return field.__class__.from_array( - getattr(xp, array_builtin_name)( - masked_array, - axis=reduce_dim_index, - ), - domain=new_domain, + getattr(xp, array_builtin_name)(masked_array, axis=reduce_dim_index), domain=new_domain ) _builtin_op.__name__ = builtin_name diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index e0be899c00..9dd63136df 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -33,11 +33,7 @@ def _format_uncaught_error(err: exceptions.DSLError, verbose_exceptions: bool) -> list[str]: if verbose_exceptions: return formatting.format_compilation_error( - type(err), - err.message, - err.location, - err.__traceback__, - err.__cause__, + type(err), err.message, err.location, err.__traceback__, err.__cause__ ) else: return formatting.format_compilation_error(type(err), err.message, err.location) diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index 0176607971..98b68f96f6 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -78,10 +78,7 @@ def _format_cause(cause: BaseException) -> list[str]: def _format_traceback(tb: types.TracebackType) -> list[str]: """Format the traceback of an exception.""" intro_message = "Traceback (most recent call last):" - traceback_strs = [ - f"{intro_message}\n", - *traceback.format_tb(tb), - ] + traceback_strs = [f"{intro_message}\n", *traceback.format_tb(tb)] return traceback_strs diff --git a/src/gt4py/next/ffront/ast_passes/unchain_compares.py b/src/gt4py/next/ffront/ast_passes/unchain_compares.py index 8fc6d5f349..a72c2401d1 100644 --- a/src/gt4py/next/ffront/ast_passes/unchain_compares.py +++ b/src/gt4py/next/ffront/ast_passes/unchain_compares.py @@ -46,9 +46,7 @@ def visit_Compare(self, node: ast.Compare) -> ast.Compare | ast.BinOp: # left leaf of the new tree: ``a < b`` # example: ``a < b > c > d`` left_leaf = ast.Compare( - ops=node.ops[0:1], - left=node.left, - comparators=node.comparators[0:1], + ops=node.ops[0:1], left=node.left, comparators=node.comparators[0:1] ) ast.copy_location(left_leaf, node) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 4ef4d55e08..f55e829715 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -102,10 +102,7 @@ def from_function( grid_type: Optional[GridType] = None, ) -> Program: program_def = ffront_stages.ProgramDefinition(definition=definition, grid_type=grid_type) - return cls( - definition_stage=program_def, - backend=backend, - ) + return cls(definition_stage=program_def, backend=backend) # needed in testing @property @@ -233,9 +230,7 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor) self.backend( - self.definition_stage, - *args, - **(kwargs | {"offset_provider": offset_provider}), + self.definition_stage, *args, **(kwargs | {"offset_provider": offset_provider}) ) @@ -287,10 +282,7 @@ def 
__call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): raise ValueError(f"Parameter '{name}' already set as a bound argument.") type_info.accepts_args( - new_type, - with_args=arg_types, - with_kwargs=kwarg_types, - raise_exception=True, + new_type, with_args=arg_types, with_kwargs=kwarg_types, raise_exception=True ) except ValueError as err: bound_arg_names = ", ".join([f"'{bound_arg}'" for bound_arg in self.bound_args.keys()]) @@ -326,10 +318,7 @@ def itir(self): ) new_clos.inputs.pop(index) params = [sym(inp.id) for inp in new_clos.inputs] - expr = itir.FunCall( - fun=new_clos.stencil, - args=new_args, - ) + expr = itir.FunCall(fun=new_clos.stencil, args=new_args) new_clos.stencil = itir.Lambda(params=params, expr=expr) return new_itir @@ -371,9 +360,7 @@ def program( def program_inner(definition: types.FunctionType) -> Program: return Program.from_function( - definition, - DEFAULT_BACKEND if backend is eve.NOTHING else backend, - grid_type, + definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type ) return program_inner if definition is None else program_inner(definition) @@ -487,10 +474,7 @@ def as_program( # TODO(tehrengruber): check foast operator has no out argument that clashes # with the out argument of the program we generate here. hash_ = eve_utils.content_hash( - ( - tuple(arg_types), - tuple((name, arg) for name, arg in kwarg_types.items()), - ) + (tuple(arg_types), tuple((name, arg) for name, arg in kwarg_types.items())) ) try: return self._program_cache[hash_] @@ -529,7 +513,7 @@ def as_program( type=ts.DeferredType(constraint=None), namespace=dialect_ast_enums.Namespace.CLOSURE, location=loc, - ), + ) ] untyped_past_node = past.Program( @@ -553,20 +537,14 @@ def as_program( self._program_cache[hash_] = ProgramFromPast( definition_stage=None, past_stage=ffront_stages.PastProgramDefinition( - past_node=past_node, - closure_vars=closure_vars, - grid_type=self.grid_type, + past_node=past_node, closure_vars=closure_vars, grid_type=self.grid_type ), backend=self.backend, ) return self._program_cache[hash_] - def __call__( - self, - *args, - **kwargs, - ) -> None: + def __call__(self, *args, **kwargs) -> None: if not next_embedded.context.within_context() and self.backend is not None: # non embedded execution if "offset_provider" not in kwargs: @@ -635,13 +613,9 @@ def field_operator(definition=None, *, backend=eve.NOTHING, grid_type=None): ... ... 
""" - def field_operator_inner( - definition: types.FunctionType, - ) -> FieldOperator[foast.FieldOperator]: + def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast.FieldOperator]: return FieldOperator.from_function( - definition, - DEFAULT_BACKEND if backend is eve.NOTHING else backend, - grid_type, + definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type ) return field_operator_inner if definition is None else field_operator_inner(definition) diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index ec27d963e0..95240fca0c 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -96,10 +96,7 @@ def _preprocess_definition_ast(cls, definition_ast: ast.AST) -> ast.AST: @classmethod def _postprocess_dialect_ast( - cls, - output_ast: DialectRootT, - closure_vars: dict[str, Any], - annotations: dict[str, Any], + cls, output_ast: DialectRootT, closure_vars: dict[str, Any], annotations: dict[str, Any] ) -> DialectRootT: return output_ast diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index b69a118713..6771d62e4d 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -20,11 +20,7 @@ @BuiltInFunction -def as_offset( - offset_: FieldOffset, - field: common.Field, - /, -) -> common.ConnectivityField: +def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.ConnectivityField: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 34562ffdcb..3b2130ee22 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -156,37 +156,23 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: @BuiltInFunction -def neighbor_sum( - field: common.Field, - /, - axis: common.Dimension, -) -> common.Field: +def neighbor_sum(field: common.Field, /, axis: common.Dimension) -> common.Field: raise NotImplementedError() @BuiltInFunction -def max_over( - field: common.Field, - /, - axis: common.Dimension, -) -> common.Field: +def max_over(field: common.Field, /, axis: common.Dimension) -> common.Field: raise NotImplementedError() @BuiltInFunction -def min_over( - field: common.Field, - /, - axis: common.Dimension, -) -> common.Field: +def min_over(field: common.Field, /, axis: common.Dimension) -> common.Field: raise NotImplementedError() @BuiltInFunction def broadcast( - field: common.Field | core_defs.ScalarT, - dims: tuple[common.Dimension, ...], - /, + field: common.Field | core_defs.ScalarT, dims: tuple[common.Dimension, ...], / ) -> common.Field: assert core_defs.is_scalar_type( field @@ -207,9 +193,7 @@ def where( @BuiltInFunction def astype( - value: common.Field | core_defs.ScalarT | Tuple, - type_: type, - /, + value: common.Field | core_defs.ScalarT | Tuple, type_: type, / ) -> common.Field | core_defs.ScalarT | Tuple: if isinstance(value, tuple): return tuple(astype(v, type_) for v in value) @@ -269,9 +253,7 @@ def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs def _make_binary_math_builtin(name: str) -> None: def impl( - lhs: common.Field | core_defs.ScalarT, - rhs: common.Field | core_defs.ScalarT, - /, + lhs: common.Field | core_defs.ScalarT, rhs: common.Field | core_defs.ScalarT, / ) -> common.Field | core_defs.ScalarT: # default implementation for scalars, Fields are handled via dispatch assert core_defs.is_scalar_type(lhs) 
diff --git a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py index 70f2ef132f..97108f6c7f 100644 --- a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py +++ b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py @@ -35,9 +35,7 @@ class ClosureVarFolding(NodeTranslator, traits.VisitorWithSymbolTableTrait): @classmethod def apply( - cls, - node: foast.FunctionDefinition | foast.FieldOperator, - closure_vars: dict[str, Any], + cls, node: foast.FunctionDefinition | foast.FieldOperator, closure_vars: dict[str, Any] ) -> foast.FunctionDefinition: return cls(closure_vars=closure_vars).visit(node) diff --git a/src/gt4py/next/ffront/foast_passes/iterable_unpack.py b/src/gt4py/next/ffront/foast_passes/iterable_unpack.py index 88888d69ee..4c24833073 100644 --- a/src/gt4py/next/ffront/foast_passes/iterable_unpack.py +++ b/src/gt4py/next/ffront/foast_passes/iterable_unpack.py @@ -79,32 +79,22 @@ def visit_BlockStmt(self, node: foast.BlockStmt, **kwargs: Any) -> foast.BlockSt slice_indices = list(range(lower, upper)) tuple_slice = [ foast.Subscript( - value=tuple_name, - index=i, - type=el_type, - location=stmt.location, + value=tuple_name, index=i, type=el_type, location=stmt.location ) for i in slice_indices ] new_tuple = foast.TupleExpr( - elts=tuple_slice, - type=el_type, - location=stmt.location, + elts=tuple_slice, type=el_type, location=stmt.location ) new_assign = foast.Assign( - target=subtarget.id, - value=new_tuple, - location=stmt.location, + target=subtarget.id, value=new_tuple, location=stmt.location ) else: new_assign = foast.Assign( target=subtarget, value=foast.Subscript( - value=tuple_name, - index=index, - type=el_type, - location=stmt.location, + value=tuple_name, index=index, type=el_type, location=stmt.location ), location=stmt.location, ) diff --git a/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py b/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py index f77c833578..e7b01a7764 100644 --- a/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py +++ b/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py @@ -37,9 +37,7 @@ class TypeAliasReplacement(NodeTranslator, traits.VisitorWithSymbolTableTrait): @classmethod def apply( - cls, - node: foast.FunctionDefinition | foast.FieldOperator, - closure_vars: dict[str, Any], + cls, node: foast.FunctionDefinition | foast.FieldOperator, closure_vars: dict[str, Any] ) -> tuple[foast.FunctionDefinition, dict[str, Any]]: foast_node = cls(closure_vars=closure_vars).visit(node) new_closure_vars = closure_vars.copy() @@ -58,9 +56,7 @@ def is_type_alias(self, node_id: SymbolName | SymbolRef) -> bool: def visit_Name(self, node: foast.Name, **kwargs: Any) -> foast.Name: if self.is_type_alias(node.id): return foast.Name( - id=self.closure_vars[node.id].__name__, - location=node.location, - type=node.type, + id=self.closure_vars[node.id].__name__, location=node.location, type=node.type ) return node @@ -83,8 +79,7 @@ def _update_closure_var_symbols( kw_only_args={}, pos_only_args=[ts.DeferredType(constraint=ts.ScalarType)], returns=cast( - ts.DataType, - from_type_hint(self.closure_vars[var.id]), + ts.DataType, from_type_hint(self.closure_vars[var.id]) ), ), namespace=dialect_ast_enums.Namespace.CLOSURE, diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 6044b41421..471840ff1b 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ 
b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -60,9 +60,7 @@ def with_altered_scalar_kind( def construct_tuple_type( - true_branch_types: list, - false_branch_types: list, - mask_type: ts.FieldType, + true_branch_types: list, false_branch_types: list, mask_type: ts.FieldType ) -> list: """ Recursively construct the return types for the tuple return branch. @@ -89,8 +87,7 @@ def construct_tuple_type( ) else: element_types_new[i] = promote_to_mask_type( - mask_type, - type_info.promote(element_types_new[i], false_branch_types[i]), + mask_type, type_info.promote(element_types_new[i], false_branch_types[i]) ) return element_types_new @@ -298,8 +295,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs: Any) -> foast.S new_axis = self.visit(node.axis, **kwargs) if not isinstance(new_axis.type, ts.DimensionType): raise errors.DSLError( - node.location, - f"Argument 'axis' to scan operator '{node.id}' must be a dimension.", + node.location, f"Argument 'axis' to scan operator '{node.id}' must be a dimension." ) if not new_axis.type.dim.kind == DimensionKind.VERTICAL: raise errors.DSLError( @@ -309,8 +305,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs: Any) -> foast.S new_forward = self.visit(node.forward, **kwargs) if not new_forward.type.kind == ts.ScalarKind.BOOL: raise errors.DSLError( - node.location, - f"Argument 'forward' to scan operator '{node.id}' must be a boolean.", + node.location, f"Argument 'forward' to scan operator '{node.id}' must be a boolean." ) new_init = self.visit(node.init, **kwargs) if not all( @@ -339,10 +334,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs: Any) -> foast.S f"expected '{carry_type}', got '{new_init.type}'.", ) - new_type = ts_ffront.ScanOperatorType( - axis=new_axis.type.dim, - definition=new_def_type, - ) + new_type = ts_ffront.ScanOperatorType(axis=new_axis.type.dim, definition=new_def_type) return foast.ScanOperator( id=node.id, axis=new_axis, @@ -381,8 +373,7 @@ def visit_TupleTargetAssign( if not any(isinstance(i, tuple) for i in indices) and len(targets) != num_elts: raise errors.DSLError( - node.location, - f"Too many values to unpack (expected {len(targets)}).", + node.location, f"Too many values to unpack (expected {len(targets)})." ) new_targets: TargetType = [] @@ -405,23 +396,16 @@ def visit_TupleTargetAssign( else: new_type = values.type.types[index] new_target = self.visit( - old_target, - refine_type=new_type, - location=old_target.location, - **kwargs, + old_target, refine_type=new_type, location=old_target.location, **kwargs ) new_target = self.visit( - new_target, - refine_type=new_type, - location=old_target.location, - **kwargs, + new_target, refine_type=new_type, location=old_target.location, **kwargs ) new_targets.append(new_target) else: raise errors.DSLError( - node.location, - f"Assignment value must be of type tuple, got '{values.type}'.", + node.location, f"Assignment value must be of type tuple, got '{values.type}'." 
) return foast.TupleTargetAssign(targets=new_targets, value=values, location=node.location) @@ -468,10 +452,7 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs: Any) -> foast.IfStmt: return new_node def visit_Symbol( - self, - node: foast.Symbol, - refine_type: Optional[ts.FieldType] = None, - **kwargs: Any, + self, node: foast.Symbol, refine_type: Optional[ts.FieldType] = None, **kwargs: Any ) -> foast.Symbol: symtable = kwargs["symtable"] if refine_type: @@ -499,8 +480,7 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: raise errors.DSLError( - new_value.location, - "Second dimension in offset must be a local dimension.", + new_value.location, "Second dimension in offset must be a local dimension." ) new_type = ts.OffsetType(source=source, target=(target1,)) case ts.OffsetType(source=source, target=(target,)): @@ -527,11 +507,7 @@ def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> foast.BinOp: new_right = self.visit(node.right, **kwargs) new_type = self._deduce_binop_type(node, left=new_left, right=new_right) return foast.BinOp( - op=node.op, - left=new_left, - right=new_right, - location=node.location, - type=new_type, + op=node.op, left=new_left, right=new_right, location=node.location, type=new_type ) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> foast.TernaryExpr: @@ -539,10 +515,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> foast.Ter new_true_expr = self.visit(node.true_expr, **kwargs) new_false_expr = self.visit(node.false_expr, **kwargs) new_type = self._deduce_ternaryexpr_type( - node, - condition=new_condition, - true_expr=new_true_expr, - false_expr=new_false_expr, + node, condition=new_condition, true_expr=new_true_expr, false_expr=new_false_expr ) return foast.TernaryExpr( condition=new_condition, @@ -578,11 +551,7 @@ def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> foast.Compare: new_right = self.visit(node.right, **kwargs) new_type = self._deduce_compare_type(node, left=new_left, right=new_right) return foast.Compare( - op=node.op, - left=new_left, - right=new_right, - location=node.location, - type=new_type, + op=node.op, left=new_left, right=new_right, location=node.location, type=new_type ) def _deduce_compare_type( @@ -592,8 +561,7 @@ def _deduce_compare_type( for arg in (left, right): if not type_info.is_arithmetic(arg.type): raise errors.DSLError( - arg.location, - f"Type '{arg.type}' can not be used in operator '{node.op}'.", + arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." ) self._check_operand_dtypes_match(node, left=left, right=right) @@ -613,12 +581,7 @@ def _deduce_compare_type( ) from ex def _deduce_binop_type( - self, - node: foast.BinOp, - *, - left: foast.Expr, - right: foast.Expr, - **kwargs: Any, + self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: logical_ops = { dialect_ast_enums.BinaryOperator.BIT_AND, @@ -631,8 +594,7 @@ def _deduce_binop_type( for arg in (left, right): if not is_compatible(arg.type): raise errors.DSLError( - arg.location, - f"Type '{arg.type}' can not be used in operator '{node.op}'.", + arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." 
) left_type = cast(ts.FieldType | ts.ScalarType, left.type) @@ -673,10 +635,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> foast.UnaryOp: is_compatible = ( type_info.is_logical if node.op - in [ - dialect_ast_enums.UnaryOperator.NOT, - dialect_ast_enums.UnaryOperator.INVERT, - ] + in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT] else type_info.is_arithmetic ) if not is_compatible(new_operand.type): @@ -685,10 +644,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> foast.UnaryOp: f"Incompatible type for unary operator '{node.op}': '{new_operand.type}'.", ) return foast.UnaryOp( - op=node.op, - operand=new_operand, - location=node.location, - type=new_operand.type, + op=node.op, operand=new_operand, location=node.location, type=new_operand.type ) def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> foast.TupleExpr: @@ -713,12 +669,7 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: # have the proper format here. if not isinstance( new_func, - ( - foast.FunctionDefinition, - foast.FieldOperator, - foast.ScanOperator, - foast.Name, - ), + (foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name), ): raise errors.DSLError(node.location, "Functions can only be called directly.") elif isinstance(new_func.type, ts.FieldType): @@ -732,10 +683,7 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: # ensure signature is valid try: type_info.accepts_args( - func_type, - with_args=arg_types, - with_kwargs=kwarg_types, - raise_exception=True, + func_type, with_args=arg_types, with_kwargs=kwarg_types, raise_exception=True ) except ValueError as err: raise errors.DSLError( @@ -814,8 +762,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> foast.Call: return_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type) elif func_name in fbuiltins.UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES: return_type = with_altered_scalar_kind( - cast(ts.FieldType | ts.ScalarType, node.args[0].type), - ts.ScalarKind.BOOL, + cast(ts.FieldType | ts.ScalarType, node.args[0].type), ts.ScalarKind.BOOL ) elif func_name in fbuiltins.BINARY_MATH_NUMBER_BUILTIN_NAMES: try: @@ -849,8 +796,7 @@ def _visit_reduction(self, node: foast.Call, **kwargs: Any) -> foast.Call: f"'{field_dims_str}'.", ) return_type = ts.FieldType( - dims=[dim for dim in field_type.dims if dim != reduction_dim], - dtype=field_type.dtype, + dims=[dim for dim in field_type.dims if dim != reduction_dim], dtype=field_type.dtype ) return foast.Call( @@ -922,11 +868,7 @@ def _visit_as_offset(self, node: foast.Call, **kwargs: Any) -> foast.Call: ) return foast.Call( - func=node.func, - args=node.args, - kwargs=node.kwargs, - type=arg_0, - location=node.location, + func=node.func, args=node.args, kwargs=node.kwargs, type=arg_0, location=node.location ) def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: @@ -966,8 +908,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: except ValueError as ex: raise errors.DSLError( - node.location, - f"Incompatible argument in call to '{node.func!s}'.", + node.location, f"Incompatible argument in call to '{node.func!s}'." 
) from ex return foast.Call( @@ -1000,10 +941,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call: f"broadcast dimension(s) '{set(arg_dims).difference(set(broadcast_dims))}' missing", ) - return_type = ts.FieldType( - dims=broadcast_dims, - dtype=type_info.extract_dtype(arg_type), - ) + return_type = ts.FieldType(dims=broadcast_dims, dtype=type_info.extract_dtype(arg_type)) return foast.Call( func=node.func, diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 0e39853a3c..80c0f1fea3 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -33,9 +33,7 @@ from gt4py.next.type_system import type_info, type_specifications as ts -def promote_to_list( - node: foast.Symbol | foast.Expr, -) -> Callable[[itir.Expr], itir.Expr]: +def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: if not type_info.contains_local_field(node.type): return lambda x: im.promote_to_lifted_stencil("make_const_list")(x) return lambda x: x @@ -79,9 +77,7 @@ def visit_FunctionDefinition( ) -> itir.FunctionDefinition: params = self.visit(node.params) return itir.FunctionDefinition( - id=node.id, - params=params, - expr=self.visit_BlockStmt(node.body, inner_expr=None), + id=node.id, params=params, expr=self.visit_BlockStmt(node.body, inner_expr=None) ) # `expr` is a lifted stencil def visit_FieldOperator( @@ -92,9 +88,7 @@ def visit_FieldOperator( new_body = func_definition.expr return itir.FunctionDefinition( - id=func_definition.id, - params=func_definition.params, - expr=new_body, + id=func_definition.id, params=func_definition.params, expr=new_body ) def visit_ScanOperator( @@ -141,8 +135,7 @@ def visit_ScanOperator( stencil_args.append(lowering_utils.to_iterator_of_tuples(param.id, arg_type)) new_body = im.let( - param.id, - lowering_utils.to_tuples_of_iterator(param.id, arg_type), + param.id, lowering_utils.to_tuples_of_iterator(param.id, arg_type) )(new_body) else: stencil_args.append(im.ref(param.id)) @@ -151,11 +144,7 @@ def visit_ScanOperator( body = im.lift(im.call("scan")(definition, forward, init))(*stencil_args) - return itir.FunctionDefinition( - id=node.id, - params=definition.params[1:], - expr=body, - ) + return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never: raise AssertionError("Statements must always be visited in the context of a function.") @@ -259,10 +248,7 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) - if node.op in [ - dialect_ast_enums.UnaryOperator.NOT, - dialect_ast_enums.UnaryOperator.INVERT, - ]: + if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") return self._map("not_", node.operand) @@ -327,11 +313,7 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: return self._visit_type_constr(node, **kwargs) elif isinstance( node.func.type, - ( - ts.FunctionType, - ts_ffront.FieldOperatorType, - ts_ffront.ScanOperatorType, - ), + (ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), ): # ITIR has no support for keyword arguments. 
Instead, we concatenate both positional # and keyword arguments and use the unique order as given in the function signature. @@ -387,11 +369,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._map(self.visit(node.func, **kwargs), *node.args) def _make_reduction_expr( - self, - node: foast.Call, - op: str | itir.SymRef, - init_expr: itir.Expr, - **kwargs: Any, + self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any ) -> itir.Expr: # TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2))) it = self.visit(node.args[0], **kwargs) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ceac9902cf..9f24dbf6db 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -225,9 +225,7 @@ def visit_Assign( ) return foast.TupleTargetAssign( - targets=new_targets, - value=self.visit(node.value), - location=self.get_location(node), + targets=new_targets, value=self.visit(node.value), location=self.get_location(node) ) if not isinstance(target, ast.Name): @@ -267,9 +265,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs: Any) -> foast.Assign: return foast.Assign( target=foast.Symbol[ts.FieldType]( - id=node.target.id, - location=self.get_location(node.target), - type=target_type, + id=node.target.id, location=self.get_location(node.target), type=target_type ), value=self.visit(node.value) if node.value else None, location=self.get_location(node), @@ -299,22 +295,17 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs: Any) -> foast.Subscript ) from None return foast.Subscript( - value=self.visit(node.value), - index=index, - location=self.get_location(node), + value=self.visit(node.value), index=index, location=self.get_location(node) ) def visit_Attribute(self, node: ast.Attribute) -> Any: return foast.Attribute( - value=self.visit(node.value), - attr=node.attr, - location=self.get_location(node), + value=self.visit(node.value), attr=node.attr, location=self.get_location(node) ) def visit_Tuple(self, node: ast.Tuple, **kwargs: Any) -> foast.TupleExpr: return foast.TupleExpr( - elts=[self.visit(item) for item in node.elts], - location=self.get_location(node), + elts=[self.visit(item) for item in node.elts], location=self.get_location(node) ) def visit_Return(self, node: ast.Return, **kwargs: Any) -> foast.Return: @@ -481,8 +472,4 @@ def visit_Constant(self, node: ast.Constant, **kwargs: Any) -> foast.Constant: loc, f"Constants of type {type(node.value)} are not permitted." 
) from None - return foast.Constant( - value=node.value, - location=loc, - type=type_, - ) + return foast.Constant(value=node.value, location=loc, type=type_) diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index dac3f8d8c9..6864993f4c 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -86,10 +86,7 @@ class ProgramParser(DialectParser[past.Program]): @classmethod def _postprocess_dialect_ast( - cls, - output_node: past.Program, - closure_vars: dict[str, Any], - annotations: dict[str, Any], + cls, output_node: past.Program, closure_vars: dict[str, Any], annotations: dict[str, Any] ) -> past.Program: output_node = ClosureVarTypeDeduction.apply(output_node, closure_vars) return ProgramTypeDeduction.apply(output_node) diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index 1adc566497..72182b7d31 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -134,9 +134,7 @@ def process_elements( def _process_elements_impl( - process_func: Callable[..., itir.Expr], - _current_el_exprs: list[T], - current_el_type: ts.TypeSpec, + process_func: Callable[..., itir.Expr], _current_el_exprs: list[T], current_el_type: ts.TypeSpec ) -> itir.Expr: if isinstance(current_el_type, ts.TupleType): result = im.make_tuple( diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 5cd9ba5055..5f0011f47d 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -52,10 +52,7 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None: Domain has to be of type dictionary, including dimensions with values expressed as tuples of 2 numbers. """ - if not isinstance( - new_func.type, - (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), - ): + if not isinstance(new_func.type, (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType)): raise ValueError( f"Only calls to 'FieldOperators' and 'ScanOperators' " f"allowed in 'Program', got '{new_func.type}'." @@ -126,18 +123,11 @@ def visit_Subscript(self, node: past.Subscript, **kwargs: Any) -> past.Subscript def visit_TupleExpr(self, node: past.TupleExpr, **kwargs: Any) -> past.TupleExpr: elts = self.visit(node.elts, **kwargs) return past.TupleExpr( - elts=elts, - type=ts.TupleType(types=[el.type for el in elts]), - location=node.location, + elts=elts, type=ts.TupleType(types=[el.type for el in elts]), location=node.location ) def _deduce_binop_type( - self, - node: past.BinOp, - *, - left: past.Expr, - right: past.Expr, - **kwargs: Any, + self, node: past.BinOp, *, left: past.Expr, right: past.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: logical_ops = { dialect_ast_enums.BinaryOperator.BIT_AND, @@ -149,8 +139,7 @@ def _deduce_binop_type( for arg in (left, right): if not isinstance(arg.type, ts.ScalarType) or not is_compatible(arg.type): raise errors.DSLError( - arg.location, - f"Type '{arg.type}' can not be used in operator '{node.op}'.", + arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." 
) left_type = cast(ts.ScalarType, left.type) @@ -181,11 +170,7 @@ def visit_BinOp(self, node: past.BinOp, **kwargs: Any) -> past.BinOp: new_right = self.visit(node.right, **kwargs) new_type = self._deduce_binop_type(node, left=new_left, right=new_right) return past.BinOp( - op=node.op, - left=new_left, - right=new_right, - location=node.location, - type=new_type, + op=node.op, left=new_left, right=new_right, location=node.location, type=new_type ) def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call: @@ -207,10 +192,7 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call: } type_info.accepts_args( - new_func.type, - with_args=arg_types, - with_kwargs=kwarg_types, - raise_exception=True, + new_func.type, with_args=arg_types, with_kwargs=kwarg_types, raise_exception=True ) return_type = ts.VoidType() if is_operator: diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index 2fde8dd37e..f50bd21fe5 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -50,10 +50,7 @@ def _validate_args(past_node: past.Program, args: list, kwargs: dict[str, Any]) try: type_info.accepts_args( - past_node.type, - with_args=arg_types, - with_kwargs=kwarg_types, - raise_exception=True, + past_node.type, with_args=arg_types, with_kwargs=kwarg_types, raise_exception=True ) except ValueError as err: raise errors.DSLError( diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 7d59e6fd72..a7e9751c4e 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -61,9 +61,7 @@ def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall: devtools.debug(itir_program) return stages.ProgramCall( - itir_program, - inp.args, - inp.kwargs | {"column_axis": _column_axis(all_closure_vars)}, + itir_program, inp.args, inp.kwargs | {"column_axis": _column_axis(all_closure_vars)} ) @@ -80,10 +78,7 @@ def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension] for name, gt_callable in transform_utils._filter_closure_vars_by_type( all_closure_vars, gtcallable.GTCallable ).items(): - if isinstance( - (type_ := gt_callable.__gt_type__()), - ts_ffront.ScanOperatorType, - ): + if isinstance((type_ := gt_callable.__gt_type__()), ts_ffront.ScanOperatorType): scanops_per_axis.setdefault(type_.axis, []).append(name) if len(scanops_per_axis.values()) == 0: @@ -108,9 +103,7 @@ def _size_arg_from_field(field_name: str, dim: int) -> str: return f"__{field_name}_size_{dim}" -def _flatten_tuple_expr( - node: past.Expr, -) -> list[past.Name | past.Subscript]: +def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript]: if isinstance(node, (past.Name, past.Subscript)): return [node] elif isinstance(node, past.TupleExpr): @@ -208,10 +201,7 @@ def visit_Program( closures.append(self._visit_stencil_call(stmt, **kwargs)) return itir.FencilDefinition( - id=node.id, - function_definitions=function_definitions, - params=params, - closures=closures, + id=node.id, function_definitions=function_definitions, params=params, closures=closures ) def _visit_stencil_call(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure: @@ -227,10 +217,7 @@ def _visit_stencil_call(self, node: past.Call, **kwargs: Any) -> itir.StencilClo assert isinstance(node.func.type, (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType)) args, node_kwargs = type_info.canonicalize_arguments( - node.func.type, - node.args, - 
node_kwargs, - use_signature_ordering=True, + node.func.type, node.args, node_kwargs, use_signature_ordering=True ) lowered_args, lowered_kwargs = self.visit(args, **kwargs), self.visit(node_kwargs, **kwargs) @@ -253,9 +240,7 @@ def _visit_stencil_call(self, node: past.Call, **kwargs: Any) -> itir.StencilClo else: # field operators return a tuple of iterators, deref element-wise stencil_body = lowering_utils.process_elements( - im.deref, - im.call(node.func.id)(*stencil_args), - node.func.type.definition.returns, + im.deref, im.call(node.func.id)(*stencil_args), node.func.type.definition.returns ) return itir.StencilClosure( @@ -281,8 +266,7 @@ def _visit_slice_bound( ) if slice_bound.value < 0: lowered_bound = itir.FunCall( - fun=itir.SymRef(id="plus"), - args=[dim_size, self.visit(slice_bound, **kwargs)], + fun=itir.SymRef(id="plus"), args=[dim_size, self.visit(slice_bound, **kwargs)] ) else: lowered_bound = self.visit(slice_bound, **kwargs) @@ -380,10 +364,7 @@ def _construct_itir_domain_arg( ) def _construct_itir_initialized_domain_arg( - self, - dim_i: int, - dim: common.Dimension, - node_domain: past.Dict, + self, dim_i: int, dim: common.Dimension, node_domain: past.Dict ) -> list[itir.FunCall]: assert len(node_domain.values_[dim_i].elts) == 2 keys_dims_types = cast(ts.DimensionType, node_domain.keys_[dim_i].type).dim diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 83b86ac656..2bd4f21993 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -41,9 +41,7 @@ def promote_el(type_el: ts.TypeSpec) -> ts.TypeSpec: def promote_zero_dims( - function_type: ts.FunctionType, - args: list[ts.TypeSpec], - kwargs: dict[str, ts.TypeSpec], + function_type: ts.FunctionType, args: list[ts.TypeSpec], kwargs: dict[str, ts.TypeSpec] ) -> tuple[list, dict]: """ Promote arg types to zero dimensional fields if compatible and required by function signature. 
@@ -75,10 +73,7 @@ def _as_field(arg_el: ts.TypeSpec, path: tuple[int, ...]) -> ts.TypeSpec: new_args = [*args] for i, (param, arg) in enumerate( - zip( - function_type.pos_only_args + list(function_type.pos_or_kw_args.values()), - args, - ) + zip(function_type.pos_only_args + list(function_type.pos_or_kw_args.values()), args) ): new_args[i] = promote_arg(param, arg) new_kwargs = {**kwargs} @@ -204,9 +199,7 @@ def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType: @type_info.function_signature_incompatibilities.register def function_signature_incompatibilities_scanop( - scanop_type: ts_ffront.ScanOperatorType, - args: list[ts.TypeSpec], - kwargs: dict[str, ts.TypeSpec], + scanop_type: ts_ffront.ScanOperatorType, args: list[ts.TypeSpec], kwargs: dict[str, ts.TypeSpec] ) -> Iterator[str]: if not all( type_info.is_type_or_tuple_of_type(arg, (ts.ScalarType, ts.FieldType)) for arg in args @@ -276,9 +269,7 @@ def function_signature_incompatibilities_scanop( @type_info.function_signature_incompatibilities.register def function_signature_incompatibilities_program( - program_type: ts_ffront.ProgramType, - args: list[ts.TypeSpec], - kwargs: dict[str, ts.TypeSpec], + program_type: ts_ffront.ProgramType, args: list[ts.TypeSpec], kwargs: dict[str, ts.TypeSpec] ) -> Iterator[str]: args, kwargs = type_info.canonicalize_arguments( program_type.definition, args, kwargs, ignore_errors=True @@ -318,6 +309,5 @@ def return_type_scanop( [callable_type.axis], ) return type_info.apply_to_primitive_constituents( - carry_dtype, - lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg)), + carry_dtype, lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg)) ) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index c9552e7138..97f38177b5 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -832,11 +832,7 @@ def deref(self) -> Any: assert _is_concrete_position(shifted_pos) position = {**shifted_pos, **slice_column} - return _make_tuple( - self.field, - position, - column_axis=self.column_axis, - ) + return _make_tuple(self.field, position, column_axis=self.column_axis) def _get_sparse_dimensions(axes: Sequence[common.Dimension]) -> list[Tag]: @@ -856,10 +852,7 @@ def _wrap_field(field: common.Field | tuple) -> NDArrayLocatedFieldWrapper | tup def make_in_iterator( - inp_: common.Field, - pos: Position, - *, - column_axis: Optional[Tag], + inp_: common.Field, pos: Position, *, column_axis: Optional[Tag] ) -> ItIterator: inp = _wrap_field(inp_) axes = _get_axes(inp) @@ -873,11 +866,7 @@ def make_in_iterator( # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert column_range is not None new_pos[column_axis] = column_range.start - it = MDIterator( - inp, - new_pos, - column_axis=column_axis, - ) + it = MDIterator(inp, new_pos, column_axis=column_axis) if len(sparse_dimensions) >= 1: if len(sparse_dimensions) == 1: return SparseListIterator(it, sparse_dimensions[0]) @@ -1006,8 +995,7 @@ def _range2slice(r: range | common.IntIndex) -> slice | common.IntIndex: def _shift_field_indices( - ranges_or_indices: tuple[range | common.IntIndex, ...], - offsets: tuple[int, ...], + ranges_or_indices: tuple[range | common.IntIndex, ...], offsets: tuple[int, ...] 
) -> tuple[ArrayIndex, ...]: return tuple( _range2slice(r) if o == 0 else _shift_range(r, o) @@ -1358,10 +1346,7 @@ def sten(*lists): n = len(lst) res = init for i in range(n): - res = fun( - res, - *(lst[i] for lst in lists), - ) + res = fun(res, *(lst[i] for lst in lists)) return res return sten @@ -1549,8 +1534,7 @@ def closure( del domain[column_axis.value] column_range = common.NamedRange( - column_axis, - common.UnitRange(column.col_range.start, column.col_range.stop), + column_axis, common.UnitRange(column.col_range.start, column.col_range.stop) ) out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out) @@ -1563,11 +1547,7 @@ def _closure_runner(): for pos in _domain_iterator(domain): promoted_ins = [promote_scalars(inp) for inp in ins] ins_iters = list( - make_in_iterator( - inp, - pos, - column_axis=column.axis if column else None, - ) + make_in_iterator(inp, pos, column_axis=column.axis if column else None) for inp in promoted_ins ) res = sten(*ins_iters) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index ce45af0870..56f931f451 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -156,19 +156,8 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib "mod", "floordiv", # TODO see https://github.com/GridTools/gt4py/issues/1136 } -BINARY_MATH_COMPARISON_BUILTINS = { - "eq", - "less", - "greater", - "greater_equal", - "less_equal", - "not_eq", -} -BINARY_LOGICAL_BUILTINS = { - "and_", - "or_", - "xor_", -} +BINARY_MATH_COMPARISON_BUILTINS = {"eq", "less", "greater", "greater_equal", "less_equal", "not_eq"} +BINARY_LOGICAL_BUILTINS = {"and_", "or_", "xor_"} ARITHMETIC_BUILTINS = { *UNARY_MATH_NUMBER_BUILTINS, diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index e64220844d..8e505be0ec 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -353,9 +353,7 @@ def promote_to_const_iterator(expr: str | itir.Expr) -> itir.Expr: return lift(lambda_()(expr))() -def promote_to_lifted_stencil( - op: str | itir.SymRef | Callable, -) -> Callable[..., itir.FunCall]: +def promote_to_lifted_stencil(op: str | itir.SymRef | Callable) -> Callable[..., itir.FunCall]: """ Promotes a function `op` from values to iterators. 
diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index c4ac84dd57..041c12e7b0 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -43,11 +43,7 @@ } # replacements for builtin unary operations -UNARY_OPS: Final = { - "deref": "·", - "lift": "↑", - "not_": "¬", -} +UNARY_OPS: Final = {"deref": "·", "lift": "↑", "not_": "¬"} # operator precedence PRECEDENCE: Final = { diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index 8209c6dd41..f1710159a4 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -91,11 +91,7 @@ def __call__(self, *args, backend: Optional[ProgramExecutor] = None, **kwargs): if backend is not None: ensure_processor_kind(backend, ProgramExecutor) - backend( - self.itir(*args, **kwargs), - *args, - **kwargs, - ) + backend(self.itir(*args, **kwargs), *args, **kwargs) else: if fendef_embedded is None: raise RuntimeError("Embedded execution is not registered.") @@ -143,12 +139,7 @@ def _deduce_domain(domain: dict[common.Dimension, range], offset_provider: dict[ ) return domain_builtin( - *tuple( - map( - lambda x: builtins.named_range(x[0], x[1].start, x[1].stop), - domain.items(), - ) - ) + *tuple(map(lambda x: builtins.named_range(x[0], x[1].start, x[1].stop), domain.items())) ) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 82e6f20388..d01faea01c 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -237,12 +237,7 @@ def closure(domain, stencil, output, inputs): stencil(*(_s(param) for param in inspect.signature(stencil).parameters)) stencil = make_node(stencil) TracerContext.add_closure( - StencilClosure( - domain=domain, - stencil=stencil, - output=output, - inputs=inputs, - ) + StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) ) diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 6acb8a79c4..3790545b69 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -43,8 +43,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: args=[ ir.FunCall( fun=ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[offset_tag, offset_index], + fun=ir.SymRef(id="shift"), args=[offset_tag, offset_index] ), args=[it], ) diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index 694dcd6a61..c10cb6f3e7 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -110,21 +110,14 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): inlined_args.append(ir.SymRef(id=outer_op.params[i + first_param].id)) new_params.append(outer_op.params[i + first_param]) new_args.append(node.args[i]) - new_body = ir.FunCall( - fun=outer_op, - args=inlined_args, - ) + new_body = ir.FunCall(fun=outer_op, args=inlined_args) new_body = inline_lambdas.inline_lambda( new_body ) # removes one level of nesting (the recursive inliner could simplify more, however this can also be done on the full tree later) - new_op = ir.Lambda( - params=new_params, - expr=new_body, - ) + new_op = ir.Lambda(params=new_params, expr=new_body) if _is_map(node): return ir.FunCall( - fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), - args=new_args, + 
fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args ) else: # _is_reduce(node) return ir.FunCall( diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index ddbc4e0c74..d099272b2b 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -410,8 +410,7 @@ def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> di assert provider.origin_axis.kind == gtx.DimensionKind.HORIZONTAL assert provider.neighbor_axis.kind == gtx.DimensionKind.HORIZONTAL sizes[provider.origin_axis.value] = max( - sizes.get(provider.origin_axis.value, 0), - provider.table.shape[0], + sizes.get(provider.origin_axis.value, 0), provider.table.shape[0] ) sizes[provider.neighbor_axis.value] = max( sizes.get(provider.neighbor_axis.value, 0), diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index 5285bb94d7..2b6bcf3c9d 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -92,8 +92,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): if any(eligible_params): new_node = inline_lambda( - im.call(node.fun)(*new_args), - eligible_params=eligible_params, + im.call(node.fun)(*new_args), eligible_params=eligible_params ) # TODO(tehrengruber): propagate let outwards return im.let(*bound_scalars.items())(new_node) # type: ignore[arg-type] # mypy not smart enough diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index a1c9a2eb5b..3ebac575e3 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -93,10 +93,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): ) new_scanpass_body = _lambda_and_lift_inliner(new_scanpass_body) new_scanpass = ir.Lambda( - params=[ - original_scanpass.params[0], - *(ir.Sym(id=ref) for ref in refs_in_args), - ], + params=[original_scanpass.params[0], *(ir.Sym(id=ref) for ref in refs_in_args)], expr=new_scanpass_body, ) new_scan = ir.FunCall( diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 73991a869c..bf56186253 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -70,9 +70,7 @@ def _is_scan(node: ir.FunCall): def _transform_and_extract_lift_args( - node: ir.FunCall, - symtable: dict[eve.SymbolName, ir.Sym], - extracted_args: dict[ir.Sym, ir.Expr], + node: ir.FunCall, symtable: dict[eve.SymbolName, ir.Sym], extracted_args: dict[ir.Sym, ir.Expr] ): """ Transform and extract non-symbol arguments of a lifted stencil call. 
diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 06005d7135..5852ba9ae5 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -147,9 +147,7 @@ def apply_common_transforms( for _ in range(10): inlined = InlineLifts().visit(ir) inlined = InlineLambdas.apply( - inlined, - opcount_preserving=True, - force_inline_lift_args=True, + inlined, opcount_preserving=True, force_inline_lift_args=True ) if inlined == ir: break @@ -202,9 +200,7 @@ def apply_common_transforms( ir = MergeLet().visit(ir) ir = InlineLambdas.apply( - ir, - opcount_preserving=True, - force_inline_lambda_args=force_inline_lambda_args, + ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args ) return ir diff --git a/src/gt4py/next/iterator/transforms/power_unrolling.py b/src/gt4py/next/iterator/transforms/power_unrolling.py index ac71f2747d..0492c87875 100644 --- a/src/gt4py/next/iterator/transforms/power_unrolling.py +++ b/src/gt4py/next/iterator/transforms/power_unrolling.py @@ -20,9 +20,7 @@ from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas -def _is_power_call( - node: ir.FunCall, -) -> bool: +def _is_power_call(node: ir.FunCall) -> bool: """Match expressions of the form `power(base, integral_literal)`.""" return ( isinstance(node.fun, ir.SymRef) @@ -71,8 +69,7 @@ def visit_FunCall(self, node: ir.FunCall): # Nest target expression to avoid multiple redundant evaluations for i in range(pow_max, 0, -1): ret = im.let( - f"power_{2 ** i}", - im.multiplies_(f"power_{2**(i-1)}", f"power_{2**(i-1)}"), + f"power_{2 ** i}", im.multiplies_(f"power_{2**(i-1)}", f"power_{2**(i-1)}") )(ret) ret = im.let("power_1", base)(ret) diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 9f8bff7a84..21551fab6a 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -52,8 +52,7 @@ def visit_FunCall(self, node: ir.FunCall): lambda_args: list[ir.Expr] = node.args[0].args # type: ignore[attr-defined] # invariant ensured by pattern match above node = ir.FunCall( fun=ir.Lambda( - params=lambda_fun.params, - expr=ir.FunCall(fun=builtin, args=[lambda_fun.expr]), + params=lambda_fun.params, expr=ir.FunCall(fun=builtin, args=[lambda_fun.expr]) ), args=lambda_args, ) diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index cba0f364e7..08e4fa827f 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -28,10 +28,7 @@ def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): def visit_Lambda(self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node]): params = {str(p.id) for p in node.params} new_symbol_map = {k: v for k, v in symbol_map.items() if k not in params} - return ir.Lambda( - params=node.params, - expr=self.visit(node.expr, symbol_map=new_symbol_map), - ) + return ir.Lambda(params=node.params, expr=self.visit(node.expr, symbol_map=new_symbol_map)) def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] assert isinstance(node, SymbolTableTrait) == isinstance( diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 1b62a8a02e..925dbb8f43 100644 --- 
a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -98,9 +98,7 @@ class IteratorArgTracer(IteratorTracer): def shift(self, offsets: tuple[ir.OffsetLiteral, ...]): return IteratorArgTracer( - arg=self.arg, - shift_recorder=self.shift_recorder, - offsets=self.offsets + tuple(offsets), + arg=self.arg, shift_recorder=self.shift_recorder, offsets=self.offsets + tuple(offsets) ) def deref(self): @@ -367,8 +365,7 @@ def apply( def _save_to_annex( - node: ir.Node, - recorded_shifts: dict[int, set[tuple[ir.OffsetLiteral, ...]]], + node: ir.Node, recorded_shifts: dict[int, set[tuple[ir.OffsetLiteral, ...]]] ) -> None: for child_node in node.pre_walk_values(): if id(child_node) in recorded_shifts: diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 1e0977e68e..b058fc0a7b 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -43,10 +43,7 @@ def _is_neighbors_or_lifted_and_neighbors(arg: itir.Expr) -> TypeGuard[itir.FunC def _get_neighbors_args(reduce_args: Iterable[itir.Expr]) -> Iterator[itir.FunCall]: - return filter( - _is_neighbors_or_lifted_and_neighbors, - reduce_args, - ) + return filter(_is_neighbors_or_lifted_and_neighbors, reduce_args) def _is_list_of_funcalls(lst: list) -> TypeGuard[list[itir.FunCall]]: @@ -118,9 +115,7 @@ def _make_can_deref(iterator: itir.Expr) -> itir.FunCall: def _make_if(cond: itir.Expr, true_expr: itir.Expr, false_expr: itir.Expr) -> itir.FunCall: return itir.FunCall( - fun=itir.SymRef(id="if_"), - args=[cond, true_expr, false_expr], - location=cond.location, + fun=itir.SymRef(id="if_"), args=[cond, true_expr, false_expr], location=cond.location ) diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 046b8418c5..1aae474c4c 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -406,13 +406,7 @@ def _default_constraints(): ret=Val(kind=Value(), dtype=FLOAT_DTYPE, size=T0), ), ), - ( - ir.UNARY_MATH_NUMBER_BUILTINS, - FunctionType( - args=Tuple.from_elems(Val_T0_T1), - ret=Val_T0_T1, - ), - ), + (ir.UNARY_MATH_NUMBER_BUILTINS, FunctionType(args=Tuple.from_elems(Val_T0_T1), ret=Val_T0_T1)), ( {"power"}, FunctionType( @@ -438,15 +432,7 @@ def _default_constraints(): ir.BINARY_LOGICAL_BUILTINS, FunctionType(args=Tuple.from_elems(Val_BOOL_T1, Val_BOOL_T1), ret=Val_BOOL_T1), ), - ( - ir.UNARY_LOGICAL_BUILTINS, - FunctionType( - args=Tuple.from_elems( - Val_BOOL_T1, - ), - ret=Val_BOOL_T1, - ), - ), + (ir.UNARY_LOGICAL_BUILTINS, FunctionType(args=Tuple.from_elems(Val_BOOL_T1), ret=Val_BOOL_T1)), ) BUILTIN_TYPES: dict[str, Type] = { @@ -463,10 +449,7 @@ def _default_constraints(): ), ret=Val_BOOL_T1, ), - "if_": FunctionType( - args=Tuple.from_elems(Val_BOOL_T1, T2, T2), - ret=T2, - ), + "if_": FunctionType(args=Tuple.from_elems(Val_BOOL_T1, T2, T2), ret=T2), "lift": FunctionType( args=Tuple.from_elems( FunctionType( @@ -481,10 +464,7 @@ def _default_constraints(): ), "map_": FunctionType( args=Tuple.from_elems( - FunctionType( - args=ValTuple(kind=Value(), dtypes=T2, size=T1), - ret=Val_T0_T1, - ), + FunctionType(args=ValTuple(kind=Value(), dtypes=T2, size=T1), ret=Val_T0_T1) ), ret=FunctionType( args=ValListTuple(kind=Value(), list_dtypes=T2, size=T1), @@ -619,10 +599,7 @@ def visit_Sym(self, node: ir.Sym, **kwargs) -> Type: if node.kind: kind = {"Iterator": Iterator(), 
"Value": Value()}[node.kind] self.constraints.add( - ( - Val(kind=kind, current_loc=TypeVar.fresh(), defined_loc=TypeVar.fresh()), - result, - ) + (Val(kind=kind, current_loc=TypeVar.fresh(), defined_loc=TypeVar.fresh()), result) ) if node.dtype: assert node.dtype is not None @@ -630,14 +607,7 @@ def visit_Sym(self, node: ir.Sym, **kwargs) -> Type: if node.dtype[1]: dtype = List(dtype=dtype) self.constraints.add( - ( - Val( - dtype=dtype, - current_loc=TypeVar.fresh(), - defined_loc=TypeVar.fresh(), - ), - result, - ) + (Val(dtype=dtype, current_loc=TypeVar.fresh(), defined_loc=TypeVar.fresh()), result) ) return result @@ -681,11 +651,7 @@ def visit_AxisLiteral(self, node: ir.AxisLiteral, **kwargs) -> Val: def visit_OffsetLiteral(self, node: ir.OffsetLiteral, **kwargs) -> TypeVar: return TypeVar.fresh() - def visit_Lambda( - self, - node: ir.Lambda, - **kwargs, - ) -> FunctionType: + def visit_Lambda(self, node: ir.Lambda, **kwargs) -> FunctionType: ptypes = {p.id: self.visit(p, **kwargs) for p in node.params} ret = self.visit(node.expr, **kwargs) return FunctionType(args=Tuple.from_elems(*(ptypes[p.id] for p in node.params)), ret=ret) @@ -724,11 +690,7 @@ def _visit_tuple_get(self, node: ir.FunCall, **kwargs) -> Type: for _ in range(idx): dtype = Tuple(front=TypeVar.fresh(), others=dtype) - val = Val( - kind=kind, - dtype=dtype, - size=size, - ) + val = Val(kind=kind, dtype=dtype, size=size) self.constraints.add((tup, val)) return Val(kind=kind, dtype=elem, size=size) @@ -766,11 +728,7 @@ def _visit_neighbors(self, node: ir.FunCall, **kwargs) -> Type: ), ) ) - lst = List( - dtype=dtype_, - max_length=max_length, - has_skip_values=has_skip_values, - ) + lst = List(dtype=dtype_, max_length=max_length, has_skip_values=has_skip_values) return Val(kind=Value(), dtype=lst, size=size) def _visit_cast_(self, node: ir.FunCall, **kwargs) -> Type: @@ -783,22 +741,9 @@ def _visit_cast_(self, node: ir.FunCall, **kwargs) -> Type: size = TypeVar.fresh() - self.constraints.add( - ( - val_arg_type, - Val( - kind=Value(), - dtype=TypeVar.fresh(), - size=size, - ), - ) - ) + self.constraints.add((val_arg_type, Val(kind=Value(), dtype=TypeVar.fresh(), size=size))) - return Val( - kind=Value(), - dtype=Primitive(name=type_arg.id), - size=size, - ) + return Val(kind=Value(), dtype=Primitive(name=type_arg.id), size=size) def _visit_shift(self, node: ir.FunCall, **kwargs) -> Type: # Calls to shift are handled as being part of the grammar, not @@ -822,7 +767,7 @@ def _visit_shift(self, node: ir.FunCall, **kwargs) -> Type: size=size, current_loc=current_loc_in, defined_loc=defined_loc, - ), + ) ), ret=Val( kind=Iterator(), @@ -849,11 +794,7 @@ def _visit_cartesian_domain(self, node: ir.FunCall, **kwargs) -> Type: def _visit_unstructured_domain(self, node: ir.FunCall, **kwargs) -> Type: return self._visit_domain(node, **kwargs) - def visit_FunCall( - self, - node: ir.FunCall, - **kwargs, - ) -> Type: + def visit_FunCall(self, node: ir.FunCall, **kwargs) -> Type: if isinstance(node.fun, ir.SymRef) and node.fun.id in ir.GRAMMAR_BUILTINS: # builtins that are treated as part of the grammar are handled in `_visit_` return getattr(self, f"_visit_{node.fun.id}")(node, **kwargs) @@ -866,11 +807,7 @@ def visit_FunCall( self.constraints.add((fun, FunctionType(args=args, ret=ret))) return ret - def visit_FunctionDefinition( - self, - node: ir.FunctionDefinition, - **kwargs, - ) -> LetPolymorphic: + def visit_FunctionDefinition(self, node: ir.FunctionDefinition, **kwargs) -> LetPolymorphic: fun = ir.Lambda(params=node.params, 
expr=node.expr) # Since functions defined in a function definition are let-polymorphic we don't want @@ -884,31 +821,19 @@ def visit_FunctionDefinition( return fun_type - def visit_StencilClosure( - self, - node: ir.StencilClosure, - **kwargs, - ) -> Closure: + def visit_StencilClosure(self, node: ir.StencilClosure, **kwargs) -> Closure: domain = self.visit(node.domain, **kwargs) stencil = self.visit(node.stencil, **kwargs) output = self.visit(node.output, **kwargs) output_dtype = TypeVar.fresh() output_loc = TypeVar.fresh() self.constraints.add( - ( - domain, - Val(kind=Value(), dtype=Primitive(name="domain"), size=Scalar()), - ) + (domain, Val(kind=Value(), dtype=Primitive(name="domain"), size=Scalar())) ) self.constraints.add( ( output, - Val( - kind=Iterator(), - dtype=output_dtype, - size=Column(), - defined_loc=output_loc, - ), + Val(kind=Iterator(), dtype=output_dtype, size=Column(), defined_loc=output_loc), ) ) @@ -946,11 +871,7 @@ def visit_StencilClosure( def visit_FencilWithTemporaries(self, node: FencilWithTemporaries, **kwargs): return self.visit(node.fencil, **kwargs) - def visit_FencilDefinition( - self, - node: ir.FencilDefinition, - **kwargs, - ) -> FencilDefinitionType: + def visit_FencilDefinition(self, node: ir.FencilDefinition, **kwargs) -> FencilDefinitionType: ftypes = [] # Note: functions have to be ordered according to Lisp/Scheme `let*` # statements; that is, functions can only reference other functions @@ -963,9 +884,7 @@ def visit_FencilDefinition( params = [self.visit(p, **kwargs) for p in node.params] self.visit(node.closures, **kwargs) return FencilDefinitionType( - name=str(node.id), - fundefs=Tuple.from_elems(*ftypes), - params=Tuple.from_elems(*params), + name=str(node.id), fundefs=Tuple.from_elems(*ftypes), params=Tuple.from_elems(*params) ) @@ -1012,10 +931,7 @@ def infer_all( if reindex: unified_types, unsatisfiable_constraints = reindex_vars( - ( - unified_types, - unsatisfiable_constraints, - ) + (unified_types, unsatisfiable_constraints) ) result = { diff --git a/src/gt4py/next/otf/binding/cpp_interface.py b/src/gt4py/next/otf/binding/cpp_interface.py index 5f7c4225fe..2d6f806a36 100644 --- a/src/gt4py/next/otf/binding/cpp_interface.py +++ b/src/gt4py/next/otf/binding/cpp_interface.py @@ -20,10 +20,7 @@ CPP_DEFAULT: Final = languages.LanguageWithHeaderFilesSettings( - formatter_key="cpp", - formatter_style="LLVM", - file_extension="cpp", - header_extension="cpp.inc", + formatter_key="cpp", formatter_style="LLVM", file_extension="cpp", header_extension="cpp.inc" ) diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index e308883af6..3e84265c4b 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -246,23 +246,17 @@ def create_bindings( doc="", functions=[ BindingFunction( - exported_name=program_source.entry_point.name, - wrapper_name=wrapper_name, - doc="", + exported_name=program_source.entry_point.name, wrapper_name=wrapper_name, doc="" ) ], ), ) src = interface.format_source( - program_source.language_settings, - BindingCodeGenerator.apply(file_binding), + program_source.language_settings, BindingCodeGenerator.apply(file_binding) ) - return stages.BindingSource( - src, - (interface.LibraryDependency("nanobind", "1.4.0"),), - ) + return stages.BindingSource(src, (interface.LibraryDependency("nanobind", "1.4.0"),)) @workflow.make_step diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py index 
694a99e54e..82c97941e9 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py @@ -57,10 +57,7 @@ def __call__( if source.program_source.language is languages.Cuda: cmake_languages = [*cmake_languages, cmake_lists.Language(name="CUDA")] cmake_lists_src = cmake_lists.generate_cmakelists_source( - name, - source.library_deps, - [header_name, bindings_name], - languages=cmake_languages, + name, source.library_deps, [header_name, bindings_name], languages=cmake_languages ) return CMakeProject( root_path=cache.get_cache_folder(source, cache_lifetime), diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 140ea6a5fc..d4abb432ba 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -52,9 +52,7 @@ class CompiledbFactory( def __call__( self, source: stages.CompilableSource[ - SrcL, - languages.LanguageWithHeaderFilesSettings, - languages.Python, + SrcL, languages.LanguageWithHeaderFilesSettings, languages.Python ], cache_lifetime: config.BuildCacheLifetime, ) -> CompiledbProject: @@ -286,9 +284,7 @@ def _cc_create_compiledb( with log_file.open("w") as log_file_pointer: commands_json_str = subprocess.check_output( - ["ninja", "-t", "compdb"], - cwd=cache_path / "build", - stderr=log_file_pointer, + ["ninja", "-t", "compdb"], cwd=cache_path / "build", stderr=log_file_pointer ).decode("utf-8") commands = json.loads(commands_json_str) diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 702d0ebb9d..41d0c8947f 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -35,8 +35,7 @@ class ProgramTransformWorkflow(workflow.NamedStepSequence): kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) def __call__( - self, - inp: ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition, + self, inp: ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition ) -> stages.ProgramCall: past_stage = self.func_to_past(inp) return self.past_to_itir( @@ -57,12 +56,6 @@ class OTFCompileWorkflow(workflow.NamedStepSequence): """The typical compiled backend steps composed into a workflow.""" translation: step_types.TranslationStep - bindings: workflow.Workflow[ - stages.ProgramSource, - stages.CompilableSource, - ] - compilation: workflow.Workflow[ - stages.CompilableSource, - stages.CompiledProgram, - ] + bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] + compilation: workflow.Workflow[stages.CompilableSource, stages.CompiledProgram] decoration: workflow.Workflow[stages.CompiledProgram, stages.CompiledProgram] diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index c7d84fc736..8ae741195f 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -91,8 +91,7 @@ def chain(self, next_step: Workflow[EndT, NewEndT]) -> ChainableWorkflowMixin[St @dataclasses.dataclass(frozen=True) class NamedStepSequence( - ChainableWorkflowMixin[StartT, EndT], - ReplaceEnabledWorkflowMixin[StartT, EndT], + ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT] ): """ Workflow with linear succession of named steps. 
@@ -253,8 +252,7 @@ def __call__(self, inp: StartT) -> EndT: @dataclasses.dataclass(frozen=True) class SkippableStep( - ChainableWorkflowMixin[StartT, EndT], - ReplaceEnabledWorkflowMixin[StartT, EndT], + ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT] ): step: Workflow[StartT, EndT] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 0b465047cb..b266f577c3 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -147,11 +147,7 @@ class TemporaryAllocation(Node): ARITHMETIC_BUILTINS = itir.ARITHMETIC_BUILTINS TYPEBUILTINS = itir.TYPEBUILTINS -BUILTINS = { - *GTFN_BUILTINS, - *ARITHMETIC_BUILTINS, - *TYPEBUILTINS, -} +BUILTINS = {*GTFN_BUILTINS, *ARITHMETIC_BUILTINS, *TYPEBUILTINS} class TagDefinition(Node): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py index c92b269a3a..30ef08a04a 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -58,10 +58,7 @@ def _is_shifted_or_lifted_and_shifted(arg: gtfn_ir_common.Expr) -> TypeGuard[gtf def _get_shifted_args(reduce_args: Iterable[gtfn_ir_common.Expr]) -> Iterator[gtfn_ir.FunCall]: - return filter( - _is_shifted_or_lifted_and_shifted, - reduce_args, - ) + return filter(_is_shifted_or_lifted_and_shifted, reduce_args) def _is_list_of_funcalls(lst: list) -> TypeGuard[list[gtfn_ir.FunCall]]: @@ -137,10 +134,7 @@ def _make_sparse_acess( ) -> gtfn_ir.FunCall: return gtfn_ir.FunCall( fun=gtfn_ir_common.SymRef(id="tuple_get"), - args=[ - nbh_iter, - gtfn_ir.FunCall(fun=gtfn_ir_common.SymRef(id="deref"), args=[field_ref]), - ], + args=[nbh_iter, gtfn_ir.FunCall(fun=gtfn_ir_common.SymRef(id="deref"), args=[field_ref])], ) @@ -233,8 +227,7 @@ def _expand_symref( for arg in new_args ] rhs = gtfn_ir.FunCall( - fun=fun, - args=[gtfn_ir_common.SymRef(id=red_idx), *plugged_in_args], + fun=fun, args=[gtfn_ir_common.SymRef(id=red_idx), *plugged_in_args] ) self.imp_list_ir.append(AssignStmt(lhs=gtfn_ir_common.SymRef(id=red_idx), rhs=rhs)) @@ -265,13 +258,7 @@ def handle_Reduction(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_comm return gtfn_ir_common.SymRef(id=red_idx) def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.Expr: - if any( - isinstance( - arg, - gtfn_ir.Lambda, - ) - for arg in node.args - ): + if any(isinstance(arg, gtfn_ir.Lambda) for arg in node.args): # do not try to lower constructs that take lambdas as argument to something more readable lam_idx = self.uids.sequential_id(prefix="lam") self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{lam_idx}"), rhs=node)) @@ -293,10 +280,7 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common. 
) else: self.imp_list_ir.append( - InitStmt( - lhs=gtfn_ir_common.Sym(id=f"{param.id}"), - rhs=arg, - ) + InitStmt(lhs=gtfn_ir_common.Sym(id=f"{param.id}"), rhs=arg) ) expr = self.visit(node.fun.expr, **kwargs) self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{lam_idx}"), rhs=expr)) @@ -343,9 +327,7 @@ def visit_FunctionDefinition( ret = self.visit(node.expr, localized_symbols={}, **kwargs) return ImperativeFunctionDefinition( - id=node.id, - params=node.params, - fun=[*self.imp_list_ir, ReturnStmt(ret=ret)], + id=node.id, params=node.params, fun=[*self.imp_list_ir, ReturnStmt(ret=ret)] ) def visit_ScanPassDefinition( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 50d71cb94a..ca293aa235 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -133,8 +133,7 @@ def _process_regular_arguments( return parameters, arg_exprs def _process_connectivity_args( - self, - offset_provider: dict[str, Connectivity | Dimension], + self, offset_provider: dict[str, Connectivity | Dimension] ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] @@ -233,9 +232,7 @@ def generate_stencil_source( ) -> str: new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode) gtfn_ir = GTFN_lowering.apply( - new_program, - offset_provider=offset_provider, - column_axis=column_axis, + new_program, offset_provider=offset_provider, column_axis=column_axis ) if self.use_imperative_backend: @@ -246,8 +243,7 @@ def generate_stencil_source( return codegen.format_source("cpp", generated_code, style="LLVM") def __call__( - self, - inp: stages.ProgramCall, + self, inp: stages.ProgramCall ) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]: """Generate GTFN C++ code from the ITIR definition.""" program: itir.FencilDefinition = inp.program diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 2207b1c1d5..4617e54eae 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -170,10 +170,7 @@ def _collect_offset_definitions( offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name)) connectivity: common.Connectivity = dim_or_connectivity - for dim in [ - connectivity.origin_axis, - connectivity.neighbor_axis, - ]: + for dim in [connectivity.origin_axis, connectivity.neighbor_axis]: if dim.kind != common.DimensionKind.HORIZONTAL: raise NotImplementedError() offset_definitions[dim.value] = TagDefinition( @@ -359,10 +356,7 @@ def _visit_tuple_get(self, node: itir.FunCall, **kwargs: Any) -> Node: assert isinstance(node.args[0], itir.Literal) return FunCall( fun=SymRef(id="tuple_get"), - args=[ - _literal_as_integral_constant(node.args[0]), - self.visit(node.args[1]), - ], + args=[_literal_as_integral_constant(node.args[0]), self.visit(node.args[1])], ) def _visit_list_get(self, node: itir.FunCall, **kwargs: Any) -> Node: @@ -374,13 +368,7 @@ def _visit_list_get(self, node: itir.FunCall, **kwargs: Any) -> Node: node.args[0] ) # from unroll_reduce we get a `SymRef` which is refering to an `OffsetLiteral` which is lowered to integral_constant ) - return FunCall( - fun=SymRef(id="tuple_get"), - args=[ - 
tuple_idx, - self.visit(node.args[1]), - ], - ) + return FunCall(fun=SymRef(id="tuple_get"), args=[tuple_idx, self.visit(node.args[1])]) def _visit_cartesian_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: sizes, domain_offsets = self._make_domain(node) @@ -397,9 +385,7 @@ def _visit_unstructured_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: ): connectivities.append(SymRef(id=o)) return UnstructuredDomain( - tagged_sizes=sizes, - tagged_offsets=domain_offsets, - connectivities=connectivities, + tagged_sizes=sizes, tagged_offsets=domain_offsets, connectivities=connectivities ) def visit_FunCall(self, node: itir.FunCall, **kwargs: Any) -> Node: @@ -535,10 +521,7 @@ def visit_FencilDefinition( self, node: itir.FencilDefinition, **kwargs: Any ) -> FencilDefinition: extracted_functions: list[Union[FunctionDefinition, ScanPassDefinition]] = [] - executions = self.visit( - node.closures, - extracted_functions=extracted_functions, - ) + executions = self.visit(node.closures, extracted_functions=extracted_functions) executions = self._merge_scans(executions) function_definitions = self.visit(node.function_definitions) + extracted_functions offset_definitions = { diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py index fcde1cc2b6..cab490c853 100644 --- a/src/gt4py/next/program_processors/processor_interface.py +++ b/src/gt4py/next/program_processors/processor_interface.py @@ -207,11 +207,7 @@ def program_executor( return cast( ProgramExecutor, make_program_processor( - func, - ProgramExecutor, - name=name, - accept_args=accept_args, - accept_kwargs=accept_kwargs, + func, ProgramExecutor, name=name, accept_args=accept_args, accept_kwargs=accept_kwargs ), ) diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 2291541dd6..2920b3d812 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -77,8 +77,7 @@ class Params: lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}" ) auto_optimize = factory.Trait( - otf_workflow__translation__auto_optimize=True, - name_temps="_opt", + otf_workflow__translation__auto_optimize=True, name_temps="_opt" ) use_field_canonical_representation: bool = False diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 590b111437..918a7db15a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -118,8 +118,7 @@ def _ensure_is_on_device( def get_connectivity_args( - neighbor_tables: Mapping[str, common.NeighborTable], - device: dace.dtypes.DeviceType, + neighbor_tables: Mapping[str, common.NeighborTable], device: dace.dtypes.DeviceType ) -> dict[str, Any]: return { connectivity_identifier(offset): _ensure_is_on_device(offset_provider.table, device) @@ -267,9 +266,7 @@ def build_sdfg_from_itir( getframeinfo(currentframe()), # type: ignore[arg-type] ) nested_sdfg.debuginfo = dace.dtypes.DebugInfo( - start_line=frameinfo.lineno, - end_line=frameinfo.lineno, - filename=frameinfo.filename, + start_line=frameinfo.lineno, end_line=frameinfo.lineno, filename=frameinfo.filename ) # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct diff --git 
a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 2df7357f17..f0cfad5f1f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -100,18 +100,11 @@ def _get_scan_dim( else enumerate(output_type.dims) ) ] - return ( - column_axis.value, - sorted_dims.index(column_axis), - output_type.dtype, - ) + return (column_axis.value, sorted_dims.index(column_axis), output_type.dtype) def _make_array_shape_and_strides( - name: str, - dims: Sequence[Dimension], - offset_provider: Mapping[str, Any], - sort_dims: bool, + name: str, dims: Sequence[Dimension], offset_provider: Mapping[str, Any], sort_dims: bool ) -> tuple[list[dace.symbol], list[dace.symbol]]: """ Parse field dimensions and allocate symbols for array shape and strides. @@ -182,24 +175,13 @@ def __init__( self.tmps = tmps self.use_field_canonical_representation = use_field_canonical_representation - def add_storage( - self, - sdfg: dace.SDFG, - name: str, - type_: ts.TypeSpec, - sort_dimensions: bool, - ): + def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, sort_dimensions: bool): if isinstance(type_, ts.FieldType): shape, strides = _make_array_shape_and_strides( name, type_.dims, self.offset_provider, sort_dimensions ) dtype = as_dace_type(type_.dtype) - sdfg.add_array( - name, - shape=shape, - strides=strides, - dtype=dtype, - ) + sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) elif isinstance(type_, ts.ScalarType): dtype = as_dace_type(type_) @@ -263,10 +245,7 @@ def add_storage_for_temporaries( # Loop through all dimensions to visit the symbolic expressions for array shape and offset. # These expressions are later mapped to interstate symbols. - for (_, (begin, end)), shape_sym in zip( - tmp_domain, - tmp_array.shape, - ): + for (_, (begin, end)), shape_sym in zip(tmp_domain, tmp_array.shape): """ The temporary field has a dimension range defined by `begin` and `end` values. Therefore, the actual size is given by the difference `end.value - begin.value`. @@ -283,11 +262,7 @@ def add_storage_for_temporaries( return tmp_symbols - def create_memlet_at( - self, - field_name: str, - index: dict[str, str], - ): + def create_memlet_at(self, field_name: str, index: dict[str, str]): field_type = cast(ts.FieldType, self.storage_types[field_name]) if self.use_field_canonical_representation: field_index = [index[dim.value] for _, dim in get_sorted_dims(field_type.dims)] @@ -322,10 +297,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): # Add program parameters as SDFG storages. 
for param, type_ in zip(node.params, self.param_types): self.add_storage( - program_sdfg, - str(param.id), - type_, - self.use_field_canonical_representation, + program_sdfg, str(param.id), type_, self.use_field_canonical_representation ) if self.tmps: @@ -333,11 +305,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): # on the first interstate edge define symbols for shape and offsets of temporary arrays last_state = program_sdfg.add_state("init_symbols_for_temporaries") program_sdfg.add_edge( - entry_state, - last_state, - dace.InterstateEdge( - assignments=tmp_symbols, - ), + entry_state, last_state, dace.InterstateEdge(assignments=tmp_symbols) ) else: last_state = entry_state @@ -350,10 +318,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) ) self.add_storage( - program_sdfg, - connectivity_identifier(offset), - type_, - sort_dimensions=False, + program_sdfg, connectivity_identifier(offset), type_, sort_dimensions=False ) # Create a nested SDFG for all stencil closures. diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 9267270653..f400593695 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -202,12 +202,7 @@ def _visit_lift_in_neighbors_reduction( lifted_indices.pop(origin_dim) lifted_indices[neighbor_dim] = neighbor_index_node lifted_args.append( - IteratorExpr( - arg.field, - lifted_indices, - arg.dtype, - arg.dimensions, - ) + IteratorExpr(arg.field, lifted_indices, arg.dtype, arg.dimensions) ) else: lifted_args.append(arg) @@ -283,11 +278,7 @@ def _visit_lift_in_neighbors_reduction( parent_state.add_edge(access_node, None, nested_sdfg_node, inner_connector, memlet) else: parent_state.add_memlet_path( - access_node, - map_entry, - nested_sdfg_node, - dst_conn=inner_connector, - memlet=memlet, + access_node, map_entry, nested_sdfg_node, dst_conn=inner_connector, memlet=memlet ) parent_state.add_memlet_path( @@ -556,11 +547,7 @@ def builtin_can_deref( expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals) return transformer.add_expr_tasklet( - list(zip(args, internals)), - expr_code, - dace.dtypes.bool, - "can_deref", - dace_debuginfo=di, + list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref", dace_debuginfo=di ) @@ -699,9 +686,7 @@ def builtin_list_get( transformer.context.body.add_scalar(result_name, args[1].dtype, transient=True) result_node = transformer.context.state.add_access(result_name) transformer.context.state.add_nedge( - args[1].value, - result_node, - dace.Memlet(data=args[1].value.data, subset=index_value), + args[1].value, result_node, dace.Memlet(data=args[1].value.data, subset=index_value) ) return [ValueExpr(result_node, args[1].dtype)] @@ -727,11 +712,7 @@ def builtin_cast( assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) return transformer.add_expr_tasklet( - list(zip(args, internals)), - expr, - type_, - "cast", - dace_debuginfo=di, + list(zip(args, internals)), expr, type_, "cast", dace_debuginfo=di ) @@ -808,12 +789,7 @@ class GatherLambdaSymbolsPass(eve.NodeVisitor): _symbol_map: dict[str, TaskletExpr | tuple[ValueExpr]] _parent_symbol_map: dict[str, TaskletExpr] - def __init__( - self, - sdfg, - state, - parent_symbol_map, - ): + def 
__init__(self, sdfg, state, parent_symbol_map): self._sdfg = sdfg self._state = state self._symbol_map = {} @@ -891,12 +867,7 @@ def symbol_refs(self): """Dictionary of symbols referenced from the output expression.""" return self._symbol_map - def __init__( - self, - sdfg, - state, - node_types, - ): + def __init__(self, sdfg, state, node_types): self._sdfg = sdfg self._state = state self._node_types = node_types @@ -1022,9 +993,7 @@ def visit_Lambda( result_name, debuginfo=lambda_sdfg.debuginfo ) lambda_state.add_nedge( - expr.value, - result_access, - dace.Memlet(data=result_access.data, subset="0"), + expr.value, result_access, dace.Memlet(data=result_access.data, subset="0") ) result = ValueExpr(value=result_access, dtype=expr.dtype) else: @@ -1164,11 +1133,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" return self.add_expr_tasklet( - list(zip(args, internals)), - expr, - iterator.dtype, - "deref", - dace_debuginfo=di, + list(zip(args, internals)), expr, iterator.dtype, "deref", dace_debuginfo=di ) else: @@ -1197,9 +1162,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: # we create a mapped tasklet for array slicing index_name = unique_name(f"_i_{neighbor_dim}") - map_ranges = { - index_name: f"0:{offset_provider.max_neighbors}", - } + map_ranges = {index_name: f"0:{offset_provider.max_neighbors}"} src_subset = ",".join( [f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims] ) @@ -1207,15 +1170,11 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: "deref", map_ranges, inputs={k: v for k, v in zip(deref_connectors, deref_memlets)}, - outputs={ - "_out": dace.Memlet.from_array(result_name, result_array), - }, + outputs={"_out": dace.Memlet.from_array(result_name, result_array)}, code=f"_out[{index_name}] = _inp[{src_subset}]", external_edges=True, input_nodes={node.data: node for node in deref_nodes}, - output_nodes={ - result_name: result_node, - }, + output_nodes={result_name: result_node}, debuginfo=di, ) return [ValueExpr(result_node, iterator.dtype)] @@ -1277,10 +1236,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value - args = [ - ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), - offset_node, - ] + args = [ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]} * {offset_provider.max_neighbors} + {internals[1]}" else: @@ -1288,10 +1244,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: shifted_dim = self.offset_provider[offset_dim].value target_dim = shifted_dim - args = [ - ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), - offset_node, - ] + args = [ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]} + {internals[1]}" diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 9999b367b6..537bda7922 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -39,8 +39,7 @@ @dataclasses.dataclass(frozen=True) class DaCeTranslator( 
workflow.ChainableWorkflowMixin[ - stages.ProgramCall, - stages.ProgramSource[languages.SDFG, languages.LanguageSettings], + stages.ProgramCall, stages.ProgramSource[languages.SDFG, languages.LanguageSettings] ], step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], ): @@ -54,14 +53,11 @@ class DaCeTranslator( def _language_settings(self) -> languages.LanguageSettings: return languages.LanguageSettings( - formatter_key="", - formatter_style="", - file_extension="sdfg", + formatter_key="", formatter_style="", file_extension="sdfg" ) def __call__( - self, - inp: stages.ProgramCall, + self, inp: stages.ProgramCall ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the ITIR definition.""" program: itir.FencilDefinition = inp.program diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index e37fb65891..e6220ea879 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -20,7 +20,7 @@ backend = next_backend.Backend( executor=roundtrip.RoundtripExecutorFactory( - dispatch_backend=roundtrip.RoundtripExecutorFactory(), + dispatch_backend=roundtrip.RoundtripExecutorFactory() ), allocator=roundtrip.backend.allocator, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index ad24b19dbf..49ae0582df 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -54,10 +54,7 @@ def decorated_program( ) -> None: converted_args = [convert_arg(arg) for arg in args] conn_args = extract_connectivity_args(offset_provider, device) - return inp( - *converted_args, - *conn_args, - ) + return inp(*converted_args, *conn_args) return decorated_program @@ -193,8 +190,7 @@ class Params: run_gtfn = GTFNBackendFactory() run_gtfn_imperative = GTFNBackendFactory( - name_postfix="_imperative", - otf_workflow__translation__use_imperative_backend=True, + name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True ) run_gtfn_cached = GTFNBackendFactory(cached=True) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 7ca88eab06..38714221fc 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -90,11 +90,7 @@ def visit_Temporary(self, node: gtmps_transform.Temporary, **kwargs: Any) -> str assert ( isinstance(node.domain, itir.FunCall) and isinstance(node.domain.fun, itir.SymRef) - and node.domain.fun.id - in ( - "cartesian_domain", - "unstructured_domain", - ) + and node.domain.fun.id in ("cartesian_domain", "unstructured_domain") ) assert all( isinstance(r, itir.FunCall) and r.fun == itir.SymRef(id="named_range") @@ -225,10 +221,7 @@ def execute_roundtrip( use_embedded=dispatch_backend is None, ) - new_kwargs: dict[str, Any] = { - "offset_provider": offset_provider, - "column_axis": column_axis, - } + new_kwargs: dict[str, Any] = {"offset_provider": offset_provider, "column_axis": column_axis} if dispatch_backend: new_kwargs["backend"] = dispatch_backend @@ -284,6 +277,5 @@ class Params: executor = RoundtripExecutorFactory(name="roundtrip") backend = next_backend.Backend( - executor=executor, - allocator=next_allocators.StandardCPUFieldBufferAllocator(), + executor=executor, 
allocator=next_allocators.StandardCPUFieldBufferAllocator() ) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 03702517de..b235e6f26d 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -90,21 +90,18 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: @typing.overload def primitive_constituents( - symbol_type: ts.TypeSpec, - with_path_arg: typing.Literal[False] = False, + symbol_type: ts.TypeSpec, with_path_arg: typing.Literal[False] = False ) -> XIterable[ts.TypeSpec]: ... @typing.overload def primitive_constituents( - symbol_type: ts.TypeSpec, - with_path_arg: typing.Literal[True], + symbol_type: ts.TypeSpec, with_path_arg: typing.Literal[True] ) -> XIterable[tuple[ts.TypeSpec, tuple[int, ...]]]: ... def primitive_constituents( - symbol_type: ts.TypeSpec, - with_path_arg: bool = False, + symbol_type: ts.TypeSpec, with_path_arg: bool = False ) -> XIterable[ts.TypeSpec] | XIterable[tuple[ts.TypeSpec, tuple[int, ...]]]: """ Return the primitive types contained in a composite type. @@ -223,10 +220,7 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: >>> is_floating_point(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) True """ - return extract_dtype(symbol_type).kind in [ - ts.ScalarKind.FLOAT32, - ts.ScalarKind.FLOAT64, - ] + return extract_dtype(symbol_type).kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] def is_integral(symbol_type: ts.TypeSpec) -> bool: @@ -242,10 +236,7 @@ def is_integral(symbol_type: ts.TypeSpec) -> bool: >>> is_integral(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) True """ - return extract_dtype(symbol_type).kind in [ - ts.ScalarKind.INT32, - ts.ScalarKind.INT64, - ] + return extract_dtype(symbol_type).kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64] def is_number(symbol_type: ts.TypeSpec) -> bool: @@ -494,28 +485,17 @@ def return_type( @return_type.register def return_type_func( - func_type: ts.FunctionType, - *, - with_args: list[ts.TypeSpec], - with_kwargs: dict[str, ts.TypeSpec], + func_type: ts.FunctionType, *, with_args: list[ts.TypeSpec], with_kwargs: dict[str, ts.TypeSpec] ) -> ts.TypeSpec: return func_type.returns @return_type.register def return_type_field( - field_type: ts.FieldType, - *, - with_args: list[ts.TypeSpec], - with_kwargs: dict[str, ts.TypeSpec], + field_type: ts.FieldType, *, with_args: list[ts.TypeSpec], with_kwargs: dict[str, ts.TypeSpec] ) -> ts.FieldType: try: - accepts_args( - field_type, - with_args=with_args, - with_kwargs=with_kwargs, - raise_exception=True, - ) + accepts_args(field_type, with_args=with_args, with_kwargs=with_kwargs, raise_exception=True) except ValueError as ex: raise ValueError("Could not deduce return type of invalid remap operation.") from ex @@ -626,8 +606,7 @@ def structural_function_signature_incompatibilities( missing_positional_args = [] for i, arg_type in zip( - range(len(func_type.pos_only_args), num_pos_params), - func_type.pos_or_kw_args.keys(), + range(len(func_type.pos_only_args), num_pos_params), func_type.pos_or_kw_args.keys() ): if args[i] is UNDEFINED_ARG: missing_positional_args.append(f"'{arg_type}'") @@ -699,9 +678,7 @@ def function_signature_incompatibilities_func( @function_signature_incompatibilities.register def function_signature_incompatibilities_field( - field_type: ts.FieldType, - args: list[ts.TypeSpec], - kwargs: dict[str, ts.TypeSpec], + field_type: ts.FieldType, args: list[ts.TypeSpec], kwargs: dict[str, 
ts.TypeSpec] ) -> Iterator[str]: if len(args) != 1: yield f"Function takes 1 argument, but {len(args)} were given." diff --git a/src/gt4py/storage/allocators.py b/src/gt4py/storage/allocators.py index 1c9525c7b1..5e8b42f459 100644 --- a/src/gt4py/storage/allocators.py +++ b/src/gt4py/storage/allocators.py @@ -50,9 +50,7 @@ _NDBuffer: TypeAlias = Union[ # TODO: add `xtyping.Buffer` once we update typing_extensions - xtyping.ArrayInterface, - xtyping.CUDAArrayInterface, - xtyping.DLPackBuffer, + xtyping.ArrayInterface, xtyping.CUDAArrayInterface, xtyping.DLPackBuffer ] #: Tuple of positive integers encoding a permutation of the dimensions, such that @@ -350,11 +348,7 @@ class NDArrayBufferAllocator(_BaseNDArrayBufferAllocator[core_defs.DeviceTypeT]) _device_type: core_defs.DeviceTypeT _array_ns: ValidNumPyLikeAllocationNS - def __init__( - self, - device_type: core_defs.DeviceTypeT, - array_ns: ValidNumPyLikeAllocationNS, - ): + def __init__(self, device_type: core_defs.DeviceTypeT, array_ns: ValidNumPyLikeAllocationNS): object.__setattr__(self, "_device_type", device_type) object.__setattr__(self, "_array_ns", array_ns) diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 4e7ebb0c21..ce4b8c70bc 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -51,8 +51,7 @@ assert allocators.is_valid_nplike_allocation_ns(np) _CPUBufferAllocator = allocators.NDArrayBufferAllocator( - device_type=core_defs.DeviceType.CPU, - array_ns=np, + device_type=core_defs.DeviceType.CPU, array_ns=np ) _GPUBufferAllocator: Optional[allocators.NDArrayBufferAllocator] = None diff --git a/tests/cartesian_tests/integration_tests/feature_tests/test_call_interface.py b/tests/cartesian_tests/integration_tests/feature_tests/test_call_interface.py index b3af613fec..4230264320 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_call_interface.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_call_interface.py @@ -352,9 +352,7 @@ class TestAxesMismatch: @pytest.fixture def sample_stencil(self): @gtscript.stencil(backend="numpy") - def _stencil( - field_out: gtscript.Field[gtscript.IJ, np.float64], - ): + def _stencil(field_out: gtscript.Field[gtscript.IJ, np.float64]): with computation(FORWARD), interval(...): field_out = 1.0 @@ -391,9 +389,7 @@ class TestDataDimensions: @pytest.fixture def sample_stencil(self): @gtscript.stencil(backend=self.backend) - def _stencil( - field_out: gtscript.Field[gtscript.IJK, (np.float64, (2,))], - ): + def _stencil(field_out: gtscript.Field[gtscript.IJK, (np.float64, (2,))]): with computation(FORWARD), interval(...): field_out[0, 0, 0][0] = 0.0 field_out[0, 0, 0][1] = 1.0 @@ -423,18 +419,10 @@ def calc_damp(outp: Field[float], inp: Field[K, float]): outp = inp outp = gt_storage.ones( - backend=backend, - aligned_index=(1, 1, 1), - shape=(4, 4, 4), - dtype=float, - dimensions="IJK", + backend=backend, aligned_index=(1, 1, 1), shape=(4, 4, 4), dtype=float, dimensions="IJK" ) inp = gt_storage.ones( - backend=backend, - aligned_index=(1,), - shape=(4,), - dtype=float, - dimensions="K", + backend=backend, aligned_index=(1,), shape=(4,), dtype=float, dimensions="K" ) origin = {"_all_": (1, 1, 1), "inp": (1,)} @@ -454,20 +442,12 @@ def calc_damp(outp: Field[float], inp: Field[K, float]): outp = inp outp = gt_storage.ones( - backend="numpy", - aligned_index=(1, 1, 1), - shape=(4, 4, 4), - dtype=float, - dimensions="KJI", + backend="numpy", aligned_index=(1, 1, 1), shape=(4, 4, 4), 
dtype=float, dimensions="KJI" ) outp_wrap = DimensionsWrapper(array=outp, dimensions="KJI") inp = gt_storage.from_array( - data=np.arange(4), - backend="numpy", - aligned_index=(1,), - dtype=float, - dimensions="K", + data=np.arange(4), backend="numpy", aligned_index=(1,), dtype=float, dimensions="K" ) calc_damp(outp_wrap, inp) diff --git a/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py b/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py index 5814daa495..b0e1aa810f 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py @@ -173,8 +173,7 @@ def subtest_stencil_info(self, exec_info, stencil_info, last_called_stencil=Fals assert "run_time" in stencil_info if last_called_stencil: assert np.isclose( - stencil_info["run_time"], - exec_info["run_end_time"] - exec_info["run_start_time"], + stencil_info["run_time"], exec_info["run_end_time"] - exec_info["run_start_time"] ) assert stencil_info["call_time"] > stencil_info["run_time"] assert "total_run_time" in stencil_info diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 8e270362df..03974f32ac 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -413,10 +413,7 @@ def stencil( @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_nested_while_loop(backend): @gtscript.stencil(backend=backend) - def stencil( - field_a: gtscript.Field[np.float_], - field_b: gtscript.Field[np.int_], - ): + def stencil(field_a: gtscript.Field[np.float_], field_b: gtscript.Field[np.int_]): with computation(PARALLEL), interval(...): while field_a < 1: add = 0 @@ -428,9 +425,7 @@ def stencil( @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_mask_with_offset_written_in_conditional(backend): @gtscript.stencil(backend, externals={"mord": 5}) - def stencil( - outp: gtscript.Field[np.float_], - ): + def stencil(outp: gtscript.Field[np.float_]): with computation(PARALLEL), interval(...): cond = True if cond[0, -1, 0] or cond[0, 0, 0]: @@ -551,11 +546,7 @@ def k_to_ijk(outp: Field[np.float64], inp: Field[gtscript.K, np.float64]): data = np.arange(10, dtype=np.float64) inp = gt_storage.from_array( - data=data, - aligned_index=(0,), - dtype=np.float64, - dimensions="K", - backend=backend, + data=data, aligned_index=(0,), dtype=np.float64, dimensions="K", backend=backend ) outp = gt_storage.zeros( shape=(2, 2, 10), aligned_index=(0, 0, 0), dtype=np.float64, backend=backend diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py index b4317fd9e4..5547903e5a 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py @@ -69,8 +69,7 @@ def tuple_st(min_value, max_value): @pytest.mark.parametrize( - "backend", - ["dace:cpu", pytest.param("dace:gpu", marks=[pytest.mark.requires_gpu])], + "backend", ["dace:cpu", pytest.param("dace:gpu", marks=[pytest.mark.requires_gpu])] ) def test_basic(decorator, backend): @decorator(backend=backend) @@ -255,8 +254,7 @@ def stencil( outp = inp # noqa: F841 [unused-variable] 
frozen_stencil = stencil.freeze( - domain=(3, 3, 10), - origin={"inp": (2, 2, 0), "outp": (2, 2, 0), "unused_field": (0, 0, 0)}, + domain=(3, 3, 10), origin={"inp": (2, 2, 0), "outp": (2, 2, 0), "unused_field": (0, 0, 0)} ) inp = OriginWrapper( @@ -419,10 +417,7 @@ def numpy_stencil(inp: gtscript.Field[np.float64], outp: gtscript.Field[np.float ) outp = OriginWrapper( array=gt_storage.zeros( - dtype=np.float64, - shape=(10, 10, 10), - aligned_index=(0, 0, 0), - backend="numpy", + dtype=np.float64, shape=(10, 10, 10), aligned_index=(0, 0, 0), backend="numpy" ), origin=(0, 0, 0), ) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index eb94342495..0dff218e18 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -594,11 +594,7 @@ class TestNotSpecifiedTwoOptionalFields(TestTwoOptionalFields): class TestNon3DFields(gt_testing.StencilTestSuite): - dtypes = { - "field_in": np.float64, - "another_field": np.float64, - "field_out": np.float64, - } + dtypes = {"field_in": np.float64, "another_field": np.float64, "field_out": np.float64} domain_range = [(4, 10), (4, 10), (4, 10)] backends = ["gt:cpu_ifirst", "gt:cpu_kfirst", "gt:gpu", "dace:cpu", "dace:gpu"] symbols = { @@ -684,10 +680,7 @@ def validation(field_in, another_field, field_out, *, domain, origin): class TestReadOutsideKInterval1(gt_testing.StencilTestSuite): - dtypes = { - "field_in": np.float64, - "field_out": np.float64, - } + dtypes = {"field_in": np.float64, "field_out": np.float64} domain_range = [(4, 4), (4, 4), (4, 4)] backends = ALL_BACKENDS symbols = { @@ -710,10 +703,7 @@ def validation(field_in, field_out, *, domain, origin): class TestReadOutsideKInterval2(gt_testing.StencilTestSuite): - dtypes = { - "field_in": np.float64, - "field_out": np.float64, - } + dtypes = {"field_in": np.float64, "field_out": np.float64} domain_range = [(4, 4), (4, 4), (4, 4)] backends = ALL_BACKENDS symbols = { @@ -734,10 +724,7 @@ def validation(field_in, field_out, *, domain, origin): class TestReadOutsideKInterval3(gt_testing.StencilTestSuite): - dtypes = { - "field_in": np.float64, - "field_out": np.float64, - } + dtypes = {"field_in": np.float64, "field_out": np.float64} domain_range = [(4, 4), (4, 4), (4, 4)] backends = ALL_BACKENDS symbols = { @@ -777,11 +764,7 @@ def _skip_dace_cpu_gcc_error(backends): class TestVariableKRead(gt_testing.StencilTestSuite): - dtypes = { - "field_in": np.float32, - "field_out": np.float32, - "index": np.int32, - } + dtypes = {"field_in": np.float32, "field_out": np.float32, "index": np.int32} domain_range = [(2, 2), (2, 2), (2, 8)] backends = _skip_dace_cpu_gcc_error(ALL_BACKENDS) symbols = { @@ -803,11 +786,7 @@ def validation(field_in, field_out, index, *, domain, origin): class TestVariableKAndReadOutside(gt_testing.StencilTestSuite): - dtypes = { - "field_in": np.float64, - "field_out": np.float64, - "index": np.int32, - } + dtypes = {"field_in": np.float64, "field_out": np.float64, "index": np.int32} domain_range = [(2, 2), (2, 2), (2, 8)] backends = _skip_dace_cpu_gcc_error(ALL_BACKENDS) symbols = { @@ -834,10 +813,7 @@ def validation(field_in, field_out, index, *, domain, origin): class TestDiagonalKOffset(gt_testing.StencilTestSuite): - dtypes = { - "field_in": np.float64, - "field_out": np.float64, - } + dtypes = {"field_in": np.float64, "field_out": np.float64} 
domain_range = [(2, 2), (2, 2), (2, 8)] backends = ALL_BACKENDS symbols = { @@ -861,10 +837,7 @@ def validation(field_in, field_out, *, domain, origin): class TestHorizontalRegions(gt_testing.StencilTestSuite): - dtypes = { - "field_in": np.float32, - "field_out": np.float32, - } + dtypes = {"field_in": np.float32, "field_out": np.float32} domain_range = [(4, 4), (4, 4), (2, 2)] backends = ALL_BACKENDS symbols = { @@ -899,10 +872,7 @@ def validation(field_in, field_out, *, domain, origin): class TestHorizontalRegionsCorners(gt_testing.StencilTestSuite): - dtypes = { - "field_in": np.float32, - "field_out": np.float32, - } + dtypes = {"field_in": np.float32, "field_out": np.float32} domain_range = [(4, 4), (4, 4), (2, 2)] backends = ALL_BACKENDS symbols = { @@ -988,10 +958,7 @@ def validation(field_in, field_out, *, domain, origin): class TestMatrixAssignment(gt_testing.StencilTestSuite): - dtypes = { - "field_in": np.float32, - "field_out": np.float32, - } + dtypes = {"field_in": np.float32, "field_out": np.float32} domain_range = [(2, 2), (2, 2), (2, 2)] backends = ALL_BACKENDS symbols = { @@ -1132,10 +1099,7 @@ class TestMaskedMatmul(gt_testing.StencilTestSuite): backends = ALL_BACKENDS symbols = { "matrix": gt_testing.field( - in_range=(-10, 10), - axes="K", - boundary=[(0, 0), (0, 0), (0, 0)], - data_dims=(4, 6), + in_range=(-10, 10), axes="K", boundary=[(0, 0), (0, 0), (0, 0)], data_dims=(4, 6) ), "field_1": gt_testing.field( in_range=(-10, 10), axes="IJK", boundary=[(0, 0), (0, 0), (0, 0)], data_dims=(6,) diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/defir_to_gtir_definition_setup.py b/tests/cartesian_tests/unit_tests/frontend_tests/defir_to_gtir_definition_setup.py index 8ced4374ac..57b20e376c 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/defir_to_gtir_definition_setup.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/defir_to_gtir_definition_setup.py @@ -182,9 +182,7 @@ def build(self) -> ComputationBlock: end=AxisBound(level=LevelMarker.END, offset=self.end), ), iteration_order=self.order, - body=BlockStmt( - stmts=temp_decls + [stmt.build() for stmt in self.children], - ), + body=BlockStmt(stmts=temp_decls + [stmt.build() for stmt in self.children]), loc=self.loc, ) @@ -218,9 +216,7 @@ def value(self): if isinstance(self._value, str): value = TFieldRef(name=self._value, offset=self.offset) value.loc = Location( - line=self.loc.line, - column=self.loc.column + self.target.width + 3, - scope=self.loc.scope, + line=self.loc.line, column=self.loc.column + self.target.width + 3, scope=self.loc.scope ) value.parent = self return value @@ -267,11 +263,7 @@ def __init__( def build(self): if self.parent: self.loc.scope = self.parent.child_scope - return FieldRef( - name=self.name, - offset=self.offset, - loc=self.loc, - ) + return FieldRef(name=self.name, offset=self.offset, loc=self.loc) @property def height(self) -> int: @@ -287,24 +279,14 @@ def field_names(self) -> Set[str]: class TScalarLiteral(TObject): - def __init__( - self, - *, - value: Any, - loc: Location = None, - parent: TObject = None, - ): + def __init__(self, *, value: Any, loc: Location = None, parent: TObject = None): super().__init__(loc or Location(line=0, column=0), parent=parent) self.value = value def build(self): if self.parent: self.loc.scope = self.parent.child_scope - return ScalarLiteral( - value=self.value, - data_type=DataType.AUTO, - loc=self.loc, - ) + return ScalarLiteral(value=self.value, data_type=DataType.AUTO, loc=self.loc) @property def height(self) -> int: diff 
--git a/tests/cartesian_tests/unit_tests/frontend_tests/test_defir_to_gtir.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_defir_to_gtir.py index 0e4e90b9b3..73a77b05df 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_defir_to_gtir.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_defir_to_gtir.py @@ -52,7 +52,7 @@ def test_stencil_definition( TDefinition(name="definition", domain=ijk_domain, fields=["a", "b"]) .add_blocks( TComputationBlock(order=IterationOrder.PARALLEL).add_statements( - TAssign("a", "b", (0, 0, 0)), + TAssign("a", "b", (0, 0, 0)) ) ) .build() diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index 889fa0d145..fa07a53dcb 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -60,11 +60,7 @@ def parse_definition( original_annotations = gtscript._set_arg_dtypes(definition_func, dtypes=dtypes or {}) build_options = gt_definitions.BuildOptions( - name=name, - module=module, - rebuild=rebuild, - backend_opts=kwargs, - build_info=None, + name=name, module=module, rebuild=rebuild, backend_opts=kwargs, build_info=None ) gt_frontend.GTScriptFrontend.prepare_stencil_definition( @@ -1009,8 +1005,7 @@ def definition(inout_field: gtscript.Field[float]): compile_assert(inout_field[0, 0, 0] < 0) with pytest.raises( - gt_frontend.GTScriptSyntaxError, - match="Evaluation of compile_assert condition failed", + gt_frontend.GTScriptSyntaxError, match="Evaluation of compile_assert condition failed" ): parse_definition(definition, name=inspect.stack()[0][3], module=self.__class__.__name__) @@ -1085,11 +1080,7 @@ def definition( tmp[0, 0, 0][0] = field_in field_out = tmp[0, 0, 0][0] - parse_definition( - definition, - name=inspect.stack()[0][3], - module=self.__class__.__name__, - ) + parse_definition(definition, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_typed_temp_missing(self): def definition( @@ -1320,10 +1311,7 @@ def definition_func( class TestBuiltinDTypes: @staticmethod - def literal_add_func( - in_field: gtscript.Field[float], - out_field: gtscript.Field["my_float"], - ): + def literal_add_func(in_field: gtscript.Field[float], out_field: gtscript.Field["my_float"]): with computation(PARALLEL), interval(...): out_field = in_field + 42.0 diff --git a/tests/cartesian_tests/unit_tests/test_cli.py b/tests/cartesian_tests/unit_tests/test_cli.py index 12959a88cd..142e79f52b 100644 --- a/tests/cartesian_tests/unit_tests/test_cli.py +++ b/tests/cartesian_tests/unit_tests/test_cli.py @@ -35,10 +35,8 @@ def clirunner(): @pytest.fixture( params=[ *ALL_BACKENDS, # gtc backends require definition ir as input, for now we skip the tests - pytest.param( - "nocli", - ), - ], + pytest.param("nocli"), + ] ) def backend_name(request, nocli_backend): """Parametrize by backend name.""" diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py index 1a51cad736..437b6ad999 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py @@ -140,10 +140,7 @@ class AssignStmt(Stmt, common.AssignStmt[DummyExpr, Expr]): DataType.BOOL, ), ( - UnaryOp( - expr=DummyExpr(dtype=ARITHMETIC_TYPE), - op=A_ARITHMETIC_UNARY_OPERATOR, - ), + UnaryOp(expr=DummyExpr(dtype=ARITHMETIC_TYPE), 
op=A_ARITHMETIC_UNARY_OPERATOR), ARITHMETIC_TYPE, ), ], @@ -157,9 +154,7 @@ def test_dtype_propagation(node, expected): [ ( lambda: TernaryOp( - cond=DummyExpr(dtype=ARITHMETIC_TYPE), - true_expr=DummyExpr(), - false_expr=DummyExpr(), + cond=DummyExpr(dtype=ARITHMETIC_TYPE), true_expr=DummyExpr(), false_expr=DummyExpr() ), r"Condition.*must be bool.*", ValueError, @@ -293,7 +288,7 @@ class SymbolTableRootNode(eve.Node, eve.ValidatedSymbolTableTrait): SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo"), SymbolRefChildNode(name="foo2"), - ], + ] ), lambda: SymbolTableRootNode( nodes=[ @@ -309,16 +304,14 @@ def test_symbolref_validation_for_invalid_tree(tree_with_missing_symbol): def test_symbolref_validation_for_valid_tree(): - SymbolTableRootNode( - nodes=[SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo")], - ) + SymbolTableRootNode(nodes=[SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo")]) ( SymbolTableRootNode( # noqa: B018 nodes=[ SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo"), SymbolRefChildNode(name="foo"), - ], + ] ), ) SymbolTableRootNode( @@ -368,9 +361,9 @@ def construct_dims_assignment(dimensions: Tuple[bool, bool, bool], direction: co MultiDimLoop( loop_order=direction, assigns=[ - AssignStmt(left=MultiDimRef(name=out_name), right=MultiDimRef(name=in_name)), + AssignStmt(left=MultiDimRef(name=out_name), right=MultiDimRef(name=in_name)) ], - ), + ) ], ) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_cuir_compilation.py b/tests/cartesian_tests/unit_tests/test_gtc/test_cuir_compilation.py index 634748ae99..e787b36ea6 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_cuir_compilation.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_cuir_compilation.py @@ -33,12 +33,7 @@ def build_gridtools_test(tmp_path: Path, code: str): opts = pyext_builder.get_gt_pyext_build_opts(uses_cuda=True) assert isinstance(opts["include_dirs"], list) opts["include_dirs"].append(gridtools_cpp.get_include_dir()) - ext_module = setuptools.Extension( - "test", - [str(tmp_src.absolute())], - language="c++", - **opts, - ) + ext_module = setuptools.Extension("test", [str(tmp_src.absolute())], language="c++", **opts) args = [ "build_ext", "--build-temp=" + str(tmp_src.parent), @@ -47,9 +42,7 @@ def build_gridtools_test(tmp_path: Path, code: str): ] setuptools.setup( name="test", - ext_modules=[ - ext_module, - ], + ext_modules=[ext_module], script_args=args, cmdclass={"build_ext": pyext_builder.CUDABuildExtension}, ) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_cuir_kernel_fusion.py b/tests/cartesian_tests/unit_tests/test_gtc/test_cuir_kernel_fusion.py index e3a21cdd41..2d0a5eaffe 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_cuir_kernel_fusion.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_cuir_kernel_fusion.py @@ -55,7 +55,7 @@ def test_forward_backward_fusion(): ), ) ), - ], + ] ) transformed = kernel_fusion.FuseKernels().visit(testee) assert len(transformed.kernels) == 1 @@ -95,7 +95,7 @@ def test_no_fusion_with_parallel_offsets(): ), ) ), - ], + ] ) transformed = kernel_fusion.FuseKernels().visit(testee) assert len(transformed.kernels) == 2 @@ -112,7 +112,7 @@ def test_no_fusion_with_parallel_offsets(): left__name="out", right__name="tmp", right__offset__k=1 ) ), - ], + ] ) transformed = kernel_fusion.FuseKernels().visit(testee) assert len(transformed.kernels) == 2 diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_gtcpp_compilation.py 
b/tests/cartesian_tests/unit_tests/test_gtc/test_gtcpp_compilation.py index 7c9b1626e1..61fa81bea5 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_gtcpp_compilation.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_gtcpp_compilation.py @@ -55,22 +55,13 @@ def build_gridtools_test(tmp_path: Path, code: str): "--build-lib=" + str(tmp_src.parent), "--force", ] - setuptools.setup( - name="test", - ext_modules=[ - ext_module, - ], - script_args=args, - ) + setuptools.setup(name="test", ext_modules=[ext_module], script_args=args) def make_compilation_input_and_expected(): return [ (ProgramFactory(name="test"), r"auto test"), - ( - ProgramFactory(functors__0__name="fun"), - r"struct fun", - ), + (ProgramFactory(functors__0__name="fun"), r"struct fun"), ( ProgramFactory( functors__0__applies=[], @@ -88,10 +79,7 @@ def make_compilation_input_and_expected(): ), r"inout_accessor<0, extent<1,\s*2,\s*-3,\s*-4,\s*10,\s*-10>", ), - ( - ProgramFactory(), - r"void\s*apply\(", - ), + (ProgramFactory(), r"void\s*apply\("), (ProgramFactory(parameters=[FieldDeclFactory(name="my_param")]), r"my_param"), ( ProgramFactory( @@ -141,10 +129,7 @@ def _embed_apply_method_in_program(apply_method: GTApplyMethod): GTApplyMethodFactory(body__0__left__name="foo", body__0__right__name="bar"), r"foo.*=.*bar", ), - ( - GTApplyMethodFactory(body__0=IfStmtFactory()), - r"if", - ), + (GTApplyMethodFactory(body__0=IfStmtFactory()), r"if"), ], ) def test_apply_method_compilation_succeeds(tmp_path, apply_method, expected_regex): diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_gtir.py b/tests/cartesian_tests/unit_tests/test_gtc/test_gtir.py index 34bce743f3..eaa5e06bff 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_gtir.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_gtir.py @@ -94,11 +94,7 @@ def test_symbolref_without_decl(): lambda: VerticalLoopFactory( body=[ ParAssignStmtFactory( - right=BinaryOpFactory( - left__name="foo", - right__name="foo", - right__offset__i=1, - ) + right=BinaryOpFactory(left__name="foo", right__name="foo", right__offset__i=1) ), ParAssignStmtFactory(left__name="foo"), ] @@ -106,10 +102,7 @@ def test_symbolref_without_decl(): # offset access in condition lambda: VerticalLoopFactory( body=[ - FieldIfStmtFactory( - cond__name="foo", - cond__offset__i=1, - ), + FieldIfStmtFactory(cond__name="foo", cond__offset__i=1), ParAssignStmtFactory(left__name="foo"), ] ), @@ -157,10 +150,7 @@ def test_while_with_accumulated_extents(): ): WhileFactory( cond=BinaryOpFactory( - left__name="a", - right__name="b", - op=ComparisonOperator.LT, - dtype=DataType.BOOL, + left__name="a", right__name="b", op=ComparisonOperator.LT, dtype=DataType.BOOL ), body=[ ParAssignStmtFactory(left__name="a", right__name="b", right__offset__i=1), diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_gtir_dtype_resolver.py b/tests/cartesian_tests/unit_tests/test_gtc/test_gtir_dtype_resolver.py index b038581eaf..a193308f6a 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_gtir_dtype_resolver.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_gtir_dtype_resolver.py @@ -38,11 +38,7 @@ def test_propagate_dtype_to_FieldAccess(): name = "foo" - decl = FieldDecl( - name=name, - dtype=A_ARITHMETIC_TYPE, - dimensions=(True, True, True), - ) + decl = FieldDecl(name=name, dtype=A_ARITHMETIC_TYPE, dimensions=(True, True, True)) testee = FieldAccessFactory(name=name) @@ -85,10 +81,7 @@ def test_resolve_dtype_to_FieldAccess(): right__dtype=DataType.AUTO, ), ) - 
resolve_dtype_and_validate( - testee, - {"field": A_ARITHMETIC_TYPE}, - ) + resolve_dtype_and_validate(testee, {"field": A_ARITHMETIC_TYPE}) def test_resolve_dtype_to_FieldAccess_variable(): diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_gtir_to_oir.py b/tests/cartesian_tests/unit_tests/test_gtc/test_gtir_to_oir.py index ff2e4dae31..3cd7697182 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_gtir_to_oir.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_gtir_to_oir.py @@ -110,11 +110,7 @@ def test_indirect_read_with_offset_and_write(): ParAssignStmtFactory(right__name="tmp"), ], ), - VerticalLoopFactory( - body=[ - ParAssignStmtFactory(left__name="foo"), - ], - ), + VerticalLoopFactory(body=[ParAssignStmtFactory(left__name="foo")]), ] ) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_gtir_upcaster.py b/tests/cartesian_tests/unit_tests/test_gtc/test_gtir_upcaster.py index 53d2ca65ec..ef4d413e62 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_gtir_upcaster.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_gtir_upcaster.py @@ -99,11 +99,7 @@ def test_upcast_ParAssignStmt(): def test_upcast_TernaryOp(): - testee = TernaryOp( - cond=A_BOOL_LITERAL, - true_expr=A_INT64_LITERAL, - false_expr=A_FLOAT64_LITERAL, - ) + testee = TernaryOp(cond=A_BOOL_LITERAL, true_expr=A_INT64_LITERAL, false_expr=A_FLOAT64_LITERAL) upcast_and_validate(testee, [Cast(dtype=DataType.FLOAT64, expr=A_INT64_LITERAL)]) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_npir_codegen.py b/tests/cartesian_tests/unit_tests/test_gtc/test_npir_codegen.py index de3269ff91..b808fc107e 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_npir_codegen.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_npir_codegen.py @@ -165,10 +165,7 @@ def test_native_function() -> None: result = NpirCodegen().visit( NativeFuncCallFactory( func=common.NativeFunction.MIN, - args=[ - FieldSliceFactory(name="a"), - ParamAccessFactory(name="p"), - ], + args=[FieldSliceFactory(name="a"), ParamAccessFactory(name="p")], ), is_serial=False, ) @@ -226,10 +223,7 @@ def test_vector_arithmetic() -> None: def test_vector_unary_op() -> None: result = NpirCodegen().visit( - npir.VectorUnaryOp( - expr=FieldSliceFactory(name="a"), - op=common.UnaryOperator.NEG, - ), + npir.VectorUnaryOp(expr=FieldSliceFactory(name="a"), op=common.UnaryOperator.NEG), is_serial=False, ) assert result == "(-(a[i:I, j:J, k:K]))" @@ -277,11 +271,7 @@ def test_vertical_pass_seq() -> None: def test_vertical_pass_par() -> None: result = NpirCodegen().visit(VerticalPassFactory(direction=common.LoopOrder.PARALLEL)) print(result) - match = re.match( - (r"(#.*?\n)?" r"k, K = _dk_, _dK_\n"), - result, - re.MULTILINE, - ) + match = re.match((r"(#.*?\n)?" 
r"k, K = _dk_, _dK_\n"), result, re.MULTILINE) assert match @@ -333,13 +323,7 @@ def test_full_computation_valid(tmp_path) -> None: a = np.zeros((10, 10, 10)) b = np.ones_like(a) * 3 p = 2 - mod.run( - a=a, - b=b, - p=p, - _domain_=(8, 5, 9), - _origin_={"a": (1, 1, 0), "b": (0, 0, 0)}, - ) + mod.run(a=a, b=b, p=p, _domain_=(8, 5, 9), _origin_={"a": (1, 1, 0), "b": (0, 0, 0)}) assert (a[1:9, 1:6, 0:9] == 5).all() @@ -359,7 +343,7 @@ def test_variable_read_outside_bounds(tmp_path) -> None: k=FieldSliceFactory(name="index", dtype=common.DataType.INT32) ), ), - ), + ) ) result = NpirCodegen().visit(computation) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_oir.py b/tests/cartesian_tests/unit_tests/test_gtc/test_oir.py index b82130e6cc..9b1ca00131 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_oir.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_oir.py @@ -69,10 +69,7 @@ def test_eq_true(self, lhs, rhs): assert isinstance(res2, bool) assert res2 - @pytest.mark.parametrize( - ["lhs", "rhs"], - LESS_AXISBOUNDS + GREATER_AXISBOUNDS, - ) + @pytest.mark.parametrize(["lhs", "rhs"], LESS_AXISBOUNDS + GREATER_AXISBOUNDS) def test_eq_false(self, lhs, rhs): res1 = lhs == rhs assert isinstance(res1, bool) @@ -88,10 +85,7 @@ def test_lt_true(self, lhs, rhs): assert isinstance(res, bool) assert res - @pytest.mark.parametrize( - ["lhs", "rhs"], - GREATER_AXISBOUNDS + EQUAL_AXISBOUNDS, - ) + @pytest.mark.parametrize(["lhs", "rhs"], GREATER_AXISBOUNDS + EQUAL_AXISBOUNDS) def test_lt_false(self, lhs, rhs): res = lhs < rhs assert isinstance(res, bool) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_access_kinds.py b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_access_kinds.py index dfce0e1ac0..0ba9aa5bd3 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_access_kinds.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_access_kinds.py @@ -35,7 +35,7 @@ def test_access_readwrite(): vertical_loops__0__sections__0__horizontal_executions__0__body=[ AssignStmtFactory(left__name="output_field", right__name="inout_field"), AssignStmtFactory(left__name="inout_field", right__name="other_field"), - ], + ] ) access = compute_access_kinds(testee) @@ -48,7 +48,7 @@ def test_access_write_only(): vertical_loops__0__sections__0__horizontal_executions__0__body=[ AssignStmtFactory(left__name="inout_field", right__name="other_field"), AssignStmtFactory(left__name="output_field", right__name="inout_field"), - ], + ] ) access = compute_access_kinds(testee) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_gtcpp.py b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_gtcpp.py index 585c2932cd..76bdf12ad0 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_gtcpp.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_gtcpp.py @@ -37,29 +37,25 @@ def test_horizontal_mask(): AssignStmtFactory(left__name=out_name, right__name=in_name), HorizontalRestrictionFactory( mask=HorizontalMask( - i=HorizontalInterval.at_endpt(LevelMarker.START, 0), - j=HorizontalInterval.full(), + i=HorizontalInterval.at_endpt(LevelMarker.START, 0), j=HorizontalInterval.full() ), body=[AssignStmtFactory(left__name=out_name, right=LiteralFactory())], ), HorizontalRestrictionFactory( mask=HorizontalMask( - i=HorizontalInterval.at_endpt(LevelMarker.END, 0), - j=HorizontalInterval.full(), + i=HorizontalInterval.at_endpt(LevelMarker.END, 0), j=HorizontalInterval.full() ), body=[AssignStmtFactory(left__name=out_name, right=LiteralFactory())], ), 
HorizontalRestrictionFactory( mask=HorizontalMask( - i=HorizontalInterval.full(), - j=HorizontalInterval.at_endpt(LevelMarker.START, 0), + i=HorizontalInterval.full(), j=HorizontalInterval.at_endpt(LevelMarker.START, 0) ), body=[AssignStmtFactory(left__name=out_name, right=LiteralFactory())], ), HorizontalRestrictionFactory( mask=HorizontalMask( - i=HorizontalInterval.full(), - j=HorizontalInterval.at_endpt(LevelMarker.END, 0), + i=HorizontalInterval.full(), j=HorizontalInterval.at_endpt(LevelMarker.END, 0) ), body=[AssignStmtFactory(left__name=out_name, right=LiteralFactory())], ), @@ -91,7 +87,7 @@ def test_variable_offset_accessor(): right__offset=VariableKOffsetFactory(k__name=index_name), ) ], - ), + ) ) gtcpp_program = OIRToGTCpp().visit(oir_stencil) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py index 53332dd601..4d611c8268 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py @@ -41,14 +41,8 @@ def test_stencil_to_computation() -> None: stencil = StencilFactory( name="stencil", params=[ - FieldDeclFactory( - name="a", - dtype=common.DataType.FLOAT64, - ), - oir.ScalarDecl( - name="b", - dtype=common.DataType.INT32, - ), + FieldDeclFactory(name="a", dtype=common.DataType.FLOAT64), + oir.ScalarDecl(name="b", dtype=common.DataType.INT32), ], vertical_loops__0__sections__0__horizontal_executions__0__body=[ AssignStmtFactory( @@ -58,9 +52,7 @@ def test_stencil_to_computation() -> None: ) computation = OirToNpir().visit(stencil) - assert set(d.name for d in computation.api_field_decls) == { - "a", - } + assert set(d.name for d in computation.api_field_decls) == {"a"} assert set(computation.arguments) == {"a", "b"} assert len(computation.vertical_passes) == 1 diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_caches.py b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_caches.py index 55f1d157fd..9ceb995b02 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_caches.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_caches.py @@ -89,27 +89,13 @@ def test_k_cache_detection_basic(): testee = VerticalLoopFactory( loop_order=LoopOrder.FORWARD, sections__0__horizontal_executions__0__body=[ + AssignStmtFactory(left__name="foo", right__name="foo", right__offset__k=1), + AssignStmtFactory(left__name="bar", right__name="foo", right__offset__k=-1), AssignStmtFactory( - left__name="foo", - right__name="foo", - right__offset__k=1, + left__name="baz", right__name="baz", right__offset__i=1, right__offset__k=1 ), AssignStmtFactory( - left__name="bar", - right__name="foo", - right__offset__k=-1, - ), - AssignStmtFactory( - left__name="baz", - right__name="baz", - right__offset__i=1, - right__offset__k=1, - ), - AssignStmtFactory( - left__name="foo", - right__name="baz", - right__offset__j=1, - right__offset__k=-1, + left__name="foo", right__name="baz", right__offset__j=1, right__offset__k=-1 ), ], ) @@ -191,13 +177,11 @@ def test_prune_k_cache_fills_forward_with_reads_outside_interval(): loop_order=LoopOrder.FORWARD, sections__0=VerticalLoopSectionFactory( horizontal_executions__0__body=[ - AssignStmtFactory(left__name="foo", right__name="foo", right__offset__k=-1), + AssignStmtFactory(left__name="foo", right__name="foo", right__offset__k=-1) ], 
interval__start=AxisBound.from_start(1), ), - caches=[ - KCacheFactory(name="foo", fill=True), - ], + caches=[KCacheFactory(name="foo", fill=True)], ) transformed = PruneKCacheFills().visit(testee) cache_dict = {c.name: c for c in transformed.caches} @@ -284,10 +268,7 @@ def test_prune_k_cache_flushes(): ], ), ], - declarations=[ - TemporaryFactory(name="tmp1"), - TemporaryFactory(name="tmp2"), - ], + declarations=[TemporaryFactory(name="tmp1"), TemporaryFactory(name="tmp2")], ) transformed = PruneKCacheFlushes().visit(testee) cache_dict = {c.name: c for c in transformed.vertical_loops[0].caches} @@ -667,10 +648,7 @@ def test_fill_flush_to_local_k_caches_basic_forward(): VerticalLoopFactory( loop_order=LoopOrder.FORWARD, sections__0__horizontal_executions__0__body=[ - AssignStmtFactory( - left__name="foo", - right__name="foo", - ), + AssignStmtFactory(left__name="foo", right__name="foo") ], caches=[KCacheFactory(name="foo", fill=True, flush=True)], ) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_horizontal_execution_merging.py b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_horizontal_execution_merging.py index 070f724cc6..e764dedf09 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_horizontal_execution_merging.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_horizontal_execution_merging.py @@ -127,7 +127,7 @@ def test_horiz_exec_merging_complexity(): + [ HorizontalExecutionFactory( body=[AssignStmtFactory(left__name="output", right__name=f"tmp{n}")] - ), + ) ], declarations=[TemporaryFactory(name=f"tmp{i}") for i in range(n)], ) @@ -317,8 +317,7 @@ def test_on_the_fly_merging_localscalars(): HorizontalExecutionFactory( body=[ AssignStmtFactory( - left=ScalarAccessFactory(name="scalar_tmp"), - right__name="in", + left=ScalarAccessFactory(name="scalar_tmp"), right__name="in" ), AssignStmtFactory( left__name="tmp", right=ScalarAccessFactory(name="scalar_tmp") diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_inlining.py b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_inlining.py index ba400f0d98..2df0512ed8 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_inlining.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_inlining.py @@ -43,11 +43,7 @@ def mask_cond() -> BinaryOp: @pytest.fixture def mask_assign(mask_cond) -> AssignStmtFactory: - return AssignStmtFactory( - left__name="mask_f", - left__dtype=DataType.BOOL, - right=mask_cond, - ) + return AssignStmtFactory(left__name="mask_f", left__dtype=DataType.BOOL, right=mask_cond) def test_mask_inlining(mask_assign): @@ -74,9 +70,7 @@ def test_mask_inlining(mask_assign): ], caches=[IJCacheFactory(name=mask_name)], ), - declarations=[ - TemporaryFactory(name=mask_name, dtype=DataType.BOOL), - ], + declarations=[TemporaryFactory(name=mask_name, dtype=DataType.BOOL)], ) pre_section = pre_oir.vertical_loops[0].sections[0] @@ -115,8 +109,7 @@ def test_mask_no_inlining(mask_assign, mask_cond): mask=FieldAccessFactory(name=mask_name, dtype=DataType.BOOL), body=[ AssignStmtFactory( - left__name=cond_name, - right=LiteralFactory(), + left__name=cond_name, right=LiteralFactory() ) ], ) @@ -127,9 +120,7 @@ def test_mask_no_inlining(mask_assign, mask_cond): ], caches=[IJCacheFactory(name=mask_name)], ), - declarations=[ - 
TemporaryFactory(name=mask_name, dtype=DataType.BOOL), - ], + declarations=[TemporaryFactory(name=mask_name, dtype=DataType.BOOL)], ) pre_section = pre_oir.vertical_loops[0].sections[0] diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_pruning.py b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_pruning.py index 9644ece9e3..7a12c5fb7f 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_pruning.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_pruning.py @@ -56,7 +56,7 @@ def test_no_field_access_pruning(): ) ], declarations=[LocalScalarFactory(name="bar")], - ), + ) ] ), ] @@ -76,8 +76,8 @@ def test_no_field_write_access_pruning(): AssignStmtFactory( left=FieldAccessFactory(name="foo"), right=LiteralFactory() ) - ], - ), + ] + ) ] ), VerticalLoopFactory( @@ -90,10 +90,10 @@ def test_no_field_write_access_pruning(): ) ], declarations=[LocalScalarFactory(name="bar")], - ), + ) ] ), - ], + ] ) transformed = NoFieldAccessPruning().visit(testee) assert len(transformed.vertical_loops) == 1 @@ -106,9 +106,7 @@ def test_unreachable_stmt_pruning(): testee = StencilFactory( vertical_loops__0__sections__0__horizontal_executions=[ HorizontalExecutionFactory( - body=[ - AssignStmtFactory(left__name=out_name, right__name=in_name), - ] + body=[AssignStmtFactory(left__name=out_name, right__name=in_name)] ), HorizontalExecutionFactory( body=[ diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_temporaries.py b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_temporaries.py index 5cf3cd3b84..7f00226298 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_temporaries.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_temporaries.py @@ -46,10 +46,7 @@ def test_local_temporaries_to_scalars_multiexec(): testee = StencilFactory( vertical_loops__0__sections__0__horizontal_executions=[ HorizontalExecutionFactory( - body=[ - AssignStmtFactory(left__name="tmp"), - AssignStmtFactory(right__name="tmp"), - ] + body=[AssignStmtFactory(left__name="tmp"), AssignStmtFactory(right__name="tmp")] ), HorizontalExecutionFactory(body=[AssignStmtFactory(right__name="tmp")]), ], diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_utils.py b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_utils.py index 7d6610b3ea..558144bb10 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_utils.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_utils.py @@ -53,7 +53,7 @@ def test_access_collector(): body=[ AssignStmtFactory( left__name="baz", right__name="tmp", right__offset__j=1 - ), + ) ], mask=FieldAccessFactory( name="mask", @@ -63,7 +63,7 @@ def test_access_collector(): offset__k=1, ), ) - ], + ] ), ], declarations=[TemporaryFactory(name="tmp")], @@ -188,8 +188,7 @@ def test_access_overlap_along_axis(): ), ( common.HorizontalMask( - i=common.HorizontalInterval.full(), - j=common.HorizontalInterval.full(), + i=common.HorizontalInterval.full(), j=common.HorizontalInterval.full() ), -1, ((-1, 0), (0, 0)), @@ -211,7 +210,7 @@ def test_stencil_extents_region(mask, offset, access_extent): left__name="tmp", right__name="input", right__offset__i=offset ) ], - ), + ) ] ), 
HorizontalExecutionFactory( diff --git a/tests/eve_tests/definitions.py b/tests/eve_tests/definitions.py index f4f0232ae4..f32885fa08 100644 --- a/tests/eve_tests/definitions.py +++ b/tests/eve_tests/definitions.py @@ -370,8 +370,7 @@ def make_simple_node_with_collections(*, fixed: bool = False) -> SimpleNodeWithC def make_simple_node_with_abstract_collections( - *, - fixed: bool = False, + *, fixed: bool = False ) -> SimpleNodeWithAbstractCollections: int_value = make_int_value(fixed=fixed) int_sequence = make_collection_value(int, collection_type=tuple, length=3) @@ -386,10 +385,7 @@ def make_simple_node_with_abstract_collections( ) -def make_simple_node_with_symbol_name( - *, - fixed: bool = False, -) -> SimpleNodeWithSymbolName: +def make_simple_node_with_symbol_name(*, fixed: bool = False) -> SimpleNodeWithSymbolName: int_value = make_int_value(fixed=fixed) name = make_str_value(fixed=fixed) @@ -397,8 +393,7 @@ def make_simple_node_with_symbol_name( def make_simple_node_with_default_symbol_name( - *, - fixed: bool = False, + *, fixed: bool = False ) -> SimpleNodeWithDefaultSymbolName: int_value = make_int_value(fixed=fixed) diff --git a/tests/eve_tests/unit_tests/test_utils.py b/tests/eve_tests/unit_tests/test_utils.py index a6fb5028d2..e64de42db8 100644 --- a/tests/eve_tests/unit_tests/test_utils.py +++ b/tests/eve_tests/unit_tests/test_utils.py @@ -27,12 +27,7 @@ def test_getitem_(): from gt4py.eve.utils import getitem_ - mapping = { - "true": True, - 1: True, - "false": False, - 0: False, - } + mapping = {"true": True, 1: True, "false": False, 0: False} sequence = [False, True, True] diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index e1341f99ae..44989e7804 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -177,7 +177,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST + [ # awaiting dace fix, see https://github.com/spcl/dace/pull/1442 - (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE) ], ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], @@ -186,11 +186,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST - + [ - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), - ], + + [(USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE)], ProgramFormatterId.GTFN_CPP_FORMATTER: [ - (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), + (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) ], ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)], } diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 4b50e21260..c6f0f9a113 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -144,10 +144,7 @@ def field( dtype: np.typing.DTypeLike, ) -> FieldValue: return constructors.full( - domain=common.domain(sizes), - fill_value=self.value, - dtype=dtype, - allocator=allocator, + domain=common.domain(sizes), fill_value=self.value, dtype=dtype, allocator=allocator ) @@ -422,13 +419,7 @@ def verify( Else, ``inout`` will not be passed and compared to ``ref``. 
""" if out: - run( - case, - fieldview_prog, - *args, - out=out, - offset_provider=offset_provider, - ) + run(case, fieldview_prog, *args, out=out, offset_provider=offset_provider) else: run(case, fieldview_prog, *args, offset_provider=offset_provider) @@ -478,9 +469,7 @@ def verify_with_default_data( @pytest.fixture -def cartesian_case( - exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor, -): +def cartesian_case(exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor): yield Case( exec_alloc_descriptor if exec_alloc_descriptor.executor else None, offset_provider={"Ioff": IDim, "Joff": JDim, "Koff": KDim}, @@ -492,8 +481,7 @@ def cartesian_case( @pytest.fixture def unstructured_case( - mesh_descriptor, - exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor, + mesh_descriptor, exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor ): yield Case( exec_alloc_descriptor if exec_alloc_descriptor.executor else None, @@ -568,8 +556,7 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> def extend_sizes( - sizes: dict[gtx.Dimension, int], - extend: Optional[dict[gtx.Dimension, tuple[int, int]]] = None, + sizes: dict[gtx.Dimension, int], extend: Optional[dict[gtx.Dimension, tuple[int, int]]] = None ) -> dict[gtx.Dimension, int]: """Calculate the sizes per dimension given a set of extensions.""" sizes = sizes.copy() @@ -580,8 +567,7 @@ def extend_sizes( def get_default_data( - case: Case, - fieldview_prog: decorator.FieldOperator | decorator.Program, + case: Case, fieldview_prog: decorator.FieldOperator | decorator.Program ) -> tuple[tuple[gtx.Field | ScalarValue | tuple, ...], dict[str, gtx.Field]]: """ Allocate default data for a fieldview code object given a test case. 
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 04fbb58305..388849bf09 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -279,15 +279,7 @@ def skip_value_mesh() -> MeshDescriptor: ) c2v_arr = np.array( - [ - [0, 6, 5], - [0, 2, 6], - [0, 1, 2], - [2, 3, 6], - [3, 4, 6], - [4, 5, 6], - ], - dtype=gtx.IndexType, + [[0, 6, 5], [0, 2, 6], [0, 1, 2], [2, 3, 6], [3, 4, 6], [4, 5, 6]], dtype=gtx.IndexType ) c2e_arr = np.array( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index b8d9841616..4b3ca54c36 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -124,13 +124,7 @@ def foo(x: IField, y: IField, z: IField) -> IField: @program def testee( - a: IField, - b: IField, - c: IField, - out1: IField, - out2: IField, - out3: IField, - out4: IField, + a: IField, b: IField, c: IField, out1: IField, out2: IField, out3: IField, out4: IField ): foo(a, b, c, out=out1) foo(a, y=b, z=c, out=out2) @@ -230,9 +224,7 @@ def test_scan_wrong_return_type(cartesian_case): ): @scan_operator(axis=KDim, forward=True, init=0) - def testee_scan( - state: int32, - ) -> float: + def testee_scan(state: int32) -> float: return 1.0 @program @@ -250,9 +242,7 @@ def test_scan_wrong_state_type(cartesian_case): ): @scan_operator(axis=KDim, forward=True, init=0) - def testee_scan( - state: float, - ) -> int32: + def testee_scan(state: float) -> int32: return 1 @program diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 9da6d260e5..49e428541a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -15,10 +15,7 @@ import numpy as np from typing import Tuple import pytest -from next_tests.integration_tests.cases import ( - KDim, - cartesian_case, -) +from next_tests.integration_tests.cases import KDim, cartesian_case from gt4py import next as gtx from gt4py.next.ffront.experimental import concat_where from next_tests.integration_tests import cases diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8b6d414f0f..6e716298b7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -187,11 +187,7 @@ def testee(a: int32) -> cases.VField: cases.verify_with_default_data( unstructured_case, testee, - ref=lambda a: np.full( - [unstructured_case.default_sizes[Vertex]], - a + 1, - dtype=int32, - ), + ref=lambda a: np.full([unstructured_case.default_sizes[Vertex]], a + 1, dtype=int32), comparison=lambda a, b: np.all(a == b), ) @@ -246,11 +242,7 @@ def testee(size: gtx.IndexType, out: gtx.Field[[IDim], gtx.IndexType]): out = cases.allocate(cartesian_case, testee, 
"out").zeros()() cases.verify( - cartesian_case, - testee, - size, - out=out, - ref=np.full_like(out, size, dtype=gtx.IndexType), + cartesian_case, testee, size, out=out, ref=np.full_like(out, size, dtype=gtx.IndexType) ) @@ -363,30 +355,17 @@ def field_op_returning_a_tuple( @gtx.field_operator def cast_tuple( - a: cases.IFloatField, - b: cases.IFloatField, - a_asint: cases.IField, - b_asint: cases.IField, + a: cases.IFloatField, b: cases.IFloatField, a_asint: cases.IField, b_asint: cases.IField ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: result = astype(field_op_returning_a_tuple(a, b), int32) - return ( - result[0] == a_asint, - result[1] == b_asint, - ) + return (result[0] == a_asint, result[1] == b_asint) @gtx.field_operator def cast_nested_tuple( - a: cases.IFloatField, - b: cases.IFloatField, - a_asint: cases.IField, - b_asint: cases.IField, + a: cases.IFloatField, b: cases.IFloatField, a_asint: cases.IField, b_asint: cases.IField ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: result = astype((a, field_op_returning_a_tuple(a, b)), int32) - return ( - result[0] == a_asint, - result[1][0] == a_asint, - result[1][1] == b_asint, - ) + return (result[0] == a_asint, result[1][0] == a_asint, result[1][1] == b_asint) a = cases.allocate(cartesian_case, cast_tuple, "a")() b = cases.allocate(cartesian_case, cast_tuple, "b")() @@ -432,10 +411,7 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]: return b cases.verify_with_default_data( - cartesian_case, - testee, - ref=lambda a: a.astype(bool), - comparison=lambda a, b: np.all(a == b), + cartesian_case, testee, ref=lambda a: a.astype(bool), comparison=lambda a, b: np.all(a == b) ) @@ -526,11 +502,9 @@ def testee(a: cases.VField) -> cases.VField: unstructured_case, testee, ref=lambda a: np.sum( - np.sum( - a[unstructured_case.offset_provider["E2V"].table], - axis=1, - initial=0, - )[unstructured_case.offset_provider["V2E"].table], + np.sum(a[unstructured_case.offset_provider["E2V"].table], axis=1, initial=0)[ + unstructured_case.offset_provider["V2E"].table + ], axis=1, where=unstructured_case.offset_provider["V2E"].table != common._DEFAULT_SKIP_VALUE, ), @@ -733,12 +707,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: unstructured_case, testee, ref=lambda a, b: ( - np.sum( - b[v2e_table], - axis=1, - initial=0, - where=v2e_table != common._DEFAULT_SKIP_VALUE, - ) + np.sum(b[v2e_table], axis=1, initial=0, where=v2e_table != common._DEFAULT_SKIP_VALUE) ), ) @@ -834,12 +803,7 @@ def test_scan_different_domain_in_tuple(cartesian_case): i_size = cartesian_case.default_sizes[IDim] k_size = cartesian_case.default_sizes[KDim] - inp1_np = np.ones( - ( - i_size + 1, - k_size, - ) - ) # i_size bigger than in the other argument + inp1_np = np.ones((i_size + 1, k_size)) # i_size bigger than in the other argument inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float) inp1 = cartesian_case.as_field([IDim, KDim], inp1_np) inp2 = cartesian_case.as_field([IDim, KDim], inp2_np) @@ -887,11 +851,7 @@ def prev_levels_iterator(i): expected = np.asarray( [ - reduce( - lambda prev, k: prev + 1.0 + inp2_np[:, k], - prev_levels_iterator(k), - init, - ) + reduce(lambda prev, k: prev + 1.0 + inp2_np[:, k], prev_levels_iterator(k), init) for k in range(k_size) ] ).transpose() @@ -954,11 +914,7 @@ def fieldop_domain(a: cases.IField) -> cases.IField: def program_domain( inp: cases.IField, out: cases.IField, lower_i: gtx.IndexType, upper_i: gtx.IndexType ): - 
fieldop_domain( - inp, - out=out, - domain={IDim: (lower_i, upper_i // 2)}, - ) + fieldop_domain(inp, out=out, domain={IDim: (lower_i, upper_i // 2)}) inp = cases.allocate(cartesian_case, program_domain, "inp")() out = cases.allocate(cartesian_case, fieldop_domain, cases.RETURN)() @@ -966,16 +922,7 @@ def program_domain( ref = out.asnumpy().copy() ref[lower_i : int(upper_i / 2)] = inp.asnumpy()[lower_i : int(upper_i / 2)] * 2 - cases.verify( - cartesian_case, - program_domain, - inp, - out, - lower_i, - upper_i, - inout=out, - ref=ref, - ) + cases.verify(cartesian_case, program_domain, inp, out, lower_i, upper_i, inout=out, ref=ref) def test_domain_input_bounds_1(cartesian_case): @@ -998,9 +945,7 @@ def program_domain( upper_j: gtx.IndexType, ): fieldop_domain( - a, - out=out, - domain={IDim: (1 * lower_i, upper_i + 0), JDim: (lower_j - 0, upper_j)}, + a, out=out, domain={IDim: (1 * lower_i, upper_i + 0), JDim: (lower_j - 0, upper_j)} ) a = cases.allocate(cartesian_case, program_domain, "a")() @@ -1034,10 +979,7 @@ def fieldop_domain_tuple( @gtx.program def program_domain_tuple( - inp0: cases.IJField, - inp1: cases.IJField, - out0: cases.IJField, - out1: cases.IJField, + inp0: cases.IJField, inp1: cases.IJField, out0: cases.IJField, out1: cases.IJField ): fieldop_domain_tuple(inp0, inp1, out=(out0, out1), domain={IDim: (1, 9), JDim: (4, 6)}) @@ -1103,14 +1045,7 @@ def fieldop_implicit_broadcast_2(inp: cases.IField) -> cases.IField: @pytest.mark.uses_tuple_returns def test_tuple_unpacking(cartesian_case): @gtx.field_operator - def unpack( - inp: cases.IField, - ) -> tuple[ - cases.IField, - cases.IField, - cases.IField, - cases.IField, - ]: + def unpack(inp: cases.IField) -> tuple[cases.IField, cases.IField, cases.IField, cases.IField]: a, b, c, d = (inp + 2, inp + 3, inp + 5, inp + 7) return a, b, c, d @@ -1137,9 +1072,7 @@ def test_tuple_unpacking_star_multi(cartesian_case): ] @gtx.field_operator - def unpack( - inp: cases.IField, - ) -> OutType: + def unpack(inp: cases.IField) -> OutType: *a, a2, a3 = (inp, inp + 1, inp + 2, inp + 3) b1, *b, b3 = (inp + 4, inp + 5, inp + 6, inp + 7) c1, c2, *c = (inp + 8, inp + 9, inp + 10, inp + 11) @@ -1166,10 +1099,7 @@ def unpack( def test_tuple_unpacking_too_many_values(cartesian_case): - with pytest.raises( - errors.DSLError, - match=(r"Too many values to unpack \(expected 3\)."), - ): + with pytest.raises(errors.DSLError, match=(r"Too many values to unpack \(expected 3\).")): @gtx.field_operator(backend=cartesian_case.executor) def _star_unpack() -> tuple[int32, float64, int32]: @@ -1191,17 +1121,12 @@ def _invalid_unpack() -> tuple[int32, float64, int32]: def test_constant_closure_vars(cartesian_case): from gt4py.eve.utils import FrozenNamespace - constants = FrozenNamespace( - PI=np.float64(3.142), - E=np.float64(2.718), - ) + constants = FrozenNamespace(PI=np.float64(3.142), E=np.float64(2.718)) @gtx.field_operator def consume_constants(input: cases.IFloatField) -> cases.IFloatField: return constants.PI * constants.E * input cases.verify_with_default_data( - cartesian_case, - consume_constants, - ref=lambda input: constants.PI * constants.E * input, + cartesian_case, consume_constants, ref=lambda input: constants.PI * constants.E * input ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 497dadf864..ab0f519391 100644 --- 
a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -49,12 +49,7 @@ def testee( inp, ones, out=cases.allocate(unstructured_case, testee, cases.RETURN)(), - ref=np.sum( - v2e_table, - axis=1, - initial=0, - where=v2e_table != common._DEFAULT_SKIP_VALUE, - ), + ref=np.sum(v2e_table, axis=1, initial=0, where=v2e_table != common._DEFAULT_SKIP_VALUE), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 686f951549..0bb1e78582 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -220,10 +220,7 @@ def testee(flux: cases.EField) -> cases.VField: unstructured_case, testee, ref=lambda flux: np.sum( - flux[v2e_table] * 2, - axis=1, - initial=0, - where=v2e_table != common._DEFAULT_SKIP_VALUE, + flux[v2e_table] * 2, axis=1, initial=0, where=v2e_table != common._DEFAULT_SKIP_VALUE ), ) @@ -234,8 +231,7 @@ def test_conditional_nested_tuple(cartesian_case): def conditional_nested_tuple( mask: cases.IBoolField, a: cases.IFloatField, b: cases.IFloatField ) -> tuple[ - tuple[cases.IFloatField, cases.IFloatField], - tuple[cases.IFloatField, cases.IFloatField], + tuple[cases.IFloatField, cases.IFloatField], tuple[cases.IFloatField, cases.IFloatField] ]: return where(mask, ((a, b), (b, a)), ((5.0, 7.0), (7.0, 5.0))) @@ -362,10 +358,7 @@ def conditional_shifted( @gtx.program def conditional_program( - mask: cases.IBoolField, - a: cases.IFloatField, - b: cases.IFloatField, - out: cases.IFloatField, + mask: cases.IBoolField, a: cases.IFloatField, b: cases.IFloatField, out: cases.IFloatField ): conditional_shifted(mask, a, b, out=out) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index df0009d0d4..9e858f5afd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -130,9 +130,7 @@ def fo_from_fo_program(in_field: cases.IFloatField, out: cases.IFloatField): pow_three(in_field, out=out) cases.verify_with_default_data( - cartesian_case, - fo_from_fo_program, - ref=lambda in_field: in_field**3, + cartesian_case, fo_from_fo_program, ref=lambda in_field: in_field**3 ) @@ -236,7 +234,7 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): msgs = [ r"- Expected argument 'in_field' to be of type 'Field\[\[IDim], float64\]'," - r" got 'Field\[\[JDim\], float64\]'.", + r" got 'Field\[\[JDim\], float64\]'." 
] for msg in msgs: assert re.search(msg, exc_info.value.__cause__.args[0]) is not None @@ -256,7 +254,6 @@ def empty_domain_program(a: cases.IJField, out_field: cases.IJField): out_field = cases.allocate(cartesian_case, empty_domain_program, "out_field")() with pytest.raises( - ValueError, - match=(r"Dimensions in out field and field domain are not equivalent"), + ValueError, match=(r"Dimensions in out field and field domain are not equivalent") ): cases.run(cartesian_case, empty_domain_program, a, out_field, offset_provider={}) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index da61233d95..2d33fdf230 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -67,10 +67,7 @@ def simple_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField def test_simple_if_conditional(condition1, condition2, cartesian_case): @field_operator def simple_if( - a: cases.IField, - b: cases.IField, - condition1: bool, - condition2: bool, + a: cases.IField, b: cases.IField, condition1: bool, condition2: bool ) -> cases.IField: if condition1: result1 = a @@ -361,8 +358,7 @@ def if_non_boolean_condition( def test_if_inconsistent_types(): with pytest.raises( - errors.DSLError, - match="Inconsistent types between two branches for variable", + errors.DSLError, match="Inconsistent types between two branches for variable" ): @field_operator diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 191f8ee739..9807982797 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -47,8 +47,8 @@ def run_gtfn_with_temporaries_and_symbolic_sizes(): "Cell": "num_cells", "Edge": "num_edges", "Vertex": "num_vertices", - }, - ), + } + ) ), ), allocator=run_gtfn_with_temporaries.allocator, @@ -64,11 +64,7 @@ def testee_op(a: cases.VField) -> cases.EField: @gtx.program def prog( - a: cases.VField, - out: cases.EField, - num_vertices: int32, - num_edges: int32, - num_cells: int32, + a: cases.VField, out: cases.EField, num_vertices: int32, num_edges: int32, num_cells: int32 ): testee_op(a, out=out) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 2174871f89..06226548ed 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -44,25 +44,14 @@ def type_info_cases() -> list[tuple[Optional[ts.TypeSpec], dict]]: return [ - ( - ts.DeferredType(constraint=None), - { - "is_concrete": False, - }, - ), + (ts.DeferredType(constraint=None), {"is_concrete": False}), ( ts.DeferredType(constraint=ts.ScalarType), - { - "is_concrete": False, - "type_class": ts.ScalarType, - }, + {"is_concrete": False, "type_class": ts.ScalarType}, ), ( ts.DeferredType(constraint=ts.FieldType), - { - "is_concrete": False, - "type_class": ts.FieldType, - }, + {"is_concrete": False, "type_class": ts.FieldType}, ), ( 
ts.ScalarType(kind=ts.ScalarKind.INT64), @@ -287,13 +276,7 @@ def callable_type_info_cases(): [], ts.VoidType(), ), - ( - unary_tuple_arg_func_type, - [tuple_type], - {}, - [], - ts.VoidType(), - ), + (unary_tuple_arg_func_type, [tuple_type], {}, [], ts.VoidType()), ( unary_tuple_arg_func_type, [ts.TupleType(types=[float_type, field_type])], @@ -350,10 +333,7 @@ def callable_type_info_cases(): ), ( scanop_type, - [ - ts.FieldType(dims=[KDim], dtype=int_type), - ts.FieldType(dims=[KDim], dtype=int_type), - ], + [ts.FieldType(dims=[KDim], dtype=int_type), ts.FieldType(dims=[KDim], dtype=int_type)], {}, [], ts.FieldType(dims=[KDim], dtype=float_type), @@ -384,13 +364,7 @@ def callable_type_info_cases(): ), ( tuple_scanop_type, - [ - ts.TupleType( - types=[ - ts.FieldType(dims=[IDim, JDim, KDim], dtype=int_type), - ] - ) - ], + [ts.TupleType(types=[ts.FieldType(dims=[IDim, JDim, KDim], dtype=int_type)])], {}, [ r"Expected argument 'a' to be of type 'tuple\[Field\[\[I, J, K\], int64\], " @@ -419,9 +393,7 @@ def test_accept_args( assert accepts_args == type_info.accepts_args(func_type, with_args=args, with_kwargs=kwargs) if len(expected) > 0: - with pytest.raises( - ValueError, - ) as exc_info: + with pytest.raises(ValueError) as exc_info: type_info.accepts_args( func_type, with_args=args, with_kwargs=kwargs, raise_exception=True ) @@ -453,12 +425,10 @@ def unpack_explicit_tuple( parsed = FieldOperatorParser.apply_to_function(unpack_explicit_tuple) assert parsed.body.annex.symtable[ssa.unique_name("tmp_a", 0)].type == ts.FieldType( - dims=[TDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None), + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) ) assert parsed.body.annex.symtable[ssa.unique_name("tmp_b", 0)].type == ts.FieldType( - dims=[TDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None), + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) ) @@ -471,14 +441,8 @@ def temp_tuple(a: Field[[TDim], float64], b: Field[[TDim], int64]): assert parsed.body.annex.symtable[ssa.unique_name("tmp", 0)].type == ts.TupleType( types=[ - ts.FieldType( - dims=[TDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None), - ), - ts.FieldType( - dims=[TDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.INT64, shape=None), - ), + ts.FieldType(dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None)), + ts.FieldType(dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64, shape=None)), ] ) @@ -490,8 +454,7 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): return a + b with pytest.raises( - errors.DSLError, - match=(r"Type 'Field\[\[TDim\], bool\]' can not be used in operator '\+'."), + errors.DSLError, match=(r"Type 'Field\[\[TDim\], bool\]' can not be used in operator '\+'.") ): _ = FieldOperatorParser.apply_to_function(add_bools) @@ -641,10 +604,7 @@ def test_broadcast_disjoint(): def disjoint_broadcast(a: Field[[ADim], float64]): return broadcast(a, (BDim, CDim)) - with pytest.raises( - errors.DSLError, - match=r"expected broadcast dimension\(s\) \'.*\' missing", - ): + with pytest.raises(errors.DSLError, match=r"expected broadcast dimension\(s\) \'.*\' missing"): _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) @@ -657,8 +617,7 @@ def badtype_broadcast(a: Field[[ADim], float64]): return broadcast(a, (BDim, CDim)) with pytest.raises( - errors.DSLError, - match=r"expected all broadcast dimensions to be of type 'Dimension'.", + errors.DSLError, match=r"expected all broadcast dimensions to 
be of type 'Dimension'." ): _ = FieldOperatorParser.apply_to_function(badtype_broadcast) @@ -722,10 +681,7 @@ def test_where_bad_dim(): def bad_dim_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): return where(a, ((5.0, 9.0), (b, 6.0)), b) - with pytest.raises( - errors.DSLError, - match=r"Return arguments need to be of same type", - ): + with pytest.raises(errors.DSLError, match=r"Return arguments need to be of same type"): _ = FieldOperatorParser.apply_to_function(bad_dim_where) @@ -792,10 +748,7 @@ def simple_astype(a: Field[[TDim], float64]): _ = FieldOperatorParser.apply_to_function(simple_astype) assert ( - re.search( - "Expected 1st argument to be of type", - exc_info.value.__cause__.args[0], - ) + re.search("Expected 1st argument to be of type", exc_info.value.__cause__.args[0]) is not None ) @@ -804,10 +757,7 @@ def test_mod_floats(): def modulo_floats(inp: Field[[TDim], float]): return inp % 3.0 - with pytest.raises( - errors.DSLError, - match=r"Type 'float64' can not be used in operator '%'", - ): + with pytest.raises(errors.DSLError, match=r"Type 'float64' can not be used in operator '%'"): _ = FieldOperatorParser.apply_to_function(modulo_floats) @@ -827,10 +777,7 @@ def test_as_offset_dim(): def as_offset_dim(a: Field[[ADim, BDim], float], b: Field[[ADim], int]): return a(as_offset(Boff, b)) - with pytest.raises( - errors.DSLError, - match=f"not in list of offset field dimensions", - ): + with pytest.raises(errors.DSLError, match=f"not in list of offset field dimensions"): _ = FieldOperatorParser.apply_to_function(as_offset_dim) @@ -842,8 +789,5 @@ def test_as_offset_dtype(): def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): return a(as_offset(Boff, b)) - with pytest.raises( - errors.DSLError, - match=f"expected integer for offset field dtype", - ): + with pytest.raises(errors.DSLError, match=f"expected integer for offset field dtype"): _ = FieldOperatorParser.apply_to_function(as_offset_dtype) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py index 2fc31e6574..308367bf6b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py @@ -15,13 +15,7 @@ import numpy as np from typing import Tuple import pytest -from next_tests.integration_tests.cases import ( - IDim, - JDim, - KDim, - Koff, - cartesian_case, -) +from next_tests.integration_tests.cases import IDim, JDim, KDim, Koff, cartesian_case from gt4py import next as gtx from gt4py.next.ffront.fbuiltins import where, broadcast from next_tests.integration_tests import cases diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index d3f3f35699..3e06789367 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -138,11 +138,7 @@ def arithmetic_and_logical_test_data(): (minus, [2.0, 3.0], -1.0), (multiplies, [2.0, 3.0], 6.0), (divides, [6.0, 2.0], 3.0), - ( - if_, - [[True, False], [1.0, 1.0], [2.0, 2.0]], - [1.0, 2.0], - ), + (if_, [[True, False], [1.0, 1.0], [2.0, 2.0]], [1.0, 2.0]), (mod, [5, 2], 1), (greater, [[2.0, 1.0, 1.0], [1.0, 2.0, 1.0]], [True, False, False]), (greater_equal, [[2.0, 1.0, 1.0], [1.0, 2.0, 1.0]], 
[True, False, True]), @@ -192,9 +188,7 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): gtfn_without_transforms = dataclasses.replace( gtfn_executor, otf_workflow=gtfn_executor.otf_workflow.replace( - translation=gtfn_executor.otf_workflow.translation.replace( - enable_itir_transforms=False - ), + translation=gtfn_executor.otf_workflow.translation.replace(enable_itir_transforms=False) ), ) # avoid inlining the function fencil(builtin, out, *inps, processor=gtfn_without_transforms) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py index b82cea4b22..214c5b70f3 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py @@ -34,20 +34,14 @@ def foo(inp): @fendef(offset_provider={"I": I_loc, "J": J_loc}) def fencil(output, input): closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), - foo, - output, - [input], + cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] ) @fendef(offset_provider={"I": J_loc, "J": I_loc}) def fencil_swapped(output, input): closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), - foo, - output, - [input], + cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] ) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index db7776b2f4..134af33b1b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -39,16 +39,8 @@ def test_conditional_w_tuple(program_processor): inp = gtx.as_field([IDim], np.random.randint(0, 2, shape, dtype=np.int32)) out = gtx.as_field([IDim], np.zeros(shape)) - dom = { - IDim: range(0, shape[0]), - } - run_processor( - stencil_conditional[dom], - program_processor, - inp, - out=out, - offset_provider={}, - ) + dom = {IDim: range(0, shape[0])} + run_processor(stencil_conditional[dom], program_processor, inp, out=out, offset_provider={}) if validate: assert np.all(out.asnumpy()[inp.asnumpy() == 0] == 3.0) assert np.all(out.asnumpy()[inp.asnumpy() == 1] == 7.0) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index 4cdfff46da..b24cec7d02 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -56,10 +56,7 @@ def test_simple_indirection(program_processor): pytest.xfail("Applied shifts in if_ statements are not supported in TraceShift pass.") - if program_processor in [ - type_check.check_type_inference, - gtfn_format_sourcecode, - ]: + if program_processor in [type_check.check_type_inference, gtfn_format_sourcecode]: pytest.xfail( "We only support applied shifts in type_inference." ) # TODO fix test or generalize itir? 
diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index dd603fa3be..5d4c6cae4d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -39,12 +39,7 @@ def foo(inp): @fendef(offset_provider={"O": LocA2LocAB_offset_provider}) def fencil(size, out, inp): - closure( - unstructured_domain(named_range(LocA, 0, size)), - foo, - out, - [inp], - ) + closure(unstructured_domain(named_range(LocA, 0, size)), foo, out, [inp]) @pytest.mark.uses_strided_neighbor_offset @@ -56,12 +51,7 @@ def test_strided_offset_provider(program_processor): LocAB_size = LocA_size * max_neighbors rng = np.random.default_rng() - inp = gtx.as_field( - [LocAB], - rng.normal( - size=(LocAB_size,), - ), - ) + inp = gtx.as_field([LocAB], rng.normal(size=(LocAB_size,))) out = gtx.as_field([LocA], np.zeros((LocA_size,))) ref = np.sum(inp.asnumpy().reshape(LocA_size, max_neighbors), axis=-1) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 8e12647c1b..f85b9b4035 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -106,10 +106,7 @@ def test_shifted_arg_to_lift(program_processor, lift_mode): @fendef def fen_direct_deref(i_size, j_size, out, inp): closure( - cartesian_domain( - named_range(IDim, 0, i_size), - named_range(JDim, 0, j_size), - ), + cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)), deref, out, [inp], diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 925ad33e86..de12a6ba5f 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -45,35 +45,22 @@ def tuple_output2(inp1, inp2): return make_tuple(deref(inp1), deref(inp2)) -@pytest.mark.parametrize( - "stencil", - [tuple_output1, tuple_output2], -) +@pytest.mark.parametrize("stencil", [tuple_output1, tuple_output2]) @pytest.mark.uses_tuple_returns def test_tuple_output(program_processor, stencil): program_processor, validate = program_processor shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - inp2 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) + inp1 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) + inp2 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) out = ( gtx.as_field([IDim, JDim, KDim], np.zeros(shape)), gtx.as_field([IDim, JDim, KDim], np.zeros(shape)), ) - dom = { - IDim: range(0, shape[0]), - JDim: range(0, shape[1]), - KDim: range(0, shape[2]), - } + dom = {IDim: range(0, shape[0]), JDim: range(0, shape[1]), KDim: range(0, shape[2])} run_processor(stencil[dom], program_processor, inp1, inp2, out=out, offset_provider={}) if validate: assert np.allclose(inp1.asnumpy(), out[0].asnumpy()) @@ -100,22 +87,10 @@ def 
stencil(inp1, inp2, inp3, inp4): shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - inp2 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - inp3 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - inp4 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) + inp1 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) + inp2 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) + inp3 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) + inp4 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) out = ( ( @@ -128,20 +103,9 @@ def stencil(inp1, inp2, inp3, inp4): ), ) - dom = { - IDim: range(0, shape[0]), - JDim: range(0, shape[1]), - KDim: range(0, shape[2]), - } + dom = {IDim: range(0, shape[0]), JDim: range(0, shape[1]), KDim: range(0, shape[2])} run_processor( - stencil[dom], - program_processor, - inp1, - inp2, - inp3, - inp4, - out=out, - offset_provider={}, + stencil[dom], program_processor, inp1, inp2, inp3, inp4, out=out, offset_provider={} ) if validate: assert np.allclose(inp1.asnumpy(), out[0][0].asnumpy()) @@ -150,10 +114,7 @@ def stencil(inp1, inp2, inp3, inp4): assert np.allclose(inp4.asnumpy(), out[1][1].asnumpy()) -@pytest.mark.parametrize( - "stencil", - [tuple_output1, tuple_output2], -) +@pytest.mark.parametrize("stencil", [tuple_output1, tuple_output2]) def test_tuple_of_field_output_constructed_inside(program_processor, stencil): program_processor, validate = program_processor @@ -172,14 +133,8 @@ def fencil(size0, size1, size2, inp1, inp2, out1, out2): shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - inp2 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) + inp1 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) + inp2 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) out1 = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) out2 = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) @@ -223,18 +178,9 @@ def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - inp2 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - inp3 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) + inp1 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) + inp2 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) + inp3 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) out1 = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) out2 = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) @@ -261,32 +207,19 @@ def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") -@pytest.mark.parametrize( - "stencil", - [tuple_output1, tuple_output2], -) +@pytest.mark.parametrize("stencil", [tuple_output1, tuple_output2]) def 
test_field_of_extra_dim_output(program_processor, stencil): program_processor, validate = program_processor shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - inp2 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) + inp1 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) + inp2 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) out_np = np.zeros(shape + [2]) out = gtx.as_field([IDim, JDim, KDim, None], out_np) - dom = { - IDim: range(0, shape[0]), - JDim: range(0, shape[1]), - KDim: range(0, shape[2]), - } + dom = {IDim: range(0, shape[0]), JDim: range(0, shape[1]), KDim: range(0, shape[2])} run_processor(stencil[dom], program_processor, inp1, inp2, out=out, offset_provider={}) if validate: assert np.allclose(inp1, out_np[:, :, :, 0]) @@ -305,10 +238,7 @@ def test_tuple_field_input(program_processor): shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.as_field( - [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), - ) + inp1 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) inp2 = gtx.as_field( [IDim, JDim, KDim], rng.normal( @@ -318,11 +248,7 @@ def test_tuple_field_input(program_processor): out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) - dom = { - IDim: range(0, shape[0]), - JDim: range(0, shape[1]), - KDim: range(0, shape[2]), - } + dom = {IDim: range(0, shape[0]), JDim: range(0, shape[1]), KDim: range(0, shape[2])} run_processor(tuple_input[dom], program_processor, (inp1, inp2), out=out, offset_provider={}) if validate: assert np.allclose(inp1.asnumpy() + inp2.asnumpy()[:, :, :-1], out.asnumpy()) @@ -342,11 +268,7 @@ def test_field_of_extra_dim_input(program_processor): inp = gtx.as_field([IDim, JDim, KDim, None], inp) out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) - dom = { - IDim: range(0, shape[0]), - JDim: range(0, shape[1]), - KDim: range(0, shape[2]), - } + dom = {IDim: range(0, shape[0]), JDim: range(0, shape[1]), KDim: range(0, shape[2])} run_processor(tuple_input[dom], program_processor, inp, out=out, offset_provider={}) if validate: assert np.allclose(np.asarray(inp1) + np.asarray(inp2), out) @@ -377,11 +299,7 @@ def test_tuple_of_tuple_of_field_input(program_processor): out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) - dom = { - IDim: range(0, shape[0]), - JDim: range(0, shape[1]), - KDim: range(0, shape[2]), - } + dom = {IDim: range(0, shape[0]), JDim: range(0, shape[1]), KDim: range(0, shape[2])} run_processor( tuple_tuple_input[dom], program_processor, @@ -408,17 +326,7 @@ def test_field_of_2_extra_dim_input(program_processor): out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) - dom = { - IDim: range(0, shape[0]), - JDim: range(0, shape[1]), - KDim: range(0, shape[2]), - } - run_processor( - tuple_tuple_input[dom], - program_processor, - inp, - out=out, - offset_provider={}, - ) + dom = {IDim: range(0, shape[0]), JDim: range(0, shape[1]), KDim: range(0, shape[2])} + run_processor(tuple_tuple_input[dom], program_processor, inp, out=out, offset_provider={}) if validate: assert np.allclose(np.sum(inp, axis=(3, 4)), out) diff --git a/tests/next_tests/integration_tests/feature_tests/math_builtin_test_data.py b/tests/next_tests/integration_tests/feature_tests/math_builtin_test_data.py index 367784cd4e..f05e3ccb37 100644 --- 
a/tests/next_tests/integration_tests/feature_tests/math_builtin_test_data.py +++ b/tests/next_tests/integration_tests/feature_tests/math_builtin_test_data.py @@ -21,48 +21,19 @@ def math_builtin_test_data() -> list[tuple[str, tuple[list[int | float], ...]]]: # FIXME(ben): dataset is missing invalid ranges (mostly nan outputs) # FIXME(ben): we're not properly testing different datatypes # builtin name, tuple of arguments - ( - "abs", - ([-1, 1, -1.0, 1.0, 0, -0, 0.0, -0.0],), - ), + ("abs", ([-1, 1, -1.0, 1.0, 0, -0, 0.0, -0.0],)), ( "minimum", ( [2, 2.0, 2.0, 3.0, 2, 3, -2, -2.0, -2.0, -3.0, -2, -3], - [ - 2, - 2.0, - 3.0, - 2.0, - 3, - 2, - -2, - -2.0, - -3.0, - -2.0, - -3, - -2, - ], + [2, 2.0, 3.0, 2.0, 3, 2, -2, -2.0, -3.0, -2.0, -3, -2], ), ), ( "maximum", ( [2, 2.0, 2.0, 3.0, 2, 3, -2, -2.0, -2.0, -3.0, -2, -3], - [ - 2, - 2.0, - 3.0, - 2.0, - 3, - 2, - -2, - -2.0, - -3.0, - -2.0, - -3, - -2, - ], + [2, 2.0, 3.0, 2.0, 3, 2, -2, -2.0, -3.0, -2.0, -3, -2], ), ), ( @@ -75,26 +46,11 @@ def math_builtin_test_data() -> list[tuple[str, tuple[list[int | float], ...]]]: # ([6, 6.0, -6, 6.0, 7, -7.0, 4.8, 4], [2, 2.0, 2.0, -2, 3.0, -3, 1.2, -1.2]), ([2, 2.0], [2, 2.0]), ), - ( - "sin", - ([0, 0.1, -0.01, np.pi, -2.0 / 3.0 * np.pi, 2.0 * np.pi, 3, 1000, -1000],), - ), - ( - "cos", - ([0, 0.1, -0.01, np.pi, -2.0 / 3.0 * np.pi, 2.0 * np.pi, 3, 1000, -1000],), - ), - ( - "tan", - ([0, 0.1, -0.01, np.pi, -2.0 / 3.0 * np.pi, 2.0 * np.pi, 3, 1000, -1000],), - ), - ( - "arcsin", - ([-1.0, -1, -0.7, -0.2, -0.0, 0, 0.0, 0.2, 0.7, 1, 1.0],), - ), - ( - "arccos", - ([-1.0, -1, -0.7, -0.2, -0.0, 0, 0.0, 0.2, 0.7, 1, 1.0],), - ), + ("sin", ([0, 0.1, -0.01, np.pi, -2.0 / 3.0 * np.pi, 2.0 * np.pi, 3, 1000, -1000],)), + ("cos", ([0, 0.1, -0.01, np.pi, -2.0 / 3.0 * np.pi, 2.0 * np.pi, 3, 1000, -1000],)), + ("tan", ([0, 0.1, -0.01, np.pi, -2.0 / 3.0 * np.pi, 2.0 * np.pi, 3, 1000, -1000],)), + ("arcsin", ([-1.0, -1, -0.7, -0.2, -0.0, 0, 0.0, 0.2, 0.7, 1, 1.0],)), + ("arccos", ([-1.0, -1, -0.7, -0.2, -0.0, 0, 0.0, 0.2, 0.7, 1, 1.0],)), ( "arctan", ( @@ -225,14 +181,8 @@ def math_builtin_test_data() -> list[tuple[str, tuple[list[int | float], ...]]]: ], ), ), - ( - "arccosh", - ([1, 1.0, 1.2, 1.7, 2, 2.0, 100, 103.7, 1000, 1379.89],), - ), - ( - "arctanh", - ([-1.0, -1, -0.7, -0.2, -0.0, 0, 0.0, 0.2, 0.7, 1, 1.0],), - ), + ("arccosh", ([1, 1.0, 1.2, 1.7, 2, 2.0, 100, 103.7, 1000, 1379.89],)), + ("arctanh", ([-1.0, -1, -0.7, -0.2, -0.0, 0, 0.0, 0.2, 0.7, 1, 1.0],)), ( "sqrt", ( @@ -338,29 +288,14 @@ def math_builtin_test_data() -> list[tuple[str, tuple[list[int | float], ...]]]: ], ), ), - ( - "isfinite", - ([1000, 0, 1, np.pi, -np.inf, np.inf, np.nan, np.nan + 1],), - ), - ( - "isinf", - ([1000, 0, 1, np.pi, -np.inf, np.inf, np.nan, np.nan + 1],), - ), + ("isfinite", ([1000, 0, 1, np.pi, -np.inf, np.inf, np.nan, np.nan + 1],)), + ("isinf", ([1000, 0, 1, np.pi, -np.inf, np.inf, np.nan, np.nan + 1],)), ( "isnan", # TODO(BenWeber42): would be good to ensure we have nans with different bit patterns ([1000, 0, 1, np.pi, -np.inf, np.inf, np.nan, np.nan + 1],), ), - ( - "floor", - ([-3.4, -1.5, -0.6, -0.1, -0.0, 0.0, 0.1, 0.6, 1.5, 3.4],), - ), - ( - "ceil", - ([-3.4, -1.5, -0.6, -0.1, -0.0, 0.0, 0.1, 0.6, 1.5, 3.4],), - ), - ( - "trunc", - ([-3.4, -1.5, -0.6, -0.1, -0.0, 0.0, 0.1, 0.6, 1.5, 3.4],), - ), + ("floor", ([-3.4, -1.5, -0.6, -0.1, -0.0, 0.0, 0.1, 0.6, 1.5, 3.4],)), + ("ceil", ([-3.4, -1.5, -0.6, -0.1, -0.0, 0.0, 0.1, 0.6, 1.5, 3.4],)), + ("trunc", ([-3.4, -1.5, -0.6, -0.1, -0.0, 0.0, 0.1, 0.6, 1.5, 3.4],)), ] 
diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index 936bb3ee58..19befd6304 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -32,7 +32,7 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): build_the_program = workflow.make_step(nanobind.bind_source).chain( compiler.Compiler( cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=cmake.CMakeFactory() - ), + ) ) compiled_program = build_the_program(example_program_source) buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) @@ -51,7 +51,7 @@ def test_gtfn_cpp_with_compiledb(program_source_with_name): compiler.Compiler( cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=compiledb.CompiledbFactory(), - ), + ) ) compiled_program = build_the_program(example_program_source) buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index d739cec659..5c72a856a5 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -92,9 +92,7 @@ def test_verify_fails_with_wrong_type(cartesian_case): @pytest.mark.parametrize("exec_alloc_descriptor", [definitions.ProgramBackendId.ROUNDTRIP.load()]) -def test_verify_with_default_data_fails_with_wrong_reference( - cartesian_case, -): +def test_verify_with_default_data_fails_with_wrong_reference(cartesian_case): def wrong_ref(a, b): return a - b diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py index abc3755dca..fe8b54f95c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py @@ -53,12 +53,7 @@ def compute_pnabla(pp, S_M, sign, vol): def zavgS_fencil(edge_domain, out, pp, S_M): - closure( - edge_domain, - compute_zavgS, - out, - [pp, S_M], - ) + closure(edge_domain, compute_zavgS, out, [pp, S_M]) Vertex = gtx.Dimension("Vertex") diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 8da95712f4..ae6331f76b 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -38,12 +38,7 @@ @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0, True)) def _scan( - state: tuple[float, float, bool], - w: float, - z_q: float, - z_a: float, - z_b: float, - z_c: float, + state: tuple[float, float, bool], w: float, z_q: float, z_a: float, z_b: float, z_c: float ) -> tuple[float, float, bool]: z_q_m1, w_m1, first = state z_g = z_b + z_a * z_q_m1 @@ -78,11 +73,7 @@ def solve_nonhydro_stencil_52_like( dummy: gtx.Field[[Cell, KDim], bool], ): _solve_nonhydro_stencil_52_like( - z_alpha, - z_beta, - z_q, - w, - out=(z_q[:, 1:], w[:, 1:], dummy[:, 1:]), + z_alpha, z_beta, z_q, w, out=(z_q[:, 1:], w[:, 1:], dummy[:, 1:]) ) @@ -109,11 
+100,7 @@ def solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge( w: gtx.Field[[Cell, KDim], float], ): _solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge( - z_alpha, - z_beta, - z_q, - w, - out=(z_q[:, 1:], w[:, 1:]), + z_alpha, z_beta, z_q, w, out=(z_q[:, 1:], w[:, 1:]) ) @@ -168,10 +155,7 @@ def solve_nonhydro_stencil_52_like_z_q_tup( def reference( - z_alpha: np.array, - z_beta: np.array, - z_q_ref: np.array, - w_ref: np.array, + z_alpha: np.array, z_beta: np.array, z_q_ref: np.array, w_ref: np.array ) -> tuple[np.ndarray, np.ndarray]: z_q = np.copy(z_q_ref) w = np.copy(w_ref) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py index 6784857211..e4e155bc25 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py @@ -45,16 +45,14 @@ def laplap(in_field: gtx.Field[[IDim, JDim], "float"]) -> gtx.Field[[IDim, JDim] @gtx.program def lap_program( - in_field: gtx.Field[[IDim, JDim], "float"], - out_field: gtx.Field[[IDim, JDim], "float"], + in_field: gtx.Field[[IDim, JDim], "float"], out_field: gtx.Field[[IDim, JDim], "float"] ): lap(in_field, out=out_field[1:-1, 1:-1]) @gtx.program def laplap_program( - in_field: gtx.Field[[IDim, JDim], "float"], - out_field: gtx.Field[[IDim, JDim], "float"], + in_field: gtx.Field[[IDim, JDim], "float"], out_field: gtx.Field[[IDim, JDim], "float"] ): laplap(in_field, out=out_field[2:-2, 2:-2]) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py index 03e1af27dd..f6d22c31a2 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py @@ -166,10 +166,7 @@ def input_field(self): zlonc = 3.0 * rpi / 2.0 m_rlonlatcr = self.fs_nodes.create_field( - name="m_rlonlatcr", - levels=1, - dtype=np.float64, - variables=self.edges_per_node, + name="m_rlonlatcr", levels=1, dtype=np.float64, variables=self.edges_per_node ) rlonlatcr = np.array(m_rlonlatcr, copy=False) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 9ba8eef3a3..c7c8cf6c57 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -199,10 +199,7 @@ def ksum_fencil(i_size, k_start, k_end, inp, out): @pytest.mark.parametrize( "kstart, reference", - [ - (0, np.asarray([[0, 1, 3, 6, 10, 15, 21]])), - (2, np.asarray([[0, 0, 2, 5, 9, 14, 20]])), - ], + [(0, np.asarray([[0, 1, 3, 6, 10, 15, 21]])), (2, np.asarray([[0, 0, 2, 5, 9, 14, 20]]))], ) def test_ksum_scan(program_processor, lift_mode, kstart, reference): program_processor, validate = program_processor @@ -339,12 +336,7 @@ def sum_shifted(inp0, inp1): @fendef(column_axis=KDim) def sum_shifted_fencil(out, inp0, inp1, k_size): - closure( - cartesian_domain(named_range(KDim, 1, k_size)), - sum_shifted, - out, - [inp0, inp1], - ) + closure(cartesian_domain(named_range(KDim, 1, k_size)), sum_shifted, out, [inp0, inp1]) def test_different_vertical_sizes(program_processor): @@ 
-357,13 +349,7 @@ def test_different_vertical_sizes(program_processor): ref = inp0.ndarray + inp1.ndarray[1:] run_processor( - sum_shifted_fencil, - program_processor, - out, - inp0, - inp1, - k_size, - offset_provider={"K": KDim}, + sum_shifted_fencil, program_processor, out, inp0, inp1, k_size, offset_provider={"K": KDim} ) if validate: @@ -377,12 +363,7 @@ def sum(inp0, inp1): @fendef(column_axis=KDim) def sum_fencil(out, inp0, inp1, k_size): - closure( - cartesian_domain(named_range(KDim, 0, k_size)), - sum, - out, - [inp0, inp1], - ) + closure(cartesian_domain(named_range(KDim, 0, k_size)), sum, out, [inp0, inp1]) @pytest.mark.uses_origin @@ -396,13 +377,7 @@ def test_different_vertical_sizes_with_origin(program_processor): ref = inp0.asnumpy() + inp1.asnumpy()[:-1] run_processor( - sum_fencil, - program_processor, - out, - inp0, - inp1, - k_size, - offset_provider={"K": KDim}, + sum_fencil, program_processor, out, inp0, inp1, k_size, offset_provider={"K": KDim} ) if validate: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 9ee364c014..d29ef68d4e 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -61,18 +61,8 @@ def compute_zavgS(pp, S_M): @fendef -def compute_zavgS_fencil( - n_edges, - out, - pp, - S_M, -): - closure( - unstructured_domain(named_range(Edge, 0, n_edges)), - compute_zavgS, - out, - [pp, S_M], - ) +def compute_zavgS_fencil(n_edges, out, pp, S_M): + closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS, out, [pp, S_M]) @fundef @@ -116,15 +106,7 @@ def compute_pnabla2(pp, S_M, sign, vol): @fendef -def nabla( - n_nodes, - out, - pp, - S_MXX, - S_MYY, - sign, - vol, -): +def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): closure( unstructured_domain(named_range(Vertex, 0, n_nodes)), pnabla, @@ -178,18 +160,8 @@ def test_compute_zavgS(program_processor, lift_mode): @fendef -def compute_zavgS2_fencil( - n_edges, - out, - pp, - S_M, -): - closure( - unstructured_domain(named_range(Edge, 0, n_edges)), - compute_zavgS2, - out, - [pp, S_M], - ) +def compute_zavgS2_fencil(n_edges, out, pp, S_M): + closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS2, out, [pp, S_M]) @pytest.mark.requires_atlas @@ -273,14 +245,7 @@ def test_nabla(program_processor, lift_mode): @fendef -def nabla2( - n_nodes, - out, - pp, - S, - sign, - vol, -): +def nabla2(n_nodes, out, pp, S, sign, vol): closure( unstructured_domain(named_range(Vertex, 0, n_nodes)), compute_pnabla2, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 46038832d1..820e9415bc 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -88,9 +88,7 @@ def tridiag_reference(): def fen_solve_tridiag(i_size, j_size, k_size, a, b, c, d, x): closure( cartesian_domain( - named_range(IDim, 0, i_size), - named_range(JDim, 0, j_size), - named_range(KDim, 0, k_size), + named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) ), solve_tridiag, x, @@ -102,9 +100,7 @@ 
def fen_solve_tridiag(i_size, j_size, k_size, a, b, c, d, x): def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): closure( cartesian_domain( - named_range(IDim, 0, i_size), - named_range(JDim, 0, j_size), - named_range(KDim, 0, k_size), + named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) ), solve_tridiag2, x, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 0b9b639b08..714e568b8f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -247,9 +247,7 @@ def test_slice_sparse(program_processor, lift_mode): program_processor, inp, out=out, - offset_provider={ - "V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4), - }, + offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, lift_mode=lift_mode, ) @@ -274,9 +272,7 @@ def test_slice_twice_sparse(program_processor, lift_mode): program_processor, inp, out=out, - offset_provider={ - "V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4), - }, + offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, lift_mode=lift_mode, ) @@ -302,9 +298,7 @@ def test_shift_sliced_sparse(program_processor, lift_mode): program_processor, inp, out=out, - offset_provider={ - "V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4), - }, + offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, lift_mode=lift_mode, ) @@ -330,9 +324,7 @@ def test_slice_shifted_sparse(program_processor, lift_mode): program_processor, inp, out=out, - offset_provider={ - "V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4), - }, + offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, lift_mode=lift_mode, ) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 111622ac42..367ecbfdcf 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -71,31 +71,19 @@ def test_slice_range(rng, slce, expected): ([(I, (-2, 3))], NamedRange(I, UnitRange(-3, -2)), IndexError), ([(I, (-2, 3))], NamedIndex(I, 3), IndexError), ([(I, (-2, 3))], NamedRange(I, UnitRange(3, 4)), IndexError), - ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - 2, - [(J, (3, 6)), (K, (4, 7))], - ), + ([(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], 2, [(J, (3, 6)), (K, (4, 7))]), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], slice(2, 3), [(I, (4, 5)), (J, (3, 6)), (K, (4, 7))], ), - ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - NamedIndex(I, 2), - [(J, (3, 6)), (K, (4, 7))], - ), + ([(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], NamedIndex(I, 2), [(J, (3, 6)), (K, (4, 7))]), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], NamedRange(I, UnitRange(2, 3)), [(I, (2, 3)), (J, (3, 6)), (K, (4, 7))], ), - ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - NamedIndex(J, 3), - [(I, (2, 5)), (K, (4, 7))], - ), + ([(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], NamedIndex(J, 3), [(I, (2, 5)), (K, (4, 7))]), ( [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], NamedRange(J, UnitRange(4, 5)), @@ -165,14 +153,8 @@ def test_iterate_domain(): (NamedRange(J, common.UnitRange(3, 
6)), NamedRange(I, common.UnitRange(3, 5))), ], [slice(I(1), J(7)), IndexError], - [ - slice(I(1), None), - IndexError, - ], - [ - slice(None, K(8)), - IndexError, - ], + [slice(I(1), None), IndexError], + [slice(None, K(8)), IndexError], ], ) def test_slicing(slices, expected): diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 375b654475..adf01bd613 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -61,9 +61,7 @@ def binary_arithmetic_op(request): yield request.param -@pytest.fixture( - params=[operator.xor, operator.and_, operator.or_], -) +@pytest.fixture(params=[operator.xor, operator.and_, operator.or_]) def binary_logical_op(request): yield request.param @@ -96,10 +94,7 @@ def _make_field_or_scalar( buffer = nd_array_implementation.asarray(lst, dtype=dtype) if domain is None: domain = _make_default_domain(buffer.shape) - return common._field( - buffer, - domain=domain, - ) + return common._field(buffer, domain=domain) def _np_asarray_or_scalar(value: Iterable | core_defs.Scalar, dtype=None): @@ -158,8 +153,7 @@ def test_where_builtin_different_domain(nd_array_implementation): domain=common.domain({D0: common.UnitRange(0, 2), D1: common.UnitRange(-1, 2)}), ) false_field = common._field( - nd_array_implementation.asarray(false_), - domain=common.domain({D1: common.UnitRange(-1, 3)}), + nd_array_implementation.asarray(false_), domain=common.domain({D1: common.UnitRange(-1, 3)}) ) expected = np.where(cond[np.newaxis, :], true_[:, 1:], false_[np.newaxis, 1:-1]) @@ -233,10 +227,7 @@ def test_binary_logical_ops(binary_logical_op, nd_array_implementation, lhs, rhs def test_unary_logical_ops(unary_logical_op, nd_array_implementation): - inp = [ - True, - False, - ] + inp = [True, False] expected = unary_logical_op(np.asarray(inp)) @@ -260,11 +251,7 @@ def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation): @pytest.mark.parametrize( - "dims,expected_indices", - [ - ((D0,), (slice(5, 10), None)), - ((D1,), (None, slice(5, 10))), - ], + "dims,expected_indices", [((D0,), (slice(5, 10), None)), ((D1,), (None, slice(5, 10)))] ) def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expected_indices): arr1 = np.arange(10) @@ -345,12 +332,7 @@ def test_remap_implementation(): ) e2v_conn = common._connectivity( np.arange(E_START, E_STOP), - domain=common.Domain( - dims=(E,), - ranges=[ - UnitRange(E_START, E_STOP), - ], - ), + domain=common.Domain(dims=(E,), ranges=[UnitRange(E_START, E_STOP)]), codomain=V, ) @@ -437,10 +419,7 @@ def test_field_broadcast(new_dims, field, expected_domain): @pytest.mark.parametrize( "domain_slice", - [ - (NamedRange(D0, UnitRange(0, 10)),), - common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)), - ], + [(NamedRange(D0, UnitRange(0, 10)),), common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),))], ) def test_get_slices_with_named_indices_3d_to_1d(domain_slice): field_domain = common.Domain( @@ -472,18 +451,12 @@ def test_get_slices_invalid_type(): "domain_slice,expected_dimensions,expected_shape", [ ( - ( - NamedRange(D0, UnitRange(7, 9)), - NamedRange(D1, UnitRange(8, 10)), - ), + (NamedRange(D0, UnitRange(7, 9)), NamedRange(D1, UnitRange(8, 10))), (D0, D1, D2), (2, 2, 15), ), ( - ( - NamedRange(D0, UnitRange(7, 9)), - NamedRange(D2, UnitRange(12, 20)), - ), + (NamedRange(D0, UnitRange(7, 9)), NamedRange(D2, UnitRange(12, 20))), 
(D0, D1, D2), (2, 10, 8), ), @@ -491,14 +464,7 @@ def test_get_slices_invalid_type(): ((NamedIndex(D0, 8),), (D1, D2), (10, 15)), ((NamedIndex(D1, 9),), (D0, D2), (5, 15)), ((NamedIndex(D2, 11),), (D0, D1), (5, 10)), - ( - ( - NamedIndex(D0, 8), - NamedRange(D1, UnitRange(8, 10)), - ), - (D1, D2), - (2, 15), - ), + ((NamedIndex(D0, 8), NamedRange(D1, UnitRange(8, 10))), (D1, D2), (2, 15)), (NamedIndex(D0, 5), (D1, D2), (10, 15)), (NamedRange(D0, UnitRange(5, 7)), (D0, D1, D2), (2, 10, 15)), ], @@ -586,21 +552,13 @@ def test_absolute_indexing_value_return(): (5, 10), Domain(NamedRange(D0, UnitRange(5, 10)), NamedRange(D1, UnitRange(2, 12))), ), - ( - (Ellipsis, 1), - (10,), - Domain(NamedRange(D0, UnitRange(5, 15))), - ), + ((Ellipsis, 1), (10,), Domain(NamedRange(D0, UnitRange(5, 15)))), ( (slice(2, 3), slice(5, 7)), (1, 2), Domain(NamedRange(D0, UnitRange(7, 8)), NamedRange(D1, UnitRange(7, 9))), ), - ( - (slice(1, 2), 0), - (1,), - Domain(NamedRange(D0, UnitRange(6, 7))), - ), + ((slice(1, 2), 0), (1,), Domain(NamedRange(D0, UnitRange(6, 7)))), ], ) def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): @@ -626,24 +584,21 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): (slice(None),), (10, 15, 10), Domain( - dims=(D0, D1, D2), - ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)) ), ), ( (slice(None), slice(None), slice(None)), (10, 15, 10), Domain( - dims=(D0, D1, D2), - ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)) ), ), ( (slice(None)), (10, 15, 10), Domain( - dims=(D0, D1, D2), - ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)) ), ), ((0, Ellipsis, 0), (15,), Domain(dims=(D1,), ranges=(UnitRange(10, 25),))), @@ -651,8 +606,7 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): Ellipsis, (10, 15, 10), Domain( - dims=(D0, D1, D2), - ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)) ), ), ], @@ -669,13 +623,7 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): assert indexed_field.domain == expected_domain -@pytest.mark.parametrize( - "index, expected_value", - [ - ((1, 0), 10), - ((0, 1), 1), - ], -) +@pytest.mark.parametrize("index, expected_value", [((1, 0), 10), ((0, 1), 1)]) def test_relative_indexing_value_return(index, expected_value): domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(5, 15), UnitRange(2, 12))) field = common._field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain) @@ -916,10 +864,7 @@ def test_connectivity_field_inverse_image_2d_domain_skip_values(): ([0, 1, 0], None), ([0, -1, 0], [(0, 3)]), ([[1, 1, 1], [1, 0, 0]], [(1, 2), (1, 3)]), - ( - [[1, 0, -1], [1, 0, 0]], - [(0, 2), (1, 3)], - ), + ([[1, 0, -1], [1, 0, 0]], [(0, 2), (1, 3)]), ], ) def test_hypercube(index_array, expected): diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index 9ebd991e36..2bb2c844a9 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -480,11 +480,9 @@ def reduction(edge_f: 
gtx.Field[[Edge], float64]): im.call("reduce")( "plus", im.deref(im.promote_to_const_iterator(im.literal(value="0", typename="float64"))), - ), + ) ) - )( - im.lifted_neighbors("V2E", "edge_f"), - ) + )(im.lifted_neighbors("V2E", "edge_f")) assert lowered.expr == reference @@ -515,26 +513,16 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl im.deref( im.promote_to_const_iterator(im.literal(value="0", typename="float64")) ), - ), + ) ) - )( - mapped, - ) + )(mapped) ) assert lowered.expr == reference def test_builtin_int_constructors(): - def int_constrs() -> ( - tuple[ - int32, - int32, - int64, - int32, - int64, - ] - ): + def int_constrs() -> tuple[int32, int32, int64, int32, int64]: return 1, int32(1), int64(1), int32("1"), int64("1") parsed = FieldOperatorParser.apply_to_function(int_constrs) @@ -552,17 +540,7 @@ def int_constrs() -> ( def test_builtin_float_constructors(): - def float_constrs() -> ( - tuple[ - float, - float, - float32, - float64, - float, - float32, - float64, - ] - ): + def float_constrs() -> tuple[float, float, float32, float64, float, float32, float64]: return ( 0.1, float(0.1), diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index e527d18e4c..de83741a5e 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -87,10 +87,7 @@ def test_mistyped_arg(): def mistyped(inp: gtx.Field): return inp - with pytest.raises( - ValueError, - match="Field type requires two arguments, got 0.", - ): + with pytest.raises(ValueError, match="Field type requires two arguments, got 0."): _ = FieldOperatorParser.apply_to_function(mistyped) @@ -103,8 +100,7 @@ def rettype(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: parsed = FieldOperatorParser.apply_to_function(rettype) assert parsed.body.stmts[-1].value.type == ts.FieldType( - dims=[TDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None), + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) ) @@ -114,10 +110,7 @@ def test_invalid_syntax_no_return(): def no_return(inp: gtx.Field[[TDim], "float64"]): tmp = inp # noqa - with pytest.raises( - errors.DSLError, - match=".*return.*", - ): + with pytest.raises(errors.DSLError, match=".*return.*"): _ = FieldOperatorParser.apply_to_function(no_return) @@ -145,8 +138,7 @@ def copy_field(inp: gtx.Field[[TDim], "float64"]): parsed = FieldOperatorParser.apply_to_function(copy_field) assert parsed.body.annex.symtable[ssa.unique_name("tmp", 0)].type == ts.FieldType( - dims=[TDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None), + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) ) @@ -168,8 +160,7 @@ def power(inp: gtx.Field[[TDim], "float64"]): parsed = FieldOperatorParser.apply_to_function(power) assert parsed.body.stmts[-1].value.type == ts.FieldType( - dims=[TDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None), + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) ) @@ -180,8 +171,7 @@ def modulo(inp: gtx.Field[[TDim], "int32"]): parsed = FieldOperatorParser.apply_to_function(modulo) assert parsed.body.stmts[-1].value.type == ts.FieldType( - dims=[TDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.INT32, shape=None), + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32, shape=None) ) @@ -189,10 +179,7 @@ def test_boolean_and_op_unsupported(): def 
bool_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a and b - with pytest.raises( - errors.UnsupportedPythonFeatureError, - match=r".*and.*or.*", - ): + with pytest.raises(errors.UnsupportedPythonFeatureError, match=r".*and.*or.*"): _ = FieldOperatorParser.apply_to_function(bool_and) @@ -200,10 +187,7 @@ def test_boolean_or_op_unsupported(): def bool_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a or b - with pytest.raises( - errors.UnsupportedPythonFeatureError, - match=r".*and.*or.*", - ): + with pytest.raises(errors.UnsupportedPythonFeatureError, match=r".*and.*or.*"): _ = FieldOperatorParser.apply_to_function(bool_or) @@ -214,8 +198,7 @@ def bool_xor(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): parsed = FieldOperatorParser.apply_to_function(bool_xor) assert parsed.body.stmts[-1].value.type == ts.FieldType( - dims=[TDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL, shape=None), + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL, shape=None) ) @@ -226,8 +209,7 @@ def unary_tilde(a: gtx.Field[[TDim], "bool"]): parsed = FieldOperatorParser.apply_to_function(unary_tilde) assert parsed.body.stmts[-1].value.type == ts.FieldType( - dims=[TDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL, shape=None), + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL, shape=None) ) @@ -241,9 +223,7 @@ def cast_scalar_temp(): def test_conditional_wrong_mask_type(): - def conditional_wrong_mask_type( - a: gtx.Field[[TDim], float64], - ) -> gtx.Field[[TDim], float64]: + def conditional_wrong_mask_type(a: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: return where(a, a, a) msg = r"expected a field with dtype 'bool'" @@ -253,9 +233,7 @@ def conditional_wrong_mask_type( def test_conditional_wrong_arg_type(): def conditional_wrong_arg_type( - mask: gtx.Field[[TDim], bool], - a: gtx.Field[[TDim], float32], - b: gtx.Field[[TDim], float64], + mask: gtx.Field[[TDim], bool], a: gtx.Field[[TDim], float32], b: gtx.Field[[TDim], float64] ) -> gtx.Field[[TDim], float64]: return where(mask, a, b) @@ -321,8 +299,7 @@ def astype_fieldop(a: gtx.Field[[TDim], "int64"]) -> gtx.Field[[TDim], float64]: parsed = FieldOperatorParser.apply_to_function(astype_fieldop) assert parsed.body.stmts[-1].value.type == ts.FieldType( - dims=[TDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None), + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) ) @@ -371,8 +348,7 @@ def wrong_return_type_annotation(a: gtx.Field[[ADim], float64]) -> gtx.Field[[BD return a with pytest.raises( - errors.DSLError, - match=r"Annotated return type does not match deduced return type", + errors.DSLError, match=r"Annotated return type does not match deduced return type" ): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -382,8 +358,7 @@ def empty_dims() -> gtx.Field[[], float]: return 1.0 with pytest.raises( - errors.DSLError, - match=r"Annotated return type does not match deduced return type", + errors.DSLError, match=r"Annotated return type does not match deduced return type" ): _ = FieldOperatorParser.apply_to_function(empty_dims) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py index 02dbfbfee2..5123e1d5b1 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py +++ 
b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py @@ -36,10 +36,7 @@ def test_invalid_syntax_error_empty_return(): def wrong_syntax(inp: gtx.Field[[TDim], float]): return # <-- this line triggers the syntax error - with pytest.raises( - f2f.errors.DSLError, - match=(r".*return.*"), - ) as exc_info: + with pytest.raises(f2f.errors.DSLError, match=(r".*return.*")) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(wrong_syntax) assert exc_info.value.location diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index 605048d834..dde7383cfa 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -49,8 +49,7 @@ def test_copy_parsing(copy_program_def): past_node = ProgramParser.apply_to_function(copy_program_def) field_type = ts.FieldType( - dims=[IDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None), + dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) ) pattern_node = P( past.Program, @@ -76,8 +75,7 @@ def test_double_copy_parsing(double_copy_program_def): past_node = ProgramParser.apply_to_function(double_copy_program_def) field_type = ts.FieldType( - dims=[IDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None), + dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) ) pattern_node = P( past.Program, @@ -111,10 +109,7 @@ def test_undefined_field_program(identity_def): def undefined_field_program(in_field: gtx.Field[[IDim], "float64"]): identity(in_field, out=out_field) # noqa: F821 [undefined-name] - with pytest.raises( - errors.DSLError, - match=(r"Undeclared or untyped symbol 'out_field'."), - ): + with pytest.raises(errors.DSLError, match=(r"Undeclared or untyped symbol 'out_field'.")): ProgramParser.apply_to_function(undefined_field_program) @@ -122,8 +117,7 @@ def test_copy_restrict_parsing(copy_restrict_program_def): past_node = ProgramParser.apply_to_function(copy_restrict_program_def) field_type = ts.FieldType( - dims=[IDim], - dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None), + dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) ) slice_pattern_node = P( past.Slice, lower=P(past.Constant, value=1), upper=P(past.Constant, value=2) @@ -160,9 +154,7 @@ def test_domain_exception_1(identity_def): def domain_format_1_program(in_field: gtx.Field[[IDim], float64]): domain_format_1(in_field, out=in_field, domain=(0, 2)) - with pytest.raises( - errors.DSLError, - ) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: ProgramParser.apply_to_function(domain_format_1_program) assert exc_info.match("Invalid call to 'domain_format_1'") @@ -179,9 +171,7 @@ def test_domain_exception_2(identity_def): def domain_format_2_program(in_field: gtx.Field[[IDim], float64]): domain_format_2(in_field, out=in_field, domain={IDim: (0, 1, 2)}) - with pytest.raises( - errors.DSLError, - ) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: ProgramParser.apply_to_function(domain_format_2_program) assert exc_info.match("Invalid call to 'domain_format_2'") @@ -198,9 +188,7 @@ def test_domain_exception_3(identity_def): def domain_format_3_program(in_field: gtx.Field[[IDim], float64]): domain_format_3(in_field, domain={IDim: (0, 2)}) - with pytest.raises( - errors.DSLError, - ) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: 
ProgramParser.apply_to_function(domain_format_3_program) assert exc_info.match("Invalid call to 'domain_format_3'") @@ -219,9 +207,7 @@ def domain_format_4_program(in_field: gtx.Field[[IDim], float64]): in_field, out=(in_field[0:1], (in_field[0:1], in_field[0:1])), domain={IDim: (0, 1)} ) - with pytest.raises( - errors.DSLError, - ) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: ProgramParser.apply_to_function(domain_format_4_program) assert exc_info.match("Invalid call to 'domain_format_4'") @@ -238,9 +224,7 @@ def test_domain_exception_5(identity_def): def domain_format_5_program(in_field: gtx.Field[[IDim], float64]): domain_format_5(in_field, out=in_field, domain={IDim: ("1.0", 9.0)}) - with pytest.raises( - errors.DSLError, - ) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: ProgramParser.apply_to_function(domain_format_5_program) assert exc_info.match("Invalid call to 'domain_format_5'") @@ -257,9 +241,7 @@ def test_domain_exception_6(identity_def): def domain_format_6_program(in_field: gtx.Field[[IDim], float64]): domain_format_6(in_field, out=in_field, domain={}) - with pytest.raises( - errors.DSLError, - ) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: ProgramParser.apply_to_function(domain_format_6_program) assert exc_info.match("Invalid call to 'domain_format_6'") diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index 05947996c1..49c5b11b20 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -175,8 +175,7 @@ def inout_field_program(inout_field: gtx.Field[[IDim], "float64"]): identity(inout_field, out=inout_field) with pytest.raises( - ValueError, - match=(r"Call to function with field as input and output not allowed."), + ValueError, match=(r"Call to function with field as input and output not allowed.") ): ProgramLowering.apply( ProgramParser.apply_to_function(inout_field_program), @@ -186,9 +185,7 @@ def inout_field_program(inout_field: gtx.Field[[IDim], "float64"]): def test_invalid_call_sig_program(invalid_call_sig_program_def): - with pytest.raises( - errors.DSLError, - ) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: ProgramLowering.apply( ProgramParser.apply_to_function(invalid_call_sig_program_def), function_definitions=[], diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index e1732d4142..b3a9ba8001 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -98,8 +98,7 @@ def test_bool_arithmetic(): def test_shift(): testee = "⟪Iₒ, 1ₒ⟫" expected = ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1)], + fun=ir.SymRef(id="shift"), args=[ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1)] ) actual = pparse(testee) assert actual == expected @@ -137,8 +136,7 @@ def test_named_range(): def test_cartesian_domain(): testee = "c⟨ x, y ⟩" expected = ir.FunCall( - fun=ir.SymRef(id="cartesian_domain"), - args=[ir.SymRef(id="x"), ir.SymRef(id="y")], + fun=ir.SymRef(id="cartesian_domain"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")] ) actual = pparse(testee) assert actual == expected @@ -147,8 +145,7 @@ def test_cartesian_domain(): def test_unstructured_domain(): testee = "u⟨ x, y ⟩" expected = 
ir.FunCall( - fun=ir.SymRef(id="unstructured_domain"), - args=[ir.SymRef(id="x"), ir.SymRef(id="y")], + fun=ir.SymRef(id="unstructured_domain"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")] ) actual = pparse(testee) assert actual == expected @@ -157,8 +154,7 @@ def test_unstructured_domain(): def test_if(): testee = "if x then y else z" expected = ir.FunCall( - fun=ir.SymRef(id="if_"), - args=[ir.SymRef(id="x"), ir.SymRef(id="y"), ir.SymRef(id="z")], + fun=ir.SymRef(id="if_"), args=[ir.SymRef(id="x"), ir.SymRef(id="y"), ir.SymRef(id="z")] ) actual = pparse(testee) assert actual == expected @@ -166,10 +162,7 @@ def test_if(): def test_fun_call(): testee = "f(x)" - expected = ir.FunCall( - fun=ir.SymRef(id="f"), - args=[ir.SymRef(id="x")], - ) + expected = ir.FunCall(fun=ir.SymRef(id="f"), args=[ir.SymRef(id="x")]) actual = pparse(testee) assert actual == expected @@ -177,8 +170,7 @@ def test_fun_call(): def test_lambda_call(): testee = "(λ(x) → x)(x)" expected = ir.FunCall( - fun=ir.Lambda(params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")), - args=[ir.SymRef(id="x")], + fun=ir.Lambda(params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")), args=[ir.SymRef(id="x")] ) actual = pparse(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 8e0806baa2..844c905e8e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -34,14 +34,7 @@ def test_vmerge(): a = ["This is", "block ‘a’."] b = ["This is", "block ‘b’."] c = ["This is", "block ‘c’."] - expected = [ - "This is", - "block ‘a’.", - "This is", - "block ‘b’.", - "This is", - "block ‘c’.", - ] + expected = ["This is", "block ‘a’.", "This is", "block ‘b’.", "This is", "block ‘c’."] actual = PrettyPrinter()._vmerge(a, b, c) assert actual == expected @@ -55,12 +48,7 @@ def test_indent(): def test_cost(): assert PrettyPrinter()._cost(["This is a single line."]) < PrettyPrinter()._cost( - [ - "These are", - "multiple", - "short", - "lines.", - ] + ["These are", "multiple", "short", "lines."] ) assert PrettyPrinter()._cost(["This is a short line."]) < PrettyPrinter()._cost( [ @@ -69,11 +57,7 @@ def test_cost(): ] ) assert PrettyPrinter()._cost( - [ - "Equal length!", - "Equal length!", - "Equal length!", - ] + ["Equal length!", "Equal length!", "Equal length!"] ) < PrettyPrinter()._cost(["Unequal length.", "Short…", "Looooooooooooooooooong…"]) @@ -148,17 +132,11 @@ def test_associativity(): args=[ ir.FunCall( fun=ir.SymRef(id="plus"), - args=[ - ir.Literal(value="1", type="int64"), - ir.Literal(value="2", type="int64"), - ], + args=[ir.Literal(value="1", type="int64"), ir.Literal(value="2", type="int64")], ), ir.FunCall( fun=ir.SymRef(id="plus"), - args=[ - ir.Literal(value="3", type="int64"), - ir.Literal(value="4", type="int64"), - ], + args=[ir.Literal(value="3", type="int64"), ir.Literal(value="4", type="int64")], ), ], ) @@ -209,8 +187,7 @@ def test_bool_arithmetic(): def test_shift(): testee = ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1)], + fun=ir.SymRef(id="shift"), args=[ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1)] ) expected = "⟪Iₒ, 1ₒ⟫" actual = pformat(testee) @@ -246,8 +223,7 @@ def test_named_range(): def test_cartesian_domain(): testee = ir.FunCall( - fun=ir.SymRef(id="cartesian_domain"), - args=[ir.SymRef(id="x"), ir.SymRef(id="y")], + 
fun=ir.SymRef(id="cartesian_domain"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")] ) expected = "c⟨ x, y ⟩" actual = pformat(testee) @@ -256,8 +232,7 @@ def test_cartesian_domain(): def test_unstructured_domain(): testee = ir.FunCall( - fun=ir.SymRef(id="unstructured_domain"), - args=[ir.SymRef(id="x"), ir.SymRef(id="y")], + fun=ir.SymRef(id="unstructured_domain"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")] ) expected = "u⟨ x, y ⟩" actual = pformat(testee) @@ -266,8 +241,7 @@ def test_unstructured_domain(): def test_if_short(): testee = ir.FunCall( - fun=ir.SymRef(id="if_"), - args=[ir.SymRef(id="x"), ir.SymRef(id="y"), ir.SymRef(id="z")], + fun=ir.SymRef(id="if_"), args=[ir.SymRef(id="x"), ir.SymRef(id="y"), ir.SymRef(id="z")] ) expected = "if x then y else z" actual = pformat(testee) @@ -291,10 +265,7 @@ def test_if_long(): def test_fun_call(): - testee = ir.FunCall( - fun=ir.SymRef(id="f"), - args=[ir.SymRef(id="x")], - ) + testee = ir.FunCall(fun=ir.SymRef(id="f"), args=[ir.SymRef(id="x")]) expected = "f(x)" actual = pformat(testee) assert actual == expected @@ -302,8 +273,7 @@ def test_fun_call(): def test_lambda_call(): testee = ir.FunCall( - fun=ir.Lambda(params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")), - args=[ir.SymRef(id="x")], + fun=ir.Lambda(params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")), args=[ir.SymRef(id="x")] ) expected = "(λ(x) → x)(x)" actual = pformat(testee) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index 73ad24f42b..53cab1c130 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -34,13 +34,9 @@ def foo(inp): def test_deduce_domain(): assert isinstance(_deduce_domain({}, {}), CartesianDomain) assert isinstance(_deduce_domain(UnstructuredDomain(), {}), UnstructuredDomain) + assert isinstance(_deduce_domain({}, {"foo": connectivity}), UnstructuredDomain) assert isinstance( - _deduce_domain({}, {"foo": connectivity}), - UnstructuredDomain, - ) - assert isinstance( - _deduce_domain(CartesianDomain([("I", range(1))]), {"foo": connectivity}), - CartesianDomain, + _deduce_domain(CartesianDomain([("I", range(1))]), {"foo": connectivity}), CartesianDomain ) @@ -50,15 +46,6 @@ def test_deduce_domain(): def test_embedded_error_on_wrong_domain(): dom = CartesianDomain([("I", range(1))]) - out = gtx.as_field( - [I], - np.zeros( - 1, - ), - ) + out = gtx.as_field([I], np.zeros(1)) with pytest.raises(RuntimeError, match="expected 'UnstructuredDomain'"): - foo[dom]( - gtx.as_field([I], np.zeros((1,))), - out=out, - offset_provider={"bar": connectivity}, - ) + foo[dom](gtx.as_field([I], np.zeros((1,))), out=out, offset_provider={"bar": connectivity}) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index cacdb7b070..7beda20d31 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -113,7 +113,7 @@ def test_deref(): size=ti.TypeVar(idx=1), current_loc=ti.TypeVar(idx=2), defined_loc=ti.TypeVar(idx=2), - ), + ) ), ret=ti.Val(kind=ti.Value(), dtype=ti.TypeVar(idx=0), size=ti.TypeVar(idx=1)), ) @@ -132,12 +132,7 @@ def test_deref_call(): def test_lambda(): testee = ir.Lambda(params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) - expected = ti.FunctionType( - args=ti.Tuple.from_elems( - 
-            ti.TypeVar(idx=0),
-        ),
-        ret=ti.TypeVar(idx=0),
-    )
+    expected = ti.FunctionType(args=ti.Tuple.from_elems(ti.TypeVar(idx=0)), ret=ti.TypeVar(idx=0))
     inferred = ti.infer(testee)
     assert inferred == expected
     assert ti.pformat(inferred) == "(T₀) → T₀"
@@ -154,10 +149,7 @@ def test_typed_lambda():
         current_loc=ti.TypeVar(idx=1),
         defined_loc=ti.TypeVar(idx=2),
     )
-    expected = ti.FunctionType(
-        args=ti.Tuple.from_elems(expected_val),
-        ret=expected_val,
-    )
+    expected = ti.FunctionType(args=ti.Tuple.from_elems(expected_val), ret=expected_val)
     inferred = ti.infer(testee)
     assert inferred == expected
     assert ti.pformat(inferred) == "(It[T₁, T₂, float64⁰]) → It[T₁, T₂, float64⁰]"
@@ -213,12 +205,7 @@ def test_if_call():
 def test_not():
     testee = ir.SymRef(id="not_")
     t = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="bool"), size=ti.TypeVar(idx=0))
-    expected = ti.FunctionType(
-        args=ti.Tuple.from_elems(
-            t,
-        ),
-        ret=t,
-    )
+    expected = ti.FunctionType(args=ti.Tuple.from_elems(t), ret=t)
     inferred = ti.infer(testee)
     assert inferred == expected
     assert ti.pformat(inferred) == "(bool⁰) → bool⁰"
@@ -257,7 +244,7 @@ def test_lift():
                 defined_locs=ti.TypeVar(idx=3),
             ),
             ret=ti.Val(kind=ti.Value(), dtype=ti.TypeVar(idx=4), size=ti.TypeVar(idx=1)),
-        ),
+        )
     ),
     ret=ti.FunctionType(
         args=ti.ValTuple(
@@ -319,7 +306,7 @@ def test_lift_application():
                 size=ti.TypeVar(idx=1),
                 current_loc=ti.TypeVar(idx=2),
                 defined_loc=ti.TypeVar(idx=3),
-            ),
+            )
         ),
         ret=ti.Val(
             kind=ti.Iterator(),
@@ -409,7 +396,7 @@ def test_tuple_get_in_lambda():
                 others=ti.Tuple(front=ti.TypeVar(idx=2), others=ti.TypeVar(idx=3)),
             ),
             size=ti.TypeVar(idx=4),
-        ),
+        )
     ),
     ret=ti.Val(kind=ti.TypeVar(idx=0), dtype=ti.TypeVar(idx=2), size=ti.TypeVar(idx=4)),
 )
@@ -537,7 +524,7 @@ def test_shift():
                 size=ti.TypeVar(idx=1),
                 current_loc=ti.TypeVar(idx=2),
                 defined_loc=ti.TypeVar(idx=3),
-            ),
+            )
         ),
         ret=ti.Val(
             kind=ti.Iterator(),
@@ -564,7 +551,7 @@ def test_shift_with_cartesian_offset_provider():
                 size=ti.TypeVar(idx=1),
                 current_loc=ti.TypeVar(idx=2),
                 defined_loc=ti.TypeVar(idx=3),
-            ),
+            )
         ),
         ret=ti.Val(
            kind=ti.Iterator(),
@@ -590,7 +577,7 @@ def test_partial_shift_with_cartesian_offset_provider():
                 size=ti.TypeVar(idx=1),
                 current_loc=ti.TypeVar(idx=2),
                 defined_loc=ti.TypeVar(idx=3),
-            ),
+            )
         ),
         ret=ti.Val(
             kind=ti.Iterator(),
@@ -618,7 +605,7 @@ def test_shift_with_unstructured_offset_provider():
                 size=ti.TypeVar(idx=1),
                 current_loc=ti.Location(name="Vertex"),
                 defined_loc=ti.TypeVar(idx=2),
-            ),
+            )
         ),
         ret=ti.Val(
             kind=ti.Iterator(),
@@ -655,7 +642,7 @@ def test_partial_shift_with_unstructured_offset_provider():
                 size=ti.TypeVar(idx=1),
                 current_loc=ti.Location(name="Vertex"),
                 defined_loc=ti.TypeVar(idx=2),
-            ),
+            )
         ),
         ret=ti.Val(
             kind=ti.Iterator(),
@@ -681,12 +668,7 @@ def test_partial_shift_with_unstructured_offset_provider():
 def test_function_definition():
     testee = ir.FunctionDefinition(id="f", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x"))
     expected = ti.LetPolymorphic(
-        dtype=ti.FunctionType(
-            args=ti.Tuple.from_elems(
-                ti.TypeVar(idx=0),
-            ),
-            ret=ti.TypeVar(idx=0),
-        ),
+        dtype=ti.FunctionType(args=ti.Tuple.from_elems(ti.TypeVar(idx=0)), ret=ti.TypeVar(idx=0))
     )
     inferred = ti.infer(testee)
     assert inferred == expected
@@ -761,7 +743,7 @@ def test_stencil_closure():
                 size=ti.Column(),
                 current_loc=ti.ANYWHERE,
                 defined_loc=ti.TypeVar(idx=1),
-            ),
+            )
         ),
     )
     inferred = ti.infer(testee)
@@ -963,10 +945,7 @@ def test_fencil_definition_with_function_definitions():
         ti.FunctionDefinitionType(
             name="f",
             fun=ti.FunctionType(
-                args=ti.Tuple.from_elems(
-                    ti.TypeVar(idx=0),
-                ),
-                ret=ti.TypeVar(idx=0),
+                args=ti.Tuple.from_elems(ti.TypeVar(idx=0)), ret=ti.TypeVar(idx=0)
             ),
         ),
         ti.FunctionDefinitionType(
@@ -979,7 +958,7 @@ def test_fencil_definition_with_function_definitions():
                         size=ti.TypeVar(idx=2),
                         current_loc=ti.TypeVar(idx=3),
                         defined_loc=ti.TypeVar(idx=3),
-                    ),
+                    )
                 ),
                 ret=ti.Val(kind=ti.Value(), dtype=ti.TypeVar(idx=1), size=ti.TypeVar(idx=2)),
             ),
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_list_get.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_list_get.py
index b025b7613d..b6463ba0d5 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_list_get.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_list_get.py
@@ -36,10 +36,7 @@ def test_list_get_neighbors():
         ir.FunCall(
             fun=ir.FunCall(
                 fun=ir.SymRef(id="shift"),
-                args=[
-                    ir.OffsetLiteral(value="foo"),
-                    ir.OffsetLiteral(value=42),
-                ],
+                args=[ir.OffsetLiteral(value="foo"), ir.OffsetLiteral(value=42)],
             ),
             args=[ir.SymRef(id="bar")],
         )
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py
index fb7720f4d7..a2d0a170a0 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py
@@ -190,11 +190,7 @@ def test_if_eligible_extraction():
     # if statement.
     # if ((a ∧ b) ∧ (a ∧ b)) then c else d
-    testee = im.if_(
-        im.and_(im.and_("a", "b"), im.and_("a", "b")),
-        "c",
-        "d",
-    )
+    testee = im.if_(im.and_(im.and_("a", "b"), im.and_("a", "b")), "c", "d")
 
     # (λ(_cs_1) → if _cs_1 ∧ _cs_1 then c else d)(a ∧ b)
     expected = im.let("_cs_1", im.and_("a", "b"))(im.if_(im.and_("_cs_1", "_cs_1"), "c", "d"))
@@ -212,17 +208,7 @@ def is_let(node: ir.Expr):
         return isinstance(node, ir.FunCall) and isinstance(node.fun, ir.Lambda)
 
     testee = im.plus(
-        im.let(
-            (
-                "c",
-                im.let(
-                    ("a", 1),
-                    ("b", 2),
-                )(im.plus("a", "b")),
-            ),
-            ("d", 3),
-        )(im.plus("c", "d")),
-        4,
+        im.let(("c", im.let(("a", 1), ("b", 2))(im.plus("a", "b"))), ("d", 3))(im.plus("c", "d")), 4
     )
 
     expected = textwrap.dedent(
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_maps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_maps.py
index 215757feef..fdd1fb6b9a 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_maps.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_maps.py
@@ -182,11 +182,7 @@ def test_nested():
                 fun=ir.SymRef(id="multiplies"),
                 args=[
                     _p_symref,
-                    P(
-                        ir.FunCall,
-                        fun=ir.SymRef(id="divides"),
-                        args=[_p_symref, _p_symref],
-                    ),
+                    P(ir.FunCall, fun=ir.SymRef(id="divides"), args=[_p_symref, _p_symref]),
                 ],
             ),
         ],
     ),
@@ -225,16 +221,8 @@ def test_multiple_maps_with_colliding_symbol_names():
             ir.FunCall,
             fun=ir.SymRef(id="plus"),
             args=[
-                P(
-                    ir.FunCall,
-                    fun=ir.SymRef(id="multiplies"),
-                    args=[_p_symref, _p_symref],
-                ),
-                P(
-                    ir.FunCall,
-                    fun=ir.SymRef(id="multiplies"),
-                    args=[_p_symref, _p_symref],
-                ),
+                P(ir.FunCall, fun=ir.SymRef(id="multiplies"), args=[_p_symref, _p_symref]),
+                P(ir.FunCall, fun=ir.SymRef(id="multiplies"), args=[_p_symref, _p_symref]),
             ],
         ),
     ),
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
index 46ca02217f..0521b0414b 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
@@ -254,17 +254,8 @@ def test_update_cartesian_domains():
                 ),
             ],
         ),
-        params=[
-            im.sym("i"),
-            im.sym("j"),
-            im.sym("k"),
-            im.sym("inp"),
-            im.sym("out"),
-        ],
-        tmps=[
-            Temporary(id="_gtmp_0"),
-            Temporary(id="_gtmp_1"),
-        ],
+        params=[im.sym("i"), im.sym("j"), im.sym("k"), im.sym("inp"), im.sym("out")],
+        tmps=[Temporary(id="_gtmp_0"), Temporary(id="_gtmp_1")],
     )
     expected = copy.deepcopy(testee)
     assert expected.fencil.params.pop() == im.sym("_gtmp_auto_domain")
@@ -306,10 +297,7 @@ def test_update_cartesian_domains():
                     im.literal("0", ir.INTEGER_INDEX_BUILTIN),
                     im.literal("1", ir.INTEGER_INDEX_BUILTIN),
                 ),
-                im.plus(
-                    im.ref("i"),
-                    ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN),
-                ),
+                im.plus(im.ref("i"), ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN)),
             ],
         )
     ]
@@ -340,10 +328,7 @@ def test_collect_tmps_info():
                 ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN),
                 ir.FunCall(
                     fun=im.ref("plus"),
-                    args=[
-                        im.ref("i"),
-                        ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN),
-                    ],
+                    args=[im.ref("i"), ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN)],
                 ),
             ],
         )
@@ -378,10 +363,7 @@ def test_collect_tmps_info():
             domain=tmp_domain,
             stencil=ir.Lambda(
                 params=[ir.Sym(id="foo_inp")],
-                expr=ir.FunCall(
-                    fun=im.ref("deref"),
-                    args=[im.ref("foo_inp")],
-                ),
+                expr=ir.FunCall(fun=im.ref("deref"), args=[im.ref("foo_inp")]),
             ),
             output=im.ref("_gtmp_1"),
             inputs=[im.ref("inp")],
@@ -430,17 +412,8 @@ def test_collect_tmps_info():
                 ),
             ],
         ),
-        params=[
-            ir.Sym(id="i"),
-            ir.Sym(id="j"),
-            ir.Sym(id="k"),
-            ir.Sym(id="inp"),
-            ir.Sym(id="out"),
-        ],
-        tmps=[
-            Temporary(id="_gtmp_0"),
-            Temporary(id="_gtmp_1"),
-        ],
+        params=[ir.Sym(id="i"), ir.Sym(id="j"), ir.Sym(id="k"), ir.Sym(id="inp"), ir.Sym(id="out")],
+        tmps=[Temporary(id="_gtmp_0"), Temporary(id="_gtmp_1")],
     )
     expected = FencilWithTemporaries(
         fencil=testee.fencil,
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py
index bf26889882..714b60eca1 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py
@@ -72,9 +72,5 @@ def test_inline_lambda_args():
             3,
         )
     )
-    inlined = InlineLambdas.apply(
-        testee,
-        opcount_preserving=True,
-        force_inline_lambda_args=True,
-    )
+    inlined = InlineLambdas.apply(testee, opcount_preserving=True, force_inline_lambda_args=True)
     assert inlined == expected
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_scan_eta_reduction.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_scan_eta_reduction.py
index e86fb65863..5a9d3a676b 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_scan_eta_reduction.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_scan_eta_reduction.py
@@ -34,8 +34,7 @@ def test_scan_eta_reduction():
     testee = ir.Lambda(
         params=[ir.Sym(id="x"), ir.Sym(id="y")],
         expr=ir.FunCall(
-            fun=_make_scan("param_y", "param_x"),
-            args=[ir.SymRef(id="y"), ir.SymRef(id="x")],
+            fun=_make_scan("param_y", "param_x"), args=[ir.SymRef(id="y"), ir.SymRef(id="x")]
         ),
     )
     expected = _make_scan("param_x", "param_y")
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_simple_inline_heuristic.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_simple_inline_heuristic.py
index a236d793c5..685625e9e7 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_simple_inline_heuristic.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_simple_inline_heuristic.py
@@ -45,10 +45,7 @@ def test_trivial(is_scan_context):
     # `↑(scan(λ(acc, it) → acc + ·↑(deref)(it)))(...)` where the inner lift should not be inlined.
     expected = not is_scan_context
     testee = ir.FunCall(
-        fun=ir.FunCall(
-            fun=ir.SymRef(id="lift"),
-            args=[ir.SymRef(id="deref")],
-        ),
+        fun=ir.FunCall(fun=ir.SymRef(id="lift"), args=[ir.SymRef(id="deref")]),
         args=[ir.SymRef(id="it")],
     )
     assert expected == is_eligible_for_inlining(testee, is_scan_context)
@@ -66,10 +63,7 @@ def test_scan_with_lifted_arg(scan):
         fun=ir.FunCall(fun=ir.SymRef(id="lift"), args=[scan]),
         args=[
             ir.FunCall(
-                fun=ir.FunCall(
-                    fun=ir.SymRef(id="lift"),
-                    args=[ir.SymRef(id="deref")],
-                ),
+                fun=ir.FunCall(fun=ir.SymRef(id="lift"), args=[ir.SymRef(id="deref")]),
                 args=[ir.SymRef(id="x")],
             )
         ],
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py
index d301ec997d..1d1a2dc89d 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py
@@ -93,8 +93,7 @@ def test_neighbors():
     testee = ir.StencilClosure(
         stencil=ir.Lambda(
             expr=ir.FunCall(
-                fun=ir.SymRef(id="neighbors"),
-                args=[ir.OffsetLiteral(value="O"), ir.SymRef(id="x")],
+                fun=ir.SymRef(id="neighbors"), args=[ir.OffsetLiteral(value="O"), ir.SymRef(id="x")]
             ),
             params=[ir.Sym(id="x")],
         ),
@@ -102,14 +101,7 @@ def test_neighbors():
         output=ir.SymRef(id="out"),
         domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]),
     )
-    expected = {
-        "inp": {
-            (
-                ir.OffsetLiteral(value="O"),
-                Sentinel.ALL_NEIGHBORS,
-            )
-        }
-    }
+    expected = {"inp": {(ir.OffsetLiteral(value="O"), Sentinel.ALL_NEIGHBORS)}}
     actual = TraceShifts.apply(testee)
     assert actual == expected
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py
index 9f1462dd8a..ba4a91e6b5 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py
@@ -142,13 +142,7 @@ def _expected(red, dim, max_neighbors, has_skip_values, shifted_arg=0):
     red_fun, red_init = red.fun.args
 
-    elements = [
-        ir.FunCall(
-            fun=ir.SymRef(id="list_get"),
-            args=[offset, arg],
-        )
-        for arg in red.args
-    ]
+    elements = [ir.FunCall(fun=ir.SymRef(id="list_get"), args=[offset, arg]) for arg in red.args]
 
     step_expr = ir.FunCall(fun=red_fun, args=[acc] + elements)
     if has_skip_values:
diff --git a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py
index 92bf54f009..3ff0d3f341 100644
--- a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py
+++ b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py
@@ -73,8 +73,7 @@ def function_buffer_example():
             interface.Parameter(
                 name="b_buf",
                 type_=ts.FieldType(
-                    dims=[gtx.Dimension("foo")],
-                    dtype=ts.ScalarType(ts.ScalarKind.INT64),
+                    dims=[gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.INT64)
                 ),
             ),
         ],
@@ -151,9 +150,7 @@ def test_render_function_declaration_tuple(function_tuple_example):
 def test_render_function_call_tuple(function_tuple_example):
     rendered = format_source(
-        "cpp",
-        cpp.render_function_call(function_tuple_example, args=["get_arg_1()"]),
-        style="LLVM",
+        "cpp", cpp.render_function_call(function_tuple_example, args=["get_arg_1()"]), style="LLVM"
     )
     expected = format_source("cpp", """example(get_arg_1())""", style="LLVM")
     assert rendered == expected
diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py
index 0911576fd6..219d51918b 100644
--- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py
+++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py
@@ -77,9 +77,7 @@ def make_program_source(name: str) -> stages.ProgramSource:
     return stages.ProgramSource(
         entry_point=entry_point,
         source_code=src,
-        library_deps=[
-            interface.LibraryDependency("gridtools_cpu", "master"),
-        ],
+        library_deps=[interface.LibraryDependency("gridtools_cpu", "master")],
         language=languages.Cpp,
         language_settings=cpp_interface.CPP_DEFAULT,
     )
diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py
index 4e865452f6..be7a9ff81e 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py
@@ -56,10 +56,7 @@ def fencil_example():
         ],
     )
     IDim = gtx.Dimension("I")
-    params = [
-        gtx.as_field([IDim], np.empty((1,), dtype=np.float32)),
-        np.float32(3.14),
-    ]
+    params = [gtx.as_field([IDim], np.empty((1,), dtype=np.float32)), np.float32(3.14)]
     return fencil, params
diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py
index 1f0634bd87..43a0b45ce6 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py
@@ -19,13 +19,10 @@ def test_funcall_to_op():
     testee = itir.FunCall(
-        fun=itir.SymRef(id="plus"),
-        args=[itir.SymRef(id="foo"), itir.SymRef(id="bar")],
+        fun=itir.SymRef(id="plus"), args=[itir.SymRef(id="foo"), itir.SymRef(id="bar")]
     )
     expected = gtfn_ir.BinaryExpr(
-        op="+",
-        lhs=gtfn_ir.SymRef(id="foo"),
-        rhs=gtfn_ir.SymRef(id="bar"),
+        op="+", lhs=gtfn_ir.SymRef(id="foo"), rhs=gtfn_ir.SymRef(id="bar")
     )
 
     actual = it2gtfn.GTFN_lowering(
diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py
index 1aeb51cb30..44150f344e 100644
--- a/tests/next_tests/unit_tests/test_common.py
+++ b/tests/next_tests/unit_tests/test_common.py
@@ -247,12 +247,7 @@ def test_range_comparison(op, rng1, rng2, expected):
 @pytest.mark.parametrize(
-    "named_rng_like",
-    [
-        (IDim, (2, 4)),
-        (IDim, range(2, 4)),
-        (IDim, UnitRange(2, 4)),
-    ],
+    "named_rng_like", [(IDim, (2, 4)), (IDim, range(2, 4)), (IDim, UnitRange(2, 4))]
 )
 def test_named_range_like(named_rng_like):
     assert named_range(named_rng_like) == (IDim, UnitRange(2, 4))
@@ -373,11 +368,7 @@ def test_domain_slice_indexing(a_domain, slice_obj, expected):
 @pytest.mark.parametrize(
-    "index, expected_result",
-    [
-        (JDim, (JDim, UnitRange(5, 15))),
-        (KDim, (KDim, UnitRange(20, 30))),
-    ],
+    "index, expected_result", [(JDim, (JDim, UnitRange(5, 15))), (KDim, (KDim, UnitRange(20, 30)))]
 )
 def test_domain_dimension_indexing(a_domain, index, expected_result):
     result = a_domain[index]
diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py
index 8d95c9951f..1f491f220e 100644
--- a/tests/next_tests/unit_tests/test_constructors.py
+++ b/tests/next_tests/unit_tests/test_constructors.py
@@ -121,8 +121,7 @@ def test_as_field():
 def test_as_field_domain():
     ref = np.random.rand(sizes[I] - 1, sizes[J] - 1).astype(gtx.float32)
     domain = common.Domain(
-        dims=(I, J),
-        ranges=(common.UnitRange(0, sizes[I] - 1), common.UnitRange(0, sizes[J] - 1)),
+        dims=(I, J), ranges=(common.UnitRange(0, sizes[I] - 1), common.UnitRange(0, sizes[J] - 1))
     )
     a = gtx.as_field(domain, ref)
     assert np.array_equal(a.ndarray, ref)
@@ -137,18 +136,12 @@ def test_as_field_origin():
 # check that `as_field()` domain is correct depending on data origin and domain itself
 def test_field_wrong_dims():
-    with pytest.raises(
-        ValueError,
-        match=(r"Cannot construct 'Field' from array of shape"),
-    ):
+    with pytest.raises(ValueError, match=(r"Cannot construct 'Field' from array of shape")):
         gtx.as_field([I, J], np.random.rand(sizes[I]).astype(gtx.float32))
 
 def test_field_wrong_domain():
-    with pytest.raises(
-        ValueError,
-        match=(r"Cannot construct 'Field' from array of shape"),
-    ):
+    with pytest.raises(ValueError, match=(r"Cannot construct 'Field' from array of shape")):
         domain = common.Domain(
             dims=(I, J),
             ranges=(common.UnitRange(0, sizes[I] - 1), common.UnitRange(0, sizes[J] - 1)),
@@ -157,16 +150,10 @@ def test_field_wrong_domain():
 
 def test_field_wrong_origin():
-    with pytest.raises(
-        ValueError,
-        match=(r"Origin keys {'J'} not in domain"),
-    ):
+    with pytest.raises(ValueError, match=(r"Origin keys {'J'} not in domain")):
         gtx.as_field([I], np.random.rand(sizes[I]).astype(gtx.float32), origin={"J": 0})
 
-    with pytest.raises(
-        ValueError,
-        match=(r"Cannot specify origin for domain I"),
-    ):
+    with pytest.raises(ValueError, match=(r"Cannot specify origin for domain I")):
         gtx.as_field("I", np.random.rand(sizes[J]).astype(gtx.float32), origin={"J": 0})
diff --git a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py
index 5d9e945798..d768a620e2 100644
--- a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py
+++ b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py
@@ -109,14 +109,8 @@ def test_invalid_scalar_kind():
             ),
         ),
         (typing.ForwardRef("float"), ts.ScalarType(kind=ts.ScalarKind.FLOAT64)),
-        (
-            typing.Annotated[float, "foo"],
-            ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
-        ),
-        (
-            typing.Annotated["float", "foo", "bar"],
-            ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
-        ),
+        (typing.Annotated[float, "foo"], ts.ScalarType(kind=ts.ScalarKind.FLOAT64)),
+        (typing.Annotated["float", "foo", "bar"], ts.ScalarType(kind=ts.ScalarKind.FLOAT64)),
         (
             typing.Annotated[typing.ForwardRef("float"), "foo"],
             ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
diff --git a/tests/storage_tests/unit_tests/test_interface.py b/tests/storage_tests/unit_tests/test_interface.py
index 9d68f13ad2..db3ea799dc 100644
--- a/tests/storage_tests/unit_tests/test_interface.py
+++ b/tests/storage_tests/unit_tests/test_interface.py
@@ -108,12 +108,7 @@ def dimensions_strategy(draw):
     if dimension < 3:
         mask_values += [False] * (3 - dimension)
 
-    mask = draw(
-        hyp_st.one_of(
-            hyp_st.just(None),
-            hyp_st.permutations(mask_values),
-        )
-    )
+    mask = draw(hyp_st.one_of(hyp_st.just(None), hyp_st.permutations(mask_values)))
     if mask is not None:
         select_dimensions = ["I", "J", "K"] + [str(d) for d in range(max(0, dimension - 3))]
         assert len(select_dimensions) == len(mask)
@@ -389,10 +384,7 @@ def test_cpu_constructor_0d(alloc_fun, backend):
     assert isinstance(stor, np.ndarray)
 
-@pytest.mark.parametrize(
-    "backend",
-    GPU_LAYOUTS,
-)
+@pytest.mark.parametrize("backend", GPU_LAYOUTS)
 def test_gpu_constructor(alloc_fun, backend):
     stor = alloc_fun(dtype=np.float64, aligned_index=(1, 2, 3), shape=(2, 4, 6), backend=backend)
     assert stor.shape == (2, 4, 6)