Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style: minimize vertical space style #1518

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions docs/user/next/workshop/exercises/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@
)


def random_mask(
sizes,
*dims,
dtype=None,
) -> MutableLocatedField:
def random_mask(sizes, *dims, dtype=None) -> MutableLocatedField:
arr = np.full(shape=sizes, fill_value=False).flatten()
arr[: int(arr.size * 0.5)] = True
np.random.shuffle(arr)
Expand Down
8 changes: 1 addition & 7 deletions src/gt4py/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,7 @@
from packaging import version as pkg_version


__all__ = [
"__author__",
"__copyright__",
"__license__",
"__version__",
"__version_info__",
]
__all__ = ["__author__", "__copyright__", "__license__", "__version__", "__version_info__"]


__author__: Final = "ETH Zurich and individual contributors"
Expand Down
4 changes: 1 addition & 3 deletions src/gt4py/_core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,7 @@ def is_positive_integral_type(integral_type: type) -> TypeGuard[Type[PositiveInt
] # TODO(egparedes) figure out if PositiveIntegral can be made to work


def is_valid_tensor_shape(
value: Sequence[IntegralScalar],
) -> TypeGuard[TensorShape]:
def is_valid_tensor_shape(value: Sequence[IntegralScalar]) -> TypeGuard[TensorShape]:
return isinstance(value, collections.abc.Sequence) and all(
isinstance(v, numbers.Integral) and v > 0 for v in value
)
Expand Down
10 changes: 2 additions & 8 deletions src/gt4py/cartesian/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,7 @@ def check_options(self, options: gt_definitions.BuildOptions) -> None:
stacklevel=2,
)

