Skip to content

Commit

Permalink
Use runtime::String
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Apr 8, 2020
1 parent 89da63e commit c79efe7
Show file tree
Hide file tree
Showing 72 changed files with 349 additions and 273 deletions.
7 changes: 4 additions & 3 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,12 @@ class PrimExpr : public BaseExpr {
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)

/*!
* \brief construct from string.
* \param str The value to be constructed.
* \brief construct from runtime String.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(std::string str); // NOLINT(*)
TVM_DLL PrimExpr(runtime::String value); // NOLINT(*)

/*! \return the data type of this expression. */
DataType dtype() const {
Expand Down
11 changes: 6 additions & 5 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#define TVM_IR_TRANSFORM_H_

#include <tvm/support/with.h>
#include <tvm/runtime/container.h>
#include <tvm/node/container.h>
#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
Expand Down Expand Up @@ -95,9 +96,9 @@ class PassContextNode : public Object {
int fallback_device{static_cast<int>(kDLCPU)};

/*! \brief The list of required passes. */
Array<PrimExpr> required_pass;
Array<runtime::String> required_pass;
/*! \brief The list of disabled passes. */
Array<PrimExpr> disabled_pass;
Array<runtime::String> disabled_pass;

TraceFunc trace_func;

Expand Down Expand Up @@ -197,7 +198,7 @@ class PassInfoNode : public Object {
std::string name;

/*! \brief The passes that are required to perform the current pass. */
Array<PrimExpr> required;
Array<runtime::String> required;

PassInfoNode() = default;

Expand Down Expand Up @@ -226,7 +227,7 @@ class PassInfo : public ObjectRef {
*/
TVM_DLL PassInfo(int opt_level,
std::string name,
Array<PrimExpr> required);
Array<runtime::String> required);

TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
Expand Down Expand Up @@ -346,7 +347,7 @@ Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const Array<PrimExpr>& required);
const Array<runtime::String>& required);

} // namespace transform
} // namespace tvm
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#define TVM_NODE_NODE_H_

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
Expand Down Expand Up @@ -62,6 +63,7 @@ using runtime::make_object;
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::String;

} // namespace tvm
#endif // TVM_NODE_NODE_H_
5 changes: 3 additions & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_

#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir/transform.h>
#include <tvm/relay/expr.h>
Expand Down Expand Up @@ -59,7 +60,7 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
const tvm::Array<runtime::String>& required);

/*! \brief Remove expressions which does not effect the program result.
*
Expand Down Expand Up @@ -355,7 +356,7 @@ TVM_DLL Pass Inline();
*
* \return The pass.
*/
TVM_DLL Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);

} // namespace transform

Expand Down
6 changes: 3 additions & 3 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ class TargetNode : public Object {
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
Array<PrimExpr> keys_array;
Array<runtime::String> keys_array;
/*! \brief Options for this target */
Array<PrimExpr> options_array;
Array<runtime::String> options_array;
/*! \brief Collection of imported libs */
Array<PrimExpr> libs_array;
Array<runtime::String> libs_array;

/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,15 +326,15 @@ class StmtExprMutator :
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of StringImm.
* \param only_enable List of runtime::String.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
TVM_DLL Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<PrimExpr>& only_enable = {});
const Array<runtime::String>& only_enable = {});

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
const tvm::Array<runtime::String>& required);

/*!
* \brief Transform the high-level PrimFunc to a low-level version
Expand Down Expand Up @@ -100,7 +100,7 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
*
* \return The pass.
*/
TVM_DLL Pass RemapThreadAxis(Map<PrimExpr, IterVar> axis_map);
TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);


/*!
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numpy as np

from tvm import target as _target
from tvm import runtime
from tvm.ir import container
from tvm.tir import expr
from tvm.te import tensor, placeholder
Expand Down Expand Up @@ -55,6 +56,8 @@ def _encode(x):
return x
if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value
if isinstance(x, runtime.container.String):
return str(x)
if x is None:
return None
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/relay/backend/graph_runtime_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def codegen(self, func):
lowered_func = self._get_irmodule()
param_names = self._list_params_name()
params = {}
for name in param_names:
key = name.value
for key in param_names:
arr = self._get_param_by_name(key)
param = empty(arr.shape, dtype=arr.dtype, ctx=arr.ctx)
arr.copyto(param)
Expand Down
63 changes: 54 additions & 9 deletions python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
# under the License.
"""Runtime container structures."""
import tvm._ffi

from tvm._ffi.base import string_types
from tvm.runtime import Object, ObjectTypes
from tvm.runtime import _ffi_api

def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Expand Down Expand Up @@ -75,18 +76,19 @@ def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)
self.__init_handle_by_constructor__(_ffi_api.ADT, tag,
*fields)

@property
def tag(self):
return _GetADTTag(self)
return _ffi_api.GetADTTag(self)

def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)
self, _ffi_api.GetADTFields, len(self), idx)

def __len__(self):
return _GetADTSize(self)
return _ffi_api.GetADTSize(self)


def tuple_object(fields=None):
Expand All @@ -106,7 +108,7 @@ def tuple_object(fields=None):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)
return _ffi_api.Tuple(*fields)


@tvm._ffi.register_object("runtime.String")
Expand All @@ -115,7 +117,7 @@ class String(Object):
Parameters
----------
string : Str
string : str
The string used to construct a runtime String object
Returns
Expand All @@ -124,7 +126,50 @@ class String(Object):
The created object.
"""
def __init__(self, string):
self.__init_handle_by_constructor__(_String, string)
self.__init_handle_by_constructor__(_ffi_api.String, string)

def __str__(self):
return _ffi_api.GetStdString(self)

def __len__(self):
return _ffi_api.GetStringSize(self)

def __hash__(self):
return _ffi_api.StringHash(self)

def __eq__(self, other):
if isinstance(other, string_types):
return self.__str__() == other

if not isinstance(other, String):
return False

return _ffi_api.CompareString(self, other) == 0

def __ne__(self, other):
return not self.__eq__(other)

def __gt__(self, other):
return _ffi_api.CompareString(self, other) > 0

def __lt__(self, other):
return _ffi_api.CompareString(self, other) < 0

def __getitem__(self, key):
return self.__str__()[key]

def startswith(self, string):
"""Check if the runtime string starts with a given string
Parameters
----------
string : str
The provided string
tvm._ffi._init_api("tvm.runtime.container")
Returns
-------
ret : boolean
Return true if the runtime string starts with the given string,
otherwise, false.
"""
return self.__str__().startswith(string)
4 changes: 2 additions & 2 deletions python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from numbers import Number, Integral
from tvm._ffi.base import string_types

from . import _ffi_node_api
from . import _ffi_node_api, _ffi_api
from .object import ObjectBase, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
Expand Down Expand Up @@ -56,7 +56,7 @@ def convert_to_object(value):
if isinstance(value, Number):
return const(value)
if isinstance(value, string_types):
return _ffi_node_api.String(value)
return _ffi_api.String(value)
if isinstance(value, (list, tuple)):
value = [convert_to_object(x) for x in value]
return _ffi_node_api.Array(*value)
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,26 @@ def __new__(cls):
@property
def keys(self):
if not self._keys:
self._keys = [k.value for k in self.keys_array]
self._keys = [str(k) for k in self.keys_array]
return self._keys

@property
def options(self):
if not self._options:
self._options = [o.value for o in self.options_array]
self._options = [str(o) for o in self.options_array]
return self._options

@property
def libs(self):
if not self._libs:
self._libs = [l.value for l in self.libs_array]
self._libs = [str(l) for l in self.libs_array]
return self._libs

@property
def model(self):
for opt in self.options_array:
if opt.value.startswith('-model='):
return opt.value[7:]
if opt.startswith('-model='):
return opt[7:]
return 'unknown'

@property
Expand Down
8 changes: 4 additions & 4 deletions src/autotvm/touch_extractor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
for (auto var : vars) {
Array<Array<PrimExpr> > feature_row;
ItervarFeature &fea = touch_analyzer.itervar_map[var];
feature_row.push_back(Array<PrimExpr>{std::string("_itervar_"), var});
feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_itervar_"), var});

Array<PrimExpr> attr{std::string("_attr_"),
Array<PrimExpr> attr{tvm::tir::StringImmNode::make("_attr_"),
FloatImm(DataType::Float(32), trans(fea.length)),
IntImm(DataType::Int(32), fea.nest_level),
FloatImm(DataType::Float(32), trans(fea.topdown_product)),
Expand All @@ -267,7 +267,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
feature_row.push_back(attr);

// arithmetic
feature_row.push_back(Array<PrimExpr>{std::string("_arith_"),
feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_arith_"),
FloatImm(DataType::Float(32), trans(fea.add_ct)),
FloatImm(DataType::Float(32), trans(fea.mul_ct)),
FloatImm(DataType::Float(32), trans(fea.div_ct)),
Expand All @@ -282,7 +282,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
for (auto k : bufs) {
TouchPattern &v = fea.touch_feature[k];
feature_row.push_back(
Array<PrimExpr>{k,
Array<PrimExpr>{tvm::tir::StringImmNode::make(k),
FloatImm(DataType::Float(32), trans(v.stride)),
FloatImm(DataType::Float(32), trans(v.mod)),
FloatImm(DataType::Float(32), trans(v.count)),
Expand Down
2 changes: 1 addition & 1 deletion src/ir/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void DictAttrsNode::InitByPackedArgs(
if (val.IsObjectRef<ObjectRef>()) {
dict.Set(key, val.operator ObjectRef());
} else if (val.type_code() == kTVMStr) {
dict.Set(key, PrimExpr(val.operator std::string()));
dict.Set(key, runtime::String(val.operator std::string()));
} else {
dict.Set(key, val.operator PrimExpr());
}
Expand Down
7 changes: 5 additions & 2 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ PrimExpr::PrimExpr(int32_t value)
PrimExpr::PrimExpr(float value)
: PrimExpr(FloatImm(DataType::Float(32), value)) {}

PrimExpr::PrimExpr(std::string str)
: PrimExpr(tir::StringImmNode::make(str)) {}
PrimExpr::PrimExpr(runtime::String value)
: PrimExpr(tir::StringImmNode::make(value)) {}

PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
using runtime::ObjectTypeChecker;
Expand All @@ -51,6 +51,9 @@ PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
if (ptr->IsInstance<te::TensorNode>()) {
return te::Tensor(ptr)();
}
if (ptr->IsInstance<runtime::StringObj>()) {
return tir::StringImmNode::make(runtime::String(ptr));
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
Expand Down
Loading

0 comments on commit c79efe7

Please sign in to comment.