Skip to content

Commit

Permalink
[REFACTOR] Polish runtime (#4729)
Browse files Browse the repository at this point in the history
- Remove operator bool from base object ref macro
  - Raitionale: operator bool can be dangerous for sub-classes
    that also overloads other operators(e.g. ==).
  - If bool is still needed, use explicit operator bool.
- Use absolute include when necessary
- Move type related util to data_type
- Isolate stackvm code from compiler
  • Loading branch information
tqchen committed Jan 17, 2020
1 parent eaa2380 commit b171cf1
Show file tree
Hide file tree
Showing 27 changed files with 144 additions and 205 deletions.
20 changes: 19 additions & 1 deletion include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include <vector>
#include <utility>
#include "expr.h"
#include "runtime/util.h"

namespace tvm {
namespace ir {
Expand Down Expand Up @@ -1677,6 +1676,25 @@ constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
*/
constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
kArrAddr,
kArrData,
kArrShape,
kArrStrides,
kArrNDim,
kArrTypeCode,
kArrTypeBits,
kArrTypeLanes,
kArrByteOffset,
kArrDeviceId,
kArrDeviceType,
kArrKindBound_,
// TVMValue field
kTVMValueContent,
kTVMValueKindBound_
};
} // namespace intrinsic

/*!
Expand Down
59 changes: 7 additions & 52 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,7 @@ class BaseExprNode : public Object {
*/
class BaseExpr : public ObjectRef {
public:
/*! \brief Cosntructor */
BaseExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit BaseExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief The container type. */
using ContainerType = BaseExprNode;
TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode);
};

/*!
Expand Down Expand Up @@ -100,13 +92,6 @@ class PrimExprNode : public BaseExprNode {
*/
class PrimExpr : public BaseExpr {
public:
/*! \brief Cosntructor */
PrimExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit PrimExpr(ObjectPtr<Object> ptr) : BaseExpr(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
Expand All @@ -127,8 +112,8 @@ class PrimExpr : public BaseExpr {
DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
}
/*! \brief The container type. */
using ContainerType = PrimExprNode;

TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode);
};

/*!
Expand Down Expand Up @@ -156,29 +141,14 @@ class IntImmNode : public PrimExprNode {
*/
class IntImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
IntImm() {}
/*!
* \brief constructor from node.
*/
explicit IntImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL IntImm(DataType dtype, int64_t value);
/*!
* \brief Get pointer to the internal value.
* \return the content of the integer.
*/
const IntImmNode* operator->() const {
return static_cast<const IntImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = IntImmNode;

TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
};

/*!
Expand Down Expand Up @@ -206,29 +176,14 @@ class FloatImmNode : public PrimExprNode {
*/
class FloatImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
FloatImm() {}
/*!
* \brief constructor from node.
*/
explicit FloatImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL FloatImm(DataType dtype, double value);
/*!
* \brief Get pointer to the container.
* \return The pointer.
*/
const FloatImmNode* operator->() const {
return static_cast<const FloatImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = FloatImmNode;

TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
};

/*!
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/c_backend_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#ifndef TVM_RUNTIME_C_BACKEND_API_H_
#define TVM_RUNTIME_C_BACKEND_API_H_

#include "c_runtime_api.h"
#include <tvm/runtime/c_runtime_api.h>

#ifdef __cplusplus
extern "C" {
Expand Down
19 changes: 18 additions & 1 deletion include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include <dmlc/logging.h>
#include <type_traits>


namespace tvm {
namespace runtime {
/*!
Expand Down Expand Up @@ -233,6 +232,24 @@ inline int GetVectorBytes(DataType dtype) {
return data_bits / 8;
}

/*!
* \brief Check whether type matches the given spec.
* \param t The type
* \param code The type code.
* \param bits The number of bits to be matched.
* \param lanes The number of lanes in the type.
*/
inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes;
}
/*!
* \brief Check whether two types are equal .
* \param lhs The left operand.
* \param rhs The right operand.
*/
inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}
} // namespace runtime

