diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index b12938bd751a..cb817a89ab81 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -53,6 +53,7 @@ class DataType { kUInt = kDLUInt, kFloat = kDLFloat, kHandle = TVMArgTypeCode::kTVMOpaqueHandle, + kBFloat = kDLBfloat, kCustomBegin = 129 }; /*! \brief default constructor */ @@ -72,6 +73,9 @@ class DataType { data_.code = static_cast(code); data_.bits = static_cast(bits); data_.lanes = static_cast(lanes); + if (code == kBFloat) { + CHECK_EQ(bits, 16); + } } /*! \return The type code. */ int code() const { return static_cast(data_.code); } @@ -89,6 +93,8 @@ class DataType { bool is_float() const { return code() == DataType::kFloat; } /*! \return whether type is a float16 type. */ bool is_float16() const { return is_float() && bits() == 16; } + /*! \return whether type is a bfloat16 type. */ + bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; } /*! \return whether type is an int type. */ bool is_int() const { return code() == DataType::kInt; } /*! \return whether type is an uint type. */ @@ -283,6 +289,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { return "float"; case DataType::kHandle: return "handle"; + case kDLBfloat: + return "bfloat"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; @@ -349,6 +357,9 @@ inline DLDataType String2DLDataType(std::string s) { t.bits = 1; t.lanes = 1; return t; + } else if (s.substr(0, 6) == "bfloat") { + t.code = DataType::kBFloat; + scan = s.c_str() + 6; } else if (s.substr(0, 6) == "custom") { t.code = ParseCustomDatatype(s, &scan); } else { diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 71e9ac4c3e22..2948bb2cc20e 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -751,7 +751,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) { return LargeUIntImm(t, static_cast(low), static_cast(high)); } } - if (t.is_float()) return FloatImm(t, static_cast(value)); + if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast(value)); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index a794c12b55ee..5e04838f7cd3 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -321,6 +321,13 @@ TVM_DLL Pass CombineContextCall(); */ TVM_DLL Pass NarrowDataType(int target_bits); +/*! + * \brief Legalize bf16 typed Ops. Add a cast to fp32 + * before Ops, then add a cast back to bf16. + * \return The pass. + */ +TVM_DLL Pass BF16Legalize(); + /*! * \brief Rewrite the pointer content type of arguments, * as well as Alloc internal to the function to use diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 2e498e38cce8..a7bfb3278784 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -54,6 +54,7 @@ class DataTypeCode(object): UINT = 1 FLOAT = 2 HANDLE = 3 + BFLOAT = 4 class DataType(ctypes.Structure): @@ -65,7 +66,8 @@ class DataType(ctypes.Structure): DataTypeCode.INT : 'int', DataTypeCode.UINT : 'uint', DataTypeCode.FLOAT : 'float', - DataTypeCode.HANDLE : 'handle' + DataTypeCode.HANDLE : 'handle', + DataTypeCode.BFLOAT : 'bfloat' } def __init__(self, type_str): super(DataType, self).__init__() @@ -96,6 +98,9 @@ def __init__(self, type_str): self.type_code = DataTypeCode.HANDLE bits = 64 head = "" + elif head.startswith("bfloat"): + self.type_code = DataTypeCode.BFLOAT + head = head[6:] elif head.startswith("custom"): # pylint: disable=import-outside-toplevel import tvm.runtime._ffi_api diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a19b097168c0..47e9a81076d4 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -176,6 +176,7 @@ def lower(sch, pass_list += [ tvm.tir.transform.InjectPrefetch(), tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), + tvm.tir.transform.BF16Legalize(), tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.Simplify(), ] diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index a5af3537473f..86e7a33ad8cb 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -226,6 +226,56 @@ def RemoveNoOp(): """ return _ffi_api.RemoveNoOp() +def BF16Legalize(): + """Legalize bf16 typed Ops. + Runs BF16Promote, BF16CastElimination and BF16TypeLowering + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.BF16Legalize() + +def BF16Promote(): + """Promote bf16 to fp32. Add a cast to fp32 + before Ops, then add a cast back to bf16. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.BF16Promote() + +def BF16CastElimination(): + """Eliminate verbose casting between fp32 and bf16 + Checks if the AST has the pattern: + castto32(castto16(some_fp32_op(...))) + The verbose casting is generated by BF16Promote for multiple + bf16 Ops in a row. e.g.: + X[i] + Y[i] + T[i] => + bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i]))) + After this pass: + bf16(float32(X[i]) + float32(Y[i]) + float32(T[i])) + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.BF16CastElimination() + +def BF16TypeLowering(): + """Replace all bf16 type with uint16. Also lower the casting + between fp32 and bf16 + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.BF16TypeLowering() def RewriteUnsafeSelect(): """Detect and rewrite unsafe select that contains memory access. diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 9d2a11c265dd..e796f49a8dd5 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -162,6 +162,7 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin pass_list.push_back(tir::transform::InjectPrefetch()); pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); // Phase 1 + pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::LoopPartition()); diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc new file mode 100644 index 000000000000..07f4775ded50 --- /dev/null +++ b/src/tir/transforms/bf16_legalize.cc @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file bf16_legalize.cc + * \brief legalize bf16 type by adding cast_to_fp32 + */ + +#include +#include +#include + +#include +#include + +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/ir_visitor_with_analyzer.h" + +namespace tvm { +namespace tir { + +using arith::Analyzer; +using arith::IRMutatorWithAnalyzer; + +class BF16PromoteRewriter : public StmtExprMutator { + public: + BF16PromoteRewriter() {} + + Stmt operator()(Stmt s) { return VisitStmt(s); } + + std::tuple DoCast(PrimExpr orig_a, PrimExpr orig_b, bool* is_bfloat16) { + auto a = this->VisitExpr(orig_a); + auto b = this->VisitExpr(orig_b); + *is_bfloat16 = false; + if (a->dtype.is_bfloat16()) { + CHECK(b->dtype.is_bfloat16()); + *is_bfloat16 = true; + } else if (b->dtype.is_bfloat16()) { + CHECK(a->dtype.is_bfloat16()); + *is_bfloat16 = true; + } + + if (*is_bfloat16) { + DataType fp32ty(kDLFloat, 32, 1); + a = Cast(fp32ty, a); + b = Cast(fp32ty, b); + } + return std::make_tuple(a, b); + } + + PrimExpr VisitExpr_(const AddNode* op) final; + PrimExpr VisitExpr_(const SubNode* op) final; + PrimExpr VisitExpr_(const MulNode* op) final; + PrimExpr VisitExpr_(const DivNode* op) final; + PrimExpr VisitExpr_(const MinNode* op) final; + PrimExpr VisitExpr_(const MaxNode* op) final; + PrimExpr VisitExpr_(const LTNode* op) final; + PrimExpr VisitExpr_(const LENode* op) final; + PrimExpr VisitExpr_(const GTNode* op) final; + PrimExpr VisitExpr_(const GENode* op) final; +}; + +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a, b; \ + bool is_bfloat16; \ + std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + auto ret = FUNC(a, b); \ + if (!is_bfloat16) \ + return ret; \ + else \ + return Cast(DataType(kDLBfloat, 16, 1), ret); \ + } \ + } + +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC) \ + PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a, b; \ + bool is_bfloat16; \ + std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + auto ret = FUNC(a, b); \ + return ret; \ + } \ + } + +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<) // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=) // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>) // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=) // NOLINT(*) + +/* + * Eliminate verbose casting between fp32 and bf16 + * Checks if the AST has the pattern: + * castto32(castto16(some_fp32_op(...))) + * The verbose casting is generated by BF16Promote for multiple + * bf16 Ops in a row. e.g.: + * X[i] + Y[i] + T[i] => + * bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i]))) + * After this pass: + * bf16(float32(X[i]) + float32(Y[i]) + float32(T[i])) + */ +class BF16CastEliminationRewriter : public StmtExprMutator { + public: + BF16CastEliminationRewriter() {} + + Stmt operator()(Stmt s) { return VisitStmt(s); } + + PrimExpr VisitExpr_(const CastNode* op) final { + auto op_val = StmtExprMutator::VisitExpr(op->value); + if (op->dtype.is_float() && op->dtype.bits() == 32) { + // if is cast_to_fp32, check if op->value is cast_to_fp16 + // and op->value->value is a float32 + if (auto innercast = op_val.as()) { + if (innercast->dtype.is_bfloat16() && innercast->value->dtype.is_float() && + innercast->value->dtype.bits() == 32) { + return innercast->value; + } + } + } + if (op->value.same_as(op_val)) return GetRef(op); + return Cast(op->dtype, op_val); + } +}; + +union FloatCaster { + uint32_t u32; + float f32; +}; + +uint16_t RoundToNearestEven(float src) { + if (std::isnan(src)) { + return UINT16_C(0x7FC0); + } else { + FloatCaster caster; + caster.f32 = src; + uint32_t rounding_bias = ((caster.u32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((caster.u32 + rounding_bias) >> 16); + } +} + +/* + * Lower the bf16 type to int16 + * Lower cast between bf16 and fp32 + * Lower bf16 FloatImm to int16 + */ +class BF16LowerRewriter : StmtExprMutator { + public: + BF16LowerRewriter() {} + + std::unordered_map buffer_remap; + std::unordered_map var_remap; + + Stmt operator()(Stmt s) { return VisitStmt(s); } + + PrimExpr VisitExpr_(const CastNode* op) final { + auto op_val = StmtExprMutator::VisitExpr(op->value); + if (op->value->dtype.is_bfloat16()) { + // if is cast_from_bf16, check if is to fp32 + CHECK(op->dtype.is_float() && op->dtype.bits() == 32); + auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); + auto uint32_v = Cast(uint32_dtype, op_val); + // to be endian invariant. + return Call(op->dtype, CallNode::reinterpret, {uint32_v << 16}, CallNode::PureIntrinsic); + + } else if (op->dtype.is_bfloat16()) { + // if is cast_to_bf16, check if op->value is fp32 + CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32); + auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); + auto uint32_v = Call(uint32_dtype, CallNode::reinterpret, {op_val}, CallNode::PureIntrinsic); + auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes()); + /* the following TIR is equivalent to the C++ code below: + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16);*/ + auto rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF); + // to be endian invariant. + return Cast(uint16_dtype, {(uint32_v + rounding_bias) >> 16}); + } + if (op->value.same_as(op_val)) return GetRef(op); + return Cast(op->dtype, op_val); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + auto itr = var_remap.find(op); + if (itr != var_remap.end()) { + return itr->second; + } + if (op->dtype.is_bfloat16()) { + CHECK(!op->type_annotation.defined()); + auto ret = Var(op->name_hint, op->dtype); + var_remap[op] = ret; + return std::move(ret); + } + return StmtExprMutator::VisitExpr_(op); + } + + Stmt VisitStmt_(const AllocateNode* op) final { + Stmt node_holder; + const AllocateNode* newop; + if (op->dtype.is_bfloat16()) { + auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), op->extents, + op->condition, op->body); + node_holder = v; + newop = static_cast(v.operator->()); + } else { + newop = op; + } + return StmtExprMutator::VisitStmt_(newop); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto itr = buffer_remap.find(op->buffer.operator->()); + const BufferStoreNode* newop; + BufferStore newop_holder; + if (itr != buffer_remap.end()) { + newop_holder = BufferStore(itr->second, op->value, op->indices); + newop = newop_holder.operator->(); + } else { + newop = op; + } + return StmtExprMutator::VisitStmt_(newop); + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + const AttrStmtNode* newop = op; + Stmt newop_holder; + if (auto buffer = op->node.as()) { + auto itr = buffer_remap.find(buffer); + if (itr != buffer_remap.end()) { + newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body); + newop = newop_holder.as(); + } + } else if (auto buffer = op->node.as()) { + auto itr = var_remap.find(buffer); + if (itr != var_remap.end()) { + newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body); + newop = newop_holder.as(); + } + } + return StmtExprMutator::VisitStmt_(newop); + } + + Stmt VisitStmt_(const BufferRealizeNode* op) final { + auto itr = buffer_remap.find(op->buffer.operator->()); + const BufferRealizeNode* newop; + Stmt newop_holder; + if (itr != buffer_remap.end()) { + auto v = BufferRealize(itr->second, op->bounds, op->condition, op->body); + newop_holder = v; + newop = v.operator->(); + } else { + newop = op; + } + return StmtExprMutator::VisitStmt_(newop); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto itr = buffer_remap.find(op->buffer.operator->()); + const BufferLoadNode* newop; + BufferLoad newop_holder; + if (itr != buffer_remap.end()) { + newop_holder = BufferLoad(itr->second, op->indices); + newop = newop_holder.operator->(); + } else { + newop = op; + } + return StmtExprMutator::VisitExpr_(newop); + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + bool is_bf16 = false; + if (op->dtype.is_bfloat16()) { + is_bf16 = true; + } + PrimExpr index = this->VisitExpr(op->index); + PrimExpr predicate = this->VisitExpr(op->predicate); + if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) { + return GetRef(op); + } else { + return Load(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, op->buffer_var, + index, predicate); + } + } + + PrimExpr VisitExpr_(const FloatImmNode* op) final { + if (op->dtype.is_bfloat16()) { + return IntImm(DataType::UInt(16, op->dtype.lanes()), + RoundToNearestEven(static_cast(op->value))); + } + return StmtExprMutator::VisitExpr_(op); + } + + void AlterBuffers(PrimFuncNode* op) { + std::vector> changes; + for (auto& itr : op->buffer_map) { + auto oldbuf = itr.second; + if (oldbuf->dtype.is_bfloat16()) { + auto newbuf = Buffer(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape, + oldbuf->strides, oldbuf->elem_offset, oldbuf->name, oldbuf->scope, + oldbuf->data_alignment, oldbuf->offset_factor, oldbuf->buffer_type); + buffer_remap[oldbuf.operator->()] = newbuf; + changes.emplace_back(itr.first, newbuf); + } + } + if (buffer_remap.size() != 0) { + op->buffer_map.assign(changes.begin(), changes.end()); + } + } +}; + +namespace transform { + +Pass BF16Promote() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = BF16PromoteRewriter()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.BF16Promote", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.BF16Promote").set_body_typed(BF16Promote); + +Pass BF16CastElimination() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = BF16CastEliminationRewriter()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.BF16CastElimination", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.BF16CastElimination").set_body_typed(BF16CastElimination); + +Pass BF16TypeLowering() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + BF16LowerRewriter lowerer; + lowerer.AlterBuffers(n); + n->body = lowerer(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.BF16TypeLowering", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.BF16TypeLowering").set_body_typed(BF16TypeLowering); + +Pass BF16Legalize() { + return Sequential({BF16Promote(), BF16CastElimination(), BF16TypeLowering()}, "tir.BF16Legalize"); +} + +TVM_REGISTER_GLOBAL("tir.transform.BF16Legalize").set_body_typed(BF16Legalize); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 1173b71ade6f..0b415b0de6ba 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -737,6 +737,53 @@ def _transform(f, *_): module(a_, b_, c_) tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) +def np_float2np_bf16(arr): + ''' Convert a numpy array of float to a numpy array + of bf16 in uint16''' + orig = arr.view('