Skip to content

Commit

Permalink
Fix JSON graph dumping. (#5591)
Browse files Browse the repository at this point in the history
* Previously this function placed a JSON-escaped string containing
   the JSON-encoded graph.
  • Loading branch information
areusch authored May 14, 2020
1 parent dc9b557 commit 482e341
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
8 changes: 4 additions & 4 deletions python/tvm/contrib/debugger/debug_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def __init__(self, graph_json, dump_path):
self._dump_path = dump_path
self._output_tensor_list = []
self._time_list = []
self._parse_graph(graph_json)
json_obj = self._parse_graph(graph_json)
# dump the json information
self.dump_graph_json(graph_json)
self._dump_graph_json(json_obj)

def _parse_graph(self, graph_json):
"""Parse and extract the JSON graph and update the nodes, shapes and dltype.
Expand All @@ -70,12 +70,12 @@ def _parse_graph(self, graph_json):
self._shapes_list = json_obj['attrs']['shape']
self._dtype_list = json_obj['attrs']['dltype']
self._update_graph_json()
return json_obj

def _update_graph_json(self):
"""update the nodes_list with name, shape and data type,
for temporarily storing the output.
"""

nodes_len = len(self._nodes_list)
for i in range(nodes_len):
node = self._nodes_list[i]
Expand Down Expand Up @@ -192,7 +192,7 @@ def node_to_events(node, times, starting_time):
with open(os.path.join(self._dump_path, CHROME_TRACE_FILE_NAME), "w") as trace_f:
json.dump(result, trace_f)

def dump_graph_json(self, graph):
def _dump_graph_json(self, graph):
"""Dump json formatted graph.
Parameters
Expand Down
13 changes: 11 additions & 2 deletions tests/python/unittest/test_runtime_graph_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
import tvm
from tvm import te
import numpy as np
import json
from tvm import rpc
from tvm.contrib import util
from tvm.contrib.debugger import debug_runtime as graph_runtime
Expand Down Expand Up @@ -75,7 +75,16 @@ def check_verify():
assert(len(os.listdir(directory)) == 1)

#verify the file name is proper
assert(os.path.exists(os.path.join(directory, GRAPH_DUMP_FILE_NAME)))
graph_dump_path = os.path.join(directory, GRAPH_DUMP_FILE_NAME)
assert(os.path.exists(graph_dump_path))

# verify the graph contains some expected keys
with open(graph_dump_path) as graph_f:
dumped_graph = json.load(graph_f)

assert isinstance(dumped_graph, dict)
for k in ("nodes", "arg_nodes", "node_row_ptr", "heads", "attrs"):
assert k in dumped_graph, f"key {k} not in dumped graph {graph!r}"

mod.run()
#Verify the tensors are dumped
Expand Down

0 comments on commit 482e341

Please sign in to comment.