Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Sep 4, 2020
1 parent 5a030f2 commit d9bcbe6
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 60 deletions.
3 changes: 2 additions & 1 deletion include/tvm/parser/source_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@ class SourceNode : public Object {
SourceName source_name;

/*! \brief The raw source. */
std::string source;
String source;

/*! \brief A mapping of line breaks into the raw source. */
std::vector<std::pair<int, int>> line_map;

// override attr visitor
void VisitAttrs(AttrVisitor* v) {
v->Visit("source_name", &source_name);
v->Visit("source", &source);
}

static constexpr const char* _type_key = "Source";
Expand Down
14 changes: 14 additions & 0 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,20 @@ class String : public ObjectRef {
*/
bool empty() const { return size() == 0; }

/*!
* \brief Read an element.
* \param pos The position at which to read the character.
*
* \return The char at position
*/
char at(size_t pos) const {
if (pos < size()) {
return data()[pos];
} else {
throw std::out_of_range("tvm::String index out of bounds");
}
}

/*!
* \brief Return the data pointer
*
Expand Down
1 change: 1 addition & 0 deletions python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@
from .container import Array, Map

from . import transform
from . import diagnostic
60 changes: 60 additions & 0 deletions python/tvm/ir/diagnostic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
The diagnostic interface to TVM, uses for reporting and rendering
diagnostic information about the compiler. This module exposes
three key abstractions a Diagnostic, the DiagnosticContext,
and the DiagnosticRenderer.
"""
import tvm._ffi
from .. import get_global_func, register_func, Object

def get_default_renderer():
return get_global_func("diagnostics.DefaultRenderer")

def set_default_renderer(render_func):

def _render_factory():
return DiagnosticRenderer(render_func)

register_func("diagnostics.DefaultRenderer", _render_factory, override=True)

@tvm._ffi.register_object("Diagnostic")
class Diagnostic(Object):
"""A single diagnostic object from TVM."""
pass

# TODO: use ffi_api pattern
_mk_renderer = get_global_func("DiagnosticRenderer")

# DiagnosticRenderer -> DiagnosticContext -> void
# This is the method on the renderer object.
render_method = get_global_func("DiagnosticRendererRender")

# Register the diagnostic renderer.
@tvm._ffi.register_object("DiagnosticRenderer")
class DiagnosticRenderer(Object):
def __init_(self, render_func):
self.__init_handle_by_constructor__(_mk_renderer, render_func)

def render(self, ctx):
return render_method(self, ctx)

# Register the diagnostic context.
@tvm._ffi.register_object("DiagnosticContext")
class DiagnosticContext(Object):
pass
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/graph_runtime_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __getitem__(self, item):

def __iter__(self):
warnings.warn(
"legacy graph runtime behaviour of producing json / lib / params will be "
"legacy graph runtime behavior of producing json / lib / params will be "
"removed in the next release ",
DeprecationWarning, 2)
return self
Expand Down
51 changes: 38 additions & 13 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -867,11 +867,13 @@ class Parser {
// Stack should only grow proportionally to the number of
// nested scopes.
// Parses `{` expression `}`.
auto block = Bracket<Expr>(TokenType::kLCurly, TokenType::kRCurly, [&]() {
PushScope();
auto expr = ParseExpr();
PopScopes(1);
return expr;
auto block = WithSpan<Expr>([&]() {
return Bracket<Expr>(TokenType::kLCurly, TokenType::kRCurly, [&]() {
PushScope();
auto expr = ParseExpr();
PopScopes(1);
return expr;
});
});
exprs.push_back(block);
break;
Expand Down Expand Up @@ -926,9 +928,9 @@ class Parser {
CHECK_GE(exprs.size(), 1);

if (exprs.size() == 1) {
ICHECK(exprs[0]->span.defined())
<< "parser must set expression spans.\n"
<< exprs[0];
// ICHECK(exprs[0].defined() && exprs[0]->span.defined())
// << "parser must set expression spans.\n"
// << exprs[0];
return exprs[0];
} else {
auto body = exprs.back();
Expand All @@ -938,7 +940,7 @@ class Parser {
ICHECK(value->span.defined())
<< "parser must set expression spans.";
exprs.pop_back();
body = relay::Let(Var("", IncompleteType()), value, body);
body = relay::Let(Var("", IncompleteType()), value, body, value->span.Merge(body->span));
}
ICHECK(body->span.defined())
<< "parser must set expression spans.";
Expand Down Expand Up @@ -1203,7 +1205,7 @@ class Parser {

Expr ParseExprBinOp() {
DLOG(INFO) << "Parser::ParseExprBinOp";
return ConsumeWhitespace<Expr>([this] {
return WithSpan<Expr>([this] {
// We must parse at least one expression, the default
// case is that there is no operator and we will fall
// through.
Expand Down Expand Up @@ -1252,6 +1254,8 @@ class Parser {
exprs.pop_back();
Expr left = exprs.back();
exprs.pop_back();
CHECK(new_op.op.defined())
<< "a call op must be set " << new_op.op;
exprs.push_back(relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span)));
}

Expand All @@ -1266,6 +1270,8 @@ class Parser {
exprs.pop_back();
Expr left = exprs.back();
exprs.pop_back();
CHECK(new_op.op.defined())
<< "a call op must be set " << new_op.op;
exprs.push_back(relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span)));
}

Expand Down Expand Up @@ -1333,6 +1339,8 @@ class Parser {
}

Expr ParseCallArgs(Expr op) {
CHECK(op.defined()) << "the operator must be defined";

try {
DLOG(INFO) << "Parser::ParseCallArgs";
Map<String, ObjectRef> raw_attrs;
Expand Down Expand Up @@ -1519,6 +1527,26 @@ class Parser {
Match(TokenType::kCloseParen);
return static_cast<Expr>(RefCreate(expr));
}
case TokenType::kRefRead: {
return WithSpan<Expr>([&]() {
Consume(TokenType::kRefRead);
Match(TokenType::kOpenParen);
auto ref = ParseExpr();
Match(TokenType::kCloseParen);
return static_cast<Expr>(RefRead(ref));
});
}
case TokenType::kRefWrite: {
return WithSpan<Expr>([&]() {
Consume(TokenType::kRefWrite);
Match(TokenType::kOpenParen);
auto ref = ParseExpr();
Match(TokenType::kComma);
auto value = ParseExpr();
Match(TokenType::kCloseParen);
return static_cast<Expr>(RefWrite(ref, value));
});
}
case TokenType::kOpenParen: {
Span sp = next->span;
Consume(TokenType::kOpenParen);
Expand Down Expand Up @@ -1778,16 +1806,13 @@ Parser InitParser(const std::string& file_name, const std::string& file_content)
IRModule ParseModule(std::string file_name, std::string file_content) {
DLOG(INFO) << "ParseModule";
auto parser = InitParser(file_name, file_content);
std::cout << "Befofre of the parsering";
auto mod = parser.ParseModule();
std::cout << "Outside of the parsering";
ICHECK(mod.defined())
<< "The parser must return a non-null module.";
// NB(@jroesch): it is very important that we render any errors before we procede
// if there were any errors which allow the parser to procede we must render them
// here.
parser.diag_ctx.Render();
std::cout << AsText(parser.module) << std::endl;
SpanCheck()(parser.module);
auto infer_type = tvm::relay::transform::InferType();
ICHECK(infer_type.defined())
Expand Down
6 changes: 4 additions & 2 deletions src/parser/source_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ Source::Source(SourceName src_name, std::string source) {
int length = 0;
n->line_map.push_back({index, length});
// NB(@jroesch):
for (auto c : n->source) {
std::string source_str = n->source;
for (auto c : source_str) {
DLOG(INFO) << "char=" << c;
if (c == '\n') {
// Record the length of the line.
Expand Down Expand Up @@ -70,7 +71,8 @@ tvm::String Source::GetLine(int line) {
int line_start = range.first;
int line_length = range.second;
DLOG(INFO) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length;
auto line_text = (*this)->source.substr(line_start, line_length);
// TODO(@jroesch): expose substring on tvm::String.
auto line_text = std::string((*this)->source).substr(line_start, line_length);
DLOG(INFO) << "Source::GetLine: line_text=" << line_text;
return line_text;
}
Expand Down
14 changes: 7 additions & 7 deletions src/parser/tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ bool IsIdentLetter(char c) { return '_' == c || ('a' <= c && c <= 'z') || ('A' <
bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); }

static std::unordered_map<std::string, TokenType> KEYWORD_TABLE = {
{"let", TokenType::kLet}, {"fn", TokenType::kFn},
{"def", TokenType::kDefn}, {"if", TokenType::kIf},
{"else", TokenType::kElse}, {"type", TokenType::kTypeDef},
{"match", TokenType::kMatch}, {"extern", TokenType::kExtern},
{"free_var", TokenType::kFreeVar},{"ref", TokenType::kRef},
{"ref_read", TokenType::kRefRead},{"ref_write", TokenType::kRefWrite}};
{"let", TokenType::kLet}, {"fn", TokenType::kFn},
{"def", TokenType::kDefn}, {"if", TokenType::kIf},
{"else", TokenType::kElse}, {"type", TokenType::kTypeDef},
{"match", TokenType::kMatch}, {"extern", TokenType::kExtern},
{"free_var", TokenType::kFreeVar}, {"ref", TokenType::kRef},
{"ref_read", TokenType::kRefRead}, {"ref_write", TokenType::kRefWrite}};

struct Tokenizer {
DiagnosticContext diag_ctx;
Expand All @@ -82,7 +82,7 @@ struct Tokenizer {
int col;
int line;
char next_char;
const std::string& source;
String source;
std::vector<Token> tokens;

char Next() {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
template <typename AttrType>
bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
ICHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;
Expand Down
34 changes: 5 additions & 29 deletions tests/python/relay/test_diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,18 @@
from tvm.parser import SpanCheck
from tvm.parser import AnnotateSpans
from tvm.runtime import Object
from tvm.ir.diagnostic import get_default_renderer, set_default_renderer

# This grabs the default renderer.
default_renderer = get_global_func("diagnostics.DefaultRenderer")

# This is the constructor for the renderer object.
mk_renderer = get_global_func("DiagnosticRenderer")

# This is the method on the renderer object.
render_method = get_global_func("DiagnosticRendererRender")

# Register the diagnostic.
@tvm._ffi.register_object("Diagnostic")
class Diagnostic(Object):
pass

# Register the diagnostic renderer.
@tvm._ffi.register_object("DiagnosticRenderer")
class DiagnosticRenderer(Object):
def render(self, ctx):
return render_method(self, ctx)

# Register the diagnostic context.
@tvm._ffi.register_object("DiagnosticContext")
class DiagnosticContext(Object):
pass

def the_testing_renderer():
std_out = default_renderer()
std_out = get_default_renderer()

def _renderer(diag_ctx):
std_out.render(diag_ctx)
# import pdb; pdb.set_trace()
file.open(....A).write(json)
breakpoint()

return mk_renderer(_renderer)

register_func("diagnostics.DefaultRenderer", the_testing_renderer, override=True)
set_default_renderer(the_testing_renderer)

def test_span_check():
data = relay.var('data', shape=(10, 1, 1, 1))
Expand Down
10 changes: 5 additions & 5 deletions tests/python/relay/test_ir_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def test_ifelse_scope():
"""
if (True) {
let %x = ();
()
(t_
} else {
%x
}
Expand All @@ -429,7 +429,8 @@ def test_ref():
#[version = "0.0.5"]
def @main(%x: float32) {
%0 = ref(%x);
read_ref(%0)
ref_write(%0, 1f);
ref_read(%0)
}
"""
tvm.parser.parse(program)
Expand Down Expand Up @@ -977,6 +978,5 @@ def @example() {
parse_module(program)

if __name__ == "__main__":
test_ref()
# import sys
# pytest.main(sys.argv)
import sys
pytest.main(sys.argv)
2 changes: 1 addition & 1 deletion tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def before():
y = relay.reshape(y, newshape=(1, 16, -1))
y = relay.reshape(y, newshape=(4, 8, -1, 16))
y = relay.reverse_reshape(y, newshape=(32, 0, -1))
return relay.Function([x, w], y)
return relay.Function([x, w], y).with_attr("foo": 1)

def expected():
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
Expand Down

0 comments on commit d9bcbe6

Please sign in to comment.