Skip to content

Commit

Permalink
SPV_KHR_untyped_pointers - implement OpTypeUntypedPointerKHR (#2687)
Browse files Browse the repository at this point in the history
This is the first part of the extension implementation.
Introducing untyped pointer type.

Spec: https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_untyped_pointers.html
  • Loading branch information
vmaksimo committed Sep 5, 2024
1 parent 2b5f15d commit 7dacb7c
Show file tree
Hide file tree
Showing 19 changed files with 372 additions and 52 deletions.
1 change: 1 addition & 0 deletions include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ EXT(SPV_KHR_subgroup_rotate)
EXT(SPV_KHR_non_semantic_info)
EXT(SPV_KHR_shader_clock)
EXT(SPV_KHR_cooperative_matrix)
EXT(SPV_KHR_untyped_pointers)
EXT(SPV_INTEL_subgroups)
EXT(SPV_INTEL_media_block_io)
EXT(SPV_INTEL_device_side_avc_motion_estimation)
Expand Down
7 changes: 7 additions & 0 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,11 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
return TypedPointerType::get(ElementTy, AS);
return mapType(T, PointerType::get(ElementTy, AS));
}
case OpTypeUntypedPointerKHR: {
const unsigned AS =
SPIRSPIRVAddrSpaceMap::rmap(T->getPointerStorageClass());
return mapType(T, PointerType::get(*Context, AS));
}
case OpTypeVector:
return mapType(T,
FixedVectorType::get(transType(T->getVectorComponentType()),
Expand Down Expand Up @@ -558,6 +563,8 @@ std::string SPIRVToLLVM::transTypeToOCLTypeName(SPIRVType *T, bool IsSigned) {
}
return transTypeToOCLTypeName(ET) + "*";
}
case OpTypeUntypedPointerKHR:
return "int*";
case OpTypeVector:
return transTypeToOCLTypeName(T->getVectorComponentType()) +
T->getVectorComponentCount();
Expand Down
56 changes: 42 additions & 14 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,8 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
// A pointer to image or pipe type in LLVM is translated to a SPIRV
// (non-pointer) image or pipe type.
if (T->isPointerTy()) {
auto *ET = Type::getInt8Ty(T->getContext());
auto AddrSpc = T->getPointerAddressSpace();
auto *ET = Type::getInt8Ty(T->getContext());
return transPointerType(ET, AddrSpc);
}

Expand Down Expand Up @@ -716,7 +716,6 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(Type *ET, unsigned AddrSpc) {
transType(ET)));
}
} else {
SPIRVType *ElementType = transType(ET);
// ET, as a recursive type, may contain exactly the same pointer T, so it
// may happen that after translation of ET we already have translated T,
// added the translated pointer to the SPIR-V module and mapped T to this
Expand All @@ -725,7 +724,17 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(Type *ET, unsigned AddrSpc) {
if (Loc != PointeeTypeMap.end()) {
return Loc->second;
}
SPIRVType *TranslatedTy = transPointerType(ElementType, AddrSpc);

SPIRVType *ElementType = nullptr;
SPIRVType *TranslatedTy = nullptr;
if (ET->isPointerTy() &&
BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers)) {
TranslatedTy = BM->addUntypedPointerKHRType(
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)));
} else {
ElementType = transType(ET);
TranslatedTy = transPointerType(ElementType, AddrSpc);
}
PointeeTypeMap[TypeKey] = TranslatedTy;
return TranslatedTy;
}
Expand All @@ -740,8 +749,16 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
if (Loc != PointeeTypeMap.end())
return Loc->second;

