Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR] Polish runtime #4729

Merged
merged 1 commit into from
Jan 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include <vector>
#include <utility>
#include "expr.h"
#include "runtime/util.h"

namespace tvm {
namespace ir {
Expand Down Expand Up @@ -1677,6 +1676,25 @@ constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
*/
constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
kArrAddr,
kArrData,
kArrShape,
kArrStrides,
kArrNDim,
kArrTypeCode,
kArrTypeBits,
kArrTypeLanes,
kArrByteOffset,
kArrDeviceId,
kArrDeviceType,
kArrKindBound_,
// TVMValue field
kTVMValueContent,
kTVMValueKindBound_
};
} // namespace intrinsic

/*!
Expand Down
59 changes: 7 additions & 52 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,7 @@ class BaseExprNode : public Object {
*/
class BaseExpr : public ObjectRef {
public:
/*! \brief Cosntructor */
BaseExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit BaseExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief The container type. */
using ContainerType = BaseExprNode;
TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode);
};

/*!
Expand Down Expand Up @@ -100,13 +92,6 @@ class PrimExprNode : public BaseExprNode {
*/
class PrimExpr : public BaseExpr {
public:
/*! \brief Cosntructor */
PrimExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit PrimExpr(ObjectPtr<Object> ptr) : BaseExpr(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
Expand All @@ -127,8 +112,8 @@ class PrimExpr : public BaseExpr {
DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
}
/*! \brief The container type. */
using ContainerType = PrimExprNode;

TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode);
};

/*!
Expand Down Expand Up @@ -156,29 +141,14 @@ class IntImmNode : public PrimExprNode {
*/
class IntImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
IntImm() {}
/*!
* \brief constructor from node.
*/
explicit IntImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL IntImm(DataType dtype, int64_t value);
/*!
* \brief Get pointer to the internal value.
* \return the content of the integer.
*/
const IntImmNode* operator->() const {
return static_cast<const IntImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = IntImmNode;

TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
};

/*!
Expand Down Expand Up @@ -206,29 +176,14 @@ class FloatImmNode : public PrimExprNode {
*/
class FloatImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
FloatImm() {}
/*!
* \brief constructor from node.
*/
explicit FloatImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL FloatImm(DataType dtype, double value);
/*!
* \brief Get pointer to the container.
* \return The pointer.
*/
const FloatImmNode* operator->() const {
return static_cast<const FloatImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = FloatImmNode;

TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
};

/*!
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/c_backend_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#ifndef TVM_RUNTIME_C_BACKEND_API_H_
#define TVM_RUNTIME_C_BACKEND_API_H_

#include "c_runtime_api.h"
#include <tvm/runtime/c_runtime_api.h>

#ifdef __cplusplus
extern "C" {
Expand Down
19 changes: 18 additions & 1 deletion include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include <dmlc/logging.h>
#include <type_traits>


namespace tvm {
namespace runtime {
/*!
Expand Down Expand Up @@ -233,6 +232,24 @@ inline int GetVectorBytes(DataType dtype) {
return data_bits / 8;
}

/*!
* \brief Check whether type matches the given spec.
* \param t The type
* \param code The type code.
* \param bits The number of bits to be matched.
* \param lanes The number of lanes in the type.
*/
inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes;
}
/*!
* \brief Check whether two types are equal .
* \param lhs The left operand.
* \param rhs The right operand.
*/
inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}
} // namespace runtime

using DataType = runtime::DataType;
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
#ifndef TVM_RUNTIME_DEVICE_API_H_
#define TVM_RUNTIME_DEVICE_API_H_

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <string>
#include "packed_func.h"
#include "c_runtime_api.h"

namespace tvm {
namespace runtime {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
#ifndef TVM_RUNTIME_MEMORY_H_
#define TVM_RUNTIME_MEMORY_H_

#include <tvm/runtime/object.h>
#include <cstdlib>
#include <utility>
#include <type_traits>
#include "object.h"

namespace tvm {
namespace runtime {
Expand Down
3 changes: 0 additions & 3 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include <string>
#include <utility>


/*!
* \brief Whether or not use atomic reference counter.
* If the reference counter is not atomic,
Expand Down Expand Up @@ -715,7 +714,6 @@ struct ObjectEqual {
const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;

/*
Expand All @@ -734,7 +732,6 @@ struct ObjectEqual {
ObjectName* operator->() const { \
return static_cast<ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;

/*!
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "c_runtime_api.h"
#include "ndarray.h"
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/ndarray.h>

namespace dmlc {
namespace serializer {
Expand Down
79 changes: 0 additions & 79 deletions include/tvm/runtime/util.h

This file was deleted.

26 changes: 24 additions & 2 deletions src/codegen/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@ namespace codegen {

using namespace ir;

// map struct field kind to runtime variants
// We keep two separate enums to ensure runtime/compiler isolation.
StackVM::StructFieldKind MapFieldKind(int64_t kind) {
auto val = static_cast<intrinsic::TVMStructFieldKind>(kind);
switch (val) {
case intrinsic::kArrData: return StackVM::kArrData;
case intrinsic::kArrShape: return StackVM::kArrShape;
case intrinsic::kArrAddr: return StackVM::kArrAddr;
case intrinsic::kArrStrides: return StackVM::kArrStrides;
case intrinsic::kArrNDim: return StackVM::kArrNDim;
case intrinsic::kArrTypeCode: return StackVM::kArrTypeCode;
case intrinsic::kArrTypeBits: return StackVM::kArrTypeBits;
case intrinsic::kArrTypeLanes: return StackVM::kArrTypeLanes;
case intrinsic::kArrByteOffset: return StackVM::kArrByteOffset;
case intrinsic::kArrDeviceId: return StackVM::kArrDeviceId;
case intrinsic::kArrDeviceType: return StackVM::kArrDeviceType;
case intrinsic::kTVMValueContent: return StackVM::kTVMValueContent;
default: LOG(FATAL) << "Do not know how to map field " << kind;
}
return StackVM::kArrData;
}

StackVM CodeGenStackVM::Compile(LoweredFunc f) {
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
Expand Down Expand Up @@ -163,7 +185,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) {
vm_.code.push_back(code);
code.v_int = index->value;
vm_.code.push_back(code);
code.v_int = kind;
code.v_int = MapFieldKind(kind);
vm_.code.push_back(code);
} else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
CHECK_GE(op->args.size(), 5U);
Expand Down Expand Up @@ -431,7 +453,7 @@ void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) {
vm_.code.push_back(code);
code.v_int = index->value;
vm_.code.push_back(code);
code.v_int = op->args[2].as<IntImmNode>()->value;
code.v_int = MapFieldKind(op->args[2].as<IntImmNode>()->value);
vm_.code.push_back(code);
} else {
this->Push(ev->value);
Expand Down
4 changes: 2 additions & 2 deletions src/pass/hoist_if_then_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {

then_for = IRTransform(for_stmt, nullptr, replace_then_case,
{PrimExpr("IfThenElse")});
if (if_stmt.as<IfThenElseNode>()->else_case) {
if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
else_for = IRTransform(for_stmt, nullptr, replace_else_case,
{PrimExpr("IfThenElse")});
}
Expand Down Expand Up @@ -221,7 +221,7 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
for2if_map_[for_stmt.get()].push_back(head);
const IfThenElseNode* if_node = head.as<IfThenElseNode>();
tracker.push(if_node->then_case);
if (if_node->else_case) {
if (if_node->else_case.defined()) {
tracker.push(if_node->else_case);
}

Expand Down
Loading