Skip to content

Commit

Permalink
Apply ruff format changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
egparedes committed Apr 2, 2024
1 parent a9540aa commit 0ef0c5c
Show file tree
Hide file tree
Showing 44 changed files with 783 additions and 649 deletions.
534 changes: 278 additions & 256 deletions docs/user/next/workshop/exercises/helpers.py

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions src/gt4py/cartesian/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,11 @@ def build_extension_module(

assert module_name == qualified_pyext_name

self.builder.with_backend_data({
"pyext_module_name": module_name,
"pyext_file_path": file_path,
})
self.builder.with_backend_data(
{
"pyext_module_name": module_name,
"pyext_file_path": file_path,
}
)

return module_name, file_path
78 changes: 44 additions & 34 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,14 @@ def _get_expansion_priority_cpu(node: StencilComputation):
expansion_priority = []
if node.has_splittable_regions():
expansion_priority.append(["Sections", "Stages", "I", "J", "K"])
expansion_priority.extend([
["TileJ", "TileI", "IMap", "JMap", "Sections", "K", "Stages"],
["TileJ", "TileI", "IMap", "JMap", "Sections", "Stages", "K"],
["TileJ", "TileI", "Sections", "Stages", "IMap", "JMap", "K"],
["TileJ", "TileI", "Sections", "K", "Stages", "JMap", "IMap"],
])
expansion_priority.extend(
[
["TileJ", "TileI", "IMap", "JMap", "Sections", "K", "Stages"],
["TileJ", "TileI", "IMap", "JMap", "Sections", "Stages", "K"],
["TileJ", "TileI", "Sections", "Stages", "IMap", "JMap", "K"],
["TileJ", "TileI", "Sections", "K", "Stages", "JMap", "IMap"],
]
)
return expansion_priority


Expand Down Expand Up @@ -487,16 +489,18 @@ def generate_tmp_allocs(self, sdfg):
threadlocal_fmt,
"}}",
]
res.extend([
fmt.format(
name=name,
sdfg_id=array_sdfg.sdfg_id,
dtype=array.dtype.ctype,
size=f"omp_max_threads * ({array.total_size})",
local_size=array.total_size,
)
for fmt in fmts
])
res.extend(
[
fmt.format(
name=name,
sdfg_id=array_sdfg.sdfg_id,
dtype=array.dtype.ctype,
size=f"omp_max_threads * ({array.total_size})",
local_size=array.total_size,
)
for fmt in fmts
]
)
return res