SPIRVType *TranslatedTy = BM->addPointerType(
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)), ET);
SPIRVType *TranslatedTy = nullptr;
if (BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers) &&
!(ET->isTypeArray() || ET->isTypeVector() || ET->isTypeImage() ||
ET->isTypeSampler() || ET->isTypePipe())) {
TranslatedTy = BM->addUntypedPointerKHRType(
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)));
} else {
TranslatedTy = BM->addPointerType(
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)), ET);
}
PointeeTypeMap[TypeKey] = TranslatedTy;
return TranslatedTy;
}
Expand Down Expand Up @@ -2176,8 +2193,13 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
MemoryAccessNoAliasINTELMaskMask);
if (MemoryAccess.front() == 0)
MemoryAccess.clear();
return mapValue(V, BM->addLoadInst(transValue(LD->getPointerOperand(), BB),
MemoryAccess, BB));
return mapValue(
V,
BM->addLoadInst(
transValue(LD->getPointerOperand(), BB), MemoryAccess, BB,
BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers)
? transType(LD->getType())
: nullptr));
}

if (BinaryOperator *B = dyn_cast<BinaryOperator>(V)) {
Expand Down Expand Up @@ -2387,14 +2409,17 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,

if (auto *Phi = dyn_cast<PHINode>(V)) {
std::vector<SPIRVValue *> IncomingPairs;
SPIRVType *Ty = transScavengedType(Phi);

for (size_t I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) {
IncomingPairs.push_back(transValue(Phi->getIncomingValue(I), BB, true,
FuncTransMode::Pointer));
SPIRVValue *Val = transValue(Phi->getIncomingValue(I), BB, true,
FuncTransMode::Pointer);
if (Val->getType() != Ty)
Val = BM->addUnaryInst(OpBitcast, Ty, Val, BB);
IncomingPairs.push_back(Val);
IncomingPairs.push_back(transValue(Phi->getIncomingBlock(I), nullptr));
}
return mapValue(V,
BM->addPhiInst(transScavengedType(Phi), IncomingPairs, BB));
return mapValue(V, BM->addPhiInst(Ty, IncomingPairs, BB));
}

if (auto *Ext = dyn_cast<ExtractValueInst>(V)) {
Expand Down Expand Up @@ -6650,9 +6675,12 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
assert((Pointee == Args[I] || !isa<Function>(Pointee)) &&
"Illegal use of a function pointer type");
}
SPArgs.push_back(SPI->isOperandLiteral(I)
? cast<ConstantInt>(Args[I])->getZExtValue()
: transValue(Args[I], BB)->getId());
if (!SPI->isOperandLiteral(I)) {
SPIRVValue *Val = transValue(Args[I], BB);
SPArgs.push_back(Val->getId());
} else {
SPArgs.push_back(cast<ConstantInt>(Args[I])->getZExtValue());
}
}
BM->addInstTemplate(SPI, SPArgs, BB, SPRetTy);
if (!SPRetTy || !SPRetTy->isTypeStruct())
Expand Down
34 changes: 25 additions & 9 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,9 +577,15 @@ class SPIRVStore : public SPIRVInstruction, public SPIRVMemoryAccess {
SPIRVInstruction::validate();
if (getSrc()->isForward() || getDst()->isForward())
return;
assert(getValueType(PtrId)->getPointerElementType() ==
getValueType(ValId) &&
"Inconsistent operand types");
assert(
(getValueType(PtrId)
->getPointerElementType()
->isTypeUntypedPointerKHR() ||
// TODO: This check should be removed once we support untyped
// variables.
getValueType(ValId)->isTypeUntypedPointerKHR() ||
getValueType(PtrId)->getPointerElementType() == getValueType(ValId)) &&
"Inconsistent operand types");
}