def make_module(
self,
**kwargs: Any,
) -> Type["StencilObject"]:
def make_module(self, **kwargs: Any) -> Type["StencilObject"]:
build_info = self.builder.options.build_info
if build_info is not None:
start_time = time.perf_counter()
Expand Down Expand Up @@ -412,10 +409,7 @@ def build_extension_module(
assert module_name == qualified_pyext_name

self.builder.with_backend_data(
{
"pyext_module_name": module_name,
"pyext_file_path": file_path,
}
{"pyext_module_name": module_name, "pyext_file_path": file_path}
)

return module_name, file_path
13 changes: 3 additions & 10 deletions src/gt4py/cartesian/backend/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ def visit_FieldDecl(self, node: cuir.FieldDecl, **kwargs):
sid_ndim = domain_ndim + data_ndim
if kwargs["external_arg"]:
return "py::object {name}, std::array<gt::int_t,{sid_ndim}> {name}_origin".format(
name=node.name,
sid_ndim=sid_ndim,
name=node.name, sid_ndim=sid_ndim
)
else:
return pybuffer_to_sid(
Expand All @@ -113,12 +112,7 @@ def visit_Program(self, node: cuir.Program, **kwargs):
assert "module_name" in kwargs
entry_params = self.visit(node.params, external_arg=True, **kwargs)
sid_params = self.visit(node.params, external_arg=False, **kwargs)
return self.generic_visit(
node,
entry_params=entry_params,
sid_params=sid_params,
**kwargs,
)
return self.generic_visit(node, entry_params=entry_params, sid_params=sid_params, **kwargs)

Program = bindings_main_template()

Expand Down Expand Up @@ -157,6 +151,5 @@ def generate(self) -> Type["StencilObject"]:

# Generate and return the Python wrapper class
return self.make_module(
pyext_module_name=pyext_module_name,
pyext_file_path=pyext_file_path,
pyext_module_name=pyext_module_name, pyext_file_path=pyext_file_path
)
15 changes: 4 additions & 11 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def _serialize_sdfg(sdfg: dace.SDFG):

def _specialize_transient_strides(sdfg: dace.SDFG, layout_map):
repldict = replace_strides(
[array for array in sdfg.arrays.values() if array.transient],
layout_map,
[array for array in sdfg.arrays.values() if array.transient], layout_map
)
sdfg.replace_dict(repldict)
for state in sdfg.nodes():
Expand Down Expand Up @@ -221,9 +220,7 @@ def _sdfg_add_arrays_and_edges(
ranges = [
(o - max(0, e), o - max(0, e) + s - 1, 1)
for o, e, s in zip(
origin,
field_info[name].boundary.lower_indices,
inner_sdfg.arrays[name].shape,
origin, field_info[name].boundary.lower_indices, inner_sdfg.arrays[name].shape
)
]
ranges += [(0, d, 1) for d in field_info[name].data_dims]
Expand Down Expand Up @@ -786,8 +783,7 @@ def generate(self) -> Type["StencilObject"]:

# Generate and return the Python wrapper class
return self.make_module(
pyext_module_name=pyext_module_name,
pyext_file_path=pyext_file_path,
pyext_module_name=pyext_module_name, pyext_file_path=pyext_file_path
)


Expand Down Expand Up @@ -826,10 +822,7 @@ class DaceGPUBackend(BaseDaceBackend):
),
}
MODULE_GENERATOR_CLASS = DaCeCUDAPyExtModuleGenerator
options = {
**BaseGTBackend.GT_BACKEND_OPTS,
"device_sync": {"versioning": True, "type": bool},
}
options = {**BaseGTBackend.GT_BACKEND_OPTS, "device_sync": {"versioning": True, "type": bool}}

def generate_extension(self, **kwargs: Any) -> Tuple[str, str]:
return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=True)
8 changes: 2 additions & 6 deletions src/gt4py/cartesian/backend/dace_stencil_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ def add_optional_fields(
if info.access == AccessKind.NONE and name in kwargs and name not in sdfg.arrays:
outer_array = kwargs[name]
sdfg.add_array(
name,
shape=outer_array.shape,
dtype=outer_array.dtype,
strides=outer_array.strides,
name, shape=outer_array.shape, dtype=outer_array.dtype, strides=outer_array.strides
)

for name, info in parameter_info.items():
Expand Down Expand Up @@ -194,8 +191,7 @@ def normalize_args(
name: (kwargs[name] if name in kwargs else next(args_iter)) for name in arg_names
}
arg_infos = _extract_array_infos(
field_args=args_as_kwargs,
device=backend_cls.storage_info["device"],
field_args=args_as_kwargs, device=backend_cls.storage_info["device"]
)

origin = DaCeStencilObject._normalize_origins(arg_infos, field_info, origin)
Expand Down
14 changes: 3 additions & 11 deletions src/gt4py/cartesian/backend/gtc_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,10 @@ def pybuffer_to_sid(

sid_def = """gt::{as_sid}<{ctype}, {sid_ndim},
gt::integral_constant<int, {unique_index}>>({name})""".format(
name=name,
ctype=ctype,
unique_index=stride_kind_index,
sid_ndim=sid_ndim,
as_sid=as_sid,
name=name, ctype=ctype, unique_index=stride_kind_index, sid_ndim=sid_ndim, as_sid=as_sid
)
sid_def = "gt::sid::shift_sid_origin({sid_def}, {name}_origin)".format(
sid_def=sid_def,
name=name,
sid_def=sid_def, name=name
)
if domain_ndim != 3:
gt_dims = [
Expand Down Expand Up @@ -145,10 +140,7 @@ def __init__(self):
self.pyext_file_path = None

def __call__(
self,
args_data: ModuleData,
builder: Optional["StencilBuilder"] = None,
**kwargs: Any,
self, args_data: ModuleData, builder: Optional["StencilBuilder"] = None, **kwargs: Any
) -> str:
self.pyext_module_name = kwargs["pyext_module_name"]
self.pyext_file_path = kwargs["pyext_file_path"]
Expand Down
10 changes: 2 additions & 8 deletions src/gt4py/cartesian/backend/gtcpp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,7 @@ def visit_Program(self, node: gtcpp.Program, **kwargs):
assert "module_name" in kwargs
entry_params = self.visit(node.parameters, external_arg=True, **kwargs)
sid_params = self.visit(node.parameters, external_arg=False, **kwargs)
return self.generic_visit(
node,
entry_params=entry_params,
sid_params=sid_params,
**kwargs,
)
return self.generic_visit(node, entry_params=entry_params, sid_params=sid_params, **kwargs)

Program = bindings_main_template()

Expand Down Expand Up @@ -151,8 +146,7 @@ def generate(self) -> Type["StencilObject"]:

# Generate and return the Python wrapper class
return self.make_module(
pyext_module_name=pyext_module_name,
pyext_file_path=pyext_file_path,
pyext_module_name=pyext_module_name, pyext_file_path=pyext_file_path
)


Expand Down
5 changes: 1 addition & 4 deletions src/gt4py/cartesian/backend/module_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,7 @@ def __init__(self, builder: Optional["StencilBuilder"] = None):
)

def __call__(
self,
args_data: ModuleData,
builder: Optional["StencilBuilder"] = None,
**kwargs: Any,
self, args_data: ModuleData, builder: Optional["StencilBuilder"] = None, **kwargs: Any
) -> str:
"""
Generate source code for a Python module containing a StencilObject.
Expand Down
7 changes: 1 addition & 6 deletions src/gt4py/cartesian/backend/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,7 @@ def _make_npir(self) -> npir.Computation:
oir_pipeline = self.builder.options.backend_opts.get(
"oir_pipeline",
DefaultPipeline(
skip=[
IJCacheDetection,
KCacheDetection,
PruneKCacheFills,
PruneKCacheFlushes,
]
skip=[IJCacheDetection, KCacheDetection, PruneKCacheFills, PruneKCacheFlushes]
),
)
oir_node = oir_pipeline.run(base_oir)
Expand Down
12 changes: 2 additions & 10 deletions src/gt4py/cartesian/backend/pyext_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,7 @@ def get_gt_pyext_build_opts(
# The following tells mypy to accept unpacking kwargs
@overload
def build_pybind_ext(
name: str,
sources: list,
build_path: str,
target_path: str,
**kwargs: str,
name: str, sources: list, build_path: str, target_path: str, **kwargs: str
) -> Tuple[str, str]: ...


Expand Down Expand Up @@ -283,11 +279,7 @@ def build_pybind_ext(
# The following tells mypy to accept unpacking kwargs
@overload
def build_pybind_cuda_ext(
name: str,
sources: list,
build_path: str,
target_path: str,
**kwargs: str,
name: str, sources: list, build_path: str, target_path: str, **kwargs: str
) -> Tuple[str, str]:
pass

Expand Down
15 changes: 3 additions & 12 deletions src/gt4py/cartesian/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,7 @@ class BackendChoice(click.Choice):
name = "backend"

def convert(
self,
value: str,
param: Optional[click.Parameter],
ctx: Optional[click.Context],
self, value: str, param: Optional[click.Parameter], ctx: Optional[click.Context]
) -> Type[CLIBackendMixin]:
"""Convert a CLI option argument to a backend."""
name = super().convert(value, param, ctx)
Expand Down Expand Up @@ -239,10 +236,7 @@ def write_computation_src(
self.reporter.echo(f"Writing source file: {file_path}")
file_path.write_text(content)

def generate_stencils(
self,
build_options: Optional[Dict[str, Any]] = None,
) -> None:
def generate_stencils(self, build_options: Optional[Dict[str, Any]] = None) -> None:
for proto_stencil in self.iterate_stencils():
self.reporter.echo(f"Building stencil {proto_stencil.builder.options.name}")
builder = proto_stencil.builder.with_backend(self.backend_cls.name)
Expand Down Expand Up @@ -314,8 +308,5 @@ def gen(
) -> None:
"""Generate stencils from gtscript modules or packages."""
GTScriptBuilder(
input_path=input_path,
output_path=output_path,
backend=backend,
silent=silent,
input_path=input_path, output_path=output_path, backend=backend, silent=silent
).generate_stencils(build_options=dict(options))
8 changes: 1 addition & 7 deletions src/gt4py/cartesian/frontend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,4 @@
from .base import REGISTRY, Frontend, from_name, register


__all__ = [
"gtscript_frontend",
"REGISTRY",
"Frontend",
"from_name",
"register",
]
__all__ = ["gtscript_frontend", "REGISTRY", "Frontend", "from_name", "register"]
12 changes: 3 additions & 9 deletions src/gt4py/cartesian/frontend/defir_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,7 @@ def visit_FieldRef(self, node: FieldRef, *, fields_decls: Dict[str, FieldDecl],
data_type = DataType.INT32
data_index = [ScalarLiteral(value=index, data_type=data_type)]
element_ref = FieldRef(
name=node.name,
offset=node.offset,
data_index=data_index,
loc=node.loc,
name=node.name, offset=node.offset, data_index=data_index, loc=node.loc
)
field_list.append(element_ref)
# matrix
Expand Down Expand Up @@ -398,9 +395,7 @@ def visit_ComputationBlock(self, node: ComputationBlock) -> gtir.VerticalLoop:
stmts.append(decl_or_stmt)
start, end = self.visit(node.interval)
interval = gtir.Interval(
start=start,
end=end,
loc=location_to_source_location(node.interval.loc),
start=start, end=end, loc=location_to_source_location(node.interval.loc)
)
return gtir.VerticalLoop(
interval=interval,
Expand Down Expand Up @@ -525,8 +520,7 @@ def make_bound_or_level(bound: AxisBound, level) -> Optional[common.AxisBound]:
}

return gtir.HorizontalRestriction(
mask=common.HorizontalMask(**axes),
body=self.visit(node.body),
mask=common.HorizontalMask(**axes), body=self.visit(node.body)
)

def visit_While(self, node: While) -> gtir.While:
Expand Down
10 changes: 3 additions & 7 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,8 +658,7 @@ def _make_init_computations(
else:
stmts.append(
nodes.Assign(
target=nodes.FieldRef.at_center(name, axes=decl.axes),
value=init_values[name],
target=nodes.FieldRef.at_center(name, axes=decl.axes), value=init_values[name]
)
)

Expand Down Expand Up @@ -982,8 +981,7 @@ def _visit_computation_node(self, node: ast.With) -> nodes.ComputationBlock:
if intervals_dicts:
stmts = [
nodes.HorizontalIf(
intervals=intervals_dict,
body=nodes.BlockStmt(stmts=stmts, loc=loc),
intervals=intervals_dict, body=nodes.BlockStmt(stmts=stmts, loc=loc)
)
for intervals_dict in intervals_dicts
]
Expand All @@ -1010,9 +1008,7 @@ def visit_Constant(
elif isinstance(value, bool):
return nodes.Cast(
data_type=nodes.DataType.BOOL,
expr=nodes.BuiltinLiteral(
value=nodes.Builtin.from_value(value),
),
expr=nodes.BuiltinLiteral(value=nodes.Builtin.from_value(value)),
loc=nodes.Location.from_ast_node(node),
)
elif isinstance(value, numbers.Number):
Expand Down
10 changes: 2 additions & 8 deletions src/gt4py/cartesian/gtc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,7 @@ def _make_root_validator(impl: datamodels.RootValidator) -> datamodels.RootValid


def assign_stmt_dtype_validation(*, strict: bool) -> datamodels.RootValidator:
def _impl(
cls: Type[datamodels.DataModel],
instance: datamodels.DataModel,
) -> None:
def _impl(cls: Type[datamodels.DataModel], instance: datamodels.DataModel) -> None:
assert isinstance(instance, AssignStmt)
verify_and_get_common_dtype(cls, [instance.left, instance.right], strict=strict)

Expand Down Expand Up @@ -865,10 +862,7 @@ def data_type_to_typestr(dtype: DataType) -> str:
ComparisonOperator.EQ: "equal",
ComparisonOperator.NE: "not_equal",
},
LogicalOperator: {
LogicalOperator.AND: "logical_and",
LogicalOperator.OR: "logical_or",
},
LogicalOperator: {LogicalOperator.AND: "logical_and", LogicalOperator.OR: "logical_or"},
NativeFunction: {
NativeFunction.ABS: "abs",
NativeFunction.MIN: "minimum",
Expand Down
Loading
Loading