From baa7f3bf92ae0d3d0147cb3834439b299e02d76e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 9 Oct 2020 11:44:09 -0700 Subject: [PATCH] [Diagnostics][Relay][InferType] Refactor InferType to work on whole module, and use new diagnostics. (#6274) * Refactor the type checker to use diagnostics Although this patch is very large and seemingly disjoint the fixes are required to get it working for the entire stack. I started with first changing InferType to use the diagnostics, these weren't yet in the pass manager so this required changes to module and module pass. InferType wasn't actually written correctly as a pass requring refactoring there, then in order to add spans to AST it required turning on AnnotateSpans which in term required changes to the parser, and module to make it possible to use the errors. These changes to parse and module required changes to diagnostics and InferType. Althought seemingly disconnected there are hidden cycles between the components which require simultaneous change in order to remove the old error reporting. A huge change due to this patch is that the module no longer implicitly type checks functions which are added. * Apply suggestions from code review Co-authored-by: Robert Kimball Co-authored-by: Junru Shao * Apply suggestions from code review Co-authored-by: Tristan Konolige * Clean up parser * CR feedback * Apply Bobs suggestions * Fix up Python interface for diagnostics * Fix test_ir_parser and formatting * Fix cpplint * Fix lint * Fix format * More lint * Fix format * Kill dead doc comment * Fix documentation comment * Rebase fixups * Add docs for type.h * Fix parser.cc * Fix unittests * Fix black * Skip previously typechecked functions * fix ACL * Fix numerous issues * Add repr method * Fix issue with Pytest, I am ready to cry * Fix the rest of tests * Kill dead code * Fix dignostic tests * Fix more tests * fix more tests (#11) * Fix diagnostic.py deinit bug * Fix deinit issue * Format * Tweak disabling of override * Format * Fix BYOC * Fix TensorArray stuff * Fix PyTorch * Format * Format Co-authored-by: Robert Kimball Co-authored-by: Junru Shao Co-authored-by: Tristan Konolige Co-authored-by: Cody Yu Co-authored-by: Zhi <5145158+zhiics@users.noreply.github.com> --- src/parser/diagnostic.h | 179 --------------------------- src/relay/op/nn/convolution.h | 1 + tests/python/relay/test_ir_parser.py | 1 + 3 files changed, 2 insertions(+), 179 deletions(-) delete mode 100644 src/parser/diagnostic.h diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h deleted file mode 100644 index 085d1c4ea8fb..000000000000 --- a/src/parser/diagnostic.h +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file diagnostic.h - * \brief A new diagnostic interface for TVM error reporting. - * - * A prototype of the new diagnostic reporting interface for TVM. - * - * Eventually we hope to promote this file to the top-level and - * replace the existing errors.h. - */ - -#ifndef TVM_PARSER_DIAGNOSTIC_H_ -#define TVM_PARSER_DIAGNOSTIC_H_ - -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace parser { - -/*! \brief The diagnostic level, controls the printing of the message. */ -enum class DiagnosticLevel { - kBug, - kError, - kWarning, - kNote, - kHelp, -}; - -struct DiagnosticBuilder; - -/*! \brief A diagnostic message. */ -struct Diagnostic { - /*! \brief The level. */ - DiagnosticLevel level; - /*! \brief The span at which to report an error. */ - Span span; - /*! \brief The diagnostic message. */ - std::string message; - - Diagnostic(DiagnosticLevel level, Span span, const std::string& message) - : level(level), span(span), message(message) {} - - static DiagnosticBuilder Bug(Span span); - static DiagnosticBuilder Error(Span span); - static DiagnosticBuilder Warning(Span span); - static DiagnosticBuilder Note(Span span); - static DiagnosticBuilder Help(Span span); -}; - -/*! - * \brief A wrapper around std::stringstream to build a diagnostic. - * - * \code - * - * void ReportError(const Error& err); - * - * void Test(int number) { - * // Use error reporter to construct an error. - * ReportError(ErrorBuilder() << "This is an error number=" << number); - * } - * - * \endcode - */ -struct DiagnosticBuilder { - public: - /*! \brief The level. */ - DiagnosticLevel level; - - /*! \brief The source name. */ - SourceName source_name; - - /*! \brief The span of the diagnostic. */ - Span span; - - template - DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*) - stream_ << val; - return *this; - } - - DiagnosticBuilder() : level(DiagnosticLevel::kError), source_name(), span(Span()) {} - - DiagnosticBuilder(const DiagnosticBuilder& builder) - : level(builder.level), source_name(builder.source_name), span(builder.span) {} - - DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {} - - operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); } - - private: - std::stringstream stream_; - friend struct Diagnostic; -}; - -DiagnosticBuilder Diagnostic::Bug(Span span) { - return DiagnosticBuilder(DiagnosticLevel::kBug, span); -} - -DiagnosticBuilder Diagnostic::Error(Span span) { - return DiagnosticBuilder(DiagnosticLevel::kError, span); -} - -DiagnosticBuilder Diagnostic::Warning(Span span) { - return DiagnosticBuilder(DiagnosticLevel::kWarning, span); -} - -DiagnosticBuilder Diagnostic::Note(Span span) { - return DiagnosticBuilder(DiagnosticLevel::kNote, span); -} - -DiagnosticBuilder Diagnostic::Help(Span span) { - return DiagnosticBuilder(DiagnosticLevel::kHelp, span); -} - -/*! \brief A diagnostic context for recording errors against a source file. - * TODO(jroesch): convert source map and improve in follow up PR, the parser - * assumes a single global file for now. - */ -struct DiagnosticContext { - /*! \brief The source to report against. */ - Source source; - - /*! \brief The set of diagnostics to report. */ - std::vector diagnostics; - - explicit DiagnosticContext(const Source& source) : source(source) {} - - /*! \brief Emit a diagnostic. */ - void Emit(const Diagnostic& diagnostic) { diagnostics.push_back(diagnostic); } - - /*! \brief Emit a diagnostic. */ - void EmitFatal(const Diagnostic& diagnostic) { - diagnostics.push_back(diagnostic); - Render(std::cout); - } - - // TODO(jroesch): eventually modularize the rendering interface to provide control of how to - // format errors. - void Render(std::ostream& ostream) { - for (auto diagnostic : diagnostics) { - source.ReportAt(ostream, diagnostic.span, diagnostic.message); - } - - if (diagnostics.size()) { - LOG(FATAL) << "DiagnosticError: one or more error diagnostics were " - << "emitted, please check diagnostic render for output."; - } - } -}; - -} // namespace parser -} // namespace tvm -#endif // TVM_PARSER_DIAGNOSTIC_H_ diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 9459b68c23af..cd334d7269ab 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -369,6 +369,7 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, << "Conv3D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << wshape; } + if (!dshape_ncdhw[1].as() && !wshape[1].as()) { CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1])); } diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 5b4f05202c98..c5217ba41bfd 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -883,6 +883,7 @@ def test_op_string_attr(): nn.conv2d(%x, %y, data_layout="NHWC", kernel_layout="HWIO") """ ) + assert isinstance(call.op, tvm.ir.Op) assert call.op.name == "nn.conv2d" assert call.attrs.data_layout == "NHWC"