private:
Expand All @@ -594,11 +600,12 @@ class SPIRVLoad : public SPIRVInstruction, public SPIRVMemoryAccess {
// Complete constructor
SPIRVLoad(SPIRVId TheId, SPIRVId PointerId,
const std::vector<SPIRVWord> &TheMemoryAccess,
SPIRVBasicBlock *TheBB)
SPIRVBasicBlock *TheBB, SPIRVType *TheType = nullptr)
: SPIRVInstruction(
FixedWords + TheMemoryAccess.size(), OpLoad,
TheBB->getValueType(PointerId)->getPointerElementType(), TheId,
TheBB),
TheType ? TheType
: TheBB->getValueType(PointerId)->getPointerElementType(),
TheId, TheBB),
SPIRVMemoryAccess(TheMemoryAccess), PtrId(PointerId),
MemoryAccess(TheMemoryAccess) {
validate();
Expand Down Expand Up @@ -628,6 +635,12 @@ class SPIRVLoad : public SPIRVInstruction, public SPIRVMemoryAccess {
void validate() const override {
SPIRVInstruction::validate();
assert((getValue(PtrId)->isForward() ||
getValueType(PtrId)
->getPointerElementType()
->isTypeUntypedPointerKHR() ||
// TODO: This check should be removed once we support untyped
// variables.
Type->isTypeUntypedPointerKHR() ||
Type == getValueType(PtrId)->getPointerElementType()) &&
"Inconsistent types");
}
Expand Down Expand Up @@ -2010,7 +2023,8 @@ class SPIRVCompositeExtractBase : public SPIRVInstTemplateBase {
(void)Composite;
assert(getValueType(Composite)->isTypeArray() ||
getValueType(Composite)->isTypeStruct() ||
getValueType(Composite)->isTypeVector());
getValueType(Composite)->isTypeVector() ||
getValueType(Composite)->isTypeUntypedPointerKHR());
}
};

