Skip to content

Commit

Permalink
[Vulkan] Broke out implicit device requirements into SPIRVSupport
Browse files Browse the repository at this point in the history
Codifies the current requirements that are implicit in the shaders
built by CodeGenSPIRV (e.g. can read from 8-bit buffers).  The next
steps for this development are (1) to query driver/device support
information from the device, (2) to pass these query parameters
through the Target, and (3) to ensure correct shader generation even
when features are not supported.

Step (3) will require exposing the target properties to relay
optimization passes.
  • Loading branch information
Lunderberg committed May 18, 2021
1 parent c510c2b commit 43db690
Show file tree
Hide file tree
Showing 7 changed files with 467 additions and 46 deletions.
2 changes: 1 addition & 1 deletion src/target/spirv/build_vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction)

mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));

CodeGenSPIRV cg;
CodeGenSPIRV cg(target);

for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenSPIRV: Can only take PrimFunc";
Expand Down
7 changes: 5 additions & 2 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@
namespace tvm {
namespace codegen {

CodeGenSPIRV::CodeGenSPIRV(Target target) : spirv_support_(target) {}

runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) {
this->InitFuncState();
ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model";
std::vector<Var> pod_args;
uint32_t num_buffer = 0;

// Currently, all storage and uniform buffer arguments are passed as
// a single descriptor set at index 0.
// a single descriptor set at index 0. If ever non-zero, must
// ensure it is less than maxBoundDescriptorSets.
const uint32_t descriptor_set = 0;

for (Var arg : f->params) {
Expand Down Expand Up @@ -114,7 +117,7 @@ void CodeGenSPIRV::InitFuncState() {
var_map_.clear();
storage_info_.clear();
analyzer_.reset(new arith::Analyzer());
builder_.reset(new spirv::IRBuilder());
builder_.reset(new spirv::IRBuilder(spirv_support_));
builder_->InitHeader();
}

Expand Down
12 changes: 12 additions & 0 deletions src/target/spirv/codegen_spirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_

#include <tvm/arith/analyzer.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
Expand All @@ -38,6 +39,7 @@
#include "../../runtime/thread_storage_scope.h"
#include "../../runtime/vulkan/vulkan_shader.h"
#include "ir_builder.h"
#include "spirv_support.h"

namespace tvm {
namespace codegen {
Expand All @@ -50,6 +52,14 @@ using namespace tir;
class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Initialize the codegen based on a specific target.
*
* \param target The target for which code should be generated. The
* device_type for this target must be kDLVulkan.
*/
CodeGenSPIRV(Target target);

/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
Expand Down Expand Up @@ -131,6 +141,8 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent);
spirv::Value CreateStorageSync(const CallNode* op);
void Scalarize(const PrimExpr& e, std::function<void(int i, spirv::Value v)> f);
// SPIRV-related capabilities of the target
SPIRVSupport spirv_support_;
// The builder
std::unique_ptr<spirv::IRBuilder> builder_;
// Work group size of three
Expand Down
136 changes: 111 additions & 25 deletions src/target/spirv/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,35 @@
*/
#include "ir_builder.h"

#include <spirv.hpp>

namespace tvm {
namespace codegen {
namespace spirv {

// implementations

IRBuilder::IRBuilder(const SPIRVSupport& support) : spirv_support_(support) {}

void IRBuilder::InitHeader() {
ICHECK_EQ(header_.size(), 0U);
header_.push_back(spv::MagicNumber);

// Use the spirv version as indicated in the SDK.
#if SPV_VERSION >= 0x10300
header_.push_back(0x10300);
#else
// Target SPIR-V version 1.0. Additional functionality will be
// enabled through extensions.
header_.push_back(0x10000);
#endif

// generator: set to 0, unknown
header_.push_back(0U);
// Bound: set during Finalize
header_.push_back(0U);
// Schema: reserved
header_.push_back(0U);
// shader
ib_.Begin(spv::OpCapability).Add(spv::CapabilityShader).Commit(&header_);
// Declare int64 capability by default
ib_.Begin(spv::OpCapability).Add(spv::CapabilityInt64).Commit(&header_);

// Declare CapabilityShader by default. All other capabilities are
// determined by the types declared.
capabilities_used_.insert(spv::CapabilityShader);

// memory model
ib_.Begin(spv::OpMemoryModel)
.AddSeq(spv::AddressingModelLogical, spv::MemoryModelGLSL450)
Expand All @@ -71,6 +73,30 @@ void IRBuilder::InitPreDefs() {
ib_.Begin(spv::OpTypeFunction).AddSeq(t_void_func_, t_void_).Commit(&global_);
}

std::vector<uint32_t> IRBuilder::Finalize() {
std::vector<uint32_t> data;
// Index for upper bound of id numbers.
const int kBoundLoc = 3;
header_[kBoundLoc] = id_counter_;
data.insert(data.end(), header_.begin(), header_.end());
for (const auto& capability : capabilities_used_) {
ib_.Begin(spv::OpCapability).Add(capability).Commit(&data);
}
for (const auto& ext_name : extensions_used_) {
ib_.Begin(spv::OpExtension).Add(ext_name).Commit(&data);
}
data.insert(data.end(), extended_instruction_section_.begin(),
extended_instruction_section_.end());
data.insert(data.end(), entry_.begin(), entry_.end());
data.insert(data.end(), exec_mode_.begin(), exec_mode_.end());
data.insert(data.end(), debug_.begin(), debug_.end());
data.insert(data.end(), decorate_.begin(), decorate_.end());
data.insert(data.end(), global_.begin(), global_.end());
data.insert(data.end(), func_header_.begin(), func_header_.end());
data.insert(data.end(), function_.begin(), function_.end());
return data;
}

SType IRBuilder::GetSType(const DataType& dtype) {
if (dtype == DataType::Int(32)) {
return t_int32_;
Expand Down Expand Up @@ -145,16 +171,19 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems)
.AddSeq(struct_type, 0, spv::DecorationOffset, 0)
.Commit(&decorate_);

#if SPV_VERSION < 0x10300
// NOTE: BufferBlock was deprecated in SPIRV 1.3
// use StorageClassStorageBuffer instead.
// runtime array are always decorated as BufferBlock(shader storage buffer)
if (num_elems == 0) {
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
// Runtime array are always decorated as Block or BufferBlock
// (shader storage buffer)
if (spirv_support_.supports_storage_buffer_storage_class) {
// If SPIRV 1.3+, or with extension
// SPV_KHR_storage_buffer_storage_class, BufferBlock is
// deprecated.
extensions_used_.insert("SPV_KHR_storage_buffer_storage_class");
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
} else {
if (num_elems == 0) {
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
}
}
#else
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
#endif
struct_array_type_tbl_[key] = struct_type;
return struct_type;
}
Expand Down Expand Up @@ -186,13 +215,14 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) {

Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set,
uint32_t binding) {
// NOTE: BufferBlock was deprecated in SPIRV 1.3
// use StorageClassStorageBuffer instead.
#if SPV_VERSION >= 0x10300
spv::StorageClass storage_class = spv::StorageClassStorageBuffer;
#else
spv::StorageClass storage_class = spv::StorageClassUniform;
#endif
// If SPIRV 1.3+, or with extension SPV_KHR_storage_buffer_storage_class, BufferBlock is
// deprecated.
spv::StorageClass storage_class;
if (spirv_support_.supports_storage_buffer_storage_class) {
storage_class = spv::StorageClassStorageBuffer;
} else {
storage_class = spv::StorageClassUniform;
}

SType sarr_type = GetStructArrayType(value_type, 0);
SType ptr_type = GetPointerType(sarr_type, storage_class);
Expand Down Expand Up @@ -383,6 +413,8 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) {
}

SType IRBuilder::DeclareType(const DataType& dtype) {
AddCapabilityFor(dtype);

if (dtype.lanes() == 1) {
SType t;
t.id = id_counter_++;
Expand Down Expand Up @@ -410,6 +442,60 @@ SType IRBuilder::DeclareType(const DataType& dtype) {
}
}

void IRBuilder::AddCapabilityFor(const DataType& dtype) {
// Declare appropriate capabilities for int/float types
if (dtype.is_int() || dtype.is_uint()) {
if (dtype.bits() == 8) {
ICHECK(spirv_support_.supports_int8) << "Vulkan target does not support Int8 capability";
capabilities_used_.insert(spv::CapabilityInt8);
} else if (dtype.bits() == 16) {
ICHECK(spirv_support_.supports_int16) << "Vulkan target does not support Int16 capability";
capabilities_used_.insert(spv::CapabilityInt16);
} else if (dtype.bits() == 64) {
ICHECK(spirv_support_.supports_int64) << "Vulkan target does not support Int64 capability";
capabilities_used_.insert(spv::CapabilityInt64);
}

} else if (dtype.is_float()) {
if (dtype.bits() == 16) {
ICHECK(spirv_support_.supports_float16)
<< "Vulkan target does not support Float16 capability";
capabilities_used_.insert(spv::CapabilityFloat16);
} else if (dtype.bits() == 64) {
ICHECK(spirv_support_.supports_float64)
<< "Vulkan target does not support Float64 capability";
capabilities_used_.insert(spv::CapabilityFloat64);
}
}

// Declare ability to read type to/from storage buffers. Doing so
// here is a little bit overzealous, should be relaxed in the
// future. Requiring StorageBuffer8BitAccess in order to declare an
// Int8 prevents use of an 8-bit loop iterator on a device that
// supports Int8 but doesn't support 8-bit buffer access.
if (dtype.bits() == 8) {
ICHECK(spirv_support_.supports_storage_buffer_8bit_access)
<< "Vulkan target does not support StorageBuffer8BitAccess";
capabilities_used_.insert(spv::CapabilityStorageBuffer8BitAccess);
extensions_used_.insert("SPV_KHR_8bit_storage");

ICHECK(spirv_support_.supports_storage_buffer_storage_class)
<< "Illegal Vulkan target description. "
<< "Vulkan spec requires extension VK_KHR_storage_buffer_storage_class "
<< "if VK_KHR_8bit_storage is supported";
} else if (dtype.bits() == 16) {
ICHECK(spirv_support_.supports_storage_buffer_8bit_access)
<< "Vulkan target does not support StorageBuffer16BitAccess";

extensions_used_.insert("SPV_KHR_16bit_storage");
if (spirv_support_.supports_storage_buffer_storage_class) {
capabilities_used_.insert(spv::CapabilityStorageBuffer16BitAccess);
} else {
capabilities_used_.insert(spv::CapabilityStorageUniformBufferBlock16);
}
}
}

PhiValue IRBuilder::MakePhi(const SType& out_type, uint32_t num_incoming) {
Value val = NewValue(out_type, kNormal);
ib_.Begin(spv::OpPhi).AddSeq(out_type, val);
Expand Down
65 changes: 47 additions & 18 deletions src/target/spirv/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@
// clang-format off
#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <spirv.hpp>
// clang-format on

#include "spirv_support.h"

namespace tvm {
namespace codegen {
namespace spirv {
Expand Down Expand Up @@ -268,6 +271,14 @@ class InstrBuilder {
*/
class IRBuilder {
public:
/*!
* \brief Initialize the codegen based on a specific feature set.
*
* \param support The features in SPIRV that are supported by the
* target device.
*/
explicit IRBuilder(const SPIRVSupport& support);

/*! \brief Initialize header */
void InitHeader();
/*! \brief Initialize the predefined contents */
Expand All @@ -278,29 +289,21 @@ class IRBuilder {
* \return The finalized binary instruction.
*/
Value ExtInstImport(const std::string& name) {
auto it = ext_inst_tbl_.find(name);
if (it != ext_inst_tbl_.end()) {
return it->second;
}
Value val = NewValue(SType(), kExtInst);
ib_.Begin(spv::OpExtInstImport).AddSeq(val, name).Commit(&header_);
ib_.Begin(spv::OpExtInstImport).AddSeq(val, name).Commit(&extended_instruction_section_);
ext_inst_tbl_[name] = val;
return val;
}
/*!
* \brief Get the final binary built from the builder
* \return The finalized binary instruction.
*/
std::vector<uint32_t> Finalize() {
std::vector<uint32_t> data;
// set bound
const int kBoundLoc = 3;
header_[kBoundLoc] = id_counter_;
data.insert(data.end(), header_.begin(), header_.end());
data.insert(data.end(), entry_.begin(), entry_.end());
data.insert(data.end(), exec_mode_.begin(), exec_mode_.end());
data.insert(data.end(), debug_.begin(), debug_.end());
data.insert(data.end(), decorate_.begin(), decorate_.end());
data.insert(data.end(), global_.begin(), global_.end());
data.insert(data.end(), func_header_.begin(), func_header_.end());
data.insert(data.end(), function_.begin(), function_.end());
return data;
}
std::vector<uint32_t> Finalize();

/*!
* \brief Create new label
* \return The created new label
Expand Down Expand Up @@ -599,6 +602,19 @@ class IRBuilder {
Value GetConst_(const SType& dtype, const uint64_t* pvalue);
// declare type
SType DeclareType(const DataType& dtype);

// Declare the appropriate SPIR-V capabilities and extensions to use
// this data type.
void AddCapabilityFor(const DataType& dtype);

/*! \brief SPIRV-related capabilities of the target
*
* This SPIRVSupport object is owned by the same CodeGenSPIRV
* object that owns the IRBuilder. Therefore, safe to use a
* reference as the CodeGenSPIRV will live longer.
*/
const SPIRVSupport& spirv_support_;

/*! \brief internal instruction builder */
InstrBuilder ib_;
/*! \brief Current label */
Expand All @@ -623,9 +639,22 @@ class IRBuilder {
std::map<std::pair<uint32_t, spv::StorageClass>, SType> pointer_type_tbl_;
/*! \brief map from constant int to its value */
std::map<std::pair<uint32_t, uint64_t>, Value> const_tbl_;
/*! \brief Header segment, include import */
/*! \brief map from name of a ExtInstImport to its value */
std::map<std::string, Value> ext_inst_tbl_;

/*! \brief Header segment
*
* 5 words long, described in "First Words of Physical Layout"
* section of SPIR-V documentation.
*/
std::vector<uint32_t> header_;
/*! \brief engtry point segment */
/*! \brief SPIR-V capabilities used by this module. */
std::set<spv::Capability> capabilities_used_;
/*! \brief SPIR-V extensions used by this module. */
std::set<std::string> extensions_used_;
/*! \brief entry point segment */
std::vector<uint32_t> extended_instruction_section_;
/*! \brief entry point segment */
std::vector<uint32_t> entry_;
/*! \brief Header segment */
std::vector<uint32_t> exec_mode_;
Expand Down
Loading

0 comments on commit 43db690

Please sign in to comment.