Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JSON graph dumping. #5591

Merged
merged 1 commit into from
May 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/tvm/contrib/debugger/debug_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def __init__(self, graph_json, dump_path):
self._dump_path = dump_path
self._output_tensor_list = []
self._time_list = []
self._parse_graph(graph_json)
json_obj = self._parse_graph(graph_json)
# dump the json information
self.dump_graph_json(graph_json)
self._dump_graph_json(json_obj)

def _parse_graph(self, graph_json):
"""Parse and extract the JSON graph and update the nodes, shapes and dltype.
Expand All @@ -70,12 +70,12 @@ def _parse_graph(self, graph_json):
self._shapes_list = json_obj['attrs']['shape']
self._dtype_list = json_obj['attrs']['dltype']
self._update_graph_json()
return json_obj

def _update_graph_json(self):
"""update the nodes_list with name, shape and data type,
for temporarily storing the output.
"""

nodes_len = len(self._nodes_list)
for i in range(nodes_len):
node = self._nodes_list[i]
Expand Down Expand Up @@ -192,7 +192,7 @@ def node_to_events(node, times, starting_time):
with open(os.path.join(self._dump_path, CHROME_TRACE_FILE_NAME), "w") as trace_f:
json.dump(result, trace_f)

def dump_graph_json(self, graph):
def _dump_graph_json(self, graph):
"""Dump json formatted graph.

Parameters
Expand Down
13 changes: 11 additions & 2 deletions tests/python/unittest/test_runtime_graph_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
import tvm
from tvm import te
import numpy as np
import json
from tvm import rpc
from tvm.contrib import util
from tvm.contrib.debugger import debug_runtime as graph_runtime
Expand Down Expand Up @@ -75,7 +75,16 @@ def check_verify():
assert(len(os.listdir(directory)) == 1)

#verify the file name is proper
assert(os.path.exists(os.path.join(directory, GRAPH_DUMP_FILE_NAME)))
graph_dump_path = os.path.join(directory, GRAPH_DUMP_FILE_NAME)
assert(os.path.exists(graph_dump_path))

# verify the graph contains some expected keys
with open(graph_dump_path) as graph_f:
dumped_graph = json.load(graph_f)

assert isinstance(dumped_graph, dict)
for k in ("nodes", "arg_nodes", "node_row_ptr", "heads", "attrs"):
assert k in dumped_graph, f"key {k} not in dumped graph {graph!r}"

mod.run()
#Verify the tensors are dumped
Expand Down