From 5571047ded965efddbd242b1f0c5ac357868d31a Mon Sep 17 00:00:00 2001 From: Marcel Koester Date: Thu, 7 Sep 2023 16:53:18 +0200 Subject: [PATCH] Fixed IfConversion. (#1081) * Extended PhiBindings analysis to check for individual branch targets. * Fixed IfConversion transformation phasing possibilities of endless loop branches. * Fixed PTXCodeGenerator to emit individual target bindings. * Added new test case for improved IfConversion transformation. --- Src/ILGPU.Tests/BasicJumps.cs | 57 +++++++-- .../PTX/PTXCodeGenerator.Terminators.cs | 89 +++++++++++++- Src/ILGPU/Backends/PTX/PTXCodeGenerator.cs | 115 ++++++++++-------- Src/ILGPU/Backends/PhiBindings.cs | 18 +++ Src/ILGPU/IR/Transformations/IfConversion.cs | 50 +++++--- 5 files changed, 250 insertions(+), 79 deletions(-) diff --git a/Src/ILGPU.Tests/BasicJumps.cs b/Src/ILGPU.Tests/BasicJumps.cs index bbbd6da86b..1facdf0ba9 100644 --- a/Src/ILGPU.Tests/BasicJumps.cs +++ b/Src/ILGPU.Tests/BasicJumps.cs @@ -1,6 +1,6 @@ // --------------------------------------------------------------------------------------- // ILGPU -// Copyright (c) 2021 ILGPU Project +// Copyright (c) 2021-2023 ILGPU Project // www.ilgpu.net // // File: BasicJumps.cs @@ -16,6 +16,7 @@ using Xunit.Abstractions; #pragma warning disable CS0162 +#pragma warning disable CS0164 namespace ILGPU.Tests { @@ -164,25 +165,63 @@ internal static void BasicNestedLoopJumpKernel( } [Theory] - [InlineData(32, 0, 67)] - [InlineData(32, 2, 25)] - [InlineData(1024, 0, 67)] - [InlineData(1024, 2, 25)] + [InlineData(0, 67)] + [InlineData(2, 25)] [KernelMethod(nameof(BasicNestedLoopJumpKernel))] - public void BasicNestedLoopJump(int length, int c, int res) + public void BasicNestedLoopJump(int c, int res) { - using var buffer = Accelerator.Allocate1D(length); - using var source = Accelerator.Allocate1D(64); + const int Length = 64; + using var buffer = Accelerator.Allocate1D(Length); + using var source = Accelerator.Allocate1D(Length); var sourceData = Enumerable.Range(0, (int)source.Length).ToArray(); sourceData[57] = 23; source.CopyFromCPU(Accelerator.DefaultStream, sourceData); Execute(buffer.Length, buffer.View, source.View, c); - var expected = Enumerable.Repeat(res, length).ToArray(); + var expected = Enumerable.Repeat(res, Length).ToArray(); + Verify(buffer.View, expected); + } + + private static void BasicNestedLoopJumpKernel2( + Index1D index, + ArrayView1D target, + ArrayView1D source) + { + int k = 0; + entry: + for (int i = 0; i < source.Length; ++i) + { + goto exit; + } + + target[index] = 42; + return; + + nested: + k = 43; + + exit: + if (k++ < 1) + goto entry; + target[index] = 23 + k; + } + + [Fact] + [KernelMethod(nameof(BasicNestedLoopJumpKernel2))] + public void BasicNestedLoopJump2() + { + const int Length = 32; + using var buffer = Accelerator.Allocate1D(Length); + using var source = Accelerator.Allocate1D(Length); + + Execute(buffer.Length, buffer.View, source.View); + + var expected = Enumerable.Repeat(25, Length).ToArray(); Verify(buffer.View, expected); } } } +#pragma warning restore CS0164 #pragma warning restore CS0162 diff --git a/Src/ILGPU/Backends/PTX/PTXCodeGenerator.Terminators.cs b/Src/ILGPU/Backends/PTX/PTXCodeGenerator.Terminators.cs index 090a94b57a..91a520210b 100644 --- a/Src/ILGPU/Backends/PTX/PTXCodeGenerator.Terminators.cs +++ b/Src/ILGPU/Backends/PTX/PTXCodeGenerator.Terminators.cs @@ -1,6 +1,6 @@ // --------------------------------------------------------------------------------------- // ILGPU -// Copyright (c) 2018-2022 ILGPU Project +// Copyright (c) 2018-2023 ILGPU Project // www.ilgpu.net // // File: PTXCodeGenerator.Terminators.cs @@ -9,8 +9,10 @@ // Source License. See LICENSE.txt for details. // --------------------------------------------------------------------------------------- +using ILGPU.IR; using ILGPU.IR.Values; using ILGPU.Util; +using System.Diagnostics.CodeAnalysis; namespace ILGPU.Backends.PTX { @@ -24,15 +26,50 @@ public void GenerateCode(ReturnTerminator returnTerminator) var resultRegister = Load(returnTerminator.ReturnValue); EmitStoreParam(ReturnParamName, resultRegister); } + Command( Uniforms.IsUniform(returnTerminator) ? PTXInstructions.UniformReturnOperation : PTXInstructions.ReturnOperation); } + private bool NeedSeparatePhiBindings( + BasicBlock basicBlock, + BasicBlock target, + [NotNullWhen(true)] + out PhiBindings.PhiBindingCollection bindings) + { + if (!phiBindings.TryGetBindings(target, out bindings)) + return false; + + // Check whether there are bindings pointing to different blocks + foreach (var (phiValue, _) in bindings) + { + if (phiValue.BasicBlock != target) + return true; + } + + // We were not able to find misleading data + return false; + } + + /// + /// Generates phi bindings for jumping to a specific target block. + /// + /// The current block. + private void GeneratePhiBindings(BasicBlock current) + { + if (!phiBindings.TryGetBindings(current, out var bindings)) + return; + BindPhis(bindings, target: null); + } + /// public void GenerateCode(UnconditionalBranch branch) { + // Bind phis + GeneratePhiBindings(branch.BasicBlock); + if (Schedule.IsImplicitSuccessor(branch.BasicBlock, branch.Target)) return; @@ -60,6 +97,53 @@ public void GenerateCode(IfBranch branch) ? PTXInstructions.UniformBranchOperation : PTXInstructions.BranchOperation; + // Gather phi bindings and test both, true and false targets + if (phiBindings.TryGetBindings(branch.BasicBlock, out var bindings) && + (bindings.NeedSeparateBindingsFor(trueTarget) || + bindings.NeedSeparateBindingsFor(falseTarget))) + { + // We need to emit different bindings in each branch + if (branch.IsInverted) + Utilities.Swap(ref trueTarget, ref falseTarget); + + // Declare a temporary jump target to skip true branches + var tempLabel = DeclareLabel(); + using (var command = BeginCommand( + branchOperation, + new PredicateConfiguration(condition, isTrue: false))) + { + command.AppendLabel(tempLabel); + } + + // Bind all true phis + BindPhis(bindings, trueTarget); + + // Jump to true target in the current case + using (var command = BeginCommand(branchOperation)) + { + var targetLabel = blockLookup[trueTarget]; + command.AppendLabel(targetLabel); + } + + // Mark the false case label and bind all values + MarkLabel(tempLabel); + BindPhis(bindings, falseTarget); + + if (!Schedule.IsImplicitSuccessor(branch.BasicBlock, falseTarget)) + { + // Jump to false target in the else case + using var command = BeginCommand(branchOperation); + var targetLabel = blockLookup[falseTarget]; + command.AppendLabel(targetLabel); + } + + // Skip further bindings an branches + return; + } + + // Generate phi bindings for all blocks + BindPhis(bindings, target: null); + // The current schedule has inverted all if conditions with implicit branch // targets to simplify the work of the PTX assembler if (Schedule.IsImplicitSuccessor(branch.BasicBlock, trueTarget)) @@ -151,6 +235,9 @@ public void GenerateCode(SwitchBranch branch) } Builder.AppendLine(";"); + // Generate all phi bindings for all cases + GeneratePhiBindings(branch.BasicBlock); + using (var command = BeginCommand( isUniform ? PTXInstructions.UniformBranchIndexOperation diff --git a/Src/ILGPU/Backends/PTX/PTXCodeGenerator.cs b/Src/ILGPU/Backends/PTX/PTXCodeGenerator.cs index 618a76a096..1c304b348c 100644 --- a/Src/ILGPU/Backends/PTX/PTXCodeGenerator.cs +++ b/Src/ILGPU/Backends/PTX/PTXCodeGenerator.cs @@ -286,10 +286,11 @@ protected static string GetParameterName(Parameter parameter) => #region Instance private int labelCounter; - private readonly Dictionary blockLookup = - new Dictionary(); - private readonly Dictionary<(Encoding, string), string> stringConstants = - new Dictionary<(Encoding, string), string>(); + private readonly Dictionary blockLookup = new(); + + private readonly Dictionary<(Encoding, string), string> stringConstants = new(); + private readonly PhiBindings phiBindings; + private readonly Dictionary intermediatePhiRegisters; private readonly string labelPrefix; /// @@ -321,6 +322,12 @@ internal PTXCodeGenerator(in GeneratorArgs args, Method method, Allocas allocas) args.Properties.GetPTXBackendMode() == PTXBackendMode.Enhanced ? Method.Blocks.CreateOptimizedPTXSchedule() : Method.Blocks.CreateDefaultPTXSchedule(); + + // Create phi bindings and initialize temporary phi registers + phiBindings = Schedule.ComputePhiBindings( + (_, phiValue) => Allocate(phiValue)); + intermediatePhiRegisters = new Dictionary( + phiBindings.MaxNumIntermediatePhis); } #endregion @@ -485,11 +492,6 @@ protected void GenerateCodeInternal(int registerOffset) blockLookup.Add(block, DeclareLabel()); } - // Find all phi nodes, allocate target registers and setup internal mapping - var phiBindings = Schedule.ComputePhiBindings( - (_, phiValue) => Allocate(phiValue)); - var intermediatePhiRegisters = new Dictionary( - phiBindings.MaxNumIntermediatePhis); Builder.AppendLine(); // Generate code @@ -524,52 +526,14 @@ protected void GenerateCodeInternal(int registerOffset) DebugInfoGenerator.ResetLocation(); - // Wire phi nodes - if (phiBindings.TryGetBindings(block, out var bindings)) - { - // Assign all phi values - foreach (var (phiValue, value) in bindings) - { - // Load the current phi target register - var phiTargetRegister = Load(phiValue); - - // Check for an intermediate phi value - if (bindings.IsIntermediate(phiValue)) - { - var intermediateRegister = AllocateType(phiValue.Type); - intermediatePhiRegisters.Add(phiValue, intermediateRegister); - - // Move this phi value into a temporary register for reuse - EmitComplexCommand( - PTXInstructions.MoveOperation, - new PhiMoveEmitter(), - intermediateRegister, - phiTargetRegister); - } - - // Determine the source value from which we need to copy from - var sourceRegister = intermediatePhiRegisters - .TryGetValue(value, out var tempRegister) - ? tempRegister - : Load(value); - - // Move contents - EmitComplexCommand( - PTXInstructions.MoveOperation, - new PhiMoveEmitter(), - phiTargetRegister, - sourceRegister); - } - - // Free temporary registers - foreach (var register in intermediatePhiRegisters.Values) - Free(register); - intermediatePhiRegisters.Clear(); - } - // Build terminator this.GenerateCodeFor(block.Terminator.AsNotNull()); Builder.AppendLine(); + + // Free temporary registers + foreach (var register in intermediatePhiRegisters.Values) + Free(register); + intermediatePhiRegisters.Clear(); } // Finish function and append register information @@ -577,6 +541,53 @@ protected void GenerateCodeInternal(int registerOffset) Builder.Insert(registerOffset, GenerateRegisterInformation("\t")); } + /// + /// Binds all phi values of the current block flowing through an edge to the + /// target block. + /// + private void BindPhis( + PhiBindings.PhiBindingCollection bindings, + BasicBlock? target) + { + // Assign all phi values + foreach (var (phiValue, value) in bindings) + { + // Reject phis not flowing to the target edge + if (target is not null && phiValue.BasicBlock != target) + continue; + + // Load the current phi target register + var phiTargetRegister = Load(phiValue); + + // Check for an intermediate phi value + if (bindings.IsIntermediate(phiValue)) + { + var intermediateRegister = AllocateType(phiValue.Type); + intermediatePhiRegisters.Add(phiValue, intermediateRegister); + + // Move this phi value into a temporary register for reuse + EmitComplexCommand( + PTXInstructions.MoveOperation, + new PhiMoveEmitter(), + intermediateRegister, + phiTargetRegister); + } + + // Determine the source value from which we need to copy from + var sourceRegister = intermediatePhiRegisters + .TryGetValue(value, out var tempRegister) + ? tempRegister + : Load(value); + + // Move contents + EmitComplexCommand( + PTXInstructions.MoveOperation, + new PhiMoveEmitter(), + phiTargetRegister, + sourceRegister); + } + } + /// /// Setups local or shared allocations. /// diff --git a/Src/ILGPU/Backends/PhiBindings.cs b/Src/ILGPU/Backends/PhiBindings.cs index 1fb764d677..3d04ff5da3 100644 --- a/Src/ILGPU/Backends/PhiBindings.cs +++ b/Src/ILGPU/Backends/PhiBindings.cs @@ -167,6 +167,24 @@ internal PhiBindingCollection(in BlockInfo info) #region Methods + /// + /// Returns true if the current binding configuration needs separate bindings + /// for individual target blocks. + /// + public bool NeedSeparateBindingsFor(BasicBlock target) + { + foreach (var (phiValue, _) in this) + { + if (phiValue.BasicBlock != target && + !phiValue.Sources.Contains(target, new BasicBlock.Comparer())) + { + return true; + } + } + + return false; + } + /// /// Returns true if the given phi is an intermediate phi value that requires /// a temporary intermediate variable to be assigned to. diff --git a/Src/ILGPU/IR/Transformations/IfConversion.cs b/Src/ILGPU/IR/Transformations/IfConversion.cs index 3515eff6af..5f6bb5e2e4 100644 --- a/Src/ILGPU/IR/Transformations/IfConversion.cs +++ b/Src/ILGPU/IR/Transformations/IfConversion.cs @@ -729,12 +729,20 @@ private readonly struct CaseBlocks /// The set of all block kinds. /// The current block. /// The determined true block. - private static BasicBlock GetTrueExit( + private static BasicBlock? TryGetTrueExit( in BasicBlockMap kinds, - BasicBlock current) => - kinds[current] == BlockKind.Exit - ? current - : GetTrueExit(kinds, GetIfBranch(current).TrueTarget); + BasicBlock current) + { + var next = current; + do + { + if (kinds[next] == BlockKind.Exit) + return next; + next = GetIfBranch(next).TrueTarget; + } + while (next != current); + return null; + } /// /// Gets the primary false leaf that is used to created the merged branch. @@ -767,8 +775,9 @@ private static BasicBlock GetFalseExit( /// The current root block to start the search. public CaseBlocks(in BasicBlockMap kinds, BasicBlock current) { - TrueBlock = GetTrueExit(kinds, current); - FalseBlock = GetFalseExit(kinds, TrueBlock); + TrueBlock = TryGetTrueExit(kinds, current); + if (TrueBlock is not null) + FalseBlock = GetFalseExit(kinds, TrueBlock); } #endregion @@ -778,17 +787,22 @@ public CaseBlocks(in BasicBlockMap kinds, BasicBlock current) /// /// Returns the true block. /// - public BasicBlock TrueBlock { get; } + public BasicBlock? TrueBlock { get; } /// /// Returns the false block. /// - public BasicBlock FalseBlock { get; } + public BasicBlock? FalseBlock { get; } #endregion #region Methods + /// + /// Returns true if this conversion phase is able to convert the pair. + /// + public bool IsValid => TrueBlock is not null && FalseBlock is not null; + /// /// Returns true if the given block is the . /// @@ -796,7 +810,7 @@ public CaseBlocks(in BasicBlockMap kinds, BasicBlock current) /// /// True, if the given block is the . /// - public readonly bool IsTrueBlock(BasicBlock block) + public bool IsTrueBlock(BasicBlock block) { bool result = block == TrueBlock; block.Assert(result || block == FalseBlock); @@ -812,7 +826,7 @@ public readonly bool IsTrueBlock(BasicBlock block) /// True, if the given block is either the or the /// . /// - public readonly bool Contains(BasicBlock block) => + public bool Contains(BasicBlock block) => block == TrueBlock || block == FalseBlock; /// @@ -820,7 +834,7 @@ public readonly bool Contains(BasicBlock block) => /// or the . /// /// The value to test. - public readonly void AssertInBlocks(Value value) => + public void AssertInBlocks(Value value) => value.Assert(Contains(value.BasicBlock)); #endregion @@ -1098,6 +1112,8 @@ public bool CanConvert( // Get the true and false-branch leaf nodes that are used to build the // conditional branch in the end var caseBlocks = new CaseBlocks(kinds, current); + if (!caseBlocks.IsValid) + return false; // Check all phi-value references and determine all phis that need // to be adjusted after folding all blocks @@ -1244,7 +1260,7 @@ private readonly bool GatherPhiValues( /// A conditional converter to perform the actual if/switch conversion into /// conditional value predicates. /// - private ref struct ConditionalConverter + private readonly ref struct ConditionalConverter { #region Instance @@ -1320,13 +1336,13 @@ internal ConditionalConverter( /// /// Returns true if the given block should be maintained. /// - private readonly bool IsBlockToKeep(BasicBlock block) => + private bool IsBlockToKeep(BasicBlock block) => Bitwise.Or(block == EntryBlock, CaseBlocks.Contains(block)); /// /// Returns true if the given block is an exit block. /// - private readonly bool IsExit(BasicBlock block) => + private bool IsExit(BasicBlock block) => Kinds[block] == BlockKind.Exit; /// @@ -1350,8 +1366,8 @@ public void Convert() BlockBuilder.CreateIfBranch( terminator.Location, condition.AsNotNull(), - CaseBlocks.TrueBlock, - CaseBlocks.FalseBlock); + CaseBlocks.TrueBlock.AsNotNull(), + CaseBlocks.FalseBlock.AsNotNull()); // Adapt all phis AdaptPhis();