[DataType] Add bfloat16 #5601
Conversation
cc @gussmith23 might be related to BYOD
Thanks @Menooker for the great work! The proposed changes mostly look good. I left a few comments.
src/target/llvm/codegen_llvm.cc
Outdated
@@ -309,6 +309,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
      default:
        LOG(FATAL) << "do not support " << dtype;
    }
  } else if (dtype.is_bfloat()) {
    CHECK_EQ(dtype.bits(), 16);
Since bfloat is assumed to be 16-bit, can we keep the terminology more consistent? The data type is referred to as bf, bf16, bfloat16, and bfloat in the proposed change. Or are we going to support more data types like bfloat18 and bfloat20 in the future?
Sorry for the lack of clarity. I think that in bfloat[X], only X=16 makes sense. But TVM's type system allows specifying the bits of a type, so the check here makes sure it is bf16.
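(For illustration, a tiny Python check of how the dtype string carries the bit width; the import path and printed values are my assumptions, not part of this diff:)

```python
from tvm.runtime import DataType

t = DataType("bfloat16")   # parsed into (type code, bits, lanes)
print(t.bits, t.lanes)     # expected: 16 1
# Only bits == 16 is meaningful for bfloat: the CHECK_EQ(bits, 16) added in
# this PR rejects other widths once the type reaches the C++ runtime.
```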
A good question. Will we treat TensorFloat-32 as bfloat20? If so, then bits is useful to distinguish those.
if __name__ == "__main__":
    test_promote()
    test_eliminate()
    test_legalize()
Please leave a newline at EOF, even though this is a test script :)
def np_bf162np_float(arr):
    ''' Convert a numpy array of bf16 (uint16) to a numpy array
    of float'''
    u32 = np.left_shift(arr.astype('uint32'), 16)
Are we going to introduce a potential endianness problem here?
In my understanding, fp32=>bf16 casting preserves the higher-order bits (bits 31-16). We don't need to know whether those bits are stored at a higher or lower address (which is what endianness determines); we only need to extract them by shifting, which is well-defined on values, so shifting alone is enough.
Reference: wiki for fp32 bit order
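(As a concrete check of the shift-only argument, a minimal NumPy sketch; the truncating fp32-to-bf16 helper and the final view('float32') are my additions, not code from this diff:)

```python
import numpy as np

def np_float2np_bf16_trunc(arr):
    # fp32 -> bf16 (stored as uint16) by truncation: keep the high 16 bits.
    return np.right_shift(arr.astype('float32').view('uint32'), 16).astype('uint16')

def np_bf162np_float(arr):
    # bf16 (stored as uint16) -> fp32: place the payload in the high 16 bits.
    return np.left_shift(arr.astype('uint32'), 16).view('float32')

x = np.array([1.5, -2.0], dtype='float32')   # both exactly representable in bf16
assert (np_bf162np_float(np_float2np_bf16_trunc(x)) == x).all()
# The shifts operate on integer values, not on bytes in memory, so the result
# does not depend on the host's byte order.
```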
I am not 100% sure about this. I have tested the code on x86, not (yet) on other architectures.
Can we reuse the following code snippet, which preserves the endianness checks? It also has wrapper functions below.
If my understanding is correct, we don't need to care about endianness. BF16 conversion only involves taking the higher-order bits, and the operations that extract higher-order bits in C++/NumPy are well-defined.
src/target/llvm/codegen_llvm.cc
Outdated
@@ -906,7 +954,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul);
  llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
    if (t.is_int()) { \
      return builder_->CreateICmpS##Op(a, b); \
    } else if (t.is_uint()) { \
    } else if (t.is_uint() || t.is_bfloat()) { \
Isn't comparing bfloat16 this way risky?
FP32/FP64 comparisons are also bit-wise in my understanding.
src/target/llvm/codegen_llvm.cc
Outdated
        ? static_cast<llvm::Type*>(builder_->getInt32Ty())
        : llvm::VectorType::get(builder_->getInt32Ty(), from.lanes());
    auto v = builder_->CreateZExt(value, extended_type);
    v = builder_->CreateShl(v, 16);
Potential endianness problem here?
include/tvm/runtime/c_runtime_api.h
Outdated
@@ -114,6 +114,7 @@ typedef enum {
  kTVMNNVMLast = 20U,
  // The following section of code is used for non-reserved types.
  kTVMExtReserveEnd = 64U,
  kTVMBFloat = 65U,
We do not want BFloat to be passed as a PackedFunc argument; most PackedFunc arguments should always be passed as double.
I suppose TVM should support kernel generation, e.g. generating a fused "conv+bn+relu", rather than only generating end-to-end models, which is the usual case. In this scenario we might select some intermediate layers of the model and let TVM generate only those layers, and they may require bf16 as the dtype since they sit in the middle of the model.
What I want to say is that we sometimes need bf16 as the input dtype. In our use case at Intel, we need to generate a bf16 kernel (e.g. conv+bn+relu).
Such a dtype is covered by allocating a DLTensor whose type_code equals kBFloat; it does not need a patch to the code here (which is only needed for passing parameter arguments to a PackedFunc).
The particular code is used when we directly pass a constant into a PackedFunc, e.g. f(1.0, some_float_value). In these cases double can be used.
If we remove this type from the TVM runtime, we cannot pass a bf16 array to TVM via Python; users could only pass bf16 buffers via the C runtime (or construct a bf16 DLTensor via Python in some awkward way). Currently, with kTVMBFloat defined, we can:
import numpy as np
import tvm
from tvm import te

A = te.placeholder((32, ), dtype='bfloat16')
B = te.placeholder((32, ), dtype='bfloat16')
d = te.compute((32, ), lambda x: A[x] + B[x])
sch = te.create_schedule(d.op)
module = tvm.build(sch, [A, B, d])
npa = np.random.rand(32).astype('float32')
npb = np.random.rand(32).astype('float32')
a_ = np_float2tvm_bf16(npa)
b_ = np_float2tvm_bf16(npb)
c_ = tvm.nd.empty((32,), 'bfloat16')
module(a_, b_, c_)
This is useful for testing and prototyping.
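(np_float2tvm_bf16 is a helper from the PR's test code that is not shown in this thread; a rough sketch of what such helpers could look like — the round-to-nearest-even scheme and the uint16-backed NDArray are assumptions on my part, not necessarily the exact implementation in the PR:)

```python
import numpy as np
import tvm

def np_float2np_bf16(arr):
    # fp32 -> bf16 (uint16 payload) with round-to-nearest-even,
    # as discussed later in this thread.
    orig = arr.astype('float32').view('uint32')
    bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
    return np.right_shift(orig + bias, 16).astype('uint16')

def np_float2tvm_bf16(arr):
    # Wrap the bf16 bit pattern in a TVM NDArray backed by uint16 storage.
    nparr = np_float2np_bf16(arr)
    return tvm.nd.empty(nparr.shape, 'uint16').copyfrom(nparr)
```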
I don't think you will need kTVMBFloat to support this feature. The DataType::kDLBFloat flag in runtime::DataType should be sufficient for NDArray contents (because the runtime::DataType type code used for NDArray contents diverges from the TVM type code above OpaqueHandle).
ok I understand. will change that
include/tvm/runtime/data_type.h
Outdated
@@ -81,6 +82,10 @@ class DataType {
  bool is_float() const { return code() == DataType::kFloat; }
  /*! \return whether type is a float16 type. */
  bool is_float16() const { return is_float() && bits() == 16; }
  /*! \return whether type is a bfloat type. */
  bool is_bfloat() const { return code() == DataType::kBFloat; }
Given that only bfloat16 is defined, is_bf16 is a good enough function.
ok, changed
include/tvm/runtime/data_type.h
Outdated
@@ -297,6 +302,8 @@ inline const char* TypeCode2Str(int type_code) {
      return "Object";
    case kTVMObjectRValueRefArg:
      return "ObjectRValueRefArg";
    case kTVMBFloat:
      return "bf";
bfloat
ok, changed
src/target/llvm/codegen_llvm.cc
Outdated
// cast operator
llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
  llvm::Type* target = DTypeToLLVMType(to);
  if (value->getType() == target) return value;
  if (to.is_handle()) {
    return builder_->CreateBitCast(value, target);
  } else if (to.is_float() && from.is_bfloat()) {
    CHECK_EQ(from.bits(), 16);
If LLVM does not support bfloat, then perhaps we should do the legalization as a TIR=>TIR pass as opposed to doing it in LLVM.
We are actually doing a TIR=>TIR legalization pass in TVM. See src/tir/transforms/bf16_legalize.cc.
Then we should directly change the type to i16 during legalization and remove the special handling code for bfloat16.
Then we cannot tell whether it is a float32 => i16 or a float32 => bfloat16 cast.
There are 2 kinds of legalization:
- TIR->TIR. TIR has the full ability to describe any bfloat16 operation after this PR. This legalization is introduced only because of the hardware limitation that current hardware provides few bfloat16 operations. One day, when hardware has full instruction support for bfloat16, this legalization can ideally be skipped, so it is a target-dependent pass.
- TIR->LLVM IR. I guess this is the legalization that @tqchen mentions. Because LLVM IR doesn't natively support bfloat16, i16 will be used to replace bfloat16. In this PR, I guess this is done within codegen_llvm, not by a separate pass.
Then we should legalize the cast as well in TIR, to introduce the actual implementation of the cast functions in TIR. Please also refer to https://tvm.apache.org/2020/05/20/bring-your-own-datatypes for a related implementation.
Just two small questions.
Did you mean totally eliminating the bf16 dtype in the legalization pass? That would bring much more complexity into the BF16Legalize pass, because we would need to check every TIR node to replace bf16 with int16. In contrast, the current implementation only changes the computation TIR nodes, and in codegen the bf16 handling is quite simple: just another 'else if' in the cast node and in the TVM-dtype-to-LLVM-type converter.
Also, I think the way of processing a "custom data type" that you mentioned does not fit this PR well. I had actually noticed that feature before I wrote this bf16 support, but it needs function calls to do the lowering, which is not friendly to the codegen backend for auto-vectorization and so on. Of course, we could implement the cast function as an intrinsic, but that brings more complexity.
I think letting the bf16 dtype live until codegen is a good idea; it makes legalization and the implementation of casting easier.
Given that this is a new feature that will affect quite a few people, please open a new RFC thread in the discuss forum to describe the motivation and the high-level design. Thank you!
@tqchen Thanks for the clarification. I have changed kTVMNullPtr back to 4.
@vinx13 @ZihengJiang @liangfu it would be great if you can take another look. Thanks @Menooker for continuing to improve the PR.
src/tir/transforms/bf16_legalize.cc
Outdated
// implementation from
// https://github.com/pytorch/pytorch/blob/master/c10/util/BFloat16.h
inline uint16_t round_to_nearest_even(float src) {
Per Google C style, we cannot directly copy code from another codebase into the mainline; we would need to either put it in 3rdparty or implement it independently.
changed
@junrushao1994 can you also take a quick look at this PR? Thank you!
def orig1(a,b):
    return lambda i: a[i]+b[i]+a[99-i]+b[99-i]
def after1(a,b):
    return lambda i: to16(to32(a[i])+to32(b[i])+to32(a[99-i])+to32(b[99-i]))
def orig2(a,b):
    return lambda i: a[i]*b[i]+a[99-i]*b[99-i]+a[i]
def after2(a,b):
    return lambda i: to16(to32(a[i])*to32(b[i])+to32(a[99-i])*to32(b[99-i])+to32(a[i]))
I am not so sure why the coding style here can pass pylint... Mind sending a simple fix?
Sorry for that. I have now formatted the file and manually run pylint on it. BTW, the test Python files are never checked by pylint in TVM's CI :)
Oops, I see. That makes sense then :-)
@vinx13 @ZihengJiang @liangfu it would be great if you can take another look; see also https://tvm.apache.org/docs/contribute/code_review.html#approve-and-request-changes-explicitly
@@ -72,6 +73,9 @@ class DataType {
    data_.code = static_cast<uint8_t>(code);
    data_.bits = static_cast<uint8_t>(bits);
    data_.lanes = static_cast<uint16_t>(lanes);
    if (code == kBFloat) {
      CHECK_EQ(bits, 16);
It is understandable that right now we only support bf16, but my concern is: should we put the check here?
I understand your concern. Any suggestions for the location where we put this check? Thanks.
@tqchen This is just a nitpick. What do you think?
let us leave it as it is for now, we can come back to it later
include/tvm/runtime/data_type.h
Outdated
@@ -372,7 +372,7 @@ inline DLDataType String2DLDataType(std::string s) {
    t.lanes = 1;
    return t;
  } else if (s.substr(0, 6) == "bfloat") {
    t.code = kTVMBFloat;
    t.code = kDLBfloat;
I agree with tq
python/tvm/_ffi/_cython/base.pxi
Outdated
@@ -27,7 +27,7 @@ cdef enum TVMTypeCode:
  kUInt = 1
  kFloat = 2
  kTVMOpaqueHandle = 3
  kTVMNullptr = 4
  kBFloat = 4
shall we remove this?
python/tvm/_ffi/runtime_ctypes.py
Outdated
@@ -96,6 +98,9 @@ def __init__(self, type_str):
            self.type_code = DataTypeCode.HANDLE
            bits = 64
            head = ""
        elif head.startswith("bfloat"):
            self.type_code = 4
not sure if it is good to hard code here
Changed to DataTypeCode. TVM refactors a lot (which is good), and when this PR was raised, all the type codes here were hard-coded.
The other two issues you raised have also been addressed as requested.
LGTM :-)
Thanks @Menooker for being patient and continuing to improve the PR to maintain a high quality standard! Thanks @ZhennanQin @junrushao1994 @liangfu for the helpful reviews!
We add bfloat16 as a new type named "bf16" in the frontend, and complete the LLVM backend for generating bf16 code.

Details on legalization

Since most hardware has no native support for computation on bf16, we added a pass, BF16Legalization, that uses fp32 to compute on bf16 data. It adds cast_to_fp32() before each Op involving bf16 operands and uses the fp32 version of the Op to compute, then adds a cast_to_bf16() after each Op that is altered, e.g. add(a,b) => cast16(add(cast32(a), cast32(b))). We call this phase "BF16Promotion". It is a sub-pass of the BF16Legalization pass.

We note that this adds redundant casting, e.g. add(a, neg(b)) => cast16(add(cast32(a), cast32(cast16(neg(cast32(b)))))). The pattern cast32(cast16(some_fp32_value)) can be simplified to some_fp32_value. Thus, we add an optimization pass after "BF16Promotion" in the BF16Legalization pass, which eliminates such redundant casts.

After the BF16Legalization pass, there is no bf16-related computation left in the AST, except casting between fp32 and bf16, bf16 value comparison, and assignment.

Casting between fp32 and bf16

We follow PyTorch's bf16 casting implementation.
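As a rough end-to-end sketch of the flow described above (assuming the pass is exposed on the Python side as tvm.tir.transform.BF16Legalize, matching the pass name in src/tir/transforms/bf16_legalize.cc; details may differ from the merged code):

```python
import tvm
from tvm import te

a = te.placeholder((100,), dtype="bfloat16", name="a")
b = te.placeholder((100,), dtype="bfloat16", name="b")
c = te.compute((100,), lambda i: a[i] + b[i], name="c")
sch = te.create_schedule(c.op)

mod = tvm.lower(sch, [a, b, c])
# BF16Promotion turns  a[i] + b[i]  into  cast16(cast32(a[i]) + cast32(b[i])),
# and the cast-elimination sub-pass then removes any redundant
# cast32(cast16(x)) pairs introduced between adjacent promoted ops.
legalized = tvm.tir.transform.BF16Legalize()(mod)
print(legalized)
```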