Expand All @@ -2036,7 +2050,8 @@ class SPIRVCompositeInsertBase : public SPIRVInstTemplateBase {
(void)Composite;
assert(getValueType(Composite)->isTypeArray() ||
getValueType(Composite)->isTypeStruct() ||
getValueType(Composite)->isTypeVector());
getValueType(Composite)->isTypeVector() ||
getValueType(Composite)->isTypeUntypedPointerKHR());
assert(Type == getValueType(Composite));
}
};
Expand Down Expand Up @@ -2383,7 +2398,8 @@ template <Op OC> class SPIRVLifetime : public SPIRVInstruction {
// Signedness of 1, its sign bit cannot be set.
if (!(ObjType->getPointerElementType()->isTypeVoid() ||
// (void *) is i8* in LLVM IR
ObjType->getPointerElementType()->isTypeInt(8)) ||
ObjType->getPointerElementType()->isTypeInt(8) ||
ObjType->getPointerElementType()->isTypeUntypedPointerKHR()) ||
!Module->hasCapability(CapabilityAddresses))
assert(Size == 0 && "Size must be 0");
}
Expand Down
28 changes: 23 additions & 5 deletions lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVTypeInt *addIntegerType(unsigned BitWidth) override;
SPIRVTypeOpaque *addOpaqueType(const std::string &) override;
SPIRVTypePointer *addPointerType(SPIRVStorageClassKind, SPIRVType *) override;
SPIRVTypeUntypedPointerKHR *
addUntypedPointerKHRType(SPIRVStorageClassKind) override;
SPIRVTypeImage *addImageType(SPIRVType *,
const SPIRVTypeImageDescriptor &) override;
SPIRVTypeImage *addImageType(SPIRVType *, const SPIRVTypeImageDescriptor &,
Expand Down Expand Up @@ -353,7 +355,7 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVInstruction *addCmpInst(Op, SPIRVType *, SPIRVValue *, SPIRVValue *,
SPIRVBasicBlock *) override;
SPIRVInstruction *addLoadInst(SPIRVValue *, const std::vector<SPIRVWord> &,
SPIRVBasicBlock *) override;
SPIRVBasicBlock *, SPIRVType *) override;
SPIRVInstruction *addPhiInst(SPIRVType *, std::vector<SPIRVValue *>,
SPIRVBasicBlock *) override;
SPIRVInstruction *addCompositeConstructInst(SPIRVType *,
Expand Down Expand Up @@ -563,6 +565,8 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVUnknownStructFieldMap UnknownStructFieldMap;
SPIRVTypeBool *BoolTy;
SPIRVTypeVoid *VoidTy;
SmallDenseMap<SPIRVStorageClassKind, SPIRVTypeUntypedPointerKHR *>
UntypedPtrTyMap;
SmallDenseMap<unsigned, SPIRVTypeInt *, 4> IntTypeMap;
SmallDenseMap<unsigned, SPIRVTypeFloat *, 4> FloatTypeMap;
SmallDenseMap<std::pair<unsigned, SPIRVType *>, SPIRVTypePointer *, 4>
Expand Down Expand Up @@ -1014,6 +1018,17 @@ SPIRVModuleImpl::addPointerType(SPIRVStorageClassKind StorageClass,
return addType(Ty);
}

SPIRVTypeUntypedPointerKHR *
SPIRVModuleImpl::addUntypedPointerKHRType(SPIRVStorageClassKind StorageClass) {
auto Loc = UntypedPtrTyMap.find(StorageClass);
if (Loc != UntypedPtrTyMap.end())
return Loc->second;

auto *Ty = new SPIRVTypeUntypedPointerKHR(this, getId(), StorageClass);
UntypedPtrTyMap[StorageClass] = Ty;
return addType(Ty);
}

SPIRVTypeFunction *SPIRVModuleImpl::addFunctionType(
SPIRVType *ReturnType, const std::vector<SPIRVType *> &ParameterTypes) {
return addType(
Expand Down Expand Up @@ -1430,9 +1445,10 @@ SPIRVModuleImpl::addInstruction(SPIRVInstruction *Inst, SPIRVBasicBlock *BB,
SPIRVInstruction *
SPIRVModuleImpl::addLoadInst(SPIRVValue *Source,
const std::vector<SPIRVWord> &TheMemoryAccess,
SPIRVBasicBlock *BB) {
SPIRVBasicBlock *BB, SPIRVType *TheType) {
return addInstruction(
new SPIRVLoad(getId(), Source->getId(), TheMemoryAccess, BB), BB);
new SPIRVLoad(getId(), Source->getId(), TheMemoryAccess, BB, TheType),
BB);
}

SPIRVInstruction *
Expand Down Expand Up @@ -1925,11 +1941,13 @@ class TopologicalSort {
// We've found a recursive data type, e.g. a structure having a member
// which is a pointer to the same structure.
State = Unvisited; // Forget about it
if (E->getOpCode() == OpTypePointer) {
if (E->getOpCode() == OpTypePointer ||
E->getOpCode() == OpTypeUntypedPointerKHR) {
// If we have a pointer in the recursive chain, we can break the
// cyclic dependency by inserting a forward declaration of that
// pointer.
SPIRVTypePointer *Ptr = static_cast<SPIRVTypePointer *>(E);
SPIRVTypePointerBase<> *Ptr =
static_cast<SPIRVTypePointerBase<> *>(E);
SPIRVModule *BM = E->getModule();
ForwardPointerSet.insert(BM->add(new SPIRVTypeForwardPointer(
BM, Ptr->getId(), Ptr->getPointerStorageClass())));
Expand Down
6 changes: 5 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class SPIRVTypeFunction;
class SPIRVTypeInt;
class SPIRVTypeOpaque;
class SPIRVTypePointer;
class SPIRVTypeUntypedPointerKHR;
class SPIRVTypeImage;
class SPIRVTypeSampler;
class SPIRVTypeSampledImage;
Expand Down Expand Up @@ -257,6 +258,8 @@ class SPIRVModule {
virtual SPIRVTypeOpaque *addOpaqueType(const std::string &) = 0;
virtual SPIRVTypePointer *addPointerType(SPIRVStorageClassKind,
SPIRVType *) = 0;
virtual SPIRVTypeUntypedPointerKHR *
addUntypedPointerKHRType(SPIRVStorageClassKind) = 0;
virtual SPIRVTypeStruct *openStructType(unsigned, const std::string &) = 0;
virtual SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) = 0;
virtual void closeStructType(SPIRVTypeStruct *, bool) = 0;
Expand Down Expand Up @@ -396,7 +399,8 @@ class SPIRVModule {
SPIRVBasicBlock *BB, SPIRVType *Ty) = 0;
virtual SPIRVInstruction *addLoadInst(SPIRVValue *,
const std::vector<SPIRVWord> &,
SPIRVBasicBlock *) = 0;
SPIRVBasicBlock *,
SPIRVType *TheType = nullptr) = 0;
virtual SPIRVInstruction *addLifetimeInst(Op OC, SPIRVValue *Object,
SPIRVWord Size,
SPIRVBasicBlock *BB) = 0;
Expand Down
1 change: 1 addition & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(CapabilityRoundingModeRTZ, "RoundingModeRTZ");
add(CapabilityRayQueryProvisionalKHR, "RayQueryProvisionalKHR");
add(CapabilityRayQueryKHR, "RayQueryKHR");
add(CapabilityUntypedPointersKHR, "UntypedPointersKHR");
add(CapabilityRayTraversalPrimitiveCullingKHR,
"RayTraversalPrimitiveCullingKHR");
add(CapabilityRayTracingKHR, "RayTracingKHR");
Expand Down
3 changes: 2 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ inline bool isTypeOpCode(Op OpCode) {
OC == internal::OpTypeJointMatrixINTEL ||
OC == internal::OpTypeJointMatrixINTELv2 ||
OC == OpTypeCooperativeMatrixKHR ||
OC == internal::OpTypeTaskSequenceINTEL;
OC == internal::OpTypeTaskSequenceINTEL ||
OC == OpTypeUntypedPointerKHR;
}

inline bool isSpecConstantOpCode(Op OpCode) {
Expand Down
1 change: 1 addition & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ _SPIRV_OP(PtrEqual, 401)
_SPIRV_OP(PtrNotEqual, 402)
_SPIRV_OP(PtrDiff, 403)
_SPIRV_OP(CopyLogical, 400)
_SPIRV_OP(TypeUntypedPointerKHR, 4417)
_SPIRV_OP(GroupNonUniformRotateKHR, 4431)
_SPIRV_OP(SDotKHR, 4450)
_SPIRV_OP(UDotKHR, 4451)
Expand Down
16 changes: 13 additions & 3 deletions lib/SPIRV/libSPIRV/SPIRVType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,16 @@ SPIRVType *SPIRVType::getFunctionReturnType() const {
}

SPIRVType *SPIRVType::getPointerElementType() const {
assert(OpCode == OpTypePointer && "Not a pointer type");
assert((OpCode == OpTypePointer || OpCode == OpTypeUntypedPointerKHR) &&
"Not a pointer type");
if (OpCode == OpTypeUntypedPointerKHR)
return const_cast<SPIRVType *>(this);
return static_cast<const SPIRVTypePointer *>(this)->getElementType();
}

SPIRVStorageClassKind SPIRVType::getPointerStorageClass() const {
assert(OpCode == OpTypePointer && "Not a pointer type");
assert((OpCode == OpTypePointer || OpCode == OpTypeUntypedPointerKHR) &&
"Not a pointer type");
return static_cast<const SPIRVTypePointer *>(this)->getStorageClass();
}

Expand Down Expand Up @@ -183,7 +187,13 @@ bool SPIRVType::isTypeInt(unsigned Bits) const {
return isType<SPIRVTypeInt>(this, Bits);
}

bool SPIRVType::isTypePointer() const { return OpCode == OpTypePointer; }
bool SPIRVType::isTypePointer() const {
return OpCode == OpTypePointer || OpCode == OpTypeUntypedPointerKHR;
}

bool SPIRVType::isTypeUntypedPointerKHR() const {
return OpCode == OpTypeUntypedPointerKHR;
}

bool SPIRVType::isTypeOpaque() const { return OpCode == OpTypeOpaque; }

Expand Down
Loading

0 comments on commit 7dacb7c

Please sign in to comment.