Skip to content

Commit

Permalink
Use worklist and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cassiebeckley committed Oct 1, 2024
1 parent b5fbf00 commit 72f4038
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 54 deletions.
103 changes: 53 additions & 50 deletions source/opt/copy_prop_arrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,37 +75,40 @@ Pass::Status CopyPropagateArrays::Process() {

BasicBlock* entry_bb = &*function.begin();

bool load_updated;
do {
load_updated = false;
for (auto var_inst = entry_bb->begin();
var_inst->opcode() == spv::Op::OpVariable; ++var_inst) {
// Find the only store to the entire memory location, if it exists.
Instruction* store_inst = FindStoreInstruction(&*var_inst);

if (!store_inst) {
continue;
}
for (auto var_inst = entry_bb->begin();
var_inst->opcode() == spv::Op::OpVariable; ++var_inst) {
worklist_.push(&*var_inst);
}
}

std::unique_ptr<MemoryObject> source_object =
FindSourceObjectIfPossible(&*var_inst, store_inst);
while (!worklist_.empty()) {
Instruction* var_inst = worklist_.front();
worklist_.pop();

if (source_object != nullptr) {
if (!IsPointerToArrayType(var_inst->type_id()) &&
source_object->GetStorageClass() != spv::StorageClass::Input) {
continue;
}
// Find the only store to the entire memory location, if it exists.
Instruction* store_inst = FindStoreInstruction(&*var_inst);

if (CanUpdateUses(&*var_inst,
source_object->GetPointerTypeId(this))) {
modified = true;
load_updated |=
PropagateObject(&*var_inst, source_object.get(), store_inst);
}
}
if (!store_inst) {
continue;
}

std::unique_ptr<MemoryObject> source_object =
FindSourceObjectIfPossible(&*var_inst, store_inst);

if (source_object != nullptr) {
if (!IsPointerToArrayType(var_inst->type_id()) &&
source_object->GetStorageClass() != spv::StorageClass::Input) {
continue;
}

if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
modified = true;

PropagateObject(&*var_inst, source_object.get(), store_inst);
}
} while (load_updated);
}
}

return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}

Expand Down Expand Up @@ -165,15 +168,15 @@ Instruction* CopyPropagateArrays::FindStoreInstruction(
return store_inst;
}

bool CopyPropagateArrays::PropagateObject(Instruction* var_inst,
void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
MemoryObject* source,
Instruction* insertion_point) {
assert(var_inst->opcode() == spv::Op::OpVariable &&
"This function propagates variables.");

Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
context()->KillNamesAndDecorates(var_inst);
return UpdateUses(var_inst, new_access_chain);
UpdateUses(var_inst, new_access_chain);
}

Instruction* CopyPropagateArrays::BuildNewAccessChain(
Expand Down Expand Up @@ -633,7 +636,7 @@ bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
});
}

bool CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
Instruction* new_ptr_inst) {
analysis::TypeManager* type_mgr = context()->get_type_mgr();
analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
Expand All @@ -645,8 +648,6 @@ bool CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
uses.push_back({use, index});
});

bool updated_load = false;

for (auto pair : uses) {
Instruction* use = pair.first;
uint32_t index = pair.second;
Expand Down Expand Up @@ -714,25 +715,16 @@ bool CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
context()->AnalyzeUses(use);
}

updated_load = true;
AddUsesToWorklist(use);
} break;
case spv::Op::OpExtInst: {
if (use->GetSingleWordInOperand(kExtInstSetInIdx) ==
context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450()) {
uint32_t ext_inst = use->GetSingleWordInOperand(kExtInstOpInIdx);
switch (ext_inst) {
case GLSLstd450InterpolateAtCentroid:
case GLSLstd450InterpolateAtOffset:
case GLSLstd450InterpolateAtSample:
// Replace the actual use.
context()->ForgetUses(use);
use->SetOperand(index, {new_ptr_inst->result_id()});
context()->AnalyzeUses(use);
break;
default:
assert(false && "Don't know how to rewrite instruction");
break;
}
if (IsInterpolationInstruction(use)) {
// Replace the actual use.
context()->ForgetUses(use);
use->SetOperand(index, {new_ptr_inst->result_id()});
context()->AnalyzeUses(use);
} else {
assert(false && "Don't know how to rewrite instruction");
}
} break;
case spv::Op::OpAccessChain: {
Expand Down Expand Up @@ -838,8 +830,6 @@ bool CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
break;
}
}

return updated_load;
}

uint32_t CopyPropagateArrays::GetMemberTypeId(
Expand All @@ -865,6 +855,19 @@ uint32_t CopyPropagateArrays::GetMemberTypeId(
return id;
}

void CopyPropagateArrays::AddUsesToWorklist(Instruction* inst) {
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();

def_use_mgr->ForEachUse(
inst, [this, def_use_mgr](Instruction* use, uint32_t) {
if (use->opcode() == spv::Op::OpStore) {
Instruction* target_pointer = def_use_mgr->GetDef(
use->GetSingleWordInOperand(kStorePointerInOperand));
worklist_.push(target_pointer);
}
});
}

void CopyPropagateArrays::MemoryObject::PushIndirection(
const std::vector<AccessChainEntry>& access_chain) {
access_chain_.insert(access_chain_.end(), access_chain.begin(),
Expand Down
15 changes: 11 additions & 4 deletions source/opt/copy_prop_arrays.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,7 @@ class CopyPropagateArrays : public MemPass {
// Replaces all loads of |var_inst| with a load from |source| instead.
// |insertion_pos| is a position where it is possible to construct the
// address of |source| and also dominates all of the loads of |var_inst|.
// Returns true if a load has been updated and needs to be propagated.
bool PropagateObject(Instruction* var_inst, MemoryObject* source,
void PropagateObject(Instruction* var_inst, MemoryObject* source,
Instruction* insertion_pos);

// Returns true if all of the references to |ptr_inst| can be rewritten and
Expand Down Expand Up @@ -241,8 +240,8 @@ class CopyPropagateArrays : public MemPass {
// Rewrites all uses of |original_ptr| to use |new_pointer_inst| updating
// types of other instructions as needed. This function should not be called
// if |CanUpdateUses(original_ptr_inst, new_pointer_inst->type_id())| returns
// false. Returns true if a load has been updated and needs to be propagated.
bool UpdateUses(Instruction* original_ptr_inst,
// false.
void UpdateUses(Instruction* original_ptr_inst,
Instruction* new_pointer_inst);

// Return true if |UpdateUses| is able to change all of the uses of
Expand All @@ -259,6 +258,14 @@ class CopyPropagateArrays : public MemPass {
// same way the indexes are used in an |OpCompositeExtract| instruction.
uint32_t GetMemberTypeId(uint32_t id,
const std::vector<uint32_t>& access_chain) const;

// If the result of inst is stored to a variable, add that variable to the
// worklist.
void AddUsesToWorklist(Instruction* inst);

// OpVariable worklist. An instruction is added to this list if we would like
// to run copy propagation on it.
std::queue<Instruction*> worklist_;
};

} // namespace opt
Expand Down
121 changes: 121 additions & 0 deletions test/opt/copy_prop_array_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1977,6 +1977,127 @@ OpFunctionEnd

SinglePassRunAndCheck<CopyPropagateArrays>(text, text, false);
}

TEST_F(CopyPropArrayPassTest, InterpolateFunctions) {
const std::string before = R"(OpCapability InterpolationFunction
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %in_var_COLOR
OpExecutionMode %main OriginUpperLeft
OpSource HLSL 680
OpName %in_var_COLOR "in.var.COLOR"
OpName %main "main"
OpName %offset "offset"
OpDecorate %in_var_COLOR Location 0
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%float = OpTypeFloat 32
%float_0 = OpConstant %float 0
%v2float = OpTypeVector %float 2
%v4float = OpTypeVector %float 4
%_ptr_Input_v4float = OpTypePointer Input %v4float
%void = OpTypeVoid
%19 = OpTypeFunction %void
%_ptr_Function_v4float = OpTypePointer Function %v4float
%in_var_COLOR = OpVariable %_ptr_Input_v4float Input
%main = OpFunction %void None %19
%20 = OpLabel
%45 = OpVariable %_ptr_Function_v4float Function
%25 = OpLoad %v4float %in_var_COLOR
OpStore %45 %25
; CHECK: OpExtInst %v4float %1 InterpolateAtCentroid %in_var_COLOR
%52 = OpExtInst %v4float %1 InterpolateAtCentroid %45
; CHECK: OpExtInst %v4float %1 InterpolateAtSample %in_var_COLOR %int_0
%54 = OpExtInst %v4float %1 InterpolateAtSample %45 %int_0
%offset = OpCompositeConstruct %v2float %float_0 %float_0
; CHECK: OpExtInst %v4float %1 InterpolateAtOffset %in_var_COLOR %offset
%56 = OpExtInst %v4float %1 InterpolateAtOffset %45 %offset
OpReturn
OpFunctionEnd
)";

SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER |
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
SinglePassRunAndMatch<CopyPropagateArrays>(before, false);
}

TEST_F(CopyPropArrayPassTest, InterpolateMultiPropagation) {
const std::string before = R"(OpCapability InterpolationFunction
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %in_var_COLOR
OpExecutionMode %main OriginUpperLeft
OpSource HLSL 680
OpName %in_var_COLOR "in.var.COLOR"
OpName %main "main"
OpName %param_var_color "param.var.color"
OpDecorate %in_var_COLOR Location 0
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%_ptr_Input_v4float = OpTypePointer Input %v4float
%void = OpTypeVoid
%19 = OpTypeFunction %void
%_ptr_Function_v4float = OpTypePointer Function %v4float
%in_var_COLOR = OpVariable %_ptr_Input_v4float Input
%main = OpFunction %void None %19
%20 = OpLabel
%45 = OpVariable %_ptr_Function_v4float Function
%param_var_color = OpVariable %_ptr_Function_v4float Function
%25 = OpLoad %v4float %in_var_COLOR
OpStore %param_var_color %25
; CHECK: OpExtInst %v4float %1 InterpolateAtCentroid %in_var_COLOR
%52 = OpExtInst %v4float %1 InterpolateAtCentroid %param_var_color
%49 = OpLoad %v4float %param_var_color
OpStore %45 %49
; CHECK: OpExtInst %v4float %1 InterpolateAtCentroid %in_var_COLOR
%54 = OpExtInst %v4float %1 InterpolateAtCentroid %45
OpReturn
OpFunctionEnd
)";

SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER |
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
SinglePassRunAndMatch<CopyPropagateArrays>(before, false);
}

TEST_F(CopyPropArrayPassTest, PropagateScalar) {
const std::string before = R"(OpCapability InterpolationFunction
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %in_var_SV_InstanceID
OpExecutionMode %main OriginUpperLeft
OpSource HLSL 680
OpName %in_var_SV_InstanceID "in.var.SV_InstanceID"
OpName %main "main"
OpDecorate %in_var_SV_InstanceID Location 0
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%_ptr_Input_float = OpTypePointer Input %float
%void = OpTypeVoid
%19 = OpTypeFunction %void
%_ptr_Function_float = OpTypePointer Function %float
%in_var_SV_InstanceID = OpVariable %_ptr_Input_float Input
%main = OpFunction %void None %19
%20 = OpLabel
%45 = OpVariable %_ptr_Function_float Function
%25 = OpLoad %v4float %in_var_SV_InstanceID
OpStore %45 %25
; CHECK: OpExtInst %v4float %1 InterpolateAtCentroid %in_var_SV_InstanceID
%52 = OpExtInst %v4float %1 InterpolateAtCentroid %45
OpReturn
OpFunctionEnd
)";

SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER |
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
SinglePassRunAndMatch<CopyPropagateArrays>(before, false);
}
} // namespace
} // namespace opt
} // namespace spvtools

0 comments on commit 72f4038

Please sign in to comment.