@staticmethod
Expand Down Expand Up @@ -613,18 +617,22 @@ def generate_dace_args(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> List[
# api field strides
fmt = "gt::sid::get_stride<{dim}>(gt::sid::get_strides(__{name}_sid))"

symbols.update({
f"__{name}_{dim}_stride": fmt.format(
dim=f"gt::stencil::dim::{dim.lower()}", name=name
)
for dim in dims
})
symbols.update({
f"__{name}_d{dim}_stride": fmt.format(
dim=f"gt::integral_constant<int, {3 + dim}>", name=name
)
for dim in range(data_ndim)
})
symbols.update(
{
f"__{name}_{dim}_stride": fmt.format(
dim=f"gt::stencil::dim::{dim.lower()}", name=name
)
for dim in dims
}
)
symbols.update(
{
f"__{name}_d{dim}_stride": fmt.format(
dim=f"gt::integral_constant<int, {3 + dim}>", name=name
)
for dim in range(data_ndim)
}
)

# api field pointers
fmt = """gt::sid::multi_shifted(
Expand Down Expand Up @@ -738,12 +746,14 @@ def apply(cls, stencil_ir: gtir.Stencil, sdfg: dace.SDFG, module_name: str, *, b

class DaCePyExtModuleGenerator(PyExtModuleGenerator):
def generate_imports(self):
return "\n".join([
*super().generate_imports().splitlines(),
"import dace",
"import copy",
"from gt4py.cartesian.backend.dace_stencil_object import DaCeStencilObject",
])
return "\n".join(
[
*super().generate_imports().splitlines(),
"import dace",
"import copy",
"from gt4py.cartesian.backend.dace_stencil_object import DaCeStencilObject",
]
)

def generate_base_class_name(self):
return "DaCeStencilObject"
Expand Down
14 changes: 8 additions & 6 deletions src/gt4py/cartesian/backend/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ def generate_imports(self) -> str:
comp_pkg = (
self.builder.caching.module_prefix + "computation" + self.builder.caching.module_postfix
)
return "\n".join([
*super().generate_imports().splitlines(),
"import pathlib",
"from gt4py.cartesian.utils import make_module_from_file",
f'computation = make_module_from_file("{comp_pkg}", pathlib.Path(__file__).parent / "{comp_pkg}.py")',
])
return "\n".join(
[
*super().generate_imports().splitlines(),
"import pathlib",
"from gt4py.cartesian.utils import make_module_from_file",
f'computation = make_module_from_file("{comp_pkg}", pathlib.Path(__file__).parent / "{comp_pkg}.py")',
]
)

def generate_implementation(self) -> str:
params = [f"{p.name}={p.name}" for p in self.builder.gtir.params]
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ def is_cache_info_available_and_consistent(
and cache_info_ns.module_shash == module_shash
)
if validate_extra:
result &= all([
cache_info[key] == validate_extra[key] for key in validate_extra
])
result &= all(
[cache_info[key] == validate_extra[key] for key in validate_extra]
)
except Exception as err:
if not catch_exceptions:
raise err
Expand Down
30 changes: 17 additions & 13 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,9 +1520,9 @@ def visit_With(self, node: ast.With):

self.parsing_horizontal_region = True
intervals_dicts = self._visit_with_horizontal(node.items[0], loc)
all_stmts = gt_utils.flatten([
gtc_utils.listify(self.visit(stmt)) for stmt in node.body
])
all_stmts = gt_utils.flatten(
[gtc_utils.listify(self.visit(stmt)) for stmt in node.body]
)
self.parsing_horizontal_region = False
stmts = list(filter(lambda stmt: isinstance(stmt, nodes.Decl), all_stmts))
body_block = nodes.BlockStmt(
Expand All @@ -1536,10 +1536,12 @@ def visit_With(self, node: ast.With):
"The following variables are"
f"written before being referenced with an offset in a horizontal region: {', '.join(written_then_offset)}"
)
stmts.extend([
nodes.HorizontalIf(intervals=intervals_dict, body=body_block)
for intervals_dict in intervals_dicts
])
stmts.extend(
[
nodes.HorizontalIf(intervals=intervals_dict, body=body_block)
for intervals_dict in intervals_dicts
]
)
return stmts
else:
# If we find nested `with` blocks flatten them, i.e. transform
Expand Down Expand Up @@ -1902,12 +1904,14 @@ def resolve_external_symbols(
for name, accesses in resolved_imports.items():
if accesses:
for attr_name, attr_nodes in accesses.items():
resolved_values_list.append((
attr_name,
GTScriptParser.eval_external(
attr_name, context, nodes.Location.from_ast_node(attr_nodes[0])
),
))
resolved_values_list.append(
(
attr_name,
GTScriptParser.eval_external(
attr_name, context, nodes.Location.from_ast_node(attr_nodes[0])
),
)
)

elif not exhaustive:
resolved_values_list.append((name, GTScriptParser.eval_external(name, context)))
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,9 @@ def visit_HorizontalExecution(
)
expansion_items = global_ctx.library_node.expansion_specification[stages_idx + 1 :]

iteration_ctx = iteration_ctx.push_axes_extents({
k: v for k, v in zip(dcir.Axis.dims_horizontal(), extent)
})
iteration_ctx = iteration_ctx.push_axes_extents(
{k: v for k, v in zip(dcir.Axis.dims_horizontal(), extent)}
)
iteration_ctx = iteration_ctx.push_expansion_items(expansion_items)

assert iteration_ctx.grid_subset == dcir.GridSubset.single_gridpoint()
Expand Down
10 changes: 6 additions & 4 deletions src/gt4py/cartesian/gtc/dace/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,12 @@ def __init__(
for decl in declarations.values()
if isinstance(decl, oir.ScalarDecl)
}
self.symbol_mapping.update({
axis.domain_symbol(): dace.symbol(axis.domain_symbol(), dtype=dace.int32)
for axis in dcir.Axis.dims_horizontal()
})
self.symbol_mapping.update(
{
axis.domain_symbol(): dace.symbol(axis.domain_symbol(), dtype=dace.int32)
for axis in dcir.Axis.dims_horizontal()
}
)
self.access_infos = compute_dcir_access_infos(
oir_node,
oir_decls=declarations,
Expand Down
18 changes: 12 additions & 6 deletions src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,12 @@ def visit_VerticalLoopSection(
k_grid = dcir.GridSubset.from_interval(grid_subset.intervals[dcir.Axis.K], dcir.Axis.K)
inner_infos = {name: info.apply_iteration(k_grid) for name, info in inner_infos.items()}

ctx.access_infos.update({
name: info.union(ctx.access_infos.get(name, info)) for name, info in inner_infos.items()
})
ctx.access_infos.update(
{
name: info.union(ctx.access_infos.get(name, info))
for name, info in inner_infos.items()
}
)

return ctx.access_infos

Expand Down Expand Up @@ -167,9 +170,12 @@ def visit_HorizontalExecution(

inner_infos = {name: info.apply_iteration(ij_grid) for name, info in inner_infos.items()}

ctx.access_infos.update({
name: info.union(ctx.access_infos.get(name, info)) for name, info in inner_infos.items()
})
ctx.access_infos.update(
{
name: info.union(ctx.access_infos.get(name, info))
for name, info in inner_infos.items()
}
)

return ctx.access_infos

Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/gtc/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,9 @@ def _apply(self, other, left_func, right_func=None):
raise ValueError("Incompatible instance '{obj}'".format(obj=other))

right_func = right_func or left_func
return type(self)([
tuple([left_func(a[0], b[0]), right_func(a[1], b[1])]) for a, b in zip(self, other)
])
return type(self)(
[tuple([left_func(a[0], b[0]), right_func(a[1], b[1])]) for a, b in zip(self, other)]
)

def _reduce(self, reduce_func, out_type=tuple):
return out_type([reduce_func(d[0], d[1]) for d in self])
Expand Down
10 changes: 6 additions & 4 deletions src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,12 @@ def visit_AccessorRef(
if accessor_ref.name in temp_decls and accessor_ref.data_index:
# Cannot use symtable. See https://github.com/GridTools/gt4py/issues/808
temp = temp_decls[accessor_ref.name]
data_index = "+".join([
f"{self.visit(index, in_data_index=True, **kwargs)}*{int(np.prod(temp.data_dims[i+1:], initial=1))}"
for i, index in enumerate(accessor_ref.data_index)
])
data_index = "+".join(
[
f"{self.visit(index, in_data_index=True, **kwargs)}*{int(np.prod(temp.data_dims[i+1:], initial=1))}"
for i, index in enumerate(accessor_ref.data_index)
]
)
return f"eval({accessor_ref.name}({i_offset}, {j_offset}, {k_offset}))[{data_index}]"
else:
data_index = "".join(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ class LocalTemporariesToScalars(TemporariesToScalarsBase):

def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil:
horizontal_executions = node.walk_values().if_isinstance(oir.HorizontalExecution)
temps_without_data_dims = set([
decl.name for decl in node.declarations if not decl.data_dims
])
temps_without_data_dims = set(
[decl.name for decl in node.declarations if not decl.data_dims]
)
counts: collections.Counter = sum(
(
collections.Counter(
Expand Down
22 changes: 12 additions & 10 deletions src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,18 @@ class CartesianAccessCollection(GenericAccessCollection[CartesianAccess, Tuple[i

class GeneralAccessCollection(GenericAccessCollection[GeneralAccess, GeneralOffsetTuple]):
def cartesian_accesses(self) -> "AccessCollector.CartesianAccessCollection":
return AccessCollector.CartesianAccessCollection([
CartesianAccess(
field=acc.field,
offset=cast(Tuple[int, int, int], acc.offset),
data_index=acc.data_index,
is_write=acc.is_write,
)
for acc in self._ordered_accesses
if acc.offset[2] is not None
])
return AccessCollector.CartesianAccessCollection(
[
CartesianAccess(
field=acc.field,
offset=cast(Tuple[int, int, int], acc.offset),
data_index=acc.data_index,
is_write=acc.is_write,
)
for acc in self._ordered_accesses
if acc.offset[2] is not None
]
)

def has_variable_access(self) -> bool:
return any(acc.offset[2] is None for acc in self._ordered_accesses)
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/testing/input_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ def derived_shape_st(shape_st, extra: Sequence[Optional[int]]):
both shape and extra elements are summed together.
"""
return hyp_st.builds(
lambda shape: tuple([
d + e for d, e in itertools.zip_longest(shape, extra, fillvalue=0) if e is not None
]),
lambda shape: tuple(
[d + e for d, e in itertools.zip_longest(shape, extra, fillvalue=0) if e is not None]
),
shape_st,
)

Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/testing/suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,9 @@ def _run_test_implementation(cls, parameters_dict, implementation): # too compl
referenced_inputs = {
name: info for name, info in implementation.field_info.items() if info is not None
}
referenced_inputs.update({
name: info for name, info in implementation.parameter_info.items() if info is not None
})
referenced_inputs.update(
{name: info for name, info in implementation.parameter_info.items() if info is not None}
)

# set externals for validation method
for k, v in implementation.constants.items():
Expand Down
10 changes: 6 additions & 4 deletions src/gt4py/cartesian/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,12 @@ def classmethod_to_function(class_method, instance=None, owner=None, remove_cls_

def namespace_from_nested_dict(nested_dict):
assert isinstance(nested_dict, dict)
return types.SimpleNamespace(**{
key: namespace_from_nested_dict(value) if isinstance(value, dict) else value
for key, value in nested_dict.items()
})
return types.SimpleNamespace(
**{
key: namespace_from_nested_dict(value) if isinstance(value, dict) else value
for key, value in nested_dict.items()
}
)


def make_local_dir(dir_name, base_dir=None, *, mode=0o777, is_package=False, is_cache=False):
Expand Down
Loading

0 comments on commit 0ef0c5c

Please sign in to comment.