From 6ca0e30051947f64ba1aad1df2309760b34acd12 Mon Sep 17 00:00:00 2001 From: Menooker Date: Thu, 30 Apr 2020 15:21:21 +0800 Subject: [PATCH 01/43] add bf16 --- include/tvm/runtime/c_runtime_api.h | 1 + include/tvm/runtime/data_type.h | 10 + include/tvm/tir/op.h | 2 +- include/tvm/tir/transform.h | 8 + python/tvm/driver/build_module.py | 1 + python/tvm/tir/transform/transform.py | 40 ++++ src/driver/driver_api.cc | 1 + src/target/llvm/codegen_llvm.cc | 19 +- src/tir/transforms/bf16_legalize.cc | 186 ++++++++++++++++++ .../unittest/test_target_codegen_llvm.py | 45 +++++ .../test_tir_transform_bf16_legalize.py | 139 +++++++++++++ 11 files changed, 450 insertions(+), 2 deletions(-) create mode 100644 src/tir/transforms/bf16_legalize.cc create mode 100644 tests/python/unittest/test_tir_transform_bf16_legalize.py diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index bb38ad8a84df..736b5eff2332 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -114,6 +114,7 @@ typedef enum { kTVMNNVMLast = 20U, // The following section of code is used for non-reserved types. kTVMExtReserveEnd = 64U, + kTVMBFloat = 65U, kTVMExtEnd = 128U, // The rest of the space is used for custom, user-supplied datatypes kTVMCustomBegin = 129U, diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index a10b83fd321b..eb8284cf4ae1 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -46,6 +46,7 @@ class DataType { kUInt = kDLUInt, kFloat = kDLFloat, kHandle = TVMTypeCode::kTVMOpaqueHandle, + kBFloat = kTVMBFloat, }; /*! \brief default constructor */ DataType() {} @@ -81,6 +82,10 @@ class DataType { bool is_float() const { return code() == DataType::kFloat; } /*! \return whether type is a float16 type. */ bool is_float16() const { return is_float() && bits() == 16; } + /*! \return whether type is a bfloat type. */ + bool is_bfloat() const { return code() == DataType::kBFloat; } + /*! \return whether type is a bfloat16 type. */ + bool is_bf16() const { return code() == DataType::kBFloat && bits() == 16; } /*! \return whether type is an int type. */ bool is_int() const { return code() == DataType::kInt; } /*! \return whether type is an uint type. */ @@ -297,6 +302,8 @@ inline const char* TypeCode2Str(int type_code) { return "Object"; case kTVMObjectRValueRefArg: return "ObjectRValueRefArg"; + case kTVMBFloat: + return "bf"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; @@ -363,6 +370,9 @@ inline DLDataType String2DLDataType(std::string s) { t.bits = 1; t.lanes = 1; return t; + } else if (s.substr(0, 2) == "bf") { + t.code = kTVMBFloat; + scan = s.c_str() + 2; } else if (s.substr(0, 6) == "custom") { t.code = ParseCustomDatatype(s, &scan); } else { diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 5884942ebef1..d2c934bf73e8 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -735,7 +735,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) { return LargeUIntImm(t, static_cast(low), static_cast(high)); } } - if (t.is_float()) return FloatImm(t, static_cast(value)); + if (t.is_float() || t.is_bf16()) return FloatImm(t, static_cast(value)); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 13e1e2510e29..50fc31f20dc5 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -330,6 +330,14 @@ TVM_DLL Pass CombineContextCall(); */ TVM_DLL Pass NarrowDataType(int target_bits); + +/*! + * \brief Legalize bf16 typed Ops. Add a cast to fp32 + * before Ops, then add a cast back to bf16. + * \return The pass. + */ +TVM_DLL Pass BF16Legalize(); + /*! * \brief Rewrite the pointer content type of arguments, * as well as Alloc internal to the function to use diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 216cad992d98..c8cb06167d4a 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -175,6 +175,7 @@ def lower(sch, pass_list += [ tvm.tir.transform.InjectPrefetch(), tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers), + tvm.tir.transform.BF16Legalize(), tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.Simplify(), ] diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 6d797f8772ec..3112d5327194 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -255,6 +255,46 @@ def RemoveNoOp(): """ return _ffi_api.RemoveNoOp() +def BF16Legalize(): + """Legalize bf16 typed Ops. + Runs BF16Promote and BF16CastElimination + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.BF16Legalize() + +def BF16Promote(): + """Promote bf16 to fp32. Add a cast to fp32 + before Ops, then add a cast back to bf16. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.BF16Promote() + +def BF16CastElimination(): + """ + Eliminate verbose casting between fp32 and bf16 + Checks if the AST has the pattern: + castto32(castto16(some_fp32_op(...))) + The verbose casting is generated by BF16Promote for multiple + bf16 Ops in a row. e.g.: + X[i] + Y[i] + T[i] => + bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i]))) + After this pass: + bf16(float32(X[i]) + float32(Y[i]) + float32(T[i])) + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.BF16CastElimination() def RewriteUnsafeSelect(): """Detect and rewrite unsafe select that contains memory access. diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index cdd9d5441b25..d3d4387b0ebb 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -149,6 +149,7 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin pass_list.push_back(tir::transform::InjectPrefetch()); pass_list.push_back(tir::transform::StorageFlatten(64, config->instrument_bound_checkers)); // Phase 1 + pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop)); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index f664532b2dc1..9bb0aadcc150 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -309,6 +309,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { default: LOG(FATAL) << "do not support " << dtype; } + } else if (dtype.is_bfloat()) { + CHECK_EQ(dtype.bits(), 16); + etype = llvm::Type::getInt16Ty(*ctx_); } if (dtype.lanes() != 1) { return llvm::VectorType::get(etype, dtype.lanes()); @@ -561,6 +564,20 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va if (value->getType() == target) return value; if (to.is_handle()) { return builder_->CreateBitCast(value, target); + } else if (to.is_float() && from.is_bfloat()) { + CHECK_EQ(from.bits(), 16); + CHECK_EQ(to.bits(), 32); + auto v = builder_->CreateZExt(value, builder_->getInt32Ty()); + if (module_->getDataLayout().isLittleEndian()) + v = builder_->CreateShl(v, 16); + return builder_->CreateBitCast(v, target); + } else if (to.is_bfloat() && from.is_float()) { + CHECK_EQ(to.bits(), 16); + CHECK_EQ(from.bits(), 32); + auto v = builder_->CreateBitCast(value, builder_->getInt32Ty()); + if (module_->getDataLayout().isLittleEndian()) + v = builder_->CreateLShr(v, 16); + return builder_->CreateTrunc(v, target); } else if (to.is_uint() && to.bits() == 1) { if (from.is_float()) { llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.); @@ -906,7 +923,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul); llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \ if (t.is_int()) { \ return builder_->CreateICmpS##Op(a, b); \ - } else if (t.is_uint()) { \ + } else if (t.is_uint() || t.is_bfloat()) { \ return builder_->CreateICmpU##Op(a, b); \ } else { \ CHECK(t.is_float()); \ diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc new file mode 100644 index 000000000000..1cb276b80d82 --- /dev/null +++ b/src/tir/transforms/bf16_legalize.cc @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file narrow_datatype.cc + * \brief narrow the datatype of indexing vars + */ + +#include +#include +#include +#include +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/ir_visitor_with_analyzer.h" +#include + +namespace tvm { +namespace tir { + +using arith::Analyzer; +using arith::IRMutatorWithAnalyzer; + +class BF16PromoteRewriter : public StmtExprMutator { + public: + explicit BF16PromoteRewriter() {} + + Stmt operator()(Stmt s) { + return VisitStmt(s); + } + + std::tuple DoCast(PrimExpr orig_a, + PrimExpr orig_b, bool& is_bf16) { + auto a = this->VisitExpr(orig_a); + auto b = this->VisitExpr(orig_b); + is_bf16 = false; + if (a->dtype.is_bf16()) { + CHECK(b->dtype.is_bf16()); + is_bf16 = true; + } else if (b->dtype.is_bf16()) { + CHECK(a->dtype.is_bf16()); + is_bf16 = true; + } + + if (is_bf16) { + DataType fp32ty(kDLFloat, 32, 1); + a = CastNode::make(fp32ty, a); + b = CastNode::make(fp32ty, b); + } + return std::make_tuple(a, b); + } + + PrimExpr VisitExpr_(const AddNode* op) final; + PrimExpr VisitExpr_(const SubNode* op) final; + PrimExpr VisitExpr_(const MulNode* op) final; + PrimExpr VisitExpr_(const DivNode* op) final; + PrimExpr VisitExpr_(const MinNode* op) final; + PrimExpr VisitExpr_(const MaxNode* op) final; + PrimExpr VisitExpr_(const LTNode* op) final; + PrimExpr VisitExpr_(const LENode* op) final; + PrimExpr VisitExpr_(const GTNode* op) final; + PrimExpr VisitExpr_(const GENode* op) final; +}; + + +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a, b; \ + bool is_bf16; \ + std::tie(a, b) = DoCast(op->a, op->b, is_bf16); \ + if (a.same_as(op->a) && \ + b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + auto ret = FUNC(a, b); \ + if (!is_bf16) \ + return ret; \ + else \ + return CastNode::make(DataType(kTVMBFloat, 16, 1), ret); \ + } \ + } + +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator <) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator >) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=) + +/* + * Eliminate verbose casting between fp32 and bf16 + * Checks if the AST has the pattern: + * castto32(castto16(some_fp32_op(...))) + * The verbose casting is generated by BF16Promote for multiple + * bf16 Ops in a row. e.g.: + * X[i] + Y[i] + T[i] => + * bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i]))) + * After this pass: + * bf16(float32(X[i]) + float32(Y[i]) + float32(T[i])) +*/ +class BF16CastEliminationRewriter : public StmtExprMutator { + public: + explicit BF16CastEliminationRewriter() {} + + Stmt operator()(Stmt s) { + return VisitStmt(s); + } + + PrimExpr VisitExpr_(const CastNode* op) { + auto op_val = StmtExprMutator::VisitExpr(op->value); + if (op->dtype.is_float() && op->dtype.bits() == 32) { + // if is cast_to_fp32, check if op->value is cast_to_fp16 + // and op->value->value is a float32 + if (auto innercast = op_val.as()) { + if (innercast->dtype.is_bf16() + && innercast->value->dtype.is_float() + && innercast->value->dtype.bits() == 32) { + return innercast->value; + } + } + } + if (op->value.same_as(op_val)) + return GetRef(op); + return CastNode::make(op->dtype, op_val); + } +}; + + +namespace transform { + +Pass BF16Promote() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = BF16PromoteRewriter()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass( + pass_func, 0, "tir.BF16Promote", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.BF16Promote") +.set_body_typed(BF16Promote); + +Pass BF16CastElimination() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = BF16CastEliminationRewriter()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass( + pass_func, 0, "tir.BF16CastElimination", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.BF16CastElimination") +.set_body_typed(BF16CastElimination); + +Pass BF16Legalize() { + return Sequential({BF16Promote(), BF16CastElimination()}, + "tir.BF16Legalize"); +} + +TVM_REGISTER_GLOBAL("tir.transform.BF16Legalize") +.set_body_typed(BF16Legalize); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index c6591721d247..98155d28abb2 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -710,6 +710,50 @@ def _transform(f, *_): module(a_, b_, c_) tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) +import sys +import struct + +def float2bf16(v): + ba = bytearray(struct.pack("f", v)) + if sys.byteorder=='little': + return struct.unpack('h',ba[2:])[0] + else: + return struct.unpack('h',ba[0:2])[0] + +def bf162float(v): + ba = struct.pack("h", v) + if sys.byteorder=='little': + return struct.unpack('f', b"\0\0" + ba)[0] + else: + return struct.unpack('f', ba + b"\0\0")[0] + +def bf16_cast_and_cast_back(v): + return bf162float(float2bf16(v)) + +def test_llvm_bf16(): + np.random.seed(122) + A = te.placeholder((6, )) + B = te.placeholder((6, )) + a = te.compute((6, ), lambda x: topi.cast(A[x], 'bf16'), 'A') + b = te.compute((6, ), lambda x: topi.cast(B[x], 'bf16'), 'B') + c = te.compute((6, ), lambda x: a[x] + b[x]) + d = te.compute((6, ), lambda x: topi.cast(c[x], 'float'), 'D') + sch = te.create_schedule(d.op) + module = tvm.build(sch, [A, B, d]) + + npa = np.random.rand(6).astype('float32') + npb = np.random.rand(6).astype('float32') + res = [0] * len(npa) + for i in range(len(npa)): + va = bf16_cast_and_cast_back(npa[i]) + vb = bf16_cast_and_cast_back(npb[i]) + res[i] = bf16_cast_and_cast_back(va + vb) + a_ = tvm.nd.array(npa) + b_ = tvm.nd.array(npb) + c_ = tvm.nd.array(np.zeros((6,), dtype='float32')) + module(a_, b_, c_) + tvm.testing.assert_allclose(c_.asnumpy(), res) + if __name__ == "__main__": test_multiple_func() test_llvm_large_uintimm() @@ -732,3 +776,4 @@ def _transform(f, *_): test_llvm_fp_math() test_dwarf_debug_information() test_llvm_shuffle() + test_llvm_bf16() diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py new file mode 100644 index 000000000000..4df43b2e213e --- /dev/null +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import topi +from tvm import te +from tvm.tir import const + + +def lower_stmt(sche, params, passfunc): + func = tvm.driver.build_module.form_irmodule(sche, params, "main", None)["main"] + func = passfunc()( + tvm.IRModule.from_expr(func))["main"] + stmt = func.body + return stmt + +def to32(v): + return topi.cast(v, 'float') +def to16(v): + return topi.cast(v, 'bf16') + +def test_promote(): + def runpass(op, passfunc): + a = te.placeholder((100,), dtype='bf16') + b = te.placeholder((100,), dtype='bf16') + c = te.compute((100,), lambda i: op(a[i], b[i])) + s = te.create_schedule(c.op) + return lower_stmt(s, [a, b, c], passfunc) + + def get_promoted(op): + a = te.placeholder((100,), dtype='bf16') + b = te.placeholder((100,), dtype='bf16') + c = te.compute((100,), lambda i: + topi.cast(op(topi.cast(a[i],'float'), + topi.cast(b[i],'float')), 'bf16') + ) + s = te.create_schedule(c.op) + func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"] + return func.body + + def test_promoted(op): + stmt = runpass(op, tvm.tir.transform.BF16Promote) + tvm.ir.assert_structural_equal(stmt, get_promoted(op)) + test_promoted(topi.add) + test_promoted(topi.subtract) + test_promoted(topi.multiply) + test_promoted(topi.divide) + +def test_eliminate(): + def get_eliminated(): + a = te.placeholder((100,), dtype='bf16') + b = te.placeholder((100,), dtype='bf16') + c = te.compute((100,), lambda i: to16( + topi.add( + to32( + to16( + topi.add( + to32(a[i]), + to32(b[i]), + ) + ) + ), + to32( + to16( + topi.add( + to32(a[i]), + to32(b[i]), + ) + ) + ) + ) + )) + s = te.create_schedule(c.op) + stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16CastElimination) + return stmt + + def get_target(): + a = te.placeholder((100,), dtype='bf16') + b = te.placeholder((100,), dtype='bf16') + c = te.compute((100,), lambda i: to16( + topi.add(topi.add( + to32(a[i]), + to32(b[i]), + ), + topi.add( + to32(a[i]), + to32(b[i]), + ) + ) + )) + s = te.create_schedule(c.op) + func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"] + return func.body + + tvm.ir.assert_structural_equal(get_eliminated(), get_target()) + +def test_legalize(): + def check(fcompute_before, fcompute_after): + a = te.placeholder((100,), dtype='bf16') + b = te.placeholder((100,), dtype='bf16') + c = te.compute((100,), fcompute_before(a,b)) + s = te.create_schedule(c.op) + stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16Legalize) + + a = te.placeholder((100,), dtype='bf16') + b = te.placeholder((100,), dtype='bf16') + c = te.compute((100,), fcompute_after(a,b)) + s = te.create_schedule(c.op) + func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"] + tvm.ir.assert_structural_equal(stmt, func.body) + + def orig1(a,b): + return lambda i: a[i]+b[i]+a[99-i]+b[99-i] + def after1(a,b): + return lambda i: to16(to32(a[i])+to32(b[i])+to32(a[99-i])+to32(b[99-i])) + def orig1(a,b): + return lambda i: a[i]*b[i]+a[99-i]*b[99-i]+a[i] + def after1(a,b): + return lambda i: to16(to32(a[i])*to32(b[i])+to32(a[99-i])*to32(b[99-i])+to32(a[i])) + + check(orig1, after1) + +if __name__ == "__main__": + test_promote() + test_eliminate() + test_legalize() \ No newline at end of file From 3fba6848da864b15a676bf2ea293ea8163023a4f Mon Sep 17 00:00:00 2001 From: Menooker Date: Wed, 6 May 2020 13:51:14 +0800 Subject: [PATCH 02/43] add bf16 in DataType (py) --- python/tvm/_ffi/runtime_ctypes.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 0d6e5ac18fb3..adf4439c343e 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -58,7 +58,8 @@ class DataType(ctypes.Structure): 0 : 'int', 1 : 'uint', 2 : 'float', - 4 : 'handle' + 4 : 'handle', + 65: 'bf' } def __init__(self, type_str): super(DataType, self).__init__() @@ -85,6 +86,9 @@ def __init__(self, type_str): elif head.startswith("float"): self.type_code = 2 head = head[5:] + elif head.startswith("bf"): + self.type_code = 65 + head = head[2:] elif head.startswith("handle"): self.type_code = 4 bits = 64 From 48e7e9461096ccc41d3bb9f38215aaffc6ac8049 Mon Sep 17 00:00:00 2001 From: Menooker Date: Thu, 7 May 2020 10:33:42 +0800 Subject: [PATCH 03/43] ndarray of bf16 --- python/tvm/runtime/ndarray.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 9f5f0f685e8d..6508ed990fa5 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -138,7 +138,10 @@ def copyfrom(self, source_array): if source_array.shape != shape: raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format( source_array.shape, shape)) - source_array = np.ascontiguousarray(source_array, dtype=dtype) + if dtype == 'bf16': + source_array = np.ascontiguousarray(source_array, dtype='uint16') + else: + source_array = np.ascontiguousarray(source_array, dtype=dtype) assert source_array.flags['C_CONTIGUOUS'] data = source_array.ctypes.data_as(ctypes.c_void_p) nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize) @@ -167,7 +170,10 @@ def asnumpy(self): shape = shape + (t.lanes,) t.lanes = 1 dtype = str(t) - np_arr = np.empty(shape, dtype=dtype) + if dtype == 'bf16': + np_arr = np.empty(shape, dtype='uint16') + else: + np_arr = np.empty(shape, dtype=dtype) assert np_arr.flags['C_CONTIGUOUS'] data = np_arr.ctypes.data_as(ctypes.c_void_p) nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize) From 17ef57bdfc6e351ff09330bbb447ef5964890f7c Mon Sep 17 00:00:00 2001 From: Menooker Date: Thu, 7 May 2020 12:07:47 +0800 Subject: [PATCH 04/43] do not cast back for compare op --- src/tir/transforms/bf16_legalize.cc | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 1cb276b80d82..cc5de98e8633 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -79,7 +79,7 @@ class BF16PromoteRewriter : public StmtExprMutator { #define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ - PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ + PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ PrimExpr a, b; \ bool is_bf16; \ std::tie(a, b) = DoCast(op->a, op->b, is_bf16); \ @@ -95,16 +95,30 @@ class BF16PromoteRewriter : public StmtExprMutator { } \ } +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC) \ + PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a, b; \ + bool is_bf16; \ + std::tie(a, b) = DoCast(op->a, op->b, is_bf16); \ + if (a.same_as(op->a) && \ + b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + auto ret = FUNC(a, b); \ + return ret; \ + } \ + } + DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator <) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator >) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator <) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator >) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=) /* * Eliminate verbose casting between fp32 and bf16 From 96f30195cee15e6441d204f37db19e9349998e4c Mon Sep 17 00:00:00 2001 From: Menooker Date: Thu, 7 May 2020 15:16:29 +0800 Subject: [PATCH 05/43] const gen --- src/target/llvm/codegen_llvm.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 9bb0aadcc150..99592a8d7b49 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -887,6 +887,15 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { + if (op->dtype.is_bf16()) { + auto fp = float(op->value); + auto p = reinterpret_cast(&fp); + #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + return this->builder_->getInt16(p[0]); + #else + return this->builder_->getInt16(p[1]); + #endif + } return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value); } From 4aeff415dca4d9e1ebfbc7fd67acfe256bf2ef21 Mon Sep 17 00:00:00 2001 From: Menooker Date: Thu, 7 May 2020 16:17:41 +0800 Subject: [PATCH 06/43] more precise --- src/target/llvm/codegen_llvm.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 99592a8d7b49..614eba199433 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -575,8 +575,11 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va CHECK_EQ(to.bits(), 16); CHECK_EQ(from.bits(), 32); auto v = builder_->CreateBitCast(value, builder_->getInt32Ty()); - if (module_->getDataLayout().isLittleEndian()) - v = builder_->CreateLShr(v, 16); + auto bias = builder_->CreateLShr(v, 16); + bias = builder_->CreateAnd(bias, builder_->getInt32(1)); + bias = builder_->CreateAdd(bias, builder_->getInt32(0x7fff)); + v = builder_->CreateAdd(v, bias); + v = builder_->CreateLShr(v, 16); return builder_->CreateTrunc(v, target); } else if (to.is_uint() && to.bits() == 1) { if (from.is_float()) { From c551a3d80d359c9166f32adc1f6b0f2451746cdf Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 8 May 2020 09:16:58 +0800 Subject: [PATCH 07/43] update test --- .../unittest/test_target_codegen_llvm.py | 68 +++++++++---------- 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 98155d28abb2..e848f2d9d4e7 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -710,51 +710,50 @@ def _transform(f, *_): module(a_, b_, c_) tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) -import sys -import struct - -def float2bf16(v): - ba = bytearray(struct.pack("f", v)) - if sys.byteorder=='little': - return struct.unpack('h',ba[2:])[0] - else: - return struct.unpack('h',ba[0:2])[0] - -def bf162float(v): - ba = struct.pack("h", v) - if sys.byteorder=='little': - return struct.unpack('f', b"\0\0" + ba)[0] - else: - return struct.unpack('f', ba + b"\0\0")[0] - -def bf16_cast_and_cast_back(v): - return bf162float(float2bf16(v)) +def np_float2np_bf16(arr): + ''' Convert a numpy array of float to a numpy array + of bf16 in uint16''' + orig = arr.view(' Date: Sat, 9 May 2020 15:57:09 +0800 Subject: [PATCH 08/43] enable vectorization --- src/target/llvm/codegen_llvm.cc | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 614eba199433..dd134b61cc63 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -558,6 +558,22 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va builder_->SetInsertPoint(for_end); } +static llvm::Value* GetInt32VectorOrScalar( + llvm::IRBuilder& builder, + uint32_t v, + int lanes) { + if (lanes == 1) { + return builder.getInt32(v); + } else { + std::vector consts; + for (int i = 0; i < lanes; i++) { + consts.emplace_back(builder.getInt32(v)); + } + return llvm::ConstantVector::get(consts); + } +} + // cast operatpr llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) { llvm::Type* target = DTypeToLLVMType(to); @@ -567,17 +583,22 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va } else if (to.is_float() && from.is_bfloat()) { CHECK_EQ(from.bits(), 16); CHECK_EQ(to.bits(), 32); - auto v = builder_->CreateZExt(value, builder_->getInt32Ty()); - if (module_->getDataLayout().isLittleEndian()) - v = builder_->CreateShl(v, 16); + llvm::Type* extended_type = (from.lanes() != 1) ? + static_cast(builder_->getInt32Ty()) : + llvm::VectorType::get(builder_->getInt32Ty(), from.lanes()); + auto v = builder_->CreateZExt(value, extended_type); + v = builder_->CreateShl(v, 16); return builder_->CreateBitCast(v, target); } else if (to.is_bfloat() && from.is_float()) { CHECK_EQ(to.bits(), 16); CHECK_EQ(from.bits(), 32); - auto v = builder_->CreateBitCast(value, builder_->getInt32Ty()); + llvm::Type* extended_type = (from.lanes() != 1) ? + static_cast(builder_->getInt32Ty()) : + llvm::VectorType::get(builder_->getInt32Ty(), to.lanes()); + auto v = builder_->CreateBitCast(value, extended_type); auto bias = builder_->CreateLShr(v, 16); - bias = builder_->CreateAnd(bias, builder_->getInt32(1)); - bias = builder_->CreateAdd(bias, builder_->getInt32(0x7fff)); + bias = builder_->CreateAnd(bias, GetInt32VectorOrScalar(*builder_, 1, to.lanes())); + bias = builder_->CreateAdd(bias, GetInt32VectorOrScalar(*builder_, 0x7fff, to.lanes())); v = builder_->CreateAdd(v, bias); v = builder_->CreateLShr(v, 16); return builder_->CreateTrunc(v, target); From 3c5c0f473a6ccdeb972e1890d5df361d8fe22c26 Mon Sep 17 00:00:00 2001 From: Menooker Date: Sat, 9 May 2020 16:08:03 +0800 Subject: [PATCH 09/43] correct vectorize --- src/target/llvm/codegen_llvm.cc | 4 ++-- .../unittest/test_target_codegen_llvm.py | 20 ++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index dd134b61cc63..657f908c948b 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -583,7 +583,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va } else if (to.is_float() && from.is_bfloat()) { CHECK_EQ(from.bits(), 16); CHECK_EQ(to.bits(), 32); - llvm::Type* extended_type = (from.lanes() != 1) ? + llvm::Type* extended_type = (from.lanes() == 1) ? static_cast(builder_->getInt32Ty()) : llvm::VectorType::get(builder_->getInt32Ty(), from.lanes()); auto v = builder_->CreateZExt(value, extended_type); @@ -592,7 +592,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va } else if (to.is_bfloat() && from.is_float()) { CHECK_EQ(to.bits(), 16); CHECK_EQ(from.bits(), 32); - llvm::Type* extended_type = (from.lanes() != 1) ? + llvm::Type* extended_type = (from.lanes() == 1) ? static_cast(builder_->getInt32Ty()) : llvm::VectorType::get(builder_->getInt32Ty(), to.lanes()); auto v = builder_->CreateBitCast(value, extended_type); diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index e848f2d9d4e7..d254aca68a3f 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -733,27 +733,29 @@ def np_bf16_cast_and_cast_back(arr): ''' Convert a numpy array of float to bf16 and cast back''' return np_bf162np_float(np_float2np_bf16(arr)) -def test_llvm_bf16(): +def test_llvm_bf16(do_vectorize): np.random.seed(122) - A = te.placeholder((6, ), dtype='bf16') - B = te.placeholder((6, ), dtype='bf16') - d = te.compute((6, ), lambda x: A[x] + B[x]) + A = te.placeholder((32, ), dtype='bf16') + B = te.placeholder((32, ), dtype='bf16') + d = te.compute((32, ), lambda x: A[x] + B[x]) sch = te.create_schedule(d.op) + if do_vectorize: + sch[d].vectorize(d.op.axis[0]) module = tvm.build(sch, [A, B, d]) - - npa = np.random.rand(6).astype('float32') - npb = np.random.rand(6).astype('float32') + npa = np.random.rand(32).astype('float32') + npb = np.random.rand(32).astype('float32') va = np_bf16_cast_and_cast_back(npa) vb = np_bf16_cast_and_cast_back(npb) res = np_bf16_cast_and_cast_back(va + vb) a_ = np_float2tvm_bf16(npa) b_ = np_float2tvm_bf16(npb) - c_ = tvm.nd.empty((6,), 'bf16') + c_ = tvm.nd.empty((32,), 'bf16') module(a_, b_, c_) tvm.testing.assert_allclose(np_bf162np_float(c_.asnumpy()), res) if __name__ == "__main__": - test_llvm_bf16() + test_llvm_bf16(do_vectorize=True) + test_llvm_bf16(do_vectorize=False) test_multiple_func() test_llvm_large_uintimm() test_llvm_import() From c978b9e9658c2618cd0053f1a8e2258e650f9085 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 15 May 2020 16:35:45 +0800 Subject: [PATCH 10/43] linter changes --- src/target/llvm/codegen_llvm.cc | 12 ++++++------ src/tir/transforms/bf16_legalize.cc | 20 ++++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 657f908c948b..7b6853ef20ff 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -560,15 +560,15 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va static llvm::Value* GetInt32VectorOrScalar( llvm::IRBuilder& builder, + llvm::IRBuilderDefaultInserter>* builder, uint32_t v, int lanes) { if (lanes == 1) { - return builder.getInt32(v); + return builder->getInt32(v); } else { std::vector consts; for (int i = 0; i < lanes; i++) { - consts.emplace_back(builder.getInt32(v)); + consts.emplace_back(builder->getInt32(v)); } return llvm::ConstantVector::get(consts); } @@ -597,8 +597,8 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va llvm::VectorType::get(builder_->getInt32Ty(), to.lanes()); auto v = builder_->CreateBitCast(value, extended_type); auto bias = builder_->CreateLShr(v, 16); - bias = builder_->CreateAnd(bias, GetInt32VectorOrScalar(*builder_, 1, to.lanes())); - bias = builder_->CreateAdd(bias, GetInt32VectorOrScalar(*builder_, 0x7fff, to.lanes())); + bias = builder_->CreateAnd(bias, GetInt32VectorOrScalar(builder_, 1, to.lanes())); + bias = builder_->CreateAdd(bias, GetInt32VectorOrScalar(builder_, 0x7fff, to.lanes())); v = builder_->CreateAdd(v, bias); v = builder_->CreateLShr(v, 16); return builder_->CreateTrunc(v, target); @@ -912,7 +912,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { if (op->dtype.is_bf16()) { - auto fp = float(op->value); + auto fp = static_cast(op->value); auto p = reinterpret_cast(&fp); #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ return this->builder_->getInt16(p[0]); diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index cc5de98e8633..b7a3070bd58a 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -26,9 +26,9 @@ #include #include #include + #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" -#include namespace tvm { namespace tir { @@ -38,23 +38,23 @@ using arith::IRMutatorWithAnalyzer; class BF16PromoteRewriter : public StmtExprMutator { public: - explicit BF16PromoteRewriter() {} + BF16PromoteRewriter() {} Stmt operator()(Stmt s) { return VisitStmt(s); } std::tuple DoCast(PrimExpr orig_a, - PrimExpr orig_b, bool& is_bf16) { + PrimExpr orig_b, bool* is_bf16) { auto a = this->VisitExpr(orig_a); auto b = this->VisitExpr(orig_b); - is_bf16 = false; + *is_bf16 = false; if (a->dtype.is_bf16()) { CHECK(b->dtype.is_bf16()); - is_bf16 = true; + *is_bf16 = true; } else if (b->dtype.is_bf16()) { CHECK(a->dtype.is_bf16()); - is_bf16 = true; + *is_bf16 = true; } if (is_bf16) { @@ -82,7 +82,7 @@ class BF16PromoteRewriter : public StmtExprMutator { PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ PrimExpr a, b; \ bool is_bf16; \ - std::tie(a, b) = DoCast(op->a, op->b, is_bf16); \ + std::tie(a, b) = DoCast(op->a, op->b, &is_bf16); \ if (a.same_as(op->a) && \ b.same_as(op->b)) { \ return GetRef(op); \ @@ -99,7 +99,7 @@ class BF16PromoteRewriter : public StmtExprMutator { PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ PrimExpr a, b; \ bool is_bf16; \ - std::tie(a, b) = DoCast(op->a, op->b, is_bf16); \ + std::tie(a, b) = DoCast(op->a, op->b, &is_bf16); \ if (a.same_as(op->a) && \ b.same_as(op->b)) { \ return GetRef(op); \ @@ -133,7 +133,7 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=) */ class BF16CastEliminationRewriter : public StmtExprMutator { public: - explicit BF16CastEliminationRewriter() {} + BF16CastEliminationRewriter() {} Stmt operator()(Stmt s) { return VisitStmt(s); @@ -145,7 +145,7 @@ class BF16CastEliminationRewriter : public StmtExprMutator { // if is cast_to_fp32, check if op->value is cast_to_fp16 // and op->value->value is a float32 if (auto innercast = op_val.as()) { - if (innercast->dtype.is_bf16() + if (innercast->dtype.is_bf16() && innercast->value->dtype.is_float() && innercast->value->dtype.bits() == 32) { return innercast->value; From ef6f410ac5231e08bb6be2bdab1c6a7235c20038 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 15 May 2020 16:43:25 +0800 Subject: [PATCH 11/43] linter --- src/target/llvm/codegen_llvm.cc | 10 ++-- src/tir/transforms/bf16_legalize.cc | 74 ++++++++++++++--------------- 2 files changed, 39 insertions(+), 45 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 7b6853ef20ff..9143d32b6525 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -559,10 +559,8 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va } static llvm::Value* GetInt32VectorOrScalar( - llvm::IRBuilder* builder, - uint32_t v, - int lanes) { + llvm::IRBuilder* builder, uint32_t v, + int lanes) { if (lanes == 1) { return builder->getInt32(v); } else { @@ -597,8 +595,8 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va llvm::VectorType::get(builder_->getInt32Ty(), to.lanes()); auto v = builder_->CreateBitCast(value, extended_type); auto bias = builder_->CreateLShr(v, 16); - bias = builder_->CreateAnd(bias, GetInt32VectorOrScalar(builder_, 1, to.lanes())); - bias = builder_->CreateAdd(bias, GetInt32VectorOrScalar(builder_, 0x7fff, to.lanes())); + bias = builder_->CreateAnd(bias, GetInt32VectorOrScalar(builder_.get(), 1, to.lanes())); + bias = builder_->CreateAdd(bias, GetInt32VectorOrScalar(builder_.get(), 0x7fff, to.lanes())); v = builder_->CreateAdd(v, bias); v = builder_->CreateLShr(v, 16); return builder_->CreateTrunc(v, target); diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index b7a3070bd58a..a18e8c593d77 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -22,9 +22,10 @@ * \brief narrow the datatype of indexing vars */ +#include #include #include -#include + #include #include "../../arith/ir_mutator_with_analyzer.h" @@ -44,8 +45,7 @@ class BF16PromoteRewriter : public StmtExprMutator { return VisitStmt(s); } - std::tuple DoCast(PrimExpr orig_a, - PrimExpr orig_b, bool* is_bf16) { + std::tuple DoCast(PrimExpr orig_a, PrimExpr orig_b, bool* is_bf16) { auto a = this->VisitExpr(orig_a); auto b = this->VisitExpr(orig_b); *is_bf16 = false; @@ -77,36 +77,33 @@ class BF16PromoteRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const GENode* op) final; }; - -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ - PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ - PrimExpr a, b; \ - bool is_bf16; \ - std::tie(a, b) = DoCast(op->a, op->b, &is_bf16); \ - if (a.same_as(op->a) && \ - b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - auto ret = FUNC(a, b); \ - if (!is_bf16) \ - return ret; \ - else \ - return CastNode::make(DataType(kTVMBFloat, 16, 1), ret); \ - } \ +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a, b; \ + bool is_bf16; \ + std::tie(a, b) = DoCast(op->a, op->b, &is_bf16); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + auto ret = FUNC(a, b); \ + if (!is_bf16) \ + return ret; \ + else \ + return CastNode::make(DataType(kTVMBFloat, 16, 1), ret); \ + } \ } -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC) \ - PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ - PrimExpr a, b; \ - bool is_bf16; \ - std::tie(a, b) = DoCast(op->a, op->b, &is_bf16); \ - if (a.same_as(op->a) && \ - b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - auto ret = FUNC(a, b); \ - return ret; \ - } \ +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC) \ + PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a, b; \ + bool is_bf16; \ + std::tie(a, b) = DoCast(op->a, op->b, &is_bf16); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + auto ret = FUNC(a, b); \ + return ret; \ + } \ } DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+) @@ -142,15 +139,14 @@ class BF16CastEliminationRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const CastNode* op) { auto op_val = StmtExprMutator::VisitExpr(op->value); if (op->dtype.is_float() && op->dtype.bits() == 32) { - // if is cast_to_fp32, check if op->value is cast_to_fp16 - // and op->value->value is a float32 - if (auto innercast = op_val.as()) { - if (innercast->dtype.is_bf16() - && innercast->value->dtype.is_float() - && innercast->value->dtype.bits() == 32) { - return innercast->value; - } + // if is cast_to_fp32, check if op->value is cast_to_fp16 + // and op->value->value is a float32 + if (auto innercast = op_val.as()) { + if (innercast->dtype.is_bf16() && innercast->value->dtype.is_float() && + innercast->value->dtype.bits() == 32) { + return innercast->value; } + } } if (op->value.same_as(op_val)) return GetRef(op); From 92d014a5161e436b400dbdc62fd585ce554faafc Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 15 May 2020 16:52:55 +0800 Subject: [PATCH 12/43] linter --- include/tvm/tir/transform.h | 1 - src/target/llvm/codegen_llvm.cc | 22 ++++++------- src/tir/transforms/bf16_legalize.cc | 48 +++++++++++------------------ 3 files changed, 29 insertions(+), 42 deletions(-) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 50fc31f20dc5..1a43799fc313 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -330,7 +330,6 @@ TVM_DLL Pass CombineContextCall(); */ TVM_DLL Pass NarrowDataType(int target_bits); - /*! * \brief Legalize bf16 typed Ops. Add a cast to fp32 * before Ops, then add a cast back to bf16. diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 9143d32b6525..63e6f935175c 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -581,18 +581,18 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va } else if (to.is_float() && from.is_bfloat()) { CHECK_EQ(from.bits(), 16); CHECK_EQ(to.bits(), 32); - llvm::Type* extended_type = (from.lanes() == 1) ? - static_cast(builder_->getInt32Ty()) : - llvm::VectorType::get(builder_->getInt32Ty(), from.lanes()); + llvm::Type* extended_type = (from.lanes() == 1) + ? static_cast(builder_->getInt32Ty()) + : llvm::VectorType::get(builder_->getInt32Ty(), from.lanes()); auto v = builder_->CreateZExt(value, extended_type); v = builder_->CreateShl(v, 16); return builder_->CreateBitCast(v, target); } else if (to.is_bfloat() && from.is_float()) { CHECK_EQ(to.bits(), 16); CHECK_EQ(from.bits(), 32); - llvm::Type* extended_type = (from.lanes() == 1) ? - static_cast(builder_->getInt32Ty()) : - llvm::VectorType::get(builder_->getInt32Ty(), to.lanes()); + llvm::Type* extended_type = (from.lanes() == 1) + ? static_cast(builder_->getInt32Ty()) + : llvm::VectorType::get(builder_->getInt32Ty(), to.lanes()); auto v = builder_->CreateBitCast(value, extended_type); auto bias = builder_->CreateLShr(v, 16); bias = builder_->CreateAnd(bias, GetInt32VectorOrScalar(builder_.get(), 1, to.lanes())); @@ -912,11 +912,11 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { if (op->dtype.is_bf16()) { auto fp = static_cast(op->value); auto p = reinterpret_cast(&fp); - #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - return this->builder_->getInt16(p[0]); - #else - return this->builder_->getInt16(p[1]); - #endif +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + return this->builder_->getInt16(p[0]); +#else + return this->builder_->getInt16(p[1]); +#endif } return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value); } diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index a18e8c593d77..44c58c118565 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -41,26 +41,24 @@ class BF16PromoteRewriter : public StmtExprMutator { public: BF16PromoteRewriter() {} - Stmt operator()(Stmt s) { - return VisitStmt(s); - } + Stmt operator()(Stmt s) { return VisitStmt(s); } std::tuple DoCast(PrimExpr orig_a, PrimExpr orig_b, bool* is_bf16) { auto a = this->VisitExpr(orig_a); auto b = this->VisitExpr(orig_b); *is_bf16 = false; if (a->dtype.is_bf16()) { - CHECK(b->dtype.is_bf16()); - *is_bf16 = true; + CHECK(b->dtype.is_bf16()); + *is_bf16 = true; } else if (b->dtype.is_bf16()) { - CHECK(a->dtype.is_bf16()); - *is_bf16 = true; + CHECK(a->dtype.is_bf16()); + *is_bf16 = true; } if (is_bf16) { - DataType fp32ty(kDLFloat, 32, 1); - a = CastNode::make(fp32ty, a); - b = CastNode::make(fp32ty, b); + DataType fp32ty(kDLFloat, 32, 1); + a = CastNode::make(fp32ty, a); + b = CastNode::make(fp32ty, b); } return std::make_tuple(a, b); } @@ -112,9 +110,9 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator <) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator >) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=) /* @@ -132,9 +130,7 @@ class BF16CastEliminationRewriter : public StmtExprMutator { public: BF16CastEliminationRewriter() {} - Stmt operator()(Stmt s) { - return VisitStmt(s); - } + Stmt operator()(Stmt s) { return VisitStmt(s); } PrimExpr VisitExpr_(const CastNode* op) { auto op_val = StmtExprMutator::VisitExpr(op->value); @@ -148,13 +144,11 @@ class BF16CastEliminationRewriter : public StmtExprMutator { } } } - if (op->value.same_as(op_val)) - return GetRef(op); + if (op->value.same_as(op_val)) return GetRef(op); return CastNode::make(op->dtype, op_val); } }; - namespace transform { Pass BF16Promote() { @@ -163,12 +157,10 @@ Pass BF16Promote() { n->body = BF16PromoteRewriter()(std::move(n->body)); return f; }; - return CreatePrimFuncPass( - pass_func, 0, "tir.BF16Promote", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.BF16Promote", {}); } -TVM_REGISTER_GLOBAL("tir.transform.BF16Promote") -.set_body_typed(BF16Promote); +TVM_REGISTER_GLOBAL("tir.transform.BF16Promote").set_body_typed(BF16Promote); Pass BF16CastElimination() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { @@ -176,20 +168,16 @@ Pass BF16CastElimination() { n->body = BF16CastEliminationRewriter()(std::move(n->body)); return f; }; - return CreatePrimFuncPass( - pass_func, 0, "tir.BF16CastElimination", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.BF16CastElimination", {}); } -TVM_REGISTER_GLOBAL("tir.transform.BF16CastElimination") -.set_body_typed(BF16CastElimination); +TVM_REGISTER_GLOBAL("tir.transform.BF16CastElimination").set_body_typed(BF16CastElimination); Pass BF16Legalize() { - return Sequential({BF16Promote(), BF16CastElimination()}, - "tir.BF16Legalize"); + return Sequential({BF16Promote(), BF16CastElimination()}, "tir.BF16Legalize"); } -TVM_REGISTER_GLOBAL("tir.transform.BF16Legalize") -.set_body_typed(BF16Legalize); +TVM_REGISTER_GLOBAL("tir.transform.BF16Legalize").set_body_typed(BF16Legalize); } // namespace transform } // namespace tir From e23f33c722ff7bc6c4a55afa0c7d5338f72d397c Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 15 May 2020 16:56:04 +0800 Subject: [PATCH 13/43] linter --- src/tir/transforms/bf16_legalize.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 44c58c118565..ba682de1763f 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -110,10 +110,10 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<) // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=) // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>) // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=) // NOLINT(*) /* * Eliminate verbose casting between fp32 and bf16 From a245d410b7e5970b613b3bf5bbc08cacacb96726 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 15 May 2020 20:20:42 +0800 Subject: [PATCH 14/43] Update bf16_legalize.cc --- src/tir/transforms/bf16_legalize.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index ba682de1763f..a6b2609bbdc7 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -110,10 +110,10 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<) // NOLINT(*) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=) // NOLINT(*) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>) // NOLINT(*) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=) // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<) // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=) // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>) // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=) // NOLINT(*) /* * Eliminate verbose casting between fp32 and bf16 From 17c7084ed3871af50c947503436cf3957154a794 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 15 May 2020 20:23:44 +0800 Subject: [PATCH 15/43] Update bf16_legalize.cc --- src/tir/transforms/bf16_legalize.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index a6b2609bbdc7..88ac79a8767e 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -18,8 +18,8 @@ */ /*! - * \file narrow_datatype.cc - * \brief narrow the datatype of indexing vars + * \file bf16_legalize.cc + * \brief legalize bf16 type by adding cast_to_fp32 */ #include @@ -125,7 +125,7 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=) // NOLINT(* * bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i]))) * After this pass: * bf16(float32(X[i]) + float32(Y[i]) + float32(T[i])) -*/ + */ class BF16CastEliminationRewriter : public StmtExprMutator { public: BF16CastEliminationRewriter() {} From d51bf9bf114e6e7d5217f0dd769c464102125491 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 15 May 2020 20:25:22 +0800 Subject: [PATCH 16/43] Update bf16_legalize.cc --- src/tir/transforms/bf16_legalize.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 88ac79a8767e..3862531027de 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -19,7 +19,7 @@ /*! * \file bf16_legalize.cc - * \brief legalize bf16 type by adding cast_to_fp32 + * \brief legalize bf16 type by adding cast_to_fp32 */ #include From cbb1e5bc216071a5bc1cae225e3ceb1d5347df7a Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 15 May 2020 20:49:17 +0800 Subject: [PATCH 17/43] Update transform.py --- python/tvm/tir/transform/transform.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 3112d5327194..204a09799e67 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -256,7 +256,7 @@ def RemoveNoOp(): return _ffi_api.RemoveNoOp() def BF16Legalize(): - """Legalize bf16 typed Ops. + """Legalize bf16 typed Ops. Runs BF16Promote and BF16CastElimination Returns @@ -278,16 +278,15 @@ def BF16Promote(): return _ffi_api.BF16Promote() def BF16CastElimination(): - """ - Eliminate verbose casting between fp32 and bf16 + """Eliminate verbose casting between fp32 and bf16 Checks if the AST has the pattern: castto32(castto16(some_fp32_op(...))) The verbose casting is generated by BF16Promote for multiple bf16 Ops in a row. e.g.: - X[i] + Y[i] + T[i] => - bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i]))) + X[i] + Y[i] + T[i] => + bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i]))) After this pass: - bf16(float32(X[i]) + float32(Y[i]) + float32(T[i])) + bf16(float32(X[i]) + float32(Y[i]) + float32(T[i])) Returns ------- From 680eccee47a7094554f501ee760970b3b66e295c Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 15 May 2020 22:14:25 +0800 Subject: [PATCH 18/43] fix --- src/tir/transforms/bf16_legalize.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 3862531027de..41f3a43e2ecb 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -55,7 +55,7 @@ class BF16PromoteRewriter : public StmtExprMutator { *is_bf16 = true; } - if (is_bf16) { + if (*is_bf16) { DataType fp32ty(kDLFloat, 32, 1); a = CastNode::make(fp32ty, a); b = CastNode::make(fp32ty, b); From b4b9d42dc8e5ae403c5ab95a30d66a8e36f9c6f3 Mon Sep 17 00:00:00 2001 From: Menooker Date: Sat, 16 May 2020 10:55:12 +0800 Subject: [PATCH 19/43] Update test_target_codegen_llvm.py --- .../unittest/test_target_codegen_llvm.py | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index d254aca68a3f..7f6b52486726 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -733,29 +733,30 @@ def np_bf16_cast_and_cast_back(arr): ''' Convert a numpy array of float to bf16 and cast back''' return np_bf162np_float(np_float2np_bf16(arr)) -def test_llvm_bf16(do_vectorize): - np.random.seed(122) - A = te.placeholder((32, ), dtype='bf16') - B = te.placeholder((32, ), dtype='bf16') - d = te.compute((32, ), lambda x: A[x] + B[x]) - sch = te.create_schedule(d.op) - if do_vectorize: - sch[d].vectorize(d.op.axis[0]) - module = tvm.build(sch, [A, B, d]) - npa = np.random.rand(32).astype('float32') - npb = np.random.rand(32).astype('float32') - va = np_bf16_cast_and_cast_back(npa) - vb = np_bf16_cast_and_cast_back(npb) - res = np_bf16_cast_and_cast_back(va + vb) - a_ = np_float2tvm_bf16(npa) - b_ = np_float2tvm_bf16(npb) - c_ = tvm.nd.empty((32,), 'bf16') - module(a_, b_, c_) - tvm.testing.assert_allclose(np_bf162np_float(c_.asnumpy()), res) - +def test_llvm_bf16(): + def dotest(do_vectorize): + np.random.seed(122) + A = te.placeholder((32, ), dtype='bf16') + B = te.placeholder((32, ), dtype='bf16') + d = te.compute((32, ), lambda x: A[x] + B[x]) + sch = te.create_schedule(d.op) + if do_vectorize: + sch[d].vectorize(d.op.axis[0]) + module = tvm.build(sch, [A, B, d]) + npa = np.random.rand(32).astype('float32') + npb = np.random.rand(32).astype('float32') + va = np_bf16_cast_and_cast_back(npa) + vb = np_bf16_cast_and_cast_back(npb) + res = np_bf16_cast_and_cast_back(va + vb) + a_ = np_float2tvm_bf16(npa) + b_ = np_float2tvm_bf16(npb) + c_ = tvm.nd.empty((32,), 'bf16') + module(a_, b_, c_) + tvm.testing.assert_allclose(np_bf162np_float(c_.asnumpy()), res) + dotest(true) + dotest(false) + if __name__ == "__main__": - test_llvm_bf16(do_vectorize=True) - test_llvm_bf16(do_vectorize=False) test_multiple_func() test_llvm_large_uintimm() test_llvm_import() @@ -777,3 +778,4 @@ def test_llvm_bf16(do_vectorize): test_llvm_fp_math() test_dwarf_debug_information() test_llvm_shuffle() + test_llvm_bf16() From a899ef77eda9cd9504fe753213e3dd0c79af288d Mon Sep 17 00:00:00 2001 From: Menooker Date: Sat, 16 May 2020 11:53:08 +0800 Subject: [PATCH 20/43] Update test_target_codegen_llvm.py --- tests/python/unittest/test_target_codegen_llvm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 7f6b52486726..5253bd50e67c 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -753,8 +753,8 @@ def dotest(do_vectorize): c_ = tvm.nd.empty((32,), 'bf16') module(a_, b_, c_) tvm.testing.assert_allclose(np_bf162np_float(c_.asnumpy()), res) - dotest(true) - dotest(false) + dotest(True) + dotest(False) if __name__ == "__main__": test_multiple_func() From 3523f00d0efcaa77c79ced2785504eddff12c046 Mon Sep 17 00:00:00 2001 From: Menooker Date: Sat, 16 May 2020 13:09:45 +0800 Subject: [PATCH 21/43] Update transform.py --- python/tvm/tir/transform/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 204a09799e67..b826f1302407 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -280,7 +280,7 @@ def BF16Promote(): def BF16CastElimination(): """Eliminate verbose casting between fp32 and bf16 Checks if the AST has the pattern: - castto32(castto16(some_fp32_op(...))) + castto32(castto16(some_fp32_op(...))) The verbose casting is generated by BF16Promote for multiple bf16 Ops in a row. e.g.: X[i] + Y[i] + T[i] => From bac724766f7dea6153d07b740e111f1988db7399 Mon Sep 17 00:00:00 2001 From: Menooker Date: Wed, 20 May 2020 14:22:34 +0800 Subject: [PATCH 22/43] bf16 => bfloat16 --- include/tvm/runtime/data_type.h | 13 +++---- include/tvm/tir/op.h | 2 +- python/tvm/_ffi/runtime_ctypes.py | 6 ++-- python/tvm/runtime/ndarray.py | 4 +-- src/target/llvm/codegen_llvm.cc | 13 +++---- src/tir/transforms/bf16_legalize.cc | 30 ++++++++-------- .../unittest/test_target_codegen_llvm.py | 8 ++--- .../test_tir_transform_bf16_legalize.py | 35 ++++++++++--------- 8 files changed, 55 insertions(+), 56 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index eb8284cf4ae1..20eb061a012c 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -65,6 +65,9 @@ class DataType { data_.code = static_cast(code); data_.bits = static_cast(bits); data_.lanes = static_cast(lanes); + if (code == kBFloat) { + CHECK_EQ(bits, 16); + } } /*! \return The type code. */ int code() const { return static_cast(data_.code); } @@ -82,10 +85,8 @@ class DataType { bool is_float() const { return code() == DataType::kFloat; } /*! \return whether type is a float16 type. */ bool is_float16() const { return is_float() && bits() == 16; } - /*! \return whether type is a bfloat type. */ - bool is_bfloat() const { return code() == DataType::kBFloat; } /*! \return whether type is a bfloat16 type. */ - bool is_bf16() const { return code() == DataType::kBFloat && bits() == 16; } + bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; } /*! \return whether type is an int type. */ bool is_int() const { return code() == DataType::kInt; } /*! \return whether type is an uint type. */ @@ -303,7 +304,7 @@ inline const char* TypeCode2Str(int type_code) { case kTVMObjectRValueRefArg: return "ObjectRValueRefArg"; case kTVMBFloat: - return "bf"; + return "bfloat"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; @@ -370,9 +371,9 @@ inline DLDataType String2DLDataType(std::string s) { t.bits = 1; t.lanes = 1; return t; - } else if (s.substr(0, 2) == "bf") { + } else if (s.substr(0, 6) == "bfloat") { t.code = kTVMBFloat; - scan = s.c_str() + 2; + scan = s.c_str() + 6; } else if (s.substr(0, 6) == "custom") { t.code = ParseCustomDatatype(s, &scan); } else { diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index d2c934bf73e8..6b06a4431d09 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -735,7 +735,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) { return LargeUIntImm(t, static_cast(low), static_cast(high)); } } - if (t.is_float() || t.is_bf16()) return FloatImm(t, static_cast(value)); + if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast(value)); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index adf4439c343e..c6a6e09206d3 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -59,7 +59,7 @@ class DataType(ctypes.Structure): 1 : 'uint', 2 : 'float', 4 : 'handle', - 65: 'bf' + 65: 'bfloat' } def __init__(self, type_str): super(DataType, self).__init__() @@ -86,9 +86,9 @@ def __init__(self, type_str): elif head.startswith("float"): self.type_code = 2 head = head[5:] - elif head.startswith("bf"): + elif head.startswith("bfloat"): self.type_code = 65 - head = head[2:] + head = head[6:] elif head.startswith("handle"): self.type_code = 4 bits = 64 diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 6508ed990fa5..502543a65320 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -138,7 +138,7 @@ def copyfrom(self, source_array): if source_array.shape != shape: raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format( source_array.shape, shape)) - if dtype == 'bf16': + if dtype == 'bfloat16': source_array = np.ascontiguousarray(source_array, dtype='uint16') else: source_array = np.ascontiguousarray(source_array, dtype=dtype) @@ -170,7 +170,7 @@ def asnumpy(self): shape = shape + (t.lanes,) t.lanes = 1 dtype = str(t) - if dtype == 'bf16': + if dtype == 'bfloat16': np_arr = np.empty(shape, dtype='uint16') else: np_arr = np.empty(shape, dtype=dtype) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 63e6f935175c..d7d7dd3eb92f 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -309,8 +309,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { default: LOG(FATAL) << "do not support " << dtype; } - } else if (dtype.is_bfloat()) { - CHECK_EQ(dtype.bits(), 16); + } else if (dtype.is_bfloat16()) { etype = llvm::Type::getInt16Ty(*ctx_); } if (dtype.lanes() != 1) { @@ -578,8 +577,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va if (value->getType() == target) return value; if (to.is_handle()) { return builder_->CreateBitCast(value, target); - } else if (to.is_float() && from.is_bfloat()) { - CHECK_EQ(from.bits(), 16); + } else if (to.is_float() && from.is_bfloat16()) { CHECK_EQ(to.bits(), 32); llvm::Type* extended_type = (from.lanes() == 1) ? static_cast(builder_->getInt32Ty()) @@ -587,8 +585,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va auto v = builder_->CreateZExt(value, extended_type); v = builder_->CreateShl(v, 16); return builder_->CreateBitCast(v, target); - } else if (to.is_bfloat() && from.is_float()) { - CHECK_EQ(to.bits(), 16); + } else if (to.is_bfloat16() && from.is_float()) { CHECK_EQ(from.bits(), 32); llvm::Type* extended_type = (from.lanes() == 1) ? static_cast(builder_->getInt32Ty()) @@ -909,7 +906,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { - if (op->dtype.is_bf16()) { + if (op->dtype.is_bfloat16()) { auto fp = static_cast(op->value); auto p = reinterpret_cast(&fp); #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ @@ -954,7 +951,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul); llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \ if (t.is_int()) { \ return builder_->CreateICmpS##Op(a, b); \ - } else if (t.is_uint() || t.is_bfloat()) { \ + } else if (t.is_uint() || t.is_bfloat16()) { \ return builder_->CreateICmpU##Op(a, b); \ } else { \ CHECK(t.is_float()); \ diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 41f3a43e2ecb..e68e91b9c659 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -43,19 +43,19 @@ class BF16PromoteRewriter : public StmtExprMutator { Stmt operator()(Stmt s) { return VisitStmt(s); } - std::tuple DoCast(PrimExpr orig_a, PrimExpr orig_b, bool* is_bf16) { + std::tuple DoCast(PrimExpr orig_a, PrimExpr orig_b, bool* is_bfloat16) { auto a = this->VisitExpr(orig_a); auto b = this->VisitExpr(orig_b); - *is_bf16 = false; - if (a->dtype.is_bf16()) { - CHECK(b->dtype.is_bf16()); - *is_bf16 = true; - } else if (b->dtype.is_bf16()) { - CHECK(a->dtype.is_bf16()); - *is_bf16 = true; + *is_bfloat16 = false; + if (a->dtype.is_bfloat16()) { + CHECK(b->dtype.is_bfloat16()); + *is_bfloat16 = true; + } else if (b->dtype.is_bfloat16()) { + CHECK(a->dtype.is_bfloat16()); + *is_bfloat16 = true; } - if (*is_bf16) { + if (*is_bfloat16) { DataType fp32ty(kDLFloat, 32, 1); a = CastNode::make(fp32ty, a); b = CastNode::make(fp32ty, b); @@ -78,13 +78,13 @@ class BF16PromoteRewriter : public StmtExprMutator { #define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ PrimExpr a, b; \ - bool is_bf16; \ - std::tie(a, b) = DoCast(op->a, op->b, &is_bf16); \ + bool is_bfloat16; \ + std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ if (a.same_as(op->a) && b.same_as(op->b)) { \ return GetRef(op); \ } else { \ auto ret = FUNC(a, b); \ - if (!is_bf16) \ + if (!is_bfloat16) \ return ret; \ else \ return CastNode::make(DataType(kTVMBFloat, 16, 1), ret); \ @@ -94,8 +94,8 @@ class BF16PromoteRewriter : public StmtExprMutator { #define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC) \ PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ PrimExpr a, b; \ - bool is_bf16; \ - std::tie(a, b) = DoCast(op->a, op->b, &is_bf16); \ + bool is_bfloat16; \ + std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ if (a.same_as(op->a) && b.same_as(op->b)) { \ return GetRef(op); \ } else { \ @@ -138,7 +138,7 @@ class BF16CastEliminationRewriter : public StmtExprMutator { // if is cast_to_fp32, check if op->value is cast_to_fp16 // and op->value->value is a float32 if (auto innercast = op_val.as()) { - if (innercast->dtype.is_bf16() && innercast->value->dtype.is_float() && + if (innercast->dtype.is_bfloat16() && innercast->value->dtype.is_float() && innercast->value->dtype.bits() == 32) { return innercast->value; } diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 5253bd50e67c..9b5649b84699 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -721,7 +721,7 @@ def np_float2tvm_bf16(arr): ''' Convert a numpy array of float to a TVM array of bf16''' nparr = np_float2np_bf16(arr) - return tvm.nd.empty(nparr.shape, 'bf16').copyfrom(nparr) + return tvm.nd.empty(nparr.shape, 'bfloat16').copyfrom(nparr) def np_bf162np_float(arr): ''' Convert a numpy array of bf16 (uint16) to a numpy array @@ -736,8 +736,8 @@ def np_bf16_cast_and_cast_back(arr): def test_llvm_bf16(): def dotest(do_vectorize): np.random.seed(122) - A = te.placeholder((32, ), dtype='bf16') - B = te.placeholder((32, ), dtype='bf16') + A = te.placeholder((32, ), dtype='bfloat16') + B = te.placeholder((32, ), dtype='bfloat16') d = te.compute((32, ), lambda x: A[x] + B[x]) sch = te.create_schedule(d.op) if do_vectorize: @@ -750,7 +750,7 @@ def dotest(do_vectorize): res = np_bf16_cast_and_cast_back(va + vb) a_ = np_float2tvm_bf16(npa) b_ = np_float2tvm_bf16(npb) - c_ = tvm.nd.empty((32,), 'bf16') + c_ = tvm.nd.empty((32,), 'bfloat16') module(a_, b_, c_) tvm.testing.assert_allclose(np_bf162np_float(c_.asnumpy()), res) dotest(True) diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index 4df43b2e213e..1161f38de962 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -30,22 +30,22 @@ def lower_stmt(sche, params, passfunc): def to32(v): return topi.cast(v, 'float') def to16(v): - return topi.cast(v, 'bf16') + return topi.cast(v, 'bfloat16') def test_promote(): def runpass(op, passfunc): - a = te.placeholder((100,), dtype='bf16') - b = te.placeholder((100,), dtype='bf16') + a = te.placeholder((100,), dtype='bfloat16') + b = te.placeholder((100,), dtype='bfloat16') c = te.compute((100,), lambda i: op(a[i], b[i])) s = te.create_schedule(c.op) return lower_stmt(s, [a, b, c], passfunc) def get_promoted(op): - a = te.placeholder((100,), dtype='bf16') - b = te.placeholder((100,), dtype='bf16') + a = te.placeholder((100,), dtype='bfloat16') + b = te.placeholder((100,), dtype='bfloat16') c = te.compute((100,), lambda i: topi.cast(op(topi.cast(a[i],'float'), - topi.cast(b[i],'float')), 'bf16') + topi.cast(b[i],'float')), 'bfloat16') ) s = te.create_schedule(c.op) func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"] @@ -61,8 +61,8 @@ def test_promoted(op): def test_eliminate(): def get_eliminated(): - a = te.placeholder((100,), dtype='bf16') - b = te.placeholder((100,), dtype='bf16') + a = te.placeholder((100,), dtype='bfloat16') + b = te.placeholder((100,), dtype='bfloat16') c = te.compute((100,), lambda i: to16( topi.add( to32( @@ -88,8 +88,8 @@ def get_eliminated(): return stmt def get_target(): - a = te.placeholder((100,), dtype='bf16') - b = te.placeholder((100,), dtype='bf16') + a = te.placeholder((100,), dtype='bfloat16') + b = te.placeholder((100,), dtype='bfloat16') c = te.compute((100,), lambda i: to16( topi.add(topi.add( to32(a[i]), @@ -109,14 +109,14 @@ def get_target(): def test_legalize(): def check(fcompute_before, fcompute_after): - a = te.placeholder((100,), dtype='bf16') - b = te.placeholder((100,), dtype='bf16') + a = te.placeholder((100,), dtype='bfloat16') + b = te.placeholder((100,), dtype='bfloat16') c = te.compute((100,), fcompute_before(a,b)) s = te.create_schedule(c.op) stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16Legalize) - a = te.placeholder((100,), dtype='bf16') - b = te.placeholder((100,), dtype='bf16') + a = te.placeholder((100,), dtype='bfloat16') + b = te.placeholder((100,), dtype='bfloat16') c = te.compute((100,), fcompute_after(a,b)) s = te.create_schedule(c.op) func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"] @@ -126,14 +126,15 @@ def orig1(a,b): return lambda i: a[i]+b[i]+a[99-i]+b[99-i] def after1(a,b): return lambda i: to16(to32(a[i])+to32(b[i])+to32(a[99-i])+to32(b[99-i])) - def orig1(a,b): + def orig2(a,b): return lambda i: a[i]*b[i]+a[99-i]*b[99-i]+a[i] - def after1(a,b): + def after2(a,b): return lambda i: to16(to32(a[i])*to32(b[i])+to32(a[99-i])*to32(b[99-i])+to32(a[i])) check(orig1, after1) + check(orig2, after2) if __name__ == "__main__": test_promote() test_eliminate() - test_legalize() \ No newline at end of file + test_legalize() From 7224acf30758a22b8659497cd3c64ae99d3b0ca8 Mon Sep 17 00:00:00 2001 From: Menooker Date: Wed, 20 May 2020 14:30:57 +0800 Subject: [PATCH 23/43] fix linter problem --- src/target/llvm/codegen_llvm.cc | 2 +- src/tir/transforms/bf16_legalize.cc | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index d7d7dd3eb92f..35e5765798d0 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -951,7 +951,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul); llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \ if (t.is_int()) { \ return builder_->CreateICmpS##Op(a, b); \ - } else if (t.is_uint() || t.is_bfloat16()) { \ + } else if (t.is_uint() || t.is_bfloat16()) { \ return builder_->CreateICmpU##Op(a, b); \ } else { \ CHECK(t.is_float()); \ diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index e68e91b9c659..0c79597f8e25 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -78,13 +78,13 @@ class BF16PromoteRewriter : public StmtExprMutator { #define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ PrimExpr a, b; \ - bool is_bfloat16; \ - std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ + bool is_bfloat16; \ + std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ if (a.same_as(op->a) && b.same_as(op->b)) { \ return GetRef(op); \ } else { \ auto ret = FUNC(a, b); \ - if (!is_bfloat16) \ + if (!is_bfloat16) \ return ret; \ else \ return CastNode::make(DataType(kTVMBFloat, 16, 1), ret); \ @@ -94,8 +94,8 @@ class BF16PromoteRewriter : public StmtExprMutator { #define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC) \ PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ PrimExpr a, b; \ - bool is_bfloat16; \ - std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ + bool is_bfloat16; \ + std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ if (a.same_as(op->a) && b.same_as(op->b)) { \ return GetRef(op); \ } else { \ From b36f5f4792b40ccd75081003cef2b130a1d9f0ae Mon Sep 17 00:00:00 2001 From: Menooker Date: Sun, 24 May 2020 17:41:10 +0800 Subject: [PATCH 24/43] TIR legalize --- python/tvm/tir/transform/transform.py | 13 +- src/target/llvm/codegen_llvm.cc | 49 +---- src/tir/transforms/bf16_legalize.cc | 187 +++++++++++++++++- .../unittest/test_target_codegen_llvm.py | 5 +- 4 files changed, 203 insertions(+), 51 deletions(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index b826f1302407..1e3114b2485d 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -257,7 +257,7 @@ def RemoveNoOp(): def BF16Legalize(): """Legalize bf16 typed Ops. - Runs BF16Promote and BF16CastElimination + Runs BF16Promote, BF16CastElimination and BF16TypeLowering Returns ------- @@ -295,6 +295,17 @@ def BF16CastElimination(): """ return _ffi_api.BF16CastElimination() +def BF16TypeLowering(): + """Replace all bf16 type with uint16. Also lower the casting + between fp32 and bf16 + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.BF16TypeLowering() + def RewriteUnsafeSelect(): """Detect and rewrite unsafe select that contains memory access. diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 35e5765798d0..cecafde1030b 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -309,8 +309,6 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { default: LOG(FATAL) << "do not support " << dtype; } - } else if (dtype.is_bfloat16()) { - etype = llvm::Type::getInt16Ty(*ctx_); } if (dtype.lanes() != 1) { return llvm::VectorType::get(etype, dtype.lanes()); @@ -557,46 +555,12 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va builder_->SetInsertPoint(for_end); } -static llvm::Value* GetInt32VectorOrScalar( - llvm::IRBuilder* builder, uint32_t v, - int lanes) { - if (lanes == 1) { - return builder->getInt32(v); - } else { - std::vector consts; - for (int i = 0; i < lanes; i++) { - consts.emplace_back(builder->getInt32(v)); - } - return llvm::ConstantVector::get(consts); - } -} - // cast operatpr llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) { llvm::Type* target = DTypeToLLVMType(to); if (value->getType() == target) return value; if (to.is_handle()) { return builder_->CreateBitCast(value, target); - } else if (to.is_float() && from.is_bfloat16()) { - CHECK_EQ(to.bits(), 32); - llvm::Type* extended_type = (from.lanes() == 1) - ? static_cast(builder_->getInt32Ty()) - : llvm::VectorType::get(builder_->getInt32Ty(), from.lanes()); - auto v = builder_->CreateZExt(value, extended_type); - v = builder_->CreateShl(v, 16); - return builder_->CreateBitCast(v, target); - } else if (to.is_bfloat16() && from.is_float()) { - CHECK_EQ(from.bits(), 32); - llvm::Type* extended_type = (from.lanes() == 1) - ? static_cast(builder_->getInt32Ty()) - : llvm::VectorType::get(builder_->getInt32Ty(), to.lanes()); - auto v = builder_->CreateBitCast(value, extended_type); - auto bias = builder_->CreateLShr(v, 16); - bias = builder_->CreateAnd(bias, GetInt32VectorOrScalar(builder_.get(), 1, to.lanes())); - bias = builder_->CreateAdd(bias, GetInt32VectorOrScalar(builder_.get(), 0x7fff, to.lanes())); - v = builder_->CreateAdd(v, bias); - v = builder_->CreateLShr(v, 16); - return builder_->CreateTrunc(v, target); } else if (to.is_uint() && to.bits() == 1) { if (from.is_float()) { llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.); @@ -906,15 +870,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { - if (op->dtype.is_bfloat16()) { - auto fp = static_cast(op->value); - auto p = reinterpret_cast(&fp); -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - return this->builder_->getInt16(p[0]); -#else - return this->builder_->getInt16(p[1]); -#endif - } return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value); } @@ -951,7 +906,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul); llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \ if (t.is_int()) { \ return builder_->CreateICmpS##Op(a, b); \ - } else if (t.is_uint() || t.is_bfloat16()) { \ + } else if (t.is_uint()) { \ return builder_->CreateICmpU##Op(a, b); \ } else { \ CHECK(t.is_float()); \ @@ -1330,4 +1285,4 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } } // namespace codegen } // namespace tvm -#endif // TVM_LLVM_VERSION +#endif // TVM_LLVM_VERSION \ No newline at end of file diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 0c79597f8e25..51231cbb2413 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -27,6 +27,7 @@ #include #include +#include #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" @@ -149,6 +150,177 @@ class BF16CastEliminationRewriter : public StmtExprMutator { } }; +// implementation from +// https://github.com/pytorch/pytorch/blob/master/c10/util/BFloat16.h +inline uint16_t round_to_nearest_even(float src) { +#if defined(_MSC_VER) + if (isnan(src)) { +#else + if (std::isnan(src)) { +#endif + return UINT16_C(0x7FC0); + } else { + union { + uint32_t U32; + float F32; + }; + + F32 = src; + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16); + } +} + +/* + * Lower the bf16 type to int16 + * Lower cast between bf16 and fp32 + * Lower bf16 FloatImm to int16 + */ +class BF16LowerRewriter : StmtExprMutator { + public: + BF16LowerRewriter() {} + + std::unordered_map buffer_remap; + std::unordered_map var_remap; + + Stmt operator()(Stmt s) { return VisitStmt(s); } + + PrimExpr VisitExpr_(const CastNode* op) { + auto op_val = StmtExprMutator::VisitExpr(op->value); + if (op->value->dtype.is_bfloat16()) { + // if is cast_from_bf16, check if is to fp32 + CHECK(op->dtype.is_float() && op->dtype.bits() == 32); + auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); + auto uint32_v = CastNode::make(uint32_dtype, op_val); + return CallNode::make(op->dtype, CallNode::reinterpret, {uint32_v << 16}, + CallNode::PureIntrinsic); + + } else if (op->dtype.is_bfloat16()) { + // if is cast_to_bf16, check if op->value is fp32 + CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32); + auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); + auto uint32_v = + CallNode::make(uint32_dtype, CallNode::reinterpret, {op_val}, CallNode::PureIntrinsic); + auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes()); + // uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + // return static_cast((U32 + rounding_bias) >> 16); + auto rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF); + return CastNode::make(uint16_dtype, {(uint32_v + rounding_bias) >> 16}); + } + if (op->value.same_as(op_val)) return GetRef(op); + return CastNode::make(op->dtype, op_val); + } + + PrimExpr VisitExpr_(const VarNode* op) { + auto itr = var_remap.find(op); + if (itr != var_remap.end()) { + return itr->second; + } + if (op->dtype.is_bfloat16()) { + CHECK(!op->type_annotation.defined()); + auto ret = Var(op->name_hint, op->dtype); + var_remap[op] = ret; + return ret; + } + return StmtExprMutator::VisitExpr_(op); + } + + Stmt VisitStmt_(const AllocateNode* op) { + Stmt node_holder; + const AllocateNode* newop; + if (op->dtype.is_bfloat16()) { + auto v = AllocateNode::make(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), + op->extents, op->condition, op->body); + node_holder = v; + newop = static_cast(v.operator->()); + } else { + newop = op; + } + return StmtExprMutator::VisitStmt_(newop); + } + + Stmt VisitStmt_(const BufferStoreNode* op) { + auto itr = buffer_remap.find(op->buffer.operator->()); + const BufferStoreNode* newop; + BufferStore newop_holder; + if (itr != buffer_remap.end()) { + newop_holder = BufferStore(itr->second, op->value, op->indices); + newop = newop_holder.operator->(); + } else { + newop = op; + } + return StmtExprMutator::VisitStmt_(newop); + } + + Stmt VisitStmt_(const BufferRealizeNode* op) { + auto itr = buffer_remap.find(op->buffer.operator->()); + const BufferRealizeNode* newop; + Stmt newop_holder; + if (itr != buffer_remap.end()) { + auto v = BufferRealize(itr->second, op->bounds, op->condition, op->body); + newop_holder = v; + newop = v.operator->(); + } else { + newop = op; + } + return StmtExprMutator::VisitStmt_(newop); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) override { + auto itr = buffer_remap.find(op->buffer.operator->()); + const BufferLoadNode* newop; + BufferLoad newop_holder; + if (itr != buffer_remap.end()) { + newop_holder = BufferLoad(itr->second, op->indices); + newop = newop_holder.operator->(); + } else { + newop = op; + } + return StmtExprMutator::VisitExpr_(newop); + } + + PrimExpr VisitExpr_(const LoadNode* op) override { + bool is_bf16 = false; + if (op->dtype.is_bfloat16()) { + is_bf16 = true; + } + PrimExpr index = this->VisitExpr(op->index); + PrimExpr predicate = this->VisitExpr(op->predicate); + if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) { + return GetRef(op); + } else { + return LoadNode::make(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, + op->buffer_var, index, predicate); + } + } + + PrimExpr VisitExpr_(const FloatImmNode* op) override { + if (op->dtype.is_bfloat16()) { + return IntImm(DataType::UInt(16, op->dtype.lanes()), + round_to_nearest_even(static_cast(op->value))); + } + return StmtExprMutator::VisitExpr_(op); + } + + void alter_buffers(PrimFuncNode* op) { + std::vector> changes; + for (auto& itr : op->buffer_map) { + auto oldbuf = itr.second; + if (oldbuf->dtype.is_bfloat16()) { + auto newbuf = + BufferNode::make(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape, + oldbuf->strides, oldbuf->elem_offset, oldbuf->name, oldbuf->scope, + oldbuf->data_alignment, oldbuf->offset_factor, oldbuf->buffer_type); + buffer_remap[oldbuf.operator->()] = newbuf; + changes.emplace_back(itr.first, newbuf); + } + } + if (buffer_remap.size() != 0) { + op->buffer_map.assign(changes.begin(), changes.end()); + } + } + }; + namespace transform { Pass BF16Promote() { @@ -173,8 +345,21 @@ Pass BF16CastElimination() { TVM_REGISTER_GLOBAL("tir.transform.BF16CastElimination").set_body_typed(BF16CastElimination); +Pass BF16TypeLowering() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + BF16LowerRewriter lowerer; + lowerer.alter_buffers(n); + n->body = lowerer(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.BF16TypeLowering", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.BF16TypeLowering").set_body_typed(BF16TypeLowering); + Pass BF16Legalize() { - return Sequential({BF16Promote(), BF16CastElimination()}, "tir.BF16Legalize"); + return Sequential({BF16Promote(), BF16CastElimination(), BF16TypeLowering()}, "tir.BF16Legalize"); } TVM_REGISTER_GLOBAL("tir.transform.BF16Legalize").set_body_typed(BF16Legalize); diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 9b5649b84699..a7abe12d9b35 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -721,7 +721,7 @@ def np_float2tvm_bf16(arr): ''' Convert a numpy array of float to a TVM array of bf16''' nparr = np_float2np_bf16(arr) - return tvm.nd.empty(nparr.shape, 'bfloat16').copyfrom(nparr) + return tvm.nd.empty(nparr.shape, 'uint16').copyfrom(nparr) def np_bf162np_float(arr): ''' Convert a numpy array of bf16 (uint16) to a numpy array @@ -740,6 +740,7 @@ def dotest(do_vectorize): B = te.placeholder((32, ), dtype='bfloat16') d = te.compute((32, ), lambda x: A[x] + B[x]) sch = te.create_schedule(d.op) + print(tvm.lower(sch, [A,B,d])) if do_vectorize: sch[d].vectorize(d.op.axis[0]) module = tvm.build(sch, [A, B, d]) @@ -750,7 +751,7 @@ def dotest(do_vectorize): res = np_bf16_cast_and_cast_back(va + vb) a_ = np_float2tvm_bf16(npa) b_ = np_float2tvm_bf16(npb) - c_ = tvm.nd.empty((32,), 'bfloat16') + c_ = tvm.nd.empty((32,), 'uint16') module(a_, b_, c_) tvm.testing.assert_allclose(np_bf162np_float(c_.asnumpy()), res) dotest(True) From 92968d031a536c07ae4217814b5bf95f83700bf4 Mon Sep 17 00:00:00 2001 From: Menooker Date: Sun, 24 May 2020 18:51:18 +0800 Subject: [PATCH 25/43] pass test --- .../test_tir_transform_bf16_legalize.py | 44 ++++++++++++++----- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index 1161f38de962..ed72c9123747 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -27,10 +27,6 @@ def lower_stmt(sche, params, passfunc): stmt = func.body return stmt -def to32(v): - return topi.cast(v, 'float') -def to16(v): - return topi.cast(v, 'bfloat16') def test_promote(): def runpass(op, passfunc): @@ -60,6 +56,10 @@ def test_promoted(op): test_promoted(topi.divide) def test_eliminate(): + def to32(v): + return topi.cast(v, 'float') + def to16(v): + return topi.cast(v, 'bfloat16') def get_eliminated(): a = te.placeholder((100,), dtype='bfloat16') b = te.placeholder((100,), dtype='bfloat16') @@ -104,23 +104,43 @@ def get_target(): s = te.create_schedule(c.op) func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"] return func.body - tvm.ir.assert_structural_equal(get_eliminated(), get_target()) def test_legalize(): + def to32(v): + uint32_v = topi.cast(v, "uint32") + uint32_v = tvm.tir.call_pure_intrin("uint32", "shift_left", uint32_v, tvm.tir.const(16, "uint32")) + return tvm.tir.call_pure_intrin("float32", "reinterpret", uint32_v) + def to16(v): + uint32_v = tvm.tir.call_pure_intrin("uint32", "reinterpret", v) + rounding_bias = tvm.tir.call_pure_intrin("uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32")) + rounding_bias = tvm.tir.call_pure_intrin("uint32", "bitwise_and", rounding_bias, tvm.tir.const(1, "uint32")) + rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16") + uint32_v = uint32_v + rounding_bias + uint32_v = tvm.tir.call_pure_intrin("uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32")) + return topi.cast(uint32_v, 'uint16') + def check(fcompute_before, fcompute_after): - a = te.placeholder((100,), dtype='bfloat16') - b = te.placeholder((100,), dtype='bfloat16') - c = te.compute((100,), fcompute_before(a,b)) + a = te.placeholder((100,), dtype='bfloat16', name = 'A') + b = te.placeholder((100,), dtype='bfloat16', name = 'B') + c = te.compute((100,), fcompute_before(a,b), name = 'C') s = te.create_schedule(c.op) stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16Legalize) - a = te.placeholder((100,), dtype='bfloat16') - b = te.placeholder((100,), dtype='bfloat16') - c = te.compute((100,), fcompute_after(a,b)) + a = te.placeholder((100,), dtype='uint16', name = 'A') + b = te.placeholder((100,), dtype='uint16', name = 'B') + c = te.compute((100,), fcompute_after(a,b), name = 'C') s = te.create_schedule(c.op) func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"] - tvm.ir.assert_structural_equal(stmt, func.body) + + stmt_str = str(stmt) + func_str = str(func.body) + + stmt_str = stmt_str[stmt_str.find("realize_scope"):] + func_str = func_str[func_str.find("realize_scope"):] + + assert(func_str == stmt_str) + # tvm.ir.assert_structural_equal(stmt, func.body) def orig1(a,b): return lambda i: a[i]+b[i]+a[99-i]+b[99-i] From 177f99a8ce7f95e927cabeaf18684ade0c8ebc73 Mon Sep 17 00:00:00 2001 From: Menooker Date: Sun, 24 May 2020 19:00:17 +0800 Subject: [PATCH 26/43] linter --- src/target/llvm/codegen_llvm.cc | 2 +- src/tir/transforms/bf16_legalize.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index cecafde1030b..f664532b2dc1 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1285,4 +1285,4 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } } // namespace codegen } // namespace tvm -#endif // TVM_LLVM_VERSION \ No newline at end of file +#endif // TVM_LLVM_VERSION diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 51231cbb2413..0fc040eaddbc 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -319,7 +319,7 @@ class BF16LowerRewriter : StmtExprMutator { op->buffer_map.assign(changes.begin(), changes.end()); } } - }; +}; namespace transform { From 27c0ab247dadc5a846f9f233888f8654d495fd8f Mon Sep 17 00:00:00 2001 From: Menooker Date: Sun, 24 May 2020 19:05:06 +0800 Subject: [PATCH 27/43] linter --- src/tir/transforms/bf16_legalize.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 0fc040eaddbc..fb9aad9eafca 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -26,8 +26,8 @@ #include #include -#include #include +#include #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" From 3dd2a71f58e9987f7e32dcdc870c0426b277ac8e Mon Sep 17 00:00:00 2001 From: Menooker Date: Sun, 24 May 2020 20:31:01 +0800 Subject: [PATCH 28/43] linter --- src/tir/transforms/bf16_legalize.cc | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index fb9aad9eafca..2090efa9291c 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -133,7 +133,7 @@ class BF16CastEliminationRewriter : public StmtExprMutator { Stmt operator()(Stmt s) { return VisitStmt(s); } - PrimExpr VisitExpr_(const CastNode* op) { + PrimExpr VisitExpr_(const CastNode* op) final { auto op_val = StmtExprMutator::VisitExpr(op->value); if (op->dtype.is_float() && op->dtype.bits() == 32) { // if is cast_to_fp32, check if op->value is cast_to_fp16 @@ -185,7 +185,7 @@ class BF16LowerRewriter : StmtExprMutator { Stmt operator()(Stmt s) { return VisitStmt(s); } - PrimExpr VisitExpr_(const CastNode* op) { + PrimExpr VisitExpr_(const CastNode* op) final { auto op_val = StmtExprMutator::VisitExpr(op->value); if (op->value->dtype.is_bfloat16()) { // if is cast_from_bf16, check if is to fp32 @@ -211,7 +211,7 @@ class BF16LowerRewriter : StmtExprMutator { return CastNode::make(op->dtype, op_val); } - PrimExpr VisitExpr_(const VarNode* op) { + PrimExpr VisitExpr_(const VarNode* op) final { auto itr = var_remap.find(op); if (itr != var_remap.end()) { return itr->second; @@ -220,12 +220,12 @@ class BF16LowerRewriter : StmtExprMutator { CHECK(!op->type_annotation.defined()); auto ret = Var(op->name_hint, op->dtype); var_remap[op] = ret; - return ret; + return std::move(ret); } return StmtExprMutator::VisitExpr_(op); } - Stmt VisitStmt_(const AllocateNode* op) { + Stmt VisitStmt_(const AllocateNode* op) final { Stmt node_holder; const AllocateNode* newop; if (op->dtype.is_bfloat16()) { @@ -239,7 +239,7 @@ class BF16LowerRewriter : StmtExprMutator { return StmtExprMutator::VisitStmt_(newop); } - Stmt VisitStmt_(const BufferStoreNode* op) { + Stmt VisitStmt_(const BufferStoreNode* op) final { auto itr = buffer_remap.find(op->buffer.operator->()); const BufferStoreNode* newop; BufferStore newop_holder; @@ -252,7 +252,7 @@ class BF16LowerRewriter : StmtExprMutator { return StmtExprMutator::VisitStmt_(newop); } - Stmt VisitStmt_(const BufferRealizeNode* op) { + Stmt VisitStmt_(const BufferRealizeNode* op) final { auto itr = buffer_remap.find(op->buffer.operator->()); const BufferRealizeNode* newop; Stmt newop_holder; @@ -266,7 +266,7 @@ class BF16LowerRewriter : StmtExprMutator { return StmtExprMutator::VisitStmt_(newop); } - PrimExpr VisitExpr_(const BufferLoadNode* op) override { + PrimExpr VisitExpr_(const BufferLoadNode* op) override final { auto itr = buffer_remap.find(op->buffer.operator->()); const BufferLoadNode* newop; BufferLoad newop_holder; @@ -279,7 +279,7 @@ class BF16LowerRewriter : StmtExprMutator { return StmtExprMutator::VisitExpr_(newop); } - PrimExpr VisitExpr_(const LoadNode* op) override { + PrimExpr VisitExpr_(const LoadNode* op) override final { bool is_bf16 = false; if (op->dtype.is_bfloat16()) { is_bf16 = true; @@ -294,7 +294,7 @@ class BF16LowerRewriter : StmtExprMutator { } } - PrimExpr VisitExpr_(const FloatImmNode* op) override { + PrimExpr VisitExpr_(const FloatImmNode* op) override final { if (op->dtype.is_bfloat16()) { return IntImm(DataType::UInt(16, op->dtype.lanes()), round_to_nearest_even(static_cast(op->value))); From c5463769b1fc2c6cda7b715774472678187030e0 Mon Sep 17 00:00:00 2001 From: Menooker Date: Sun, 24 May 2020 21:09:48 +0800 Subject: [PATCH 29/43] fix AttrStmtNode --- src/tir/transforms/bf16_legalize.cc | 25 ++++++++++++++++--- .../test_tir_transform_bf16_legalize.py | 10 +------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 2090efa9291c..4eb5179de9f1 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -252,6 +252,25 @@ class BF16LowerRewriter : StmtExprMutator { return StmtExprMutator::VisitStmt_(newop); } + Stmt VisitStmt_(const AttrStmtNode* op) final { + const AttrStmtNode* newop = op; + Stmt newop_holder; + if (auto buffer = op->node.as()) { + auto itr = buffer_remap.find(buffer); + if (itr != buffer_remap.end()) { + newop_holder = AttrStmtNode::make(itr->second, op->attr_key, op->value, op->body); + newop = newop_holder.as(); + } + } else if (auto buffer = op->node.as()) { + auto itr = var_remap.find(buffer); + if (itr != var_remap.end()) { + newop_holder = AttrStmtNode::make(itr->second, op->attr_key, op->value, op->body); + newop = newop_holder.as(); + } + } + return StmtExprMutator::VisitStmt_(newop); + } + Stmt VisitStmt_(const BufferRealizeNode* op) final { auto itr = buffer_remap.find(op->buffer.operator->()); const BufferRealizeNode* newop; @@ -266,7 +285,7 @@ class BF16LowerRewriter : StmtExprMutator { return StmtExprMutator::VisitStmt_(newop); } - PrimExpr VisitExpr_(const BufferLoadNode* op) override final { + PrimExpr VisitExpr_(const BufferLoadNode* op) final { auto itr = buffer_remap.find(op->buffer.operator->()); const BufferLoadNode* newop; BufferLoad newop_holder; @@ -279,7 +298,7 @@ class BF16LowerRewriter : StmtExprMutator { return StmtExprMutator::VisitExpr_(newop); } - PrimExpr VisitExpr_(const LoadNode* op) override final { + PrimExpr VisitExpr_(const LoadNode* op) final { bool is_bf16 = false; if (op->dtype.is_bfloat16()) { is_bf16 = true; @@ -294,7 +313,7 @@ class BF16LowerRewriter : StmtExprMutator { } } - PrimExpr VisitExpr_(const FloatImmNode* op) override final { + PrimExpr VisitExpr_(const FloatImmNode* op) final { if (op->dtype.is_bfloat16()) { return IntImm(DataType::UInt(16, op->dtype.lanes()), round_to_nearest_even(static_cast(op->value))); diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index ed72c9123747..f71396ebda93 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -132,15 +132,7 @@ def check(fcompute_before, fcompute_after): c = te.compute((100,), fcompute_after(a,b), name = 'C') s = te.create_schedule(c.op) func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"] - - stmt_str = str(stmt) - func_str = str(func.body) - - stmt_str = stmt_str[stmt_str.find("realize_scope"):] - func_str = func_str[func_str.find("realize_scope"):] - - assert(func_str == stmt_str) - # tvm.ir.assert_structural_equal(stmt, func.body) + tvm.ir.assert_structural_equal(stmt, func.body) def orig1(a,b): return lambda i: a[i]+b[i]+a[99-i]+b[99-i] From 01906c53b9636cb05872c82f79f8e7913d1f64dc Mon Sep 17 00:00:00 2001 From: Menooker Date: Sun, 24 May 2020 21:39:02 +0800 Subject: [PATCH 30/43] msvc compile --- src/tir/transforms/bf16_legalize.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 4eb5179de9f1..d6cae504bc72 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -153,11 +153,7 @@ class BF16CastEliminationRewriter : public StmtExprMutator { // implementation from // https://github.com/pytorch/pytorch/blob/master/c10/util/BFloat16.h inline uint16_t round_to_nearest_even(float src) { -#if defined(_MSC_VER) - if (isnan(src)) { -#else if (std::isnan(src)) { -#endif return UINT16_C(0x7FC0); } else { union { From 817f3021e2a9fa1c956a1e70689e8e73e0aafc81 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 29 May 2020 20:20:38 +0800 Subject: [PATCH 31/43] comments and notes --- src/tir/transforms/bf16_legalize.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index d6cae504bc72..6375e92ebb6b 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -188,6 +188,7 @@ class BF16LowerRewriter : StmtExprMutator { CHECK(op->dtype.is_float() && op->dtype.bits() == 32); auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); auto uint32_v = CastNode::make(uint32_dtype, op_val); + // to be endian invariant. return CallNode::make(op->dtype, CallNode::reinterpret, {uint32_v << 16}, CallNode::PureIntrinsic); @@ -198,9 +199,11 @@ class BF16LowerRewriter : StmtExprMutator { auto uint32_v = CallNode::make(uint32_dtype, CallNode::reinterpret, {op_val}, CallNode::PureIntrinsic); auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes()); - // uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); - // return static_cast((U32 + rounding_bias) >> 16); + /* the following TIR is equivalent to the C++ code below: + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16);*/ auto rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF); + // to be endian invariant. return CastNode::make(uint16_dtype, {(uint32_v + rounding_bias) >> 16}); } if (op->value.same_as(op_val)) return GetRef(op); From ae67413b47559ac1d3da430e2dfa7d30c7fc155a Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 29 May 2020 20:32:20 +0800 Subject: [PATCH 32/43] linter --- src/tir/transforms/bf16_legalize.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 6375e92ebb6b..efdb5320dd4d 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -188,7 +188,7 @@ class BF16LowerRewriter : StmtExprMutator { CHECK(op->dtype.is_float() && op->dtype.bits() == 32); auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); auto uint32_v = CastNode::make(uint32_dtype, op_val); - // to be endian invariant. + // to be endian invariant. return CallNode::make(op->dtype, CallNode::reinterpret, {uint32_v << 16}, CallNode::PureIntrinsic); @@ -203,7 +203,7 @@ class BF16LowerRewriter : StmtExprMutator { uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); return static_cast((U32 + rounding_bias) >> 16);*/ auto rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF); - // to be endian invariant. + // to be endian invariant. return CastNode::make(uint16_dtype, {(uint32_v + rounding_bias) >> 16}); } if (op->value.same_as(op_val)) return GetRef(op); From bf1b747a96c499f79a1a2205ea163ad8c8206d86 Mon Sep 17 00:00:00 2001 From: Menooker Date: Mon, 1 Jun 2020 10:17:51 +0800 Subject: [PATCH 33/43] Code style, use kDLBfloat --- include/tvm/runtime/c_runtime_api.h | 4 ++-- include/tvm/runtime/data_type.h | 6 +++--- python/tvm/_ffi/_cython/base.pxi | 3 ++- python/tvm/_ffi/runtime_ctypes.py | 15 ++++++++------- src/tir/transforms/bf16_legalize.cc | 6 +++--- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 736b5eff2332..db00d8f7e07c 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -94,7 +94,7 @@ typedef enum { // The next few fields are extension types // that is used by TVM API calls. kTVMOpaqueHandle = 3U, - kTVMNullptr = 4U, + // 4 is for kDLBfloat kTVMDataType = 5U, kTVMContext = 6U, kTVMDLTensorHandle = 7U, @@ -112,9 +112,9 @@ typedef enum { kTVMExtBegin = 15U, kTVMNNVMFirst = 16U, kTVMNNVMLast = 20U, + kTVMNullptr = 21U, // The following section of code is used for non-reserved types. kTVMExtReserveEnd = 64U, - kTVMBFloat = 65U, kTVMExtEnd = 128U, // The rest of the space is used for custom, user-supplied datatypes kTVMCustomBegin = 129U, diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 20eb061a012c..1dfaf6c888d3 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -45,8 +45,8 @@ class DataType { kInt = kDLInt, kUInt = kDLUInt, kFloat = kDLFloat, + kBFloat = kDLBfloat, kHandle = TVMTypeCode::kTVMOpaqueHandle, - kBFloat = kTVMBFloat, }; /*! \brief default constructor */ DataType() {} @@ -303,7 +303,7 @@ inline const char* TypeCode2Str(int type_code) { return "Object"; case kTVMObjectRValueRefArg: return "ObjectRValueRefArg"; - case kTVMBFloat: + case kDLBfloat: return "bfloat"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); @@ -372,7 +372,7 @@ inline DLDataType String2DLDataType(std::string s) { t.lanes = 1; return t; } else if (s.substr(0, 6) == "bfloat") { - t.code = kTVMBFloat; + t.code = kDLBfloat; scan = s.c_str() + 6; } else if (s.substr(0, 6) == "custom") { t.code = ParseCustomDatatype(s, &scan); diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 0da66ac2e034..d753f4a48eca 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -27,7 +27,7 @@ cdef enum TVMTypeCode: kUInt = 1 kFloat = 2 kTVMOpaqueHandle = 3 - kTVMNullptr = 4 + kBFloat = 4 kTVMDataType = 5 kTVMContext = 6 kTVMDLTensorHandle = 7 @@ -39,6 +39,7 @@ cdef enum TVMTypeCode: kTVMNDArrayHandle = 13 kTVMObjectRefArg = 14 kTVMExtBegin = 15 + kTVMNullptr = 21 cdef extern from "tvm/runtime/c_runtime_api.h": ctypedef struct DLDataType: diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index c6a6e09206d3..d310f874ecf3 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -29,7 +29,8 @@ class TypeCode(object): UINT = 1 FLOAT = 2 HANDLE = 3 - NULL = 4 + BFLOAT = 4 + NULL = 21 TVM_TYPE = 5 TVM_CONTEXT = 6 DLTENSOR_HANDLE = 7 @@ -58,8 +59,8 @@ class DataType(ctypes.Structure): 0 : 'int', 1 : 'uint', 2 : 'float', - 4 : 'handle', - 65: 'bfloat' + 3 : 'handle', + 4 : 'bfloat' } def __init__(self, type_str): super(DataType, self).__init__() @@ -86,13 +87,13 @@ def __init__(self, type_str): elif head.startswith("float"): self.type_code = 2 head = head[5:] - elif head.startswith("bfloat"): - self.type_code = 65 - head = head[6:] elif head.startswith("handle"): - self.type_code = 4 + self.type_code = 3 bits = 64 head = "" + elif head.startswith("bfloat"): + self.type_code = 4 + head = head[6:] elif head.startswith("custom"): # pylint: disable=import-outside-toplevel import tvm.runtime._ffi_api diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index efdb5320dd4d..a8ee4a890545 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -88,7 +88,7 @@ class BF16PromoteRewriter : public StmtExprMutator { if (!is_bfloat16) \ return ret; \ else \ - return CastNode::make(DataType(kTVMBFloat, 16, 1), ret); \ + return CastNode::make(DataType(kDLBfloat, 16, 1), ret); \ } \ } @@ -320,7 +320,7 @@ class BF16LowerRewriter : StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } - void alter_buffers(PrimFuncNode* op) { + void AlterBuffers(PrimFuncNode* op) { std::vector> changes; for (auto& itr : op->buffer_map) { auto oldbuf = itr.second; @@ -367,7 +367,7 @@ Pass BF16TypeLowering() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); BF16LowerRewriter lowerer; - lowerer.alter_buffers(n); + lowerer.AlterBuffers(n); n->body = lowerer(std::move(n->body)); return f; }; From b1c1951c8ff8789c28e61c9feb4f254484fd9abd Mon Sep 17 00:00:00 2001 From: Menooker Date: Mon, 1 Jun 2020 10:22:46 +0800 Subject: [PATCH 34/43] format --- src/tir/transforms/bf16_legalize.cc | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index a8ee4a890545..31015398c63e 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -76,20 +76,20 @@ class BF16PromoteRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const GENode* op) final; }; -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ - PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ - PrimExpr a, b; \ - bool is_bfloat16; \ - std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ - if (a.same_as(op->a) && b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - auto ret = FUNC(a, b); \ - if (!is_bfloat16) \ - return ret; \ - else \ +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a, b; \ + bool is_bfloat16; \ + std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + auto ret = FUNC(a, b); \ + if (!is_bfloat16) \ + return ret; \ + else \ return CastNode::make(DataType(kDLBfloat, 16, 1), ret); \ - } \ + } \ } #define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC) \ From 86bcc16ccbcb9af1cd69862216a6d1b7ef53a438 Mon Sep 17 00:00:00 2001 From: Menooker Date: Mon, 1 Jun 2020 10:29:14 +0800 Subject: [PATCH 35/43] update dlpack --- 3rdparty/dlpack | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/dlpack b/3rdparty/dlpack index 0acb731e0e43..3ec04430e89a 160000 --- a/3rdparty/dlpack +++ b/3rdparty/dlpack @@ -1 +1 @@ -Subproject commit 0acb731e0e43d15deee27b66f10e4c5b4e667913 +Subproject commit 3ec04430e89a6834e5a1b99471f415fa939bf642 From 7612f9d3d816591c1227239255dde1a22bb87e15 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 5 Jun 2020 12:42:06 +0800 Subject: [PATCH 36/43] change back nullptr typecode --- include/tvm/runtime/c_runtime_api.h | 3 +-- python/tvm/_ffi/_cython/base.pxi | 2 +- python/tvm/_ffi/runtime_ctypes.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index db00d8f7e07c..bb38ad8a84df 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -94,7 +94,7 @@ typedef enum { // The next few fields are extension types // that is used by TVM API calls. kTVMOpaqueHandle = 3U, - // 4 is for kDLBfloat + kTVMNullptr = 4U, kTVMDataType = 5U, kTVMContext = 6U, kTVMDLTensorHandle = 7U, @@ -112,7 +112,6 @@ typedef enum { kTVMExtBegin = 15U, kTVMNNVMFirst = 16U, kTVMNNVMLast = 20U, - kTVMNullptr = 21U, // The following section of code is used for non-reserved types. kTVMExtReserveEnd = 64U, kTVMExtEnd = 128U, diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index d753f4a48eca..54cd127f7d79 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -27,6 +27,7 @@ cdef enum TVMTypeCode: kUInt = 1 kFloat = 2 kTVMOpaqueHandle = 3 + kTVMNullptr = 4 kBFloat = 4 kTVMDataType = 5 kTVMContext = 6 @@ -39,7 +40,6 @@ cdef enum TVMTypeCode: kTVMNDArrayHandle = 13 kTVMObjectRefArg = 14 kTVMExtBegin = 15 - kTVMNullptr = 21 cdef extern from "tvm/runtime/c_runtime_api.h": ctypedef struct DLDataType: diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index d310f874ecf3..11d2d406155a 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -30,7 +30,7 @@ class TypeCode(object): FLOAT = 2 HANDLE = 3 BFLOAT = 4 - NULL = 21 + NULL = 4 TVM_TYPE = 5 TVM_CONTEXT = 6 DLTENSOR_HANDLE = 7 From 1b85a007bc018bc4c2b85e99cc31875062989f87 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 5 Jun 2020 12:52:12 +0800 Subject: [PATCH 37/43] remove python runtime type for bf16 --- python/tvm/_ffi/runtime_ctypes.py | 1 - python/tvm/runtime/ndarray.py | 10 ++-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 7dc3d26a8bab..074a69410867 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -29,7 +29,6 @@ class ArgTypeCode(object): UINT = 1 FLOAT = 2 HANDLE = 3 - BFLOAT = 4 NULL = 4 TVM_TYPE = 5 TVM_CONTEXT = 6 diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index d80c84b01846..060673dc19c6 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -138,10 +138,7 @@ def copyfrom(self, source_array): if source_array.shape != shape: raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format( source_array.shape, shape)) - if dtype == 'bfloat16': - source_array = np.ascontiguousarray(source_array, dtype='uint16') - else: - source_array = np.ascontiguousarray(source_array, dtype=dtype) + source_array = np.ascontiguousarray(source_array, dtype=dtype) assert source_array.flags['C_CONTIGUOUS'] data = source_array.ctypes.data_as(ctypes.c_void_p) nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize) @@ -170,10 +167,7 @@ def asnumpy(self): shape = shape + (t.lanes,) t.lanes = 1 dtype = str(t) - if dtype == 'bfloat16': - np_arr = np.empty(shape, dtype='uint16') - else: - np_arr = np.empty(shape, dtype=dtype) + np_arr = np.empty(shape, dtype=dtype) assert np_arr.flags['C_CONTIGUOUS'] data = np_arr.ctypes.data_as(ctypes.c_void_p) nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize) From c0cb1efcd6057c44f6e088050be40b14569e6513 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 12 Jun 2020 09:24:08 +0800 Subject: [PATCH 38/43] fix code style of RoundToNearestEven --- src/tir/transforms/bf16_legalize.cc | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 31015398c63e..6522806a0e71 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -150,20 +150,19 @@ class BF16CastEliminationRewriter : public StmtExprMutator { } }; -// implementation from -// https://github.com/pytorch/pytorch/blob/master/c10/util/BFloat16.h -inline uint16_t round_to_nearest_even(float src) { +union FloatCaster { + uint32_t u32; + float f32; +}; + +uint16_t RoundToNearestEven(float src) { if (std::isnan(src)) { return UINT16_C(0x7FC0); } else { - union { - uint32_t U32; - float F32; - }; - - F32 = src; - uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); - return static_cast((U32 + rounding_bias) >> 16); + FloatCaster caster; + caster.f32 = src; + uint32_t rounding_bias = ((caster.u32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((caster.u32 + rounding_bias) >> 16); } } @@ -315,7 +314,7 @@ class BF16LowerRewriter : StmtExprMutator { PrimExpr VisitExpr_(const FloatImmNode* op) final { if (op->dtype.is_bfloat16()) { return IntImm(DataType::UInt(16, op->dtype.lanes()), - round_to_nearest_even(static_cast(op->value))); + RoundToNearestEven(static_cast(op->value))); } return StmtExprMutator::VisitExpr_(op); } From a6341cbb7374d487c0acf341fe915c78931690f5 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 12 Jun 2020 19:59:03 +0800 Subject: [PATCH 39/43] merge newest master --- src/tir/transforms/bf16_legalize.cc | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 6522806a0e71..78c8a909d002 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -58,8 +58,8 @@ class BF16PromoteRewriter : public StmtExprMutator { if (*is_bfloat16) { DataType fp32ty(kDLFloat, 32, 1); - a = CastNode::make(fp32ty, a); - b = CastNode::make(fp32ty, b); + a = Cast(fp32ty, a); + b = Cast(fp32ty, b); } return std::make_tuple(a, b); } @@ -88,7 +88,7 @@ class BF16PromoteRewriter : public StmtExprMutator { if (!is_bfloat16) \ return ret; \ else \ - return CastNode::make(DataType(kDLBfloat, 16, 1), ret); \ + return Cast(DataType(kDLBfloat, 16, 1), ret); \ } \ } @@ -146,7 +146,7 @@ class BF16CastEliminationRewriter : public StmtExprMutator { } } if (op->value.same_as(op_val)) return GetRef(op); - return CastNode::make(op->dtype, op_val); + return Cast(op->dtype, op_val); } }; @@ -186,9 +186,9 @@ class BF16LowerRewriter : StmtExprMutator { // if is cast_from_bf16, check if is to fp32 CHECK(op->dtype.is_float() && op->dtype.bits() == 32); auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); - auto uint32_v = CastNode::make(uint32_dtype, op_val); + auto uint32_v = Cast(uint32_dtype, op_val); // to be endian invariant. - return CallNode::make(op->dtype, CallNode::reinterpret, {uint32_v << 16}, + return Call(op->dtype, CallNode::reinterpret, {uint32_v << 16}, CallNode::PureIntrinsic); } else if (op->dtype.is_bfloat16()) { @@ -196,17 +196,17 @@ class BF16LowerRewriter : StmtExprMutator { CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32); auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); auto uint32_v = - CallNode::make(uint32_dtype, CallNode::reinterpret, {op_val}, CallNode::PureIntrinsic); + Call(uint32_dtype, CallNode::reinterpret, {op_val}, CallNode::PureIntrinsic); auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes()); /* the following TIR is equivalent to the C++ code below: uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); return static_cast((U32 + rounding_bias) >> 16);*/ auto rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF); // to be endian invariant. - return CastNode::make(uint16_dtype, {(uint32_v + rounding_bias) >> 16}); + return Cast(uint16_dtype, {(uint32_v + rounding_bias) >> 16}); } if (op->value.same_as(op_val)) return GetRef(op); - return CastNode::make(op->dtype, op_val); + return Cast(op->dtype, op_val); } PrimExpr VisitExpr_(const VarNode* op) final { @@ -227,7 +227,7 @@ class BF16LowerRewriter : StmtExprMutator { Stmt node_holder; const AllocateNode* newop; if (op->dtype.is_bfloat16()) { - auto v = AllocateNode::make(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), + auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), op->extents, op->condition, op->body); node_holder = v; newop = static_cast(v.operator->()); @@ -256,13 +256,13 @@ class BF16LowerRewriter : StmtExprMutator { if (auto buffer = op->node.as()) { auto itr = buffer_remap.find(buffer); if (itr != buffer_remap.end()) { - newop_holder = AttrStmtNode::make(itr->second, op->attr_key, op->value, op->body); + newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body); newop = newop_holder.as(); } } else if (auto buffer = op->node.as()) { auto itr = var_remap.find(buffer); if (itr != var_remap.end()) { - newop_holder = AttrStmtNode::make(itr->second, op->attr_key, op->value, op->body); + newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body); newop = newop_holder.as(); } } @@ -306,7 +306,7 @@ class BF16LowerRewriter : StmtExprMutator { if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) { return GetRef(op); } else { - return LoadNode::make(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, + return Load(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, op->buffer_var, index, predicate); } } From 656e3e484000df957bb7db69146a2e57280019e1 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 12 Jun 2020 20:04:40 +0800 Subject: [PATCH 40/43] format --- src/tir/transforms/bf16_legalize.cc | 42 ++++++++++++++--------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 78c8a909d002..d201abac8fc4 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -76,20 +76,20 @@ class BF16PromoteRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const GENode* op) final; }; -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ - PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ - PrimExpr a, b; \ - bool is_bfloat16; \ - std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ - if (a.same_as(op->a) && b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - auto ret = FUNC(a, b); \ - if (!is_bfloat16) \ - return ret; \ - else \ - return Cast(DataType(kDLBfloat, 16, 1), ret); \ - } \ +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a, b; \ + bool is_bfloat16; \ + std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + auto ret = FUNC(a, b); \ + if (!is_bfloat16) \ + return ret; \ + else \ + return Cast(DataType(kDLBfloat, 16, 1), ret); \ + } \ } #define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC) \ @@ -188,15 +188,13 @@ class BF16LowerRewriter : StmtExprMutator { auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); auto uint32_v = Cast(uint32_dtype, op_val); // to be endian invariant. - return Call(op->dtype, CallNode::reinterpret, {uint32_v << 16}, - CallNode::PureIntrinsic); + return Call(op->dtype, CallNode::reinterpret, {uint32_v << 16}, CallNode::PureIntrinsic); } else if (op->dtype.is_bfloat16()) { // if is cast_to_bf16, check if op->value is fp32 CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32); auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); - auto uint32_v = - Call(uint32_dtype, CallNode::reinterpret, {op_val}, CallNode::PureIntrinsic); + auto uint32_v = Call(uint32_dtype, CallNode::reinterpret, {op_val}, CallNode::PureIntrinsic); auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes()); /* the following TIR is equivalent to the C++ code below: uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); @@ -227,8 +225,8 @@ class BF16LowerRewriter : StmtExprMutator { Stmt node_holder; const AllocateNode* newop; if (op->dtype.is_bfloat16()) { - auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), - op->extents, op->condition, op->body); + auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), op->extents, + op->condition, op->body); node_holder = v; newop = static_cast(v.operator->()); } else { @@ -306,8 +304,8 @@ class BF16LowerRewriter : StmtExprMutator { if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) { return GetRef(op); } else { - return Load(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, - op->buffer_var, index, predicate); + return Load(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, op->buffer_var, + index, predicate); } } From 25c811c85e5e2d998b0fa7d604cb8370c1adb210 Mon Sep 17 00:00:00 2001 From: Menooker Date: Sat, 13 Jun 2020 11:23:51 +0800 Subject: [PATCH 41/43] pylint on test --- .../test_tir_transform_bf16_legalize.py | 84 +++++++++++-------- 1 file changed, 50 insertions(+), 34 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index f71396ebda93..77a06022ac70 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -17,11 +17,11 @@ import tvm import topi from tvm import te -from tvm.tir import const def lower_stmt(sche, params, passfunc): - func = tvm.driver.build_module.form_irmodule(sche, params, "main", None)["main"] + func = tvm.driver.build_module.form_irmodule( + sche, params, "main", None)["main"] func = passfunc()( tvm.IRModule.from_expr(func))["main"] stmt = func.body @@ -35,16 +35,17 @@ def runpass(op, passfunc): c = te.compute((100,), lambda i: op(a[i], b[i])) s = te.create_schedule(c.op) return lower_stmt(s, [a, b, c], passfunc) - + def get_promoted(op): a = te.placeholder((100,), dtype='bfloat16') b = te.placeholder((100,), dtype='bfloat16') c = te.compute((100,), lambda i: - topi.cast(op(topi.cast(a[i],'float'), - topi.cast(b[i],'float')), 'bfloat16') - ) + topi.cast(op(topi.cast(a[i], 'float'), + topi.cast(b[i], 'float')), 'bfloat16') + ) s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"] + func = tvm.driver.build_module.form_irmodule( + s, [a, b, c], "main", None)["main"] return func.body def test_promoted(op): @@ -55,11 +56,14 @@ def test_promoted(op): test_promoted(topi.multiply) test_promoted(topi.divide) + def test_eliminate(): def to32(v): return topi.cast(v, 'float') + def to16(v): return topi.cast(v, 'bfloat16') + def get_eliminated(): a = te.placeholder((100,), dtype='bfloat16') b = te.placeholder((100,), dtype='bfloat16') @@ -92,60 +96,72 @@ def get_target(): b = te.placeholder((100,), dtype='bfloat16') c = te.compute((100,), lambda i: to16( topi.add(topi.add( - to32(a[i]), - to32(b[i]), - ), - topi.add( - to32(a[i]), - to32(b[i]), + to32(a[i]), + to32(b[i]), + ), + topi.add( + to32(a[i]), + to32(b[i]), + ) ) - ) )) s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"] + func = tvm.driver.build_module.form_irmodule( + s, [a, b, c], "main", None)["main"] return func.body tvm.ir.assert_structural_equal(get_eliminated(), get_target()) + def test_legalize(): def to32(v): uint32_v = topi.cast(v, "uint32") - uint32_v = tvm.tir.call_pure_intrin("uint32", "shift_left", uint32_v, tvm.tir.const(16, "uint32")) + uint32_v = tvm.tir.call_pure_intrin( + "uint32", "shift_left", uint32_v, tvm.tir.const(16, "uint32")) return tvm.tir.call_pure_intrin("float32", "reinterpret", uint32_v) + def to16(v): uint32_v = tvm.tir.call_pure_intrin("uint32", "reinterpret", v) - rounding_bias = tvm.tir.call_pure_intrin("uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32")) - rounding_bias = tvm.tir.call_pure_intrin("uint32", "bitwise_and", rounding_bias, tvm.tir.const(1, "uint32")) + rounding_bias = tvm.tir.call_pure_intrin( + "uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32")) + rounding_bias = tvm.tir.call_pure_intrin( + "uint32", "bitwise_and", rounding_bias, tvm.tir.const(1, "uint32")) rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16") uint32_v = uint32_v + rounding_bias - uint32_v = tvm.tir.call_pure_intrin("uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32")) + uint32_v = tvm.tir.call_pure_intrin( + "uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32")) return topi.cast(uint32_v, 'uint16') def check(fcompute_before, fcompute_after): - a = te.placeholder((100,), dtype='bfloat16', name = 'A') - b = te.placeholder((100,), dtype='bfloat16', name = 'B') - c = te.compute((100,), fcompute_before(a,b), name = 'C') + a = te.placeholder((100,), dtype='bfloat16', name='A') + b = te.placeholder((100,), dtype='bfloat16', name='B') + c = te.compute((100,), fcompute_before(a, b), name='C') s = te.create_schedule(c.op) stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16Legalize) - a = te.placeholder((100,), dtype='uint16', name = 'A') - b = te.placeholder((100,), dtype='uint16', name = 'B') - c = te.compute((100,), fcompute_after(a,b), name = 'C') + a = te.placeholder((100,), dtype='uint16', name='A') + b = te.placeholder((100,), dtype='uint16', name='B') + c = te.compute((100,), fcompute_after(a, b), name='C') s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"] + func = tvm.driver.build_module.form_irmodule( + s, [a, b, c], "main", None)["main"] tvm.ir.assert_structural_equal(stmt, func.body) - def orig1(a,b): - return lambda i: a[i]+b[i]+a[99-i]+b[99-i] - def after1(a,b): - return lambda i: to16(to32(a[i])+to32(b[i])+to32(a[99-i])+to32(b[99-i])) - def orig2(a,b): - return lambda i: a[i]*b[i]+a[99-i]*b[99-i]+a[i] - def after2(a,b): - return lambda i: to16(to32(a[i])*to32(b[i])+to32(a[99-i])*to32(b[99-i])+to32(a[i])) + def orig1(a, b): + return lambda i: a[i] + b[i] + a[99-i] + b[99-i] + + def after1(a, b): + return lambda i: to16(to32(a[i]) + to32(b[i] ) + to32(a[99 - i]) + to32(b[99 - i])) + + def orig2(a, b): + return lambda i: a[i] * b[i] + a[99 - i] * b[99 - i] + a[i] + + def after2(a, b): + return lambda i: to16(to32(a[i]) * to32(b[i]) + to32(a[99 - i]) * to32(b[99 - i]) + to32(a[i])) check(orig1, after1) check(orig2, after2) + if __name__ == "__main__": test_promote() test_eliminate() From 7a72ea99aa0c9aafc70fbd05641523fa5ea8d475 Mon Sep 17 00:00:00 2001 From: Menooker Date: Mon, 15 Jun 2020 13:32:36 +0800 Subject: [PATCH 42/43] make it run on newest master --- src/tir/transforms/bf16_legalize.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index d201abac8fc4..07f4775ded50 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -322,8 +322,7 @@ class BF16LowerRewriter : StmtExprMutator { for (auto& itr : op->buffer_map) { auto oldbuf = itr.second; if (oldbuf->dtype.is_bfloat16()) { - auto newbuf = - BufferNode::make(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape, + auto newbuf = Buffer(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape, oldbuf->strides, oldbuf->elem_offset, oldbuf->name, oldbuf->scope, oldbuf->data_alignment, oldbuf->offset_factor, oldbuf->buffer_type); buffer_remap[oldbuf.operator->()] = newbuf; From 318ddc9524ea0ce3f1712ad6a8847c3cd78b9673 Mon Sep 17 00:00:00 2001 From: Menooker Date: Wed, 17 Jun 2020 10:45:08 +0800 Subject: [PATCH 43/43] type code changes etc. --- include/tvm/runtime/data_type.h | 2 +- python/tvm/_ffi/_cython/base.pxi | 1 - python/tvm/_ffi/runtime_ctypes.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 210fa8e86e01..cb817a89ab81 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -358,7 +358,7 @@ inline DLDataType String2DLDataType(std::string s) { t.lanes = 1; return t; } else if (s.substr(0, 6) == "bfloat") { - t.code = kDLBfloat; + t.code = DataType::kBFloat; scan = s.c_str() + 6; } else if (s.substr(0, 6) == "custom") { t.code = ParseCustomDatatype(s, &scan); diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 86fe24191518..8c9e413813b9 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -28,7 +28,6 @@ cdef enum TVMArgTypeCode: kFloat = 2 kTVMOpaqueHandle = 3 kTVMNullptr = 4 - kBFloat = 4 kTVMDataType = 5 kTVMContext = 6 kTVMDLTensorHandle = 7 diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 074a69410867..a7bfb3278784 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -99,7 +99,7 @@ def __init__(self, type_str): bits = 64 head = "" elif head.startswith("bfloat"): - self.type_code = 4 + self.type_code = DataTypeCode.BFLOAT head = head[6:] elif head.startswith("custom"): # pylint: disable=import-outside-toplevel