diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 9f9718bef18e4..0e2cc04347cea 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -75,7 +75,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); - CodeGenSPIRV cg; + CodeGenSPIRV cg(target); for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 0c6deb28dca91..dc625b6a928d7 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -37,6 +37,8 @@ namespace tvm { namespace codegen { +CodeGenSPIRV::CodeGenSPIRV(Target target) : spirv_support_(target) {} + runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { this->InitFuncState(); ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; @@ -44,7 +46,8 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: uint32_t num_buffer = 0; // Currently, all storage and uniform buffer arguments are passed as - // a single descriptor set at index 0. + // a single descriptor set at index 0. If ever non-zero, must + // ensure it is less than maxBoundDescriptorSets. 
const uint32_t descriptor_set = 0; for (Var arg : f->params) { @@ -114,7 +117,7 @@ void CodeGenSPIRV::InitFuncState() { var_map_.clear(); storage_info_.clear(); analyzer_.reset(new arith::Analyzer()); - builder_.reset(new spirv::IRBuilder()); + builder_.reset(new spirv::IRBuilder(spirv_support_)); builder_->InitHeader(); } diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index e3d6c153d06fd..3868322a74e0d 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -25,6 +25,7 @@ #define TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_ #include +#include #include #include #include @@ -38,6 +39,7 @@ #include "../../runtime/thread_storage_scope.h" #include "../../runtime/vulkan/vulkan_shader.h" #include "ir_builder.h" +#include "spirv_support.h" namespace tvm { namespace codegen { @@ -50,6 +52,14 @@ using namespace tir; class CodeGenSPIRV : public ExprFunctor, public StmtFunctor { public: + /*! + * \brief Initialize the codegen based on a specific target. + * + * \param target The target for which code should be generated. The + * device_type for this target must be kDLVulkan. + */ + CodeGenSPIRV(Target target); + /*! * \brief Compile and add function f to the current module. * \param f The function to be added. 
@@ -131,6 +141,8 @@ class CodeGenSPIRV : public ExprFunctor, spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent); spirv::Value CreateStorageSync(const CallNode* op); void Scalarize(const PrimExpr& e, std::function f); + // SPIRV-related capabilities of the target + SPIRVSupport spirv_support_; // The builder std::unique_ptr builder_; // Work group size of three diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index ce2b4bc152118..d204ccc55038d 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -23,22 +23,23 @@ */ #include "ir_builder.h" +#include + namespace tvm { namespace codegen { namespace spirv { // implementations +IRBuilder::IRBuilder(const SPIRVSupport& support) : spirv_support_(support) {} + void IRBuilder::InitHeader() { ICHECK_EQ(header_.size(), 0U); header_.push_back(spv::MagicNumber); - // Use the spirv version as indicated in the SDK. -#if SPV_VERSION >= 0x10300 - header_.push_back(0x10300); -#else + // Target SPIR-V version 1.0. Additional functionality will be + // enabled through extensions. header_.push_back(0x10000); -#endif // generator: set to 0, unknown header_.push_back(0U); @@ -46,10 +47,11 @@ void IRBuilder::InitHeader() { header_.push_back(0U); // Schema: reserved header_.push_back(0U); - // shader - ib_.Begin(spv::OpCapability).Add(spv::CapabilityShader).Commit(&header_); - // Declare int64 capability by default - ib_.Begin(spv::OpCapability).Add(spv::CapabilityInt64).Commit(&header_); + + // Declare CapabilityShader by default. All other capabilities are + // determined by the types declared. 
+ capabilities_used_.insert(spv::CapabilityShader); + // memory model ib_.Begin(spv::OpMemoryModel) .AddSeq(spv::AddressingModelLogical, spv::MemoryModelGLSL450) @@ -71,6 +73,30 @@ void IRBuilder::InitPreDefs() { ib_.Begin(spv::OpTypeFunction).AddSeq(t_void_func_, t_void_).Commit(&global_); } +std::vector IRBuilder::Finalize() { + std::vector data; + // Index for upper bound of id numbers. + const int kBoundLoc = 3; + header_[kBoundLoc] = id_counter_; + data.insert(data.end(), header_.begin(), header_.end()); + for (const auto& capability : capabilities_used_) { + ib_.Begin(spv::OpCapability).Add(capability).Commit(&data); + } + for (const auto& ext_name : extensions_used_) { + ib_.Begin(spv::OpExtension).Add(ext_name).Commit(&data); + } + data.insert(data.end(), extended_instruction_section_.begin(), + extended_instruction_section_.end()); + data.insert(data.end(), entry_.begin(), entry_.end()); + data.insert(data.end(), exec_mode_.begin(), exec_mode_.end()); + data.insert(data.end(), debug_.begin(), debug_.end()); + data.insert(data.end(), decorate_.begin(), decorate_.end()); + data.insert(data.end(), global_.begin(), global_.end()); + data.insert(data.end(), func_header_.begin(), func_header_.end()); + data.insert(data.end(), function_.begin(), function_.end()); + return data; +} + SType IRBuilder::GetSType(const DataType& dtype) { if (dtype == DataType::Int(32)) { return t_int32_; @@ -145,16 +171,19 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems) .AddSeq(struct_type, 0, spv::DecorationOffset, 0) .Commit(&decorate_); -#if SPV_VERSION < 0x10300 - // NOTE: BufferBlock was deprecated in SPIRV 1.3 - // use StorageClassStorageBuffer instead. 
- // runtime array are always decorated as BufferBlock(shader storage buffer) - if (num_elems == 0) { - this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock); + // Runtime array are always decorated as Block or BufferBlock + // (shader storage buffer) + if (spirv_support_.supports_StorageBufferStorageClass) { + // If SPIRV 1.3+, or with extension + // SPV_KHR_storage_buffer_storage_class, BufferBlock is + // deprecated. + extensions_used_.insert("SPV_KHR_storage_buffer_storage_class"); + this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); + } else { + if (num_elems == 0) { + this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock); + } } -#else - this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); -#endif struct_array_type_tbl_[key] = struct_type; return struct_type; } @@ -186,13 +215,14 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding) { - // NOTE: BufferBlock was deprecated in SPIRV 1.3 - // use StorageClassStorageBuffer instead. -#if SPV_VERSION >= 0x10300 - spv::StorageClass storage_class = spv::StorageClassStorageBuffer; -#else - spv::StorageClass storage_class = spv::StorageClassUniform; -#endif + // If SPIRV 1.3+, or with extension SPV_KHR_storage_buffer_storage_class, BufferBlock is + // deprecated. 
+ spv::StorageClass storage_class; + if (spirv_support_.supports_StorageBufferStorageClass) { + storage_class = spv::StorageClassStorageBuffer; + } else { + storage_class = spv::StorageClassUniform; + } SType sarr_type = GetStructArrayType(value_type, 0); SType ptr_type = GetPointerType(sarr_type, storage_class); @@ -383,6 +413,8 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { } SType IRBuilder::DeclareType(const DataType& dtype) { + AddCapabilityFor(dtype); + if (dtype.lanes() == 1) { SType t; t.id = id_counter_++; @@ -410,6 +442,60 @@ SType IRBuilder::DeclareType(const DataType& dtype) { } } +void IRBuilder::AddCapabilityFor(const DataType& dtype) { + // Declare appropriate capabilities for int/float types + if (dtype.is_int() || dtype.is_uint()) { + if (dtype.bits() == 8) { + ICHECK(spirv_support_.supports_Int8) << "Vulkan target does not support Int8 capability"; + capabilities_used_.insert(spv::CapabilityInt8); + } else if (dtype.bits() == 16) { + ICHECK(spirv_support_.supports_Int16) << "Vulkan target does not support Int16 capability"; + capabilities_used_.insert(spv::CapabilityInt16); + } else if (dtype.bits() == 64) { + ICHECK(spirv_support_.supports_Int64) << "Vulkan target does not support Int64 capability"; + capabilities_used_.insert(spv::CapabilityInt64); + } + + } else if (dtype.is_float()) { + if (dtype.bits() == 16) { + ICHECK(spirv_support_.supports_Float16) + << "Vulkan target does not support Float16 capability"; + capabilities_used_.insert(spv::CapabilityFloat16); + } else if (dtype.bits() == 64) { + ICHECK(spirv_support_.supports_Float64) + << "Vulkan target does not support Float64 capability"; + capabilities_used_.insert(spv::CapabilityFloat64); + } + } + + // Declare ability to read type to/from storage buffers. Doing so + // here is a little bit overzealous, should be relaxed in the + // future. 
Requiring StorageBuffer8BitAccess in order to declare an + // Int8 prevents use of an 8-bit loop iterator on a device that + // supports Int8 but doesn't support 8-bit buffer access. + if (dtype.bits() == 8) { + ICHECK(spirv_support_.supports_StorageBuffer8BitAccess) + << "Vulkan target does not support StorageBuffer8BitAccess"; + capabilities_used_.insert(spv::CapabilityStorageBuffer8BitAccess); + extensions_used_.insert("SPV_KHR_8bit_storage"); + + ICHECK(spirv_support_.supports_StorageBufferStorageClass) + << "Illegal Vulkan target description. " + << "Vulkan spec requires extension VK_KHR_storage_buffer_storage_class " + << "if VK_KHR_8bit_storage is supported"; + } else if (dtype.bits() == 16) { + ICHECK(spirv_support_.supports_StorageBuffer16BitAccess) + << "Vulkan target does not support StorageBuffer16BitAccess"; + + extensions_used_.insert("SPV_KHR_16bit_storage"); + if (spirv_support_.supports_StorageBufferStorageClass) { + capabilities_used_.insert(spv::CapabilityStorageBuffer16BitAccess); + } else { + capabilities_used_.insert(spv::CapabilityStorageUniformBufferBlock16); + } + } +} + PhiValue IRBuilder::MakePhi(const SType& out_type, uint32_t num_incoming) { Value val = NewValue(out_type, kNormal); ib_.Begin(spv::OpPhi).AddSeq(out_type, val); diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 250d67067a814..afd9be92fa5a9 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -30,6 +30,7 @@ // clang-format off #include #include +#include #include #include #include @@ -37,6 +38,8 @@ #include // clang-format on +#include "spirv_support.h" + namespace tvm { namespace codegen { namespace spirv { @@ -268,6 +271,14 @@ class InstrBuilder { */ class IRBuilder { public: + /*! + * \brief Initialize the codegen based on a specific feature set. + * + * \param support The features in SPIRV that are supported by the + * target device. + */ + explicit IRBuilder(const SPIRVSupport& support); + /*! 
\brief Initialize header */ void InitHeader(); /*! \brief Initialize the predefined contents */ @@ -278,29 +289,21 @@ class IRBuilder { * \return The finalized binary instruction. */ Value ExtInstImport(const std::string& name) { + auto it = ext_inst_tbl_.find(name); + if (it != ext_inst_tbl_.end()) { + return it->second; + } Value val = NewValue(SType(), kExtInst); - ib_.Begin(spv::OpExtInstImport).AddSeq(val, name).Commit(&header_); + ib_.Begin(spv::OpExtInstImport).AddSeq(val, name).Commit(&extended_instruction_section_); + ext_inst_tbl_[name] = val; return val; } /*! * \brief Get the final binary built from the builder * \return The finalized binary instruction. */ - std::vector Finalize() { - std::vector data; - // set bound - const int kBoundLoc = 3; - header_[kBoundLoc] = id_counter_; - data.insert(data.end(), header_.begin(), header_.end()); - data.insert(data.end(), entry_.begin(), entry_.end()); - data.insert(data.end(), exec_mode_.begin(), exec_mode_.end()); - data.insert(data.end(), debug_.begin(), debug_.end()); - data.insert(data.end(), decorate_.begin(), decorate_.end()); - data.insert(data.end(), global_.begin(), global_.end()); - data.insert(data.end(), func_header_.begin(), func_header_.end()); - data.insert(data.end(), function_.begin(), function_.end()); - return data; - } + std::vector Finalize(); + /*! * \brief Create new label * \return The created new label @@ -599,6 +602,19 @@ class IRBuilder { Value GetConst_(const SType& dtype, const uint64_t* pvalue); // declare type SType DeclareType(const DataType& dtype); + + // Declare the appropriate SPIR-V capabilities and extensions to use + // this data type. + void AddCapabilityFor(const DataType& dtype); + + /*! \brief SPIRV-related capabilities of the target + * + * This SPIRVSupport object is owned by the same CodeGenSPIRV + * object that owns the IRBuilder. Therefore, safe to use a + * reference as the CodeGenSPIRV will live longer. + */ + const SPIRVSupport& spirv_support_; + /*! 
\brief internal instruction builder */ InstrBuilder ib_; /*! \brief Current label */ @@ -623,9 +639,22 @@ class IRBuilder { std::map, SType> pointer_type_tbl_; /*! \brief map from constant int to its value */ std::map, Value> const_tbl_; - /*! \brief Header segment, include import */ + /*! \brief map from name of an ExtInstImport to its value */ + std::map ext_inst_tbl_; + + /*! \brief Header segment + * + * 5 words long, described in "First Words of Physical Layout" + * section of SPIR-V documentation. + */ std::vector header_; - /*! \brief engtry point segment */ + /*! \brief SPIR-V capabilities used by this module. */ + std::set capabilities_used_; + /*! \brief SPIR-V extensions used by this module. */ + std::set extensions_used_; + /*! \brief OpExtInstImport instructions segment */ + std::vector extended_instruction_section_; + /*! \brief entry point segment */ std::vector entry_; /*! \brief Header segment */ std::vector exec_mode_; diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc new file mode 100644 index 0000000000000..6ba5da3e3ce05 --- /dev/null +++ b/src/target/spirv/spirv_support.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file spirv_support + * + * \brief Utility for determining which spirv capabilities a TVM + * target supports. + */ + +#include "spirv_support.h" + +#include + +namespace tvm { +namespace codegen { + +SPIRVSupport::SPIRVSupport(tvm::Target target) { + ICHECK_EQ(target->kind->device_type, kDLVulkan) + << "SPIRVSupport can only be checked for vulkan device type"; + + // Currently, this codifies the assumptions that were present and + // implicit in previous implementations. In the future, this will + // pull information from the specified `Target`. + + supports_StorageBufferStorageClass = (SPV_VERSION >= 0x10300); + supports_StorageBuffer8BitAccess = true; + supports_StorageBuffer16BitAccess = true; + supports_Float16 = true; + supports_Int8 = true; + supports_Int16 = true; + supports_Int64 = true; +} + +} // namespace codegen +} // namespace tvm diff --git a/src/target/spirv/spirv_support.h b/src/target/spirv/spirv_support.h new file mode 100644 index 0000000000000..ecdc4178006b6 --- /dev/null +++ b/src/target/spirv/spirv_support.h @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file spirv_support + * + * \brief Utility for determining which spirv capabilities a TVM + * target supports. + */ +#ifndef TVM_TARGET_SPIRV_SPIRV_SUPPORT_H_ +#define TVM_TARGET_SPIRV_SPIRV_SUPPORT_H_ + +#include + +namespace tvm { +namespace codegen { + +/*! \brief Represents which support a Vulkan driver has that are relevant to codegen */ +struct SPIRVSupport { + /*! \brief Determine spirv capabilities from a vulkan target. + */ + explicit SPIRVSupport(Target target); + + /*! + * \brief The supported subgroup operations + * + * Vulkan extension: VK_KHR_driver_properties + * Minimum vulkan version: 1.1 + * + * The supportedOperations bitflags from + * VkPhysicalDeviceSubgroupProperties. + * + * Requires vulkan 1.1 or higher to use. If the + * VK_KHR_driver_properties extension is not present in order to + * query this value, or if the driver does not support vulkan 1.1, + * then this value will be set to 0. + * + */ + uint32_t supportedSubgroupOperations{0}; + + /*! + * \brief The maximum size (bytes) of push constants + * + * Default value is from Vulkan spec, "Required Limits" table. + * Implementations may have a larger limit. + */ + uint32_t maxPushConstantsSize{128}; + + /*! + * \brief The maximum size (bytes) of a uniform buffer. + * + * Default value is from Vulkan spec, "Required Limits" table. + * Implementations may have a larger limit. + */ + uint32_t maxUniformBufferRange{16384}; + + /*! + * \brief The maximum size (bytes) of a storage buffer. + * + * Default value is from Vulkan spec, "Required Limits" table. + * Implementations may have a larger limit. + */ + uint32_t maxStorageBufferRange{1 << 27}; + + /*! + * \brief The maximum number of storage buffers accessible by a single shader. + * + * Default value is from Vulkan spec, "Required Limits" table. + * Implementations may have a larger limit, frequently much larger. + * (e.g. GTX 1080 has max of 2^20) + */ + uint32_t maxPerStageDescriptorStorageBuffers{4}; + + /*! 
+ * \brief Whether the driver supports StorageClassStorageBuffer + * + * Vulkan extension: VK_KHR_storage_buffer_storage_class + * Device property: N/A + * SPV Extension: SPV_KHR_storage_buffer_storage_class + * SPV Capability: N/A + * + * If support is present, access push constants and UBO as + * block-decorated StorageClassStorageBuffer. Otherwise, access as + * buffer-block-decorated StorageClassUniform. SPIRV 1.3 deprecated + * BufferBlock, so this should always be true for drivers that support + * SPIRV 1.3. + * + */ + bool supports_StorageBufferStorageClass{false}; + + /*! + * \brief Whether the driver supports reading/writing to 8-bit values + * + * Vulkan extension: VK_KHR_8bit_storage + * Device property: storageBuffer8BitAccess + * SPV extension: SPV_KHR_8bit_storage + * SPV Capability: StorageBuffer8BitAccess + * + * If support is present, can read/write 8-bit values, but doesn't + * necessarily provide 8-bit operations. + * + * If support is present, will declare StorageBuffer8BitAccess as + * needed. If support is not present, will throw error if a + * PrimFunc calls for this functionality. Unlike + * StorageUniform16BitAccess, no fallback to + * "StorageUniformBufferBlock8" is needed, as VK_KHR_8bit_storage + * requires VK_KHR_storage_buffer_storage_class to also be present. + * + */ + bool supports_StorageBuffer8BitAccess{false}; + + /*! + * \brief Whether the driver supports reading/writing to 16-bit values + * + * Vulkan extension: VK_KHR_16bit_storage + * Device property: storageBuffer16BitAccess + * SPV extension: SPV_KHR_16bit_storage + * SPV Capability: StorageBuffer16BitAccess, StorageUniformBufferBlock16 + * + * If support is present, can read/write 16-bit values, but doesn't + * necessarily provide 16-bit operations. + * + * If support is present, will declare either + * StorageBuffer16BitAccess or StorageUniformBufferBlock16 as + * needed, selecting based on the value of + * supports_StorageBufferStorageClass. 
If support is not present, + * will throw error if a PrimFunc calls for this functionality. + */ + bool supports_StorageBuffer16BitAccess{false}; + + /*! + * \brief Whether the driver supports operations involving 16-bit floats + * + * Vulkan extension: VK_KHR_shader_float16_int8 + * Device Property: shaderFloat16 + * SPV Extension name: N/A + * SPV Capability: Float16, Float16Buffer + * + * If support is present, can perform 16-bit float operations. If + * support is not present, codegen will throw exception on + * attempting to create a 16-bit float. + */ + bool supports_Float16{false}; + + /*! + * \brief Whether the driver supports operations involving 64-bit floats + * + * Vulkan extension: N/A + * Device Property: shaderFloat64 + * SPV Extension name: N/A + * SPV Capability: Float64 + * + * If support is present, can perform 64-bit float operations. If + * support is not present, codegen will throw exception on + * attempting to create a 64-bit float. + */ + bool supports_Float64{false}; + + /*! + * \brief Whether the driver supports operations involving 8-bit ints + * + * Vulkan extension: VK_KHR_shader_float16_int8 + * Device Property: shaderInt8 + * SPV Extension name: N/A + * SPV Capability: Int8 + * + * If support is present, can perform 8-bit int operations. If + * support is not present, codegen will throw exception on + * attempting to create an 8-bit int. + */ + bool supports_Int8{false}; + + /*! + * \brief Whether the driver supports operations involving 16-bit ints + * + * Vulkan extension: N/A + * Device Property: shaderInt16 + * SPV Extension name: N/A + * SPV Capability: Int16 + * + * If support is present, can perform 16-bit int operations. If + * support is not present, codegen will throw exception on + * attempting to create a 16-bit int. + */ + bool supports_Int16{false}; + + /*! 
+ * \brief Whether the driver supports operations involving 64-bit ints + * + * Vulkan extension: N/A + * Device Property: shaderInt64 + * SPV Extension name: N/A + * SPV Capability: Int64 + * + * If support is present, can perform 64-bit int operations. If + * support is not present, codegen will throw exception on + * attempting to create a 64-bit int. + */ + bool supports_Int64{false}; +}; + +} // namespace codegen +} // namespace tvm +#endif // TVM_TARGET_SPIRV_SPIRV_SUPPORT_H_