Skip to content

Commit

Permalink
reapply update in parser.onnx.model.py that was lost in rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinGeens committed Sep 6, 2024
1 parent e3802f1 commit d06c278
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
24 changes: 15 additions & 9 deletions stream/parser/onnx/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from stream.parser.onnx.pooling import PoolingParser
from stream.parser.onnx.reshape import ReshapeParser
from stream.parser.onnx.simd import SimdParser
from stream.parser.onnx.softmax import SoftmaxParser
from stream.parser.onnx.transpose import TransposeParser
from stream.workload.node import Node
from stream.workload.onnx_workload import ONNXWorkload
Expand All @@ -38,13 +39,14 @@ class ONNXModelParser:
"AveragePool": PoolingParser,
"GlobalMaxPool": PoolingParser,
"GlobalAveragePool": PoolingParser,
"Reshape": ReshapeParser,
"Flatten": FlattenParser,
"Gather": GatherParser,
"Add": SimdParser,
"Mul": SimdParser,
"Transpose": TransposeParser,
"Softmax": SoftmaxParser,
"LpNormalization": LpNormalizationParser,
"Gather": GatherParser,
"Transpose": TransposeParser,
"Reshape": ReshapeParser,
"Flatten": FlattenParser,
"Concat": ConcatParser,
}

Expand Down Expand Up @@ -102,13 +104,13 @@ def parse_workload_from_onnx_model_and_mapping(self):

# Workload Graph
workload = ONNXWorkload()

for node_id, node in enumerate(self.onnx_model.graph.node):
node_id = 0
for node in self.onnx_model.graph.node:
# If this node has no inputs, don't take it into consideration (e.g. Constant operator has no inputs)
if not node.input:
continue

nodes_inputs[node_id] = node.input
nodes_outputs[node_id] = node.output

parser_class = self.get_parser_class(node)
parser = parser_class(
Expand All @@ -121,8 +123,12 @@ def parse_workload_from_onnx_model_and_mapping(self):
)

logger.info("Parsed %s node %s.", node.op_type, node.name)
node_obj: Node = parser.run()
workload.add(node_id, node_obj)
for node_obj in parser.run():
# Parsers that yield multiple nodes increment the node id internally, so we must keep count here.
workload.add(node_id, node_obj)
node_id += 1

nodes_outputs[node_id - 1] = node.output

logger.info(
"Created ONNXWorkload graph with %i nodes and %i edges.",
Expand Down
4 changes: 1 addition & 3 deletions stream/workload/onnx_workload.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Any, TypeVar
from typing import Any

from stream.utils import DiGraphWrapper
from stream.workload.computation_node import ComputationNode
from stream.workload.node import Node

T = TypeVar("T", bound=Node)


class ONNXWorkload(DiGraphWrapper[Node]):
"""Represents a Workload Graph"""
Expand Down

0 comments on commit d06c278

Please sign in to comment.