Skip to content

Commit

Permalink
[LLVM] Avoid warnings when compiling getNumElements with LLVM12+ (#6738)
Browse files Browse the repository at this point in the history
* [LLVM] Avoid warnings when compiling getNumElements with LLVM12+

Extract the element-count code into GetVectorNumElements and make it
compile cleanly with all LLVM versions.

* Trigger another build
  • Loading branch information
Krzysztof Parzyszek committed Oct 24, 2020
1 parent fc69f68 commit c4e26b6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
18 changes: 9 additions & 9 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
}

llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
int num_elems = GetVectorNumElements(vec);
if (extent == num_elems && begin == 0) return vec;
CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
std::vector<llvm::Constant*> indices;
Expand All @@ -503,7 +503,7 @@ llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent
}

llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
int num_elems = GetVectorNumElements(vec);
#if TVM_LLVM_VERSION >= 110
std::vector<int> indices;
#else
Expand All @@ -517,7 +517,7 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {

llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes)));
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
int num_elems = GetVectorNumElements(vec);
if (num_elems == target_lanes) return vec;
CHECK_LT(num_elems, target_lanes);
for (int i = 0; i < num_elems; ++i) {
Expand All @@ -531,15 +531,15 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
int total_lanes = 0;

for (llvm::Value* v : vecs) {
total_lanes += llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
total_lanes += GetVectorNumElements(v);
}
while (vecs.size() > 1) {
std::vector<llvm::Value*> new_vecs;
for (size_t i = 0; i < vecs.size() - 1; i += 2) {
llvm::Value* lhs = vecs[i];
llvm::Value* rhs = vecs[i + 1];
const size_t lhs_lanes = llvm::cast<llvm::VectorType>(lhs->getType())->getNumElements();
const size_t rhs_lanes = llvm::cast<llvm::VectorType>(rhs->getType())->getNumElements();
const size_t lhs_lanes = GetVectorNumElements(lhs);
const size_t rhs_lanes = GetVectorNumElements(rhs);
if (lhs_lanes < rhs_lanes) {
lhs = CreateVecPad(lhs, rhs_lanes);
} else if (rhs_lanes < lhs_lanes) {
Expand Down Expand Up @@ -843,16 +843,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
return builder_->CreateFCmpUNO(a, a);
} else if (op->op.same_as(builtin::vectorlow())) {
llvm::Value* v = MakeValue(op->args[0]);
int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
int l = GetVectorNumElements(v);
return CreateVecSlice(v, 0, l / 2);
} else if (op->op.same_as(builtin::vectorhigh())) {
llvm::Value* v = MakeValue(op->args[0]);
int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
int l = GetVectorNumElements(v);
return CreateVecSlice(v, l / 2, l / 2);
} else if (op->op.same_as(builtin::vectorcombine())) {
llvm::Value* v0 = MakeValue(op->args[0]);
llvm::Value* v1 = MakeValue(op->args[1]);
int num_elems = llvm::cast<llvm::VectorType>(v0->getType())->getNumElements() * 2;
int num_elems = GetVectorNumElements(v0) * 2;
#if TVM_LLVM_VERSION >= 110
std::vector<int> indices;
#else
Expand Down
14 changes: 14 additions & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*/
llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type,
llvm::ArrayRef<llvm::Type*> arg_types);
/*!
* \brief Get the number of elements in the given vector value.
* \param vec The value, must be of a vector type.
*/
inline int GetVectorNumElements(llvm::Value* vec);
// initialize the function state.
void InitFuncState();
// Get alignment given index.
Expand Down Expand Up @@ -348,6 +353,15 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*/
static std::unique_ptr<DebugInfo> CreateDebugInfo(llvm::Module* module);
};

inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) {
#if TVM_LLVM_VERSION >= 120
return llvm::cast<llvm::FixedVectorType>(vec->getType())->getNumElements();
#else
return llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
#endif
}

} // namespace codegen
} // namespace tvm
#endif // LLVM_VERSION
Expand Down
6 changes: 5 additions & 1 deletion src/target/llvm/codegen_x86_64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr
llvm::Type* result_ty,
const std::vector<llvm::Value*>& args) {
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
#if TVM_LLVM_VERSION >= 120
size_t num_elems = llvm::cast<llvm::FixedVectorType>(result_ty)->getNumElements();
#else
size_t num_elems = llvm::cast<llvm::VectorType>(result_ty)->getNumElements();
#endif
if (intrin_lanes == num_elems) {
return builder_->CreateCall(f, args);
}
Expand All @@ -130,7 +134,7 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr
std::vector<llvm::Value*> split_args;
for (const auto& v : args) {
if (v->getType()->isVectorTy()) {
CHECK_EQ(llvm::cast<llvm::VectorType>(v->getType())->getNumElements(), num_elems);
CHECK_EQ(GetVectorNumElements(v), num_elems);
split_args.push_back(CreateVecSlice(v, i, intrin_lanes));
} else {
split_args.push_back(v);
Expand Down

0 comments on commit c4e26b6

Please sign in to comment.