From 4e5e69216df9851ee7ad6f46548cfadfc75b063e Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 13 May 2020 13:21:54 -0700 Subject: [PATCH] Fix JSON graph dumping. * Previously this function placed a JSON-escaped string containing the JSON-encoded graph. --- python/tvm/contrib/debugger/debug_result.py | 8 ++++---- tests/python/unittest/test_runtime_graph_debug.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/python/tvm/contrib/debugger/debug_result.py b/python/tvm/contrib/debugger/debug_result.py index 18920c60719e..b1fe1b62b8a9 100644 --- a/python/tvm/contrib/debugger/debug_result.py +++ b/python/tvm/contrib/debugger/debug_result.py @@ -53,9 +53,9 @@ def __init__(self, graph_json, dump_path): self._dump_path = dump_path self._output_tensor_list = [] self._time_list = [] - self._parse_graph(graph_json) + json_obj = self._parse_graph(graph_json) # dump the json information - self.dump_graph_json(graph_json) + self._dump_graph_json(json_obj) def _parse_graph(self, graph_json): """Parse and extract the JSON graph and update the nodes, shapes and dltype. @@ -70,12 +70,12 @@ def _parse_graph(self, graph_json): self._shapes_list = json_obj['attrs']['shape'] self._dtype_list = json_obj['attrs']['dltype'] self._update_graph_json() + return json_obj def _update_graph_json(self): """update the nodes_list with name, shape and data type, for temporarily storing the output. """ - nodes_len = len(self._nodes_list) for i in range(nodes_len): node = self._nodes_list[i] @@ -192,7 +192,7 @@ def node_to_events(node, times, starting_time): with open(os.path.join(self._dump_path, CHROME_TRACE_FILE_NAME), "w") as trace_f: json.dump(result, trace_f) - def dump_graph_json(self, graph): + def _dump_graph_json(self, graph): """Dump json formatted graph. Parameters diff --git a/tests/python/unittest/test_runtime_graph_debug.py b/tests/python/unittest/test_runtime_graph_debug.py index 658d9eb95ef9..ce47b16fc4d5 100644 --- a/tests/python/unittest/test_runtime_graph_debug.py +++ b/tests/python/unittest/test_runtime_graph_debug.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import json import os import tvm from tvm import te import numpy as np -import json from tvm import rpc from tvm.contrib import util from tvm.contrib.debugger import debug_runtime as graph_runtime @@ -75,7 +75,16 @@ def check_verify(): assert(len(os.listdir(directory)) == 1) #verify the file name is proper - assert(os.path.exists(os.path.join(directory, GRAPH_DUMP_FILE_NAME))) + graph_dump_path = os.path.join(directory, GRAPH_DUMP_FILE_NAME) + assert(os.path.exists(graph_dump_path)) + + # verify the graph contains some expected keys + with open(graph_dump_path) as graph_f: + dumped_graph = json.load(graph_f) + + assert isinstance(dumped_graph, dict) + for k in ("nodes", "arg_nodes", "node_row_ptr", "heads", "attrs"): + assert k in dumped_graph, f"key {k} not in dumped graph {graph!r}" mod.run() #Verify the tensors are dumped