Skip to content

Commit

Permalink
SPV_KHR_untyped_pointers - implement OpUntypedVariableKHR (#2709)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmaksimo authored Sep 16, 2024
1 parent dfeb22b commit 484e407
Show file tree
Hide file tree
Showing 15 changed files with 336 additions and 81 deletions.
24 changes: 16 additions & 8 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1527,9 +1527,15 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
case OpUndef:
return mapValue(BV, UndefValue::get(transType(BV->getType())));

case OpVariable: {
auto *BVar = static_cast<SPIRVVariable *>(BV);
auto *PreTransTy = BVar->getType()->getPointerElementType();
case OpVariable:
case OpUntypedVariableKHR: {
auto *BVar = static_cast<SPIRVVariableBase *>(BV);
SPIRVType *PreTransTy = BVar->getType()->getPointerElementType();
if (BVar->getType()->isTypeUntypedPointerKHR()) {
auto *UntypedVar = static_cast<SPIRVUntypedVariableKHR *>(BVar);
if (SPIRVType *DT = UntypedVar->getDataType())
PreTransTy = DT;
}
auto *Ty = transType(PreTransTy);
bool IsConst = BVar->isConstant();
llvm::GlobalValue::LinkageTypes LinkageTy = transLinkageType(BVar);
Expand Down Expand Up @@ -4082,7 +4088,7 @@ bool SPIRVToLLVM::transDecoration(SPIRVValue *BV, Value *V) {
return true;
}

void SPIRVToLLVM::transGlobalCtorDtors(SPIRVVariable *BV) {
void SPIRVToLLVM::transGlobalCtorDtors(SPIRVVariableBase *BV) {
if (BV->getName() != "llvm.global_ctors" &&
BV->getName() != "llvm.global_dtors")
return;
Expand Down Expand Up @@ -4914,15 +4920,17 @@ SPIRVToLLVM::transLinkageType(const SPIRVValue *V) {
return GlobalValue::ExternalLinkage;
}
// Variable declaration
if (V->getOpCode() == OpVariable) {
if (static_cast<const SPIRVVariable *>(V)->getInitializer() == 0)
if (V->getOpCode() == OpVariable ||
V->getOpCode() == OpUntypedVariableKHR) {
if (static_cast<const SPIRVVariableBase *>(V)->getInitializer() == 0)
return GlobalValue::ExternalLinkage;
}
// Definition
return GlobalValue::AvailableExternallyLinkage;
case LinkageTypeExport:
if (V->getOpCode() == OpVariable) {
if (static_cast<const SPIRVVariable *>(V)->getInitializer() == 0)
if (V->getOpCode() == OpVariable ||
V->getOpCode() == OpUntypedVariableKHR) {
if (static_cast<const SPIRVVariableBase *>(V)->getInitializer() == 0)
// Tentative definition
return GlobalValue::CommonLinkage;
}
Expand Down
2 changes: 1 addition & 1 deletion lib/SPIRV/SPIRVReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class SPIRVToLLVM : private BuiltinCallHelper {

void transUserSemantic(SPIRV::SPIRVFunction *Fun);
void transGlobalAnnotations();
void transGlobalCtorDtors(SPIRVVariable *BV);
void transGlobalCtorDtors(SPIRVVariableBase *BV);
void createCXXStructor(const char *ListName,
SmallVectorImpl<Function *> &Funcs);
void transIntelFPGADecorations(SPIRVValue *BV, Value *V);
Expand Down
65 changes: 41 additions & 24 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,8 +751,8 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {

SPIRVType *TranslatedTy = nullptr;
if (BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers) &&
!(ET->isTypeArray() || ET->isTypeVector() || ET->isTypeImage() ||
ET->isTypeSampler() || ET->isTypePipe())) {
!(ET->isTypeArray() || ET->isTypeVector() || ET->isTypeStruct() ||
ET->isTypeImage() || ET->isTypeSampler() || ET->isTypePipe())) {
TranslatedTy = BM->addUntypedPointerKHRType(
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)));
} else {
Expand Down Expand Up @@ -1304,7 +1304,8 @@ SPIRVValue *LLVMToSPIRVBase::transConstantUse(Constant *C,
if (Trans->getType() == ExpectedType || Trans->getType()->isTypePipeStorage())
return Trans;

assert(C->getType()->isPointerTy() &&
assert((C->getType()->isPointerTy() ||
ExpectedType->isTypeUntypedPointerKHR()) &&
"Only pointer type mismatches should be possible");
// In the common case of strings ([N x i8] GVs), see if we can emit a GEP
// instruction.
Expand Down Expand Up @@ -2047,8 +2048,12 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
}
}
}
SPIRVType *TransTy = transType(Ty);
BVarInit = transConstantUse(Init, TransTy->getPointerElementType());
if (BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers)) {
BVarInit = transConstantUse(Init, transType(Init->getType()));
} else {
SPIRVType *TransTy = transType(Ty);
BVarInit = transConstantUse(Init, TransTy->getPointerElementType());
}
}

SPIRVStorageClassKind StorageClass;
Expand Down Expand Up @@ -2081,9 +2086,12 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
}

SPIRVType *TranslatedTy = transType(Ty);
auto *BVar = static_cast<SPIRVVariable *>(
BM->addVariable(TranslatedTy, GV->isConstant(), transLinkageType(GV),
BVarInit, GV->getName().str(), StorageClass, nullptr));
auto *BVar = static_cast<SPIRVVariableBase *>(BM->addVariable(
TranslatedTy,
TranslatedTy->isTypeUntypedPointerKHR() ? transType(GV->getValueType())
: nullptr,
GV->isConstant(), transLinkageType(GV), BVarInit, GV->getName().str(),
StorageClass, nullptr));

if (IsVectorCompute) {
BVar->addDecorate(DecorationVectorComputeVariableINTEL);
Expand Down Expand Up @@ -2272,8 +2280,9 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
StorageClassFunction,
BM->addArrayType(transType(Alc->getAllocatedType()), Length));
SPIRVValue *Arr = BM->addVariable(
AllocationType, false, spv::internal::LinkageTypeInternal, nullptr,
Alc->getName().str() + "_alloca", StorageClassFunction, BB);
AllocationType, nullptr, false, spv::internal::LinkageTypeInternal,
nullptr, Alc->getName().str() + "_alloca", StorageClassFunction,
BB);
// Manually set alignment. OpBitcast created below will be decorated as
// that's the SPIR-V value mapped to the original LLVM one.
transAlign(Alc, Arr);
Expand All @@ -2297,7 +2306,10 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
TranslatedTy->getPointerElementType())
: TranslatedTy;
SPIRVValue *Var = BM->addVariable(
VarTy, false, spv::internal::LinkageTypeInternal, nullptr,
VarTy,
VarTy->isTypeUntypedPointerKHR() ? transType(Alc->getAllocatedType())
: nullptr,
false, spv::internal::LinkageTypeInternal, nullptr,
Alc->getName().str(), StorageClassFunction, BB);
if (V->getType()->getPointerAddressSpace() == SPIRAS_Generic) {
SPIRVValue *Cast =
Expand Down Expand Up @@ -2742,7 +2754,7 @@ void checkIsGlobalVar(SPIRVEntry *E, Decoration Dec) {
E->getErrorLog().checkError(E->isVariable(), SPIRVEC_InvalidModule, ErrStr);

auto AddrSpace = SPIRSPIRVAddrSpaceMap::rmap(
static_cast<SPIRVVariable *>(E)->getStorageClass());
static_cast<SPIRVVariableBase *>(E)->getStorageClass());
ErrStr += " in a global (module) scope";
E->getErrorLog().checkError(AddrSpace == SPIRAS_Global, SPIRVEC_InvalidModule,
ErrStr);
Expand Down Expand Up @@ -2890,10 +2902,11 @@ static void transMetadataDecorations(Metadata *MD, SPIRVValue *Target) {
case spv::internal::DecorationInitModeINTEL:
case DecorationInitModeINTEL: {
checkIsGlobalVar(Target, DecoKind);
ErrLog.checkError(static_cast<SPIRVVariable *>(Target)->getInitializer(),
SPIRVEC_InvalidLlvmModule,
"InitModeINTEL only be applied to a global (module "
"scope) variable which has an Initializer operand");
ErrLog.checkError(
static_cast<SPIRVVariableBase *>(Target)->getInitializer(),
SPIRVEC_InvalidLlvmModule,
"InitModeINTEL only be applied to a global (module "
"scope) variable which has an Initializer operand");

ErrLog.checkError(NumOperands == 2, SPIRVEC_InvalidLlvmModule,
"InitModeINTEL requires exactly 1 extra operand");
Expand Down Expand Up @@ -4135,14 +4148,18 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
SPIRVType *FTy = transType(II->getType()->getStructElementType(0));
SPIRVTypePointer *ITy = static_cast<SPIRVTypePointer *>(transPointerType(
II->getType()->getStructElementType(1), SPIRAS_Private));

unsigned BitWidth = ITy->getElementType()->getBitWidth();
BM->getErrorLog().checkError(BitWidth == 32, SPIRVEC_InvalidBitWidth,
std::to_string(BitWidth));

if (!ITy->isTypeUntypedPointerKHR()) {
unsigned BitWidth = ITy->getElementType()->getBitWidth();
BM->getErrorLog().checkError(BitWidth == 32, SPIRVEC_InvalidBitWidth,
std::to_string(BitWidth));
}
SPIRVValue *IntVal =
BM->addVariable(ITy, false, spv::internal::LinkageTypeInternal, nullptr,
"", ITy->getStorageClass(), BB);
BM->addVariable(ITy,
ITy->isTypeUntypedPointerKHR()
? transType(II->getType()->getStructElementType(1))
: nullptr,
false, spv::internal::LinkageTypeInternal, nullptr, "",
ITy->getStorageClass(), BB);

std::vector<SPIRVValue *> Ops{transValue(II->getArgOperand(0), BB), IntVal};

Expand Down Expand Up @@ -4564,7 +4581,7 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
Init = BM->addCompositeConstant(CompositeTy, Elts);
}
SPIRVType *VarTy = transPointerType(AT, SPIRV::SPIRAS_Constant);
SPIRVValue *Var = BM->addVariable(VarTy, /*isConstant*/ true,
SPIRVValue *Var = BM->addVariable(VarTy, nullptr, /*isConstant*/ true,
spv::internal::LinkageTypeInternal, Init,
"", StorageClassUniformConstant, nullptr);
SPIRVType *SourceTy =
Expand Down
2 changes: 1 addition & 1 deletion lib/SPIRV/libSPIRV/SPIRVBasicBlock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ SPIRVInstruction *SPIRVBasicBlock::getVariableInsertionPoint() const {
isa<OpNoLine>(Inst) ||
// Note: OpVariable and OpPhi instructions do not belong to the
// same block in a valid SPIR-V module.
isa<OpPhi>(Inst));
isa<OpPhi>(Inst) || isa<OpUntypedVariableKHR>(Inst));
});
if (IP == InstVec.end())
return nullptr;
Expand Down
4 changes: 2 additions & 2 deletions lib/SPIRV/libSPIRV/SPIRVBasicBlock.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ class SPIRVBasicBlock : public SPIRVValue {
const SPIRVInstruction *getTerminateInstr() const {
return InstVec.empty() ? nullptr : InstVec.back();
}
// OpVariable instructions must be the first instructions in the block,
// Variables must be the first instructions in the block,
// intermixed with OpLine and OpNoLine instructions. Return first instruction
// not being an OpVariable, OpLine or OpNoLine.
// not being an OpVariable, OpUntypedVariableKHR, OpLine or OpNoLine.
SPIRVInstruction *getVariableInsertionPoint() const;

void setScope(SPIRVEntry *Scope) override;
Expand Down
3 changes: 2 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,8 @@ SPIRVEntry::getDecorationIds(Decoration Kind) const {
}

bool SPIRVEntry::hasLinkageType() const {
return OpCode == OpFunction || OpCode == OpVariable;
return OpCode == OpFunction || OpCode == OpVariable ||
OpCode == OpUntypedVariableKHR;
}

bool SPIRVEntry::isExtInst(const SPIRVExtInstSetKind InstSet) const {
Expand Down
4 changes: 3 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,9 @@ class SPIRVEntry {
bool isUndef() const { return OpCode == OpUndef; }
bool isControlBarrier() const { return OpCode == OpControlBarrier; }
bool isMemoryBarrier() const { return OpCode == OpMemoryBarrier; }
bool isVariable() const { return OpCode == OpVariable; }
bool isVariable() const {
return OpCode == OpVariable || OpCode == OpUntypedVariableKHR;
}
bool isEndOfBlock() const;
virtual bool isInst() const { return false; }
virtual bool isOperandLiteral(unsigned Index) const {
Expand Down
108 changes: 90 additions & 18 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,24 +461,25 @@ class SPIRVMemoryAccess {
SPIRVId NoAliasInstID;
};

class SPIRVVariable : public SPIRVInstruction {
class SPIRVVariableBase : public SPIRVInstruction {
public:
// Complete constructor for integer constant
SPIRVVariable(SPIRVType *TheType, SPIRVId TheId, SPIRVValue *TheInitializer,
const std::string &TheName,
SPIRVStorageClassKind TheStorageClass, SPIRVBasicBlock *TheBB,
SPIRVModule *TheM)
: SPIRVInstruction(TheInitializer && !TheInitializer->isUndef() ? 5 : 4,
OpVariable, TheType, TheId, TheBB, TheM),
SPIRVVariableBase(Op OC, SPIRVType *TheType, SPIRVId TheId,
SPIRVValue *TheInitializer, const std::string &TheName,
SPIRVStorageClassKind TheStorageClass,
SPIRVBasicBlock *TheBB, SPIRVModule *TheM, SPIRVWord WC)
: SPIRVInstruction(WC, OC, TheType, TheId, TheBB, TheM),
StorageClass(TheStorageClass) {
if (TheInitializer && !TheInitializer->isUndef())
Initializer.push_back(TheInitializer->getId());
Name = TheName;
validate();
}
// Incomplete constructor
SPIRVVariable()
: SPIRVInstruction(OpVariable), StorageClass(StorageClassFunction) {}
// Incomplete constructors
SPIRVVariableBase(Op OC)
: SPIRVInstruction(OC), StorageClass(StorageClassFunction) {}
SPIRVVariableBase()
: SPIRVInstruction(OpNop), StorageClass(StorageClassFunction) {}

SPIRVStorageClassKind getStorageClass() const { return StorageClass; }
SPIRVValue *getInitializer() const {
Expand Down Expand Up @@ -530,6 +531,77 @@ class SPIRVVariable : public SPIRVInstruction {
std::vector<SPIRVId> Initializer;
};

class SPIRVVariable : public SPIRVVariableBase {
public:
// Complete constructor for integer constant
SPIRVVariable(SPIRVType *TheType, SPIRVId TheId, SPIRVValue *TheInitializer,
const std::string &TheName,
SPIRVStorageClassKind TheStorageClass, SPIRVBasicBlock *TheBB,
SPIRVModule *TheM)
: SPIRVVariableBase(OpVariable, TheType, TheId, TheInitializer, TheName,
TheStorageClass, TheBB, TheM,
TheInitializer && !TheInitializer->isUndef() ? 5
: 4) {}
// Incomplete constructor
SPIRVVariable() : SPIRVVariableBase(OpVariable) {}
};

class SPIRVUntypedVariableKHR : public SPIRVVariableBase {
public:
SPIRVUntypedVariableKHR(SPIRVType *TheType, SPIRVId TheId,
SPIRVType *TheDataType, SPIRVValue *TheInitializer,
const std::string &TheName,
SPIRVStorageClassKind TheStorageClass,
SPIRVBasicBlock *TheBB, SPIRVModule *TheM)
: SPIRVVariableBase(
OpUntypedVariableKHR, TheType, TheId, TheInitializer, TheName,
TheStorageClass, TheBB, TheM,
TheDataType && !TheDataType->isUndef()
? (TheInitializer && !TheInitializer->isUndef() ? 6 : 5)
: 4) {
if (TheDataType && !TheDataType->isUndef())
DataType.push_back(TheDataType->getId());
validate();
}
SPIRVUntypedVariableKHR() : SPIRVVariableBase(OpUntypedVariableKHR) {}
SPIRVType *getDataType() const {
if (DataType.empty())
return nullptr;
assert(DataType.size() == 1);
return get<SPIRVType>(DataType[0]);
}
std::vector<SPIRVEntry *> getNonLiteralOperands() const override {
std::vector<SPIRVEntry *> Vec;
if (SPIRVType *T = getDataType())
Vec.push_back(T);
if (SPIRVValue *V = getInitializer())
Vec.push_back(V);
return Vec;
}
std::optional<ExtensionID> getRequiredExtension() const override {
return ExtensionID::SPV_KHR_untyped_pointers;
}
SPIRVCapVec getRequiredCapability() const override {
return getVec(CapabilityUntypedPointersKHR);
}

protected:
void validate() const override {
SPIRVVariableBase::validate();
assert(DataType.size() == 1 || DataType.empty());
}
void setWordCount(SPIRVWord TheWordCount) override {
SPIRVEntry::setWordCount(TheWordCount);
if (TheWordCount > 4)
DataType.resize(1);
if (TheWordCount > 5)
Initializer.resize(1);
}
_SPIRV_DEF_ENCDEC5(Type, Id, StorageClass, DataType, Initializer)

std::vector<SPIRVId> DataType;
};

class SPIRVStore : public SPIRVInstruction, public SPIRVMemoryAccess {
public:
const static SPIRVWord FixedWords = 3;
Expand Down Expand Up @@ -581,9 +653,6 @@ class SPIRVStore : public SPIRVInstruction, public SPIRVMemoryAccess {
(getValueType(PtrId)
->getPointerElementType()
->isTypeUntypedPointerKHR() ||
// TODO: This check should be removed once we support untyped
// variables.
getValueType(ValId)->isTypeUntypedPointerKHR() ||
getValueType(PtrId)->getPointerElementType() == getValueType(ValId)) &&
"Inconsistent operand types");
}
Expand Down Expand Up @@ -638,8 +707,6 @@ class SPIRVLoad : public SPIRVInstruction, public SPIRVMemoryAccess {
getValueType(PtrId)
->getPointerElementType()
->isTypeUntypedPointerKHR() ||
// TODO: This check should be removed once we support untyped
// variables.
Type->isTypeUntypedPointerKHR() ||
Type == getValueType(PtrId)->getPointerElementType()) &&
"Inconsistent types");
Expand Down Expand Up @@ -697,9 +764,14 @@ class SPIRVBinary : public SPIRVInstTemplateBase {
} else if (isBinaryPtrOpCode(OpCode)) {
assert((Op1Ty->isTypePointer() && Op2Ty->isTypePointer()) &&
"Invalid types for PtrEqual, PtrNotEqual, or PtrDiff instruction");
assert(static_cast<SPIRVTypePointer *>(Op1Ty)->getElementType() ==
static_cast<SPIRVTypePointer *>(Op2Ty)->getElementType() &&
"Invalid types for PtrEqual, PtrNotEqual, or PtrDiff instruction");
if (!Op1Ty->isTypeUntypedPointerKHR() ||
!Op2Ty->isTypeUntypedPointerKHR())
assert(
static_cast<SPIRVTypePointer *>(Op1Ty)->getElementType() ==
static_cast<SPIRVTypePointer *>(Op2Ty)->getElementType() &&
"Invalid types for PtrEqual, PtrNotEqual, or PtrDiff instruction");
else if (OpCode == OpPtrDiff)
assert(Op1Ty == Op2Ty && "Invalid types for PtrDiff instruction");
} else {
assert(0 && "Invalid op code!");
}
Expand Down
Loading

0 comments on commit 484e407

Please sign in to comment.