using DataType = runtime::DataType;
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
#ifndef TVM_RUNTIME_DEVICE_API_H_
#define TVM_RUNTIME_DEVICE_API_H_

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <string>
#include "packed_func.h"
#include "c_runtime_api.h"

namespace tvm {
namespace runtime {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
#ifndef TVM_RUNTIME_MEMORY_H_
#define TVM_RUNTIME_MEMORY_H_

#include <tvm/runtime/object.h>
#include <cstdlib>
#include <utility>
#include <type_traits>
#include "object.h"

namespace tvm {
namespace runtime {
Expand Down
3 changes: 0 additions & 3 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include <string>
#include <utility>


/*!
* \brief Whether or not use atomic reference counter.
* If the reference counter is not atomic,
Expand Down Expand Up @@ -715,7 +714,6 @@ struct ObjectEqual {
const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;

/*
Expand All @@ -734,7 +732,6 @@ struct ObjectEqual {
ObjectName* operator->() const { \
return static_cast<ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;

/*!
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "c_runtime_api.h"
#include "ndarray.h"
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/ndarray.h>

namespace dmlc {
namespace serializer {
Expand Down
79 changes: 0 additions & 79 deletions include/tvm/runtime/util.h

This file was deleted.

26 changes: 24 additions & 2 deletions src/codegen/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@ namespace codegen {

using namespace ir;

// map struct field kind to runtime variants
// We keep two separate enums to ensure runtime/compiler isolation.
StackVM::StructFieldKind MapFieldKind(int64_t kind) {
auto val = static_cast<intrinsic::TVMStructFieldKind>(kind);
switch (val) {
case intrinsic::kArrData: return StackVM::kArrData;
case intrinsic::kArrShape: return StackVM::kArrShape;
case intrinsic::kArrAddr: return StackVM::kArrAddr;
case intrinsic::kArrStrides: return StackVM::kArrStrides;
case intrinsic::kArrNDim: return StackVM::kArrNDim;
case intrinsic::kArrTypeCode: return StackVM::kArrTypeCode;
case intrinsic::kArrTypeBits: return StackVM::kArrTypeBits;
case intrinsic::kArrTypeLanes: return StackVM::kArrTypeLanes;
case intrinsic::kArrByteOffset: return StackVM::kArrByteOffset;
case intrinsic::kArrDeviceId: return StackVM::kArrDeviceId;
case intrinsic::kArrDeviceType: return StackVM::kArrDeviceType;
case intrinsic::kTVMValueContent: return StackVM::kTVMValueContent;
default: LOG(FATAL) << "Do not know how to map field " << kind;
}
return StackVM::kArrData;
}

StackVM CodeGenStackVM::Compile(LoweredFunc f) {
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
Expand Down Expand Up @@ -163,7 +185,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) {
vm_.code.push_back(code);
code.v_int = index->value;
vm_.code.push_back(code);
code.v_int = kind;
code.v_int = MapFieldKind(kind);
vm_.code.push_back(code);
} else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
CHECK_GE(op->args.size(), 5U);
Expand Down Expand Up @@ -431,7 +453,7 @@ void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) {
vm_.code.push_back(code);
code.v_int = index->value;
vm_.code.push_back(code);
code.v_int = op->args[2].as<IntImmNode>()->value;
code.v_int = MapFieldKind(op->args[2].as<IntImmNode>()->value);
vm_.code.push_back(code);
} else {
this->Push(ev->value);
Expand Down
4 changes: 2 additions & 2 deletions src/pass/hoist_if_then_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {

then_for = IRTransform(for_stmt, nullptr, replace_then_case,
{PrimExpr("IfThenElse")});
if (if_stmt.as<IfThenElseNode>()->else_case) {
if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
else_for = IRTransform(for_stmt, nullptr, replace_else_case,
{PrimExpr("IfThenElse")});
}
Expand Down Expand Up @@ -221,7 +221,7 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
for2if_map_[for_stmt.get()].push_back(head);
const IfThenElseNode* if_node = head.as<IfThenElseNode>();
tracker.push(if_node->then_case);
if (if_node->else_case) {
if (if_node->else_case.defined()) {
tracker.push(if_node->else_case);
}

Expand Down
Loading

0 comments on commit b171cf1

Please sign in to comment.