Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVM] Avoid warnings when compiling getNumElements with LLVM12+ #6738

Merged
merged 2 commits into from
Oct 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
}

llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
int num_elems = GetVectorNumElements(vec);
if (extent == num_elems && begin == 0) return vec;
CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
std::vector<llvm::Constant*> indices;
Expand All @@ -503,7 +503,7 @@ llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent
}

llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
int num_elems = GetVectorNumElements(vec);
#if TVM_LLVM_VERSION >= 110
std::vector<int> indices;
#else
Expand All @@ -517,7 +517,7 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {

llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes)));
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
int num_elems = GetVectorNumElements(vec);
if (num_elems == target_lanes) return vec;
CHECK_LT(num_elems, target_lanes);
for (int i = 0; i < num_elems; ++i) {
Expand All @@ -531,15 +531,15 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
int total_lanes = 0;

for (llvm::Value* v : vecs) {
total_lanes += llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
total_lanes += GetVectorNumElements(v);
}
while (vecs.size() > 1) {
std::vector<llvm::Value*> new_vecs;
for (size_t i = 0; i < vecs.size() - 1; i += 2) {
llvm::Value* lhs = vecs[i];
llvm::Value* rhs = vecs[i + 1];
const size_t lhs_lanes = llvm::cast<llvm::VectorType>(lhs->getType())->getNumElements();
const size_t rhs_lanes = llvm::cast<llvm::VectorType>(rhs->getType())->getNumElements();
const size_t lhs_lanes = GetVectorNumElements(lhs);
const size_t rhs_lanes = GetVectorNumElements(rhs);
if (lhs_lanes < rhs_lanes) {
lhs = CreateVecPad(lhs, rhs_lanes);
} else if (rhs_lanes < lhs_lanes) {
Expand Down Expand Up @@ -843,16 +843,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
return builder_->CreateFCmpUNO(a, a);
} else if (op->op.same_as(builtin::vectorlow())) {
llvm::Value* v = MakeValue(op->args[0]);
int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
int l = GetVectorNumElements(v);
return CreateVecSlice(v, 0, l / 2);
} else if (op->op.same_as(builtin::vectorhigh())) {
llvm::Value* v = MakeValue(op->args[0]);
int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
int l = GetVectorNumElements(v);
return CreateVecSlice(v, l / 2, l / 2);
} else if (op->op.same_as(builtin::vectorcombine())) {
llvm::Value* v0 = MakeValue(op->args[0]);
llvm::Value* v1 = MakeValue(op->args[1]);
int num_elems = llvm::cast<llvm::VectorType>(v0->getType())->getNumElements() * 2;
int num_elems = GetVectorNumElements(v0) * 2;
#if TVM_LLVM_VERSION >= 110
std::vector<int> indices;
#else
Expand Down
14 changes: 14 additions & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*/
llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type,
llvm::ArrayRef<llvm::Type*> arg_types);
/*!
* \brief Get the number of elements in the given vector value.
* \param vec The value, must be of a vector type.
*/
inline int GetVectorNumElements(llvm::Value* vec);
// initialize the function state.
void InitFuncState();
// Get alignment given index.
Expand Down Expand Up @@ -348,6 +353,15 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*/
static std::unique_ptr<DebugInfo> CreateDebugInfo(llvm::Module* module);
};

inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) {
#if TVM_LLVM_VERSION >= 120
return llvm::cast<llvm::FixedVectorType>(vec->getType())->getNumElements();
#else
return llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
#endif
}

} // namespace codegen
} // namespace tvm
#endif // LLVM_VERSION
Expand Down
6 changes: 5 additions & 1 deletion src/target/llvm/codegen_x86_64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr
llvm::Type* result_ty,
const std::vector<llvm::Value*>& args) {
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
#if TVM_LLVM_VERSION >= 120
size_t num_elems = llvm::cast<llvm::FixedVectorType>(result_ty)->getNumElements();
#else
size_t num_elems = llvm::cast<llvm::VectorType>(result_ty)->getNumElements();
#endif
if (intrin_lanes == num_elems) {
return builder_->CreateCall(f, args);
}
Expand All @@ -130,7 +134,7 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr
std::vector<llvm::Value*> split_args;
for (const auto& v : args) {
if (v->getType()->isVectorTy()) {
CHECK_EQ(llvm::cast<llvm::VectorType>(v->getType())->getNumElements(), num_elems);
CHECK_EQ(GetVectorNumElements(v), num_elems);
split_args.push_back(CreateVecSlice(v, i, intrin_lanes));
} else {
split_args.push_back(v);
Expand Down