Skip to content

Commit

Permalink
Reformat with ruff to reduce vertical space.
Browse files Browse the repository at this point in the history
  • Loading branch information
egparedes committed Apr 3, 2024
1 parent 0ef0c5c commit de13b08
Show file tree
Hide file tree
Showing 199 changed files with 729 additions and 2,777 deletions.
6 changes: 1 addition & 5 deletions docs/user/next/workshop/exercises/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@
)


def random_mask(
sizes,
*dims,
dtype=None,
) -> MutableLocatedField:
def random_mask(sizes, *dims, dtype=None) -> MutableLocatedField:
arr = np.full(shape=sizes, fill_value=False).flatten()
arr[: int(arr.size * 0.5)] = True
np.random.shuffle(arr)
Expand Down
8 changes: 1 addition & 7 deletions src/gt4py/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,7 @@
from packaging import version as pkg_version


__all__ = [
"__author__",
"__copyright__",
"__license__",
"__version__",
"__version_info__",
]
__all__ = ["__author__", "__copyright__", "__license__", "__version__", "__version_info__"]


__author__: Final = "ETH Zurich and individual contributors"
Expand Down
4 changes: 1 addition & 3 deletions src/gt4py/_core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,7 @@ def is_positive_integral_type(integral_type: type) -> TypeGuard[Type[PositiveInt
] # TODO(egparedes) figure out if PositiveIntegral can be made to work


def is_valid_tensor_shape(
value: Sequence[IntegralScalar],
) -> TypeGuard[TensorShape]:
def is_valid_tensor_shape(value: Sequence[IntegralScalar]) -> TypeGuard[TensorShape]:
return isinstance(value, collections.abc.Sequence) and all(
isinstance(v, numbers.Integral) and v > 0 for v in value
)
Expand Down
10 changes: 2 additions & 8 deletions src/gt4py/cartesian/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,7 @@ def check_options(self, options: gt_definitions.BuildOptions) -> None:
stacklevel=2,
)

def make_module(
self,
**kwargs: Any,
) -> Type["StencilObject"]:
def make_module(self, **kwargs: Any) -> Type["StencilObject"]:
build_info = self.builder.options.build_info
if build_info is not None:
start_time = time.perf_counter()
Expand Down Expand Up @@ -412,10 +409,7 @@ def build_extension_module(
assert module_name == qualified_pyext_name

self.builder.with_backend_data(
{
"pyext_module_name": module_name,
"pyext_file_path": file_path,
}
{"pyext_module_name": module_name, "pyext_file_path": file_path}
)

return module_name, file_path
13 changes: 3 additions & 10 deletions src/gt4py/cartesian/backend/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ def visit_FieldDecl(self, node: cuir.FieldDecl, **kwargs):
sid_ndim = domain_ndim + data_ndim
if kwargs["external_arg"]:
return "py::object {name}, std::array<gt::int_t,{sid_ndim}> {name}_origin".format(
name=node.name,
sid_ndim=sid_ndim,
name=node.name, sid_ndim=sid_ndim
)
else:
return pybuffer_to_sid(
Expand All @@ -113,12 +112,7 @@ def visit_Program(self, node: cuir.Program, **kwargs):
assert "module_name" in kwargs
entry_params = self.visit(node.params, external_arg=True, **kwargs)
sid_params = self.visit(node.params, external_arg=False, **kwargs)
return self.generic_visit(
node,
entry_params=entry_params,
sid_params=sid_params,
**kwargs,
)
return self.generic_visit(node, entry_params=entry_params, sid_params=sid_params, **kwargs)

Program = bindings_main_template()

Expand Down Expand Up @@ -157,6 +151,5 @@ def generate(self) -> Type["StencilObject"]:

# Generate and return the Python wrapper class
return self.make_module(
pyext_module_name=pyext_module_name,
pyext_file_path=pyext_file_path,
pyext_module_name=pyext_module_name, pyext_file_path=pyext_file_path
)
15 changes: 4 additions & 11 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def _serialize_sdfg(sdfg: dace.SDFG):

def _specialize_transient_strides(sdfg: dace.SDFG, layout_map):
repldict = replace_strides(
[array for array in sdfg.arrays.values() if array.transient],
layout_map,
[array for array in sdfg.arrays.values() if array.transient], layout_map
)
sdfg.replace_dict(repldict)
for state in sdfg.nodes():
Expand Down Expand Up @@ -221,9 +220,7 @@ def _sdfg_add_arrays_and_edges(
ranges = [
(o - max(0, e), o - max(0, e) + s - 1, 1)
for o, e, s in zip(
origin,
field_info[name].boundary.lower_indices,
inner_sdfg.arrays[name].shape,
origin, field_info[name].boundary.lower_indices, inner_sdfg.arrays[name].shape
)
]
ranges += [(0, d, 1) for d in field_info[name].data_dims]
Expand Down Expand Up @@ -786,8 +783,7 @@ def generate(self) -> Type["StencilObject"]:

# Generate and return the Python wrapper class
return self.make_module(
pyext_module_name=pyext_module_name,
pyext_file_path=pyext_file_path,
pyext_module_name=pyext_module_name, pyext_file_path=pyext_file_path
)


Expand Down Expand Up @@ -826,10 +822,7 @@ class DaceGPUBackend(BaseDaceBackend):
),
}
MODULE_GENERATOR_CLASS = DaCeCUDAPyExtModuleGenerator
options = {
**BaseGTBackend.GT_BACKEND_OPTS,
"device_sync": {"versioning": True, "type": bool},
}
options = {**BaseGTBackend.GT_BACKEND_OPTS, "device_sync": {"versioning": True, "type": bool}}

def generate_extension(self, **kwargs: Any) -> Tuple[str, str]:
return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=True)
8 changes: 2 additions & 6 deletions src/gt4py/cartesian/backend/dace_stencil_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ def add_optional_fields(
if info.access == AccessKind.NONE and name in kwargs and name not in sdfg.arrays:
outer_array = kwargs[name]
sdfg.add_array(
name,
shape=outer_array.shape,
dtype=outer_array.dtype,
strides=outer_array.strides,
name, shape=outer_array.shape, dtype=outer_array.dtype, strides=outer_array.strides
)

for name, info in parameter_info.items():
Expand Down Expand Up @@ -194,8 +191,7 @@ def normalize_args(
name: (kwargs[name] if name in kwargs else next(args_iter)) for name in arg_names
}
arg_infos = _extract_array_infos(
field_args=args_as_kwargs,
device=backend_cls.storage_info["device"],
field_args=args_as_kwargs, device=backend_cls.storage_info["device"]
)

origin = DaCeStencilObject._normalize_origins(arg_infos, field_info, origin)
Expand Down
14 changes: 3 additions & 11 deletions src/gt4py/cartesian/backend/gtc_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,10 @@ def pybuffer_to_sid(

sid_def = """gt::{as_sid}<{ctype}, {sid_ndim},
gt::integral_constant<int, {unique_index}>>({name})""".format(
name=name,
ctype=ctype,
unique_index=stride_kind_index,
sid_ndim=sid_ndim,
as_sid=as_sid,
name=name, ctype=ctype, unique_index=stride_kind_index, sid_ndim=sid_ndim, as_sid=as_sid
)
sid_def = "gt::sid::shift_sid_origin({sid_def}, {name}_origin)".format(
sid_def=sid_def,
name=name,
sid_def=sid_def, name=name
)
if domain_ndim != 3:
gt_dims = [
Expand Down Expand Up @@ -145,10 +140,7 @@ def __init__(self):
self.pyext_file_path = None

def __call__(
self,
args_data: ModuleData,
builder: Optional["StencilBuilder"] = None,
**kwargs: Any,
self, args_data: ModuleData, builder: Optional["StencilBuilder"] = None, **kwargs: Any
) -> str:
self.pyext_module_name = kwargs["pyext_module_name"]
self.pyext_file_path = kwargs["pyext_file_path"]
Expand Down
10 changes: 2 additions & 8 deletions src/gt4py/cartesian/backend/gtcpp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,7 @@ def visit_Program(self, node: gtcpp.Program, **kwargs):
assert "module_name" in kwargs
entry_params = self.visit(node.parameters, external_arg=True, **kwargs)
sid_params = self.visit(node.parameters, external_arg=False, **kwargs)
return self.generic_visit(
node,
entry_params=entry_params,
sid_params=sid_params,
**kwargs,
)
return self.generic_visit(node, entry_params=entry_params, sid_params=sid_params, **kwargs)

Program = bindings_main_template()

Expand Down Expand Up @@ -151,8 +146,7 @@ def generate(self) -> Type["StencilObject"]:

# Generate and return the Python wrapper class
return self.make_module(
pyext_module_name=pyext_module_name,
pyext_file_path=pyext_file_path,
pyext_module_name=pyext_module_name, pyext_file_path=pyext_file_path
)


Expand Down
5 changes: 1 addition & 4 deletions src/gt4py/cartesian/backend/module_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,7 @@ def __init__(self, builder: Optional["StencilBuilder"] = None):
)

def __call__(
self,
args_data: ModuleData,
builder: Optional["StencilBuilder"] = None,
**kwargs: Any,
self, args_data: ModuleData, builder: Optional["StencilBuilder"] = None, **kwargs: Any
) -> str:
"""
Generate source code for a Python module containing a StencilObject.
Expand Down
7 changes: 1 addition & 6 deletions src/gt4py/cartesian/backend/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,7 @@ def _make_npir(self) -> npir.Computation:
oir_pipeline = self.builder.options.backend_opts.get(
"oir_pipeline",
DefaultPipeline(
skip=[
IJCacheDetection,
KCacheDetection,
PruneKCacheFills,
PruneKCacheFlushes,
]
skip=[IJCacheDetection, KCacheDetection, PruneKCacheFills, PruneKCacheFlushes]
),
)
oir_node = oir_pipeline.run(base_oir)
Expand Down
12 changes: 2 additions & 10 deletions src/gt4py/cartesian/backend/pyext_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,7 @@ def get_gt_pyext_build_opts(
# The following tells mypy to accept unpacking kwargs
@overload
def build_pybind_ext(
name: str,
sources: list,
build_path: str,
target_path: str,
**kwargs: str,
name: str, sources: list, build_path: str, target_path: str, **kwargs: str
) -> Tuple[str, str]: ...


Expand Down Expand Up @@ -283,11 +279,7 @@ def build_pybind_ext(
# The following tells mypy to accept unpacking kwargs
@overload
def build_pybind_cuda_ext(
name: str,
sources: list,
build_path: str,
target_path: str,
**kwargs: str,
name: str, sources: list, build_path: str, target_path: str, **kwargs: str
) -> Tuple[str, str]:
pass

Expand Down
15 changes: 3 additions & 12 deletions src/gt4py/cartesian/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,7 @@ class BackendChoice(click.Choice):
name = "backend"

def convert(
self,
value: str,
param: Optional[click.Parameter],
ctx: Optional[click.Context],
self, value: str, param: Optional[click.Parameter], ctx: Optional[click.Context]
) -> Type[CLIBackendMixin]:
"""Convert a CLI option argument to a backend."""
name = super().convert(value, param, ctx)
Expand Down Expand Up @@ -239,10 +236,7 @@ def write_computation_src(
self.reporter.echo(f"Writing source file: {file_path}")
file_path.write_text(content)

def generate_stencils(
self,
build_options: Optional[Dict[str, Any]] = None,
) -> None:
def generate_stencils(self, build_options: Optional[Dict[str, Any]] = None) -> None:
for proto_stencil in self.iterate_stencils():
self.reporter.echo(f"Building stencil {proto_stencil.builder.options.name}")
builder = proto_stencil.builder.with_backend(self.backend_cls.name)
Expand Down Expand Up @@ -314,8 +308,5 @@ def gen(
) -> None:
"""Generate stencils from gtscript modules or packages."""
GTScriptBuilder(
input_path=input_path,
output_path=output_path,
backend=backend,
silent=silent,
input_path=input_path, output_path=output_path, backend=backend, silent=silent
).generate_stencils(build_options=dict(options))
8 changes: 1 addition & 7 deletions src/gt4py/cartesian/frontend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,4 @@
from .base import REGISTRY, Frontend, from_name, register


__all__ = [
"gtscript_frontend",
"REGISTRY",
"Frontend",
"from_name",
"register",
]
__all__ = ["gtscript_frontend", "REGISTRY", "Frontend", "from_name", "register"]
12 changes: 3 additions & 9 deletions src/gt4py/cartesian/frontend/defir_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,7 @@ def visit_FieldRef(self, node: FieldRef, *, fields_decls: Dict[str, FieldDecl],
data_type = DataType.INT32
data_index = [ScalarLiteral(value=index, data_type=data_type)]
element_ref = FieldRef(
name=node.name,
offset=node.offset,
data_index=data_index,
loc=node.loc,
name=node.name, offset=node.offset, data_index=data_index, loc=node.loc
)
field_list.append(element_ref)
# matrix
Expand Down Expand Up @@ -398,9 +395,7 @@ def visit_ComputationBlock(self, node: ComputationBlock) -> gtir.VerticalLoop:
stmts.append(decl_or_stmt)
start, end = self.visit(node.interval)
interval = gtir.Interval(
start=start,
end=end,
loc=location_to_source_location(node.interval.loc),
start=start, end=end, loc=location_to_source_location(node.interval.loc)
)
return gtir.VerticalLoop(
interval=interval,
Expand Down Expand Up @@ -525,8 +520,7 @@ def make_bound_or_level(bound: AxisBound, level) -> Optional[common.AxisBound]:
}

return gtir.HorizontalRestriction(
mask=common.HorizontalMask(**axes),
body=self.visit(node.body),
mask=common.HorizontalMask(**axes), body=self.visit(node.body)
)

def visit_While(self, node: While) -> gtir.While:
Expand Down
10 changes: 3 additions & 7 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,8 +658,7 @@ def _make_init_computations(
else:
stmts.append(
nodes.Assign(
target=nodes.FieldRef.at_center(name, axes=decl.axes),
value=init_values[name],
target=nodes.FieldRef.at_center(name, axes=decl.axes), value=init_values[name]
)
)

Expand Down Expand Up @@ -982,8 +981,7 @@ def _visit_computation_node(self, node: ast.With) -> nodes.ComputationBlock:
if intervals_dicts:
stmts = [
nodes.HorizontalIf(
intervals=intervals_dict,
body=nodes.BlockStmt(stmts=stmts, loc=loc),
intervals=intervals_dict, body=nodes.BlockStmt(stmts=stmts, loc=loc)
)
for intervals_dict in intervals_dicts
]
Expand All @@ -1010,9 +1008,7 @@ def visit_Constant(
elif isinstance(value, bool):
return nodes.Cast(
data_type=nodes.DataType.BOOL,
expr=nodes.BuiltinLiteral(
value=nodes.Builtin.from_value(value),
),
expr=nodes.BuiltinLiteral(value=nodes.Builtin.from_value(value)),
loc=nodes.Location.from_ast_node(node),
)
elif isinstance(value, numbers.Number):
Expand Down
10 changes: 2 additions & 8 deletions src/gt4py/cartesian/gtc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,7 @@ def _make_root_validator(impl: datamodels.RootValidator) -> datamodels.RootValid


def assign_stmt_dtype_validation(*, strict: bool) -> datamodels.RootValidator:
def _impl(
cls: Type[datamodels.DataModel],
instance: datamodels.DataModel,
) -> None:
def _impl(cls: Type[datamodels.DataModel], instance: datamodels.DataModel) -> None:
assert isinstance(instance, AssignStmt)
verify_and_get_common_dtype(cls, [instance.left, instance.right], strict=strict)

Expand Down Expand Up @@ -865,10 +862,7 @@ def data_type_to_typestr(dtype: DataType) -> str:
ComparisonOperator.EQ: "equal",
ComparisonOperator.NE: "not_equal",
},
LogicalOperator: {
LogicalOperator.AND: "logical_and",
LogicalOperator.OR: "logical_or",
},
LogicalOperator: {LogicalOperator.AND: "logical_and", LogicalOperator.OR: "logical_or"},
NativeFunction: {
NativeFunction.ABS: "abs",
NativeFunction.MIN: "minimum",
Expand Down
Loading

0 comments on commit de13b08

Please sign in to comment.