Replace str "output" by a dummy Op in the clients of the FunctionGraph
ricardoV94 committed Jul 8, 2024
1 parent 7f623fe commit 9ba6d99
Showing 18 changed files with 172 additions and 180 deletions.
17 changes: 8 additions & 9 deletions pytensor/compile/debugmode.py
@@ -30,6 +30,7 @@
 from pytensor.graph.basic import Variable, io_toposort
 from pytensor.graph.destroyhandler import DestroyHandler
 from pytensor.graph.features import AlreadyThere, BadOptimization
+from pytensor.graph.fg import Output
 from pytensor.graph.op import HasInnerGraph, Op
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined
 from pytensor.link.basic import Container, LocalLinker
@@ -628,7 +629,9 @@ def _is_used_in_graph(fgraph, var):
     True if `var` is used by another node in the graph.
     """
-    return not (fgraph.clients[var] == [("output", 1)] or fgraph.clients[var] == [])
+    return any(
+        client for client, _ in fgraph.clients[var] if not isinstance(client.op, Output)
+    )


 def _check_strides_match(a, b, warn_err, op):
@@ -977,7 +980,7 @@ def _check_preallocated_output(
     # disable memory checks in that mode, since they were already run.
     try:
         changed_inner_mode = False
-        if isinstance(getattr(node, "op", None), HasInnerGraph):
+        if isinstance(node.op, HasInnerGraph):
             fn = node.op.fn
             if not (fn and hasattr(fn, "maker") and hasattr(fn.maker, "mode")):
                 _logger.warning(f"Expected pytensor function not found in {node.op}.fn")
@@ -1132,18 +1135,14 @@ class _FunctionGraphEvent:

     def __init__(self, kind, node, idx=None, reason=None):
         self.kind = kind
-        if node == "output":
-            self.node = "output"
-            self.op = "output"
-        else:
-            self.node = node
-            self.op = node.op
+        self.node = node
+        self.op = node.op
         self.idx = idx
         self.reason = str(reason)

     def __str__(self):
         if self.kind == "change":
-            if self.op != "output":
+            if not isinstance(self.op, Output):
                 msg = str(len(self.node.inputs))
             else:
                 msg = ""
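The debugmode changes above show the pattern the rest of the commit repeats: entries in fgraph.clients that used to be the literal string "output" are now ordinary Apply nodes whose op is the dummy Output Op imported from pytensor.graph.fg. A minimal sketch of the resulting client-iteration idiom, assuming only that fgraph.clients[var] yields (apply_node, input_index) pairs as the diff does; the helper name is invented for illustration:

from pytensor.graph.fg import Output  # dummy Op marking graph outputs (added by this commit)


def iter_non_output_clients(fgraph, var):
    # Hypothetical helper: yield (node, index) pairs for real consumers of `var`,
    # skipping the Output client that merely marks `var` as a graph output.
    # Before this commit the sentinel was the string "output", so every such loop
    # needed an `if client == "output": continue` special case.
    for client, index in fgraph.clients[var]:
        if isinstance(client.op, Output):
            continue
        yield client, index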
33 changes: 15 additions & 18 deletions pytensor/compile/function/types.py
@@ -78,8 +78,6 @@ def view_tree_set(fgraph, v, treeset):
     """
     treeset.add(v)
     for cl, v_input_pos_to_cl in fgraph.clients[v]:
-        if cl == "output":
-            continue
         vmap = cl.op.view_map
         dmap = cl.op.destroy_map
         for opos, iposlist in chain(vmap.items(), dmap.items()):
@@ -1202,8 +1200,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
     has_destroyers_attr = hasattr(fgraph, "has_destroyers")

     for i in range(len(fgraph.outputs)):
+        original_out = fgraph.outputs[i]
+        output_client = fgraph.get_output_client(i)
+
         views_of_output_i = set()
-        view_tree_set(fgraph, alias_root(fgraph.outputs[i]), views_of_output_i)
+        view_tree_set(fgraph, alias_root(original_out), views_of_output_i)
         copied = False
         # do not allow outputs to be aliased
         for j in range(i + 1, len(fgraph.outputs)):
@@ -1212,16 +1213,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
             if fgraph.outputs[j] in views_of_output_i:
                 if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
                     fgraph.change_node_input(
-                        "output", i, view_op(fgraph.outputs[i]), reason=reason
+                        *output_client, view_op(original_out), reason=reason
                     )
                 else:
                     fgraph.change_node_input(
-                        "output", i, deep_copy_op(fgraph.outputs[i]), reason=reason
+                        *output_client, deep_copy_op(original_out), reason=reason
                     )
                 copied = True
                 break

-        if not copied:
+        if not copied:  # no-break
             for input_j in all_graph_inputs:
                 # do not allow outputs to be aliased to an inputs (j), unless
                 # a) that j'th input has been 'destroyed' by
@@ -1239,33 +1240,29 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
                         j = fgraph.inputs.index(input_j)
                         if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow:
                             fgraph.change_node_input(
-                                "output",
-                                i,
-                                view_op(fgraph.outputs[i]),
+                                *output_client,
+                                view_op(original_out),
                                 reason=reason,
                             )
                             break
                         else:
                             fgraph.change_node_input(
-                                "output",
-                                i,
-                                deep_copy_op(fgraph.outputs[i]),
+                                *output_client,
+                                deep_copy_op(original_out),
                                 reason=reason,
                             )
                             break
                     elif wrapped_outputs[i].borrow:
                         fgraph.change_node_input(
-                            "output",
-                            i,
-                            view_op(fgraph.outputs[i]),
+                            *output_client,
+                            view_op(original_out),
                             reason=reason,
                         )
                         break
                     else:
                         fgraph.change_node_input(
-                            "output",
-                            i,
-                            deep_copy_op(fgraph.outputs[i]),
+                            *output_client,
+                            deep_copy_op(original_out),
                             reason=reason,
                         )
                         break
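In insert_deepcopy the sentinel pair ("output", i) that used to be passed to change_node_input is replaced by whatever fgraph.get_output_client(i) returns, unpacked with *. A hedged sketch of the new call shape, assuming (as the splat in the diff suggests) that get_output_client returns an (output Apply node, input index) pair; the function name and wrap_out are invented stand-ins for the view_op / deep_copy_op wrapping in the real code:

def copy_protect_output(fgraph, i, wrap_out, reason=None):
    # Sketch only: wrap_out would be view_op or deep_copy_op in insert_deepcopy.
    original_out = fgraph.outputs[i]
    # Assumed to be the (Output apply node, index) pair owning output i; the old
    # code passed the string sentinel instead: change_node_input("output", i, ...).
    output_client = fgraph.get_output_client(i)
    fgraph.change_node_input(*output_client, wrap_out(original_out), reason=reason)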
11 changes: 4 additions & 7 deletions pytensor/compile/profiling.py
@@ -16,20 +16,17 @@
 import time
 from collections import Counter, defaultdict
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any
+from typing import Any

 import numpy as np

 import pytensor
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Constant, Variable
+from pytensor.graph.fg import FunctionGraph, Output
 from pytensor.link.utils import get_destroy_dependencies


-if TYPE_CHECKING:
-    from pytensor.graph.fg import FunctionGraph
-
-
 @contextmanager
 def extended_open(filename, mode="r"):
     if filename == "<stdout>":
@@ -1038,7 +1035,7 @@ def count_minimum_peak(node_list, fgraph, nodes_mem):
     executable_nodes = set()
     for var in fgraph.inputs:
         for c, _ in fgraph.clients[var]:
-            if c != "output":
+            if not isinstance(c.op, Output):
                 deps = c.inputs + destroy_dependencies[c]
                 if all(compute_map[v][0] for v in deps):
                     executable_nodes.add(c)
@@ -1166,7 +1163,7 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):

         for var in node.outputs:
             for c, _ in fgraph.clients[var]:
-                if c != "output":
+                if not isinstance(c.op, Output):
                     deps = c.inputs + destroy_dependencies[c]
                     if all(compute_map[v][0] for v in deps):
                         new_exec_nodes.add(c)
5 changes: 2 additions & 3 deletions pytensor/graph/destroyhandler.py
@@ -11,6 +11,7 @@
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Constant
 from pytensor.graph.features import AlreadyThere, Bookkeeper
+from pytensor.graph.fg import Output
 from pytensor.graph.utils import InconsistencyError
 from pytensor.misc.ordered_set import OrderedSet

@@ -401,8 +402,6 @@ def has_destroyers(protected_list):
        def recursive_destroys_finder(protected_var):
            # protected_var is the idx'th input of app.
            for app, idx in fgraph.clients[protected_var]:
-                if app == "output":
-                    continue
                destroy_maps = app.op.destroy_map.values()
                # If True means that the apply node, destroys the protected_var.
                if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
@@ -578,7 +577,7 @@ def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
        app.inputs[i] changed from old_r to new_r.

        """
-        if app == "output":
+        if isinstance(app.op, Output):
            # app == 'output' is special key that means FunctionGraph is redefining which nodes are being
            # considered 'outputs' of the graph.
            pass
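For Feature hooks such as on_change_input above, the app == "output" comparison becomes an isinstance check on the client's op, so a feature can still distinguish "the FunctionGraph redefined output i" from a real node being rewired. A sketch of that dispatch, with the class name and the branch bodies invented for illustration:

from pytensor.graph.fg import Output


class TrackOutputChanges:
    # Hypothetical FunctionGraph feature illustrating the new dispatch.

    def on_change_input(self, fgraph, app, i, old_r, new_r, reason=None):
        if isinstance(app.op, Output):
            # The graph is redefining which variable is its i-th output
            # (previously signalled by app == "output").
            return
        # Otherwise a real Apply node had its i-th input replaced.
        print(f"{app} input {i}: {old_r} -> {new_r} ({reason})")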