Make all zips explicitly strict or non-strict #850

Open · wants to merge 12 commits into base: main
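For context: since Python 3.10, zip() accepts a strict keyword that controls what happens when the iterables have different lengths. The sketch below is not taken from the diff; it is a minimal illustration of the three behaviors this PR distinguishes.

# Minimal illustration of the zip() behaviors made explicit by this PR (Python >= 3.10).
a = [1, 2, 3]
b = ["x", "y"]  # deliberately one element short

# Default: silently truncates to the shorter iterable (the bug-prone case that B905 flags).
assert list(zip(a, b)) == [(1, "x"), (2, "y")]

# strict=False: the same truncating behavior, but the intent is now spelled out.
assert list(zip(a, b, strict=False)) == [(1, "x"), (2, "y")]

# strict=True: a length mismatch raises instead of being silently swallowed.
try:
    list(zip(a, b, strict=True))
except ValueError as err:
    print(err)  # e.g. "zip() argument 2 is shorter than argument 1"

Most call sites in this diff opt into strict=True; the few marked strict=False keep the old truncating behavior deliberately.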
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -129,8 +129,12 @@ exclude = ["doc/", "pytensor/_version.py"]
docstring-code-format = true

[tool.ruff.lint]
select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC"]
select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC"]
ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"]
+unfixable = [
+    # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
+    "B905",
+]


[tool.ruff.lint.isort]
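The config above enables ruff's B905 rule (zip without an explicit strict argument) but marks it unfixable: the auto-fix would insert strict=False everywhere, whereas most call sites in this PR actually want strict=True, so each call site has to be reviewed by hand. As a rough, hypothetical stand-in for that audit (this is not ruff's implementation, and find_unmarked_zips is a made-up name), an AST scan can list the zip() calls that still lack an explicit strict= keyword:

# Hypothetical audit helper, not part of this PR: list bare zip() calls in the given files.
import ast
import sys
from pathlib import Path


def find_unmarked_zips(path: Path) -> list[int]:
    """Return line numbers of zip() calls without an explicit strict= keyword."""
    tree = ast.parse(path.read_text(), filename=str(path))
    return [
        node.lineno
        for node in ast.walk(tree)
        if isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "zip"
        and not any(kw.arg == "strict" for kw in node.keywords)
    ]


if __name__ == "__main__":
    # Usage (hypothetical): python find_unmarked_zips.py pytensor/compile/builders.py
    for filename in sys.argv[1:]:
        for lineno in find_unmarked_zips(Path(filename)):
            print(f"{filename}:{lineno}: zip() call without an explicit strict= argument")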
32 changes: 21 additions & 11 deletions pytensor/compile/builders.py
@@ -43,15 +43,15 @@ def infer_shape(outs, inputs, input_shapes):
# TODO: ShapeFeature should live elsewhere
from pytensor.tensor.rewriting.shape import ShapeFeature

-for inp, inp_shp in zip(inputs, input_shapes):
+for inp, inp_shp in zip(inputs, input_shapes, strict=True):
if inp_shp is not None and len(inp_shp) != inp.type.ndim:
assert len(inp_shp) == inp.type.ndim

shape_feature = ShapeFeature()
shape_feature.on_attach(FunctionGraph([], []))

# Initialize shape_of with the input shapes
-for inp, inp_shp in zip(inputs, input_shapes):
+for inp, inp_shp in zip(inputs, input_shapes, strict=True):
shape_feature.set_shape(inp, inp_shp)

def local_traverse(out):
@@ -108,7 +108,9 @@ def construct_nominal_fgraph(

replacements = dict(
zip(
-inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
+inputs + implicit_shared_inputs,
+dummy_inputs + dummy_implicit_shared_inputs,
+strict=True,
)
)

@@ -138,7 +140,7 @@ def construct_nominal_fgraph(
NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
)

-fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
+fgraph.replace_all(zip(local_inputs, nominal_local_inputs, strict=True))

for i, inp in enumerate(fgraph.inputs):
nom_inp = nominal_local_inputs[i]
@@ -562,7 +564,9 @@ def lop_overrides(inps, grads):
# compute non-overriding downsteam grads from upstreams grads
# it's normal some input may be disconnected, thus the 'ignore'
wrt = [
-lin for lin, gov in zip(inner_inputs, custom_input_grads) if gov is None
+lin
+for lin, gov in zip(inner_inputs, custom_input_grads, strict=True)
+if gov is None
]
default_input_grads = fn_grad(wrt=wrt) if wrt else []
input_grads = self._combine_list_overrides(
@@ -653,7 +657,7 @@ def _build_and_cache_rop_op(self):
f = [
output
for output, custom_output_grad in zip(
-inner_outputs, custom_output_grads
+inner_outputs, custom_output_grads, strict=True
)
if custom_output_grad is None
]
@@ -733,18 +737,24 @@ def make_node(self, *inputs):

non_shared_inputs = [
inp_t.filter_variable(inp)
-for inp, inp_t in zip(non_shared_inputs, self.input_types)
+for inp, inp_t in zip(non_shared_inputs, self.input_types, strict=True)
]

new_shared_inputs = inputs[num_expected_inps:]
-inner_and_input_shareds = list(zip(self.shared_inputs, new_shared_inputs))
+inner_and_input_shareds = list(
+zip(self.shared_inputs, new_shared_inputs, strict=True)
+)

if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
# The shared variables are not equal to the original shared
# variables, so we construct a new `Op` that uses the new shared
# variables instead.
replace = dict(
-zip(self.inner_inputs[num_expected_inps:], new_shared_inputs)
+zip(
+self.inner_inputs[num_expected_inps:],
+new_shared_inputs,
+strict=True,
+)
)

# If the new shared variables are inconsistent with the inner-graph,
@@ -811,7 +821,7 @@ def infer_shape(self, fgraph, node, shapes):
# each shape call. PyTensor optimizer will clean this up later, but this
# will make extra work for the optimizer.

-repl = dict(zip(self.inner_inputs, node.inputs))
+repl = dict(zip(self.inner_inputs, node.inputs, strict=True))
clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)]
cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl)
ret = []
@@ -853,5 +863,5 @@ def clone(self):
def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
assert len(variables) == len(outputs)
-for output, variable in zip(outputs, variables):
+for output, variable in zip(outputs, variables, strict=True):
output[0] = variable
16 changes: 9 additions & 7 deletions pytensor/compile/debugmode.py
@@ -865,7 +865,7 @@ def _get_preallocated_maps(
# except if broadcastable, or for dimensions above
# config.DebugMode__check_preallocated_output_ndim
buf_shape = []
-for s, b in zip(r_vals[r].shape, r.broadcastable):
+for s, b in zip(r_vals[r].shape, r.broadcastable, strict=True):
if b or ((r.ndim - len(buf_shape)) > check_ndim):
buf_shape.append(s)
else:
@@ -943,7 +943,7 @@ def _get_preallocated_maps(
r_shape_diff = shape_diff[: r.ndim]
new_buf_shape = [
max((s + sd), 0)
-for s, sd in zip(r_vals[r].shape, r_shape_diff)
+for s, sd in zip(r_vals[r].shape, r_shape_diff, strict=True)
]
new_buf = np.empty(new_buf_shape, dtype=r.type.dtype)
new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
@@ -1575,7 +1575,7 @@ def f():
# try:
# compute the value of all variables
for i, (thunk_py, thunk_c, node) in enumerate(
-zip(thunks_py, thunks_c, order)
+zip(thunks_py, thunks_c, order, strict=True)
):
_logger.debug(f"{i} - starting node {i} {node}")

@@ -1855,7 +1855,7 @@ def thunk():
assert s[0] is None

# store our output variables to their respective storage lists
-for output, storage in zip(fgraph.outputs, output_storage):
+for output, storage in zip(fgraph.outputs, output_storage, strict=True):
storage[0] = r_vals[output]

# transfer all inputs back to their respective storage lists
@@ -1931,11 +1931,11 @@ def deco():
f,
[
Container(input, storage, readonly=False)
-for input, storage in zip(fgraph.inputs, input_storage)
+for input, storage in zip(fgraph.inputs, input_storage, strict=True)
],
[
Container(output, storage, readonly=True)
-for output, storage in zip(fgraph.outputs, output_storage)
+for output, storage in zip(fgraph.outputs, output_storage, strict=True)
],
thunks_py,
order,
@@ -2122,7 +2122,9 @@ def __init__(

no_borrow = [
output
-for output, spec in zip(fgraph.outputs, outputs + additional_outputs)
+for output, spec in zip(
+fgraph.outputs, outputs + additional_outputs, strict=True
+)
if not spec.borrow
]
if no_borrow:
6 changes: 3 additions & 3 deletions pytensor/compile/function/pfunc.py
@@ -603,7 +603,7 @@ def construct_pfunc_ins_and_outs(

new_inputs = []

-for i, iv in zip(inputs, input_variables):
+for i, iv in zip(inputs, input_variables, strict=True):
new_i = copy(i)
new_i.variable = iv

@@ -637,13 +637,13 @@ def construct_pfunc_ins_and_outs(
assert len(fgraph.inputs) == len(inputs)
assert len(fgraph.outputs) == len(outputs)

-for fg_inp, inp in zip(fgraph.inputs, inputs):
+for fg_inp, inp in zip(fgraph.inputs, inputs, strict=True):
if fg_inp != getattr(inp, "variable", inp):
raise ValueError(
f"`fgraph`'s input does not match the provided input: {fg_inp}, {inp}"
)

-for fg_out, out in zip(fgraph.outputs, outputs):
+for fg_out, out in zip(fgraph.outputs, outputs, strict=True):
if fg_out != getattr(out, "variable", out):
raise ValueError(
f"`fgraph`'s output does not match the provided output: {fg_out}, {out}"
34 changes: 20 additions & 14 deletions pytensor/compile/function/types.py
@@ -241,7 +241,7 @@ def std_fgraph(
fgraph.attach_feature(
Supervisor(
input
-for spec, input in zip(input_specs, fgraph.inputs)
+for spec, input in zip(input_specs, fgraph.inputs, strict=True)
if not (
spec.mutable
or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
@@ -442,7 +442,7 @@ def __init__(
# this loop works by modifying the elements (as variable c) of
# self.input_storage inplace.
for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(
-zip(self.indices, defaults)
+zip(self.indices, defaults, strict=True)
):
if indices is None:
# containers is being used as a stack. Here we pop off
@@ -671,7 +671,7 @@ def checkSV(sv_ori, sv_rpl):
else:
outs = list(map(SymbolicOutput, fg_cpy.outputs))

-for out_ori, out_cpy in zip(maker.outputs, outs):
+for out_ori, out_cpy in zip(maker.outputs, outs, strict=False):
out_cpy.borrow = out_ori.borrow

# swap SharedVariable
@@ -684,7 +684,7 @@ def checkSV(sv_ori, sv_rpl):
raise ValueError(f"SharedVariable: {sv.name} not found")

# Swap SharedVariable in fgraph and In instances
-for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
+for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs, strict=True)):
# Variables in maker.inputs are defined by user, therefore we
# use them to make comparison and do the mapping.
# Otherwise we don't touch them.
Expand All @@ -708,7 +708,7 @@ def checkSV(sv_ori, sv_rpl):

# Delete update if needed
rev_update_mapping = {v: k for k, v in fg_cpy.update_mapping.items()}
-for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs)):
+for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs, strict=True)):
inp.variable = in_var
if not delete_updates and inp.update is not None:
out_idx = rev_update_mapping[n]
@@ -768,7 +768,11 @@ def checkSV(sv_ori, sv_rpl):
).create(input_storage, storage_map=new_storage_map)

for in_ori, in_cpy, ori, cpy in zip(
-maker.inputs, f_cpy.maker.inputs, self.input_storage, f_cpy.input_storage
+maker.inputs,
+f_cpy.maker.inputs,
+self.input_storage,
+f_cpy.input_storage,
+strict=True,
):
# Share immutable ShareVariable and constant input's storage
swapped = swap is not None and in_ori.variable in swap
@@ -999,7 +1003,7 @@ def __call__(self, *args, **kwargs):
# output reference from the internal storage cells
if getattr(self.vm, "allow_gc", False):
for o_container, o_variable in zip(
-self.output_storage, self.maker.fgraph.outputs
+self.output_storage, self.maker.fgraph.outputs, strict=True
):
if o_variable.owner is not None:
# this node is the variable of computation
@@ -1009,7 +1013,7 @@ def __call__(self, *args, **kwargs):
if getattr(self.vm, "need_update_inputs", True):
# Update the inputs that have an update function
for input, storage in reversed(
-list(zip(self.maker.expanded_inputs, input_storage))
+list(zip(self.maker.expanded_inputs, input_storage, strict=True))
):
if input.update is not None:
storage.data = outputs.pop()
@@ -1040,7 +1044,7 @@ def __call__(self, *args, **kwargs):
assert len(self.output_keys) == len(outputs)

if output_subset is None:
-return dict(zip(self.output_keys, outputs))
+return dict(zip(self.output_keys, outputs, strict=True))
else:
return {
self.output_keys[index]: outputs[index]
@@ -1108,7 +1112,7 @@ def _pickle_Function(f):
input_storage = []

for (input, indices, inputs), (required, refeed, default) in zip(
-f.indices, f.defaults
+f.indices, f.defaults, strict=True
):
input_storage.append(ins[0])
del ins[0]
@@ -1150,7 +1154,7 @@ def _constructor_Function(maker, input_storage, inputs_data, trust_input=False):

f = maker.create(input_storage)
assert len(f.input_storage) == len(inputs_data)
-for container, x in zip(f.input_storage, inputs_data):
+for container, x in zip(f.input_storage, inputs_data, strict=True):
assert (
(container.data is x)
or (isinstance(x, np.ndarray) and (container.data == x).all())
@@ -1184,7 +1188,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
reason = "insert_deepcopy"
updated_fgraph_inputs = {
fgraph_i
-for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs)
+for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True)
if getattr(i, "update", False)
}

@@ -1521,7 +1525,7 @@ def __init__(
# return the internal storage pointer.
no_borrow = [
output
-for output, spec in zip(fgraph.outputs, outputs + found_updates)
+for output, spec in zip(
+fgraph.outputs, outputs + found_updates, strict=True
+)
if not spec.borrow
]

@@ -1590,7 +1596,7 @@ def create(self, input_storage=None, storage_map=None):
# defaults lists.
assert len(self.indices) == len(input_storage)
for i, ((input, indices, subinputs), input_storage_i) in enumerate(
-zip(self.indices, input_storage)
+zip(self.indices, input_storage, strict=True)
):
# Replace any default value given as a variable by its
# container. Note that this makes sense only in the
4 changes: 2 additions & 2 deletions pytensor/d3viz/formatting.py
@@ -244,14 +244,14 @@ def format_map(m):
ext_inputs = [self.__node_id(x) for x in node.inputs]
int_inputs = [gf.__node_id(x) for x in node.op.inner_inputs]
assert len(ext_inputs) == len(int_inputs)
-h = format_map(zip(ext_inputs, int_inputs))
+h = format_map(zip(ext_inputs, int_inputs, strict=True))
pd_node.get_attributes()["subg_map_inputs"] = h

# Outputs mapping
ext_outputs = [self.__node_id(x) for x in node.outputs]
int_outputs = [gf.__node_id(x) for x in node.op.inner_outputs]
assert len(ext_outputs) == len(int_outputs)
-h = format_map(zip(int_outputs, ext_outputs))
+h = format_map(zip(int_outputs, ext_outputs, strict=True))
pd_node.get_attributes()["subg_map_outputs"] = h

return graph