Skip to content

Commit

Permalink
[Diagnostics] Add environment variable for controlling top-level prin…
Browse files Browse the repository at this point in the history
…ting and fix issue with pretty printing/parsing roundtrip. (#6874)

* Update Parser in order to handle the NMS code

* Add support for displaying traces optionally

* WIP

* Fix

* Fix error reporting in parser and clean up __init__.py due to CR

* Format

* Quick fix for If

* Fix format

* Fix lint
  • Loading branch information
jroesch authored Dec 3, 2020
1 parent 42583d6 commit 8daa97e
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 31 deletions.
21 changes: 17 additions & 4 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,28 @@
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel


def _should_print_backtrace():
in_pytest = "PYTEST_CURRENT_TEST" in os.environ
tvm_backtrace = os.environ.get("TVM_BACKTRACE", "0")

try:
tvm_backtrace = bool(int(tvm_backtrace))
except ValueError:
raise ValueError(
f"invalid value for TVM_BACKTRACE `{tvm_backtrace}`, please set to 0 or 1."
)

return in_pytest or tvm_backtrace


def tvm_wrap_excepthook(exception_hook):
"""Wrap given excepthook with TVM additional work."""

def wrapper(exctype, value, trbk):
"""Clean subprocesses when TVM is interrupted."""
in_pytest = "PYTEST_CURRENT_TEST" in os.environ

if exctype is error.DiagnosticError and not in_pytest:
pass
if exctype is error.DiagnosticError and not _should_print_backtrace():
# TODO(@jroesch): consider moving to C++?
print("note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.")
else:
exception_hook(exctype, value, trbk)

Expand Down
91 changes: 64 additions & 27 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -605,30 +605,43 @@ class Parser {
return ast;
}

struct MetaRef {
std::string type_key;
uint64_t node_index;
Span span;
MetaRef(std::string type_key, uint64_t node_index, Span span)
: type_key(type_key), node_index(node_index), span(span) {}
};

MetaRef MetaRefFromToken(const Token& tok) {
Call ref = Downcast<Call>(tok->data);
auto attrs = ref->attrs.as<MetaRefAttrs>();
auto type_key = attrs->node_type_key;
auto index = attrs->node_index;
return MetaRef(type_key, index, ref->span);
}

/*! \brief Parse a meta reference of the form `meta[type_key][node_index]`.
* For example `meta[relay.Constant][0]` references the first constant, `meta[relay.Constant][1]`
* the second, and so on.
*/
ObjectRef ParseMetaRef() {
auto meta_ref = Match(TokenType::kMetaReference);
Call ref = Downcast<Call>(meta_ref->data);
auto attrs = ref->attrs.as<MetaRefAttrs>();
auto type_key = attrs->node_type_key;
auto index = attrs->node_index;
auto it = this->meta_table.find(type_key);
auto meta_ref_tok = Match(TokenType::kMetaReference);
auto meta_ref = MetaRefFromToken(meta_ref_tok);
auto it = this->meta_table.find(meta_ref.type_key);
if (it != this->meta_table.end()) {
auto nodes = (*it).second;
if (index < nodes.size()) {
return nodes[index];
if (meta_ref.node_index < nodes.size()) {
return nodes[meta_ref.node_index];
} else {
this->diag_ctx.Emit(Diagnostic::Error(meta_ref->span)
<< "the node index `" << index << "` is out of bounds for `" << type_key
<< "`");
this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span)
<< "the node index `" << meta_ref.node_index
<< "` is out of bounds for `" << meta_ref.type_key << "`");
return ObjectRef();
}
} else {
this->diag_ctx.Emit(Diagnostic::Error(meta_ref->span)
<< "no entry in the meta table for `" << type_key << "`");
this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span)
<< "no entry in the meta table for `" << meta_ref.type_key << "`");
return ObjectRef();
}
}
Expand Down Expand Up @@ -922,10 +935,7 @@ class Parser {
exprs.push_back(ParseMatch(is_total));
break;
}
case TokenType::kIf: {
exprs.push_back(ParseIf());
break;
}

// %x ...
case TokenType::kGraph:
if (Lookahead(2)->token_type == TokenType::kEqual) {
Expand Down Expand Up @@ -1344,6 +1354,10 @@ class Parser {
Match(TokenType::kIdentifier);
return ObjectRef();
}
if (id == "None") {
Match(TokenType::kIdentifier);
return Optional<ObjectRef>();
}
}
}
default:
Expand Down Expand Up @@ -1372,7 +1386,7 @@ class Parser {
ICHECK(op.defined()) << "the operator must be defined";

DLOG(INFO) << "Parser::ParseCallArgs";
Map<String, ObjectRef> raw_attrs;
Attrs attrs;
std::string op_key;
bool is_op = false;

Expand All @@ -1388,21 +1402,40 @@ class Parser {
[&] {
auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;

if (is_op && is_ident && next_is_equal) {
raw_attrs = ParseAttrs();
auto is_pretty_attrs = is_ident && next_is_equal;
auto is_meta_next = Lookahead(1)->token_type == TokenType::kMetaReference;
// TODO(@jroesch): might not handle trailing comma
auto last_meta = Lookahead(2)->token_type == TokenType::kCloseParen;
auto is_meta_attrs = is_meta_next && last_meta;

if (is_op && (is_pretty_attrs || is_meta_attrs)) {
if (is_meta_attrs) {
auto meta_ref = ParseMetaRef();
if (meta_ref.as<BaseAttrsNode>()) {
attrs = Downcast<Attrs>(meta_ref);
} else {
// Not awesome parsing code here.
this->pos--;
return false;
}
} else {
auto raw_attrs = ParseAttrs();
auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs);
ICHECK(attr_obj.defined());
attrs = Downcast<Attrs>(attr_obj);
}
return true;
}

return false;
});

Attrs attrs;

if (is_op && op_key.size()) {
auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs);
ICHECK(attr_obj.defined());
attrs = Downcast<Attrs>(attr_obj);
if (!attrs.defined()) {
if (is_op && op_key.size()) {
auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, {});
ICHECK(attr_obj.defined());
attrs = Downcast<Attrs>(attr_obj);
}
}

// TODO(@jroesch): in a secondary pass adjust spans.
Expand Down Expand Up @@ -1527,6 +1560,10 @@ class Parser {
ICHECK(e->span.defined()) << "function spans must be defined.\n" << e;
return e;
}
case TokenType::kIf: {
Expr e = ParseIf();
return e;
}
case TokenType::kRef: {
Consume(TokenType::kRef);
Match(TokenType::kOpenParen);
Expand Down
14 changes: 14 additions & 0 deletions tests/python/relay/test_ir_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,20 @@ def @example() {
parse_module(program)


def test_parse_if_in_binding():
program = """
def @example(%b: bool) {
%0 = if (%b) {
1
} else {
0
};
%0
}
"""
parse_module(program)


def test_op_string_attr():
call = parse_text(
"""
Expand Down

0 comments on commit 8daa97e

Please sign in to comment.