Skip to content

Commit

Permalink
add bfloat16 typeflag support (#4525)
Browse files Browse the repository at this point in the history
  • Loading branch information
ElaineBao authored and tqchen committed Dec 16, 2019
1 parent 62aac9f commit 8541e25
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
15 changes: 10 additions & 5 deletions nnvm/include/nnvm/top/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,14 @@ enum TypeFlag {
kInt32 = 4,
kInt8 = 5,
kInt64 = 6,
kInt16 = 7,
kUint16 = 8,
kUint32 = 9,
kUint64 = 10,
// kBool = 7,
// 7 is reserved for kBool, in order to keep consistency with MXNet TypeFlag defined in
// https://github.com/apache/incubator-mxnet/blob/master/3rdparty/mshadow/mshadow/base.h#L314
kInt16 = 8,
kUint16 = 9,
kUint32 = 10,
kUint64 = 11,
kBfloat16 = 12,
};

enum IndicatorRuleFlag {
Expand All @@ -125,7 +129,8 @@ enum IndicatorRuleFlag {
.add_enum("int8", kInt8) \
.add_enum("int16", kInt16) \
.add_enum("int32", kInt32) \
.add_enum("int64", kInt64)
.add_enum("int64", kInt64) \
.add_enum("bfloat16", kBfloat16)

struct CastParam : public dmlc::Parameter<CastParam> {
int dtype;
Expand Down
1 change: 1 addition & 0 deletions nnvm/src/pass/plan_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static int GetDTypeSize(int type_flag) {
case kInt8:
return 1;
case kFloat16:
case kBfloat16:
case kInt16:
case kUint16:
return 2;
Expand Down

0 comments on commit 8541e25

Please sign in to comment.