From 80726955389a2bd50fae9e070ad6686a89784a32 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Mon, 11 Sep 2023 15:36:29 -0400 Subject: [PATCH] Add a new test suite * Added a new test suite 'uniform_maximal' that tests that ballots all work as expected when no divergent branches exist in the code * The generator has a mode to only generate uniform conditions * removes several if, loop, and switch styles * restricts types of breaks and continues that generated * removes the generation of the election based noise operation * Adds a predefined test to cover some operations --- .../reconvergence/reconvergence.spec.ts | 28 ++- .../shader/execution/reconvergence/util.ts | 196 +++++++++++++++++- 2 files changed, 212 insertions(+), 12 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 1b91fac9d00b..65dc6b52ea12 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -113,7 +113,7 @@ function checkIds(data: Uint32Array, subgroupSize: number): Error | undefined { * * So setting kDebugLevel to 0x5 would dump WGSL and end the test. */ -const kDebugLevel = 0x0; +const kDebugLevel = 0x00; async function testProgram(t: GPUTest, program: Program) { const wgsl = program.genCode(); @@ -159,7 +159,7 @@ async function testProgram(t: GPUTest, program: Program) { // Inputs have a value equal to their index. const inputBuffer = t.makeBufferWithContents( - new Uint32Array([...iterRange(128, x => x)]), + new Uint32Array([...iterRange(129, x => x)]), GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST ); t.trackForCleanup(inputBuffer); @@ -368,6 +368,10 @@ async function predefinedTest(t: GPUTest, style: Style, test: number) { program.predefinedProgramWGSLv1(); break; } + case 15: { + program.predefinedProgramAllUniform(); + break; + } default: { unreachable('Unhandled testcase'); } @@ -376,7 +380,7 @@ async function predefinedTest(t: GPUTest, style: Style, test: number) { await testProgram(t, program); } -const kPredefinedTestCases = [...iterRange(15, x => x)]; +const kPredefinedTestCases = [...iterRange(16, x => x)]; g.test('predefined_workgroup') .desc(`Test workgroup reconvergence using some predefined programs`) @@ -495,3 +499,21 @@ g.test('random_wgslv1') await testProgram(t, program); }); + +g.test('uniform_maximal') + .desc(`Test workgroup reconvergence with only uniform branches`) + .params(u => u.combine('seed', generateSeeds(500)).beginSubcases()) + .beforeAllSubcases(t => { + t.selectDeviceOrSkipTestCase({ + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName], + }); + }) + .fn(async t => { + const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize; + + const onlyUniform: boolean = true; + const program: Program = new Program(Style.Maximal, t.params.seed, invocations, onlyUniform); + program.generate(); + + await testProgram(t, program); + }); diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 69a4392a5721..0f4ff8a221a5 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -330,6 +330,8 @@ export class Program { // Indicates if the program satisfies uniform control flow for |style| // This depends on simulating a particular subgroup size public ucf: boolean; + // Indicates that only uniform branches should be generated. + private onlyUniform: boolean; /** * constructor @@ -337,7 +339,12 @@ export class Program { * @param style Enum indicating the type of reconvergence being tested * @param seed Value used to seed the PRNG */ - constructor(style: Style = Style.Workgroup, seed: number = 1, invocations: number) { + constructor( + style: Style = Style.Workgroup, + seed: number = 1, + invocations: number, + onlyUniform: boolean = false + ) { this.invocations = invocations; assert(invocations <= 128); this.prng = new PRNG(seed); @@ -378,6 +385,7 @@ export class Program { this.maxProgramNesting = 10; // default stack allocation this.maxLocations = 130000; // keep the buffer under 256MiB this.ucf = false; + this.onlyUniform = onlyUniform; } /** @returns A random float between 0 and 1 */ @@ -390,13 +398,31 @@ export class Program { return this.prng.randomU32() % max; } + /** + * Pick |count| random instructions + * + * @param count The number of instructions + * + * If |this.onlyUniform| is true then only uniform instructions will be + * selected. + * + */ + private pickOp(count: number) { + if (this.onlyUniform) { + this.pickUniformOp(count); + } else { + this.pickAnyOp(count); + } + } + /** * Pick |count| random instructions generators * * @param count the number of instructions * + * These instructions could be uniform or non-uniform. */ - private pickOp(count: number) { + private pickAnyOp(count: number) { for (let i = 0; i < count; i++) { if (this.ops.length >= this.maxCount) { return; @@ -527,6 +553,97 @@ export class Program { } } + /** + * Pick |count| random uniform instructions generators + * + * @param count the number of instructions + * + */ + private pickUniformOp(count: number) { + for (let i = 0; i < count; i++) { + if (this.ops.length >= this.maxCount) { + return; + } + + this.genBallot(); + if (this.nesting < this.maxNesting) { + const r = this.getRandomUint(10); + switch (r) { + case 0: + case 1: { + this.genIf(IfType.Lid); + break; + } + case 2: + case 3: { + this.genIf(IfType.Uniform); + break; + } + case 4: { + // Avoid very deep loop nests to limit memory and runtime. + if (this.loopNesting < this.maxLoopNesting) { + this.genForUniform(); + } + break; + } + case 5: { + this.genBreak(); + break; + } + case 6: { + this.genContinue(); + break; + } + case 7: { + // Calls and returns. + if ( + this.getRandomFloat() < 0.2 && + this.callNesting === 0 && + this.nesting < this.maxNesting - 1 + ) { + this.genCall(); + } else { + this.genReturn(); + } + break; + } + case 8: { + if (this.loopNesting < this.maxLoopNesting) { + this.genLoopUniform(); + } + break; + } + case 9: { + // crbug.com/tint/2039 + // Tint generates invalid code for switch inside loops. + if (this.loopNestingThisFunction > 0) { + break; + } + const r2 = this.getRandomUint(2); + switch (r2) { + case 1: { + if (this.loopNesting > 0) { + this.genSwitchLoopCount(); + break; + } + // fallthrough + } + default: { + this.genSwitchUniform(); + break; + } + } + break; + } + default: { + break; + } + } + } + this.genBallot(); + } + } + /** * Ballot generation * @@ -572,7 +689,7 @@ export class Program { } const r = this.getRandomUint(10000); - if (r < 3) { + if (r < 3 && !this.onlyUniform) { this.ops.push(new Op(OpType.Noise, 0)); } else if (r < 10) { this.ops.push(new Op(OpType.Noise, 1)); @@ -590,7 +707,7 @@ export class Program { let maskIdx = this.getRandomUint(this.numMasks); if (type === IfType.Uniform) maskIdx = 0; - const lid = this.getRandomUint(this.invocations); + const lid = this.onlyUniform ? this.invocations : this.getRandomUint(this.invocations); if (type === IfType.Lid) { this.ops.push(new Op(OpType.IfId, lid)); } else if (type === IfType.LoopCount) { @@ -820,13 +937,14 @@ export class Program { * Generate a break if in a loop. * * Only generates a break within a loop, but may break out of a switch and - * not just a loop. Sometimes the break uses a non-uniform if/else to break. + * not just a loop. Sometimes the break uses a non-uniform if/else to break + * (unless only uniform branches are specified). * */ private genBreak() { if (this.loopNestingThisFunction > 0) { // Sometimes put the break in a divergent if - if (this.getRandomFloat() < 0.1) { + if (this.getRandomFloat() < 0.1 && !this.onlyUniform) { const r = this.getRandomUint(this.numMasks - 1) + 1; this.ops.push(new Op(OpType.IfMask, r)); this.ops.push(new Op(OpType.Break, 0)); @@ -843,12 +961,13 @@ export class Program { /** * Generate a continue if in a loop * - * Sometimes uses a non-uniform if/else to continue. + * Sometimes uses a non-uniform if/else to continue (unless only uniform + * branches are specified). */ private genContinue() { if (this.loopNestingThisFunction > 0 && !this.isLoopInf.get(this.loopNesting)) { // Sometimes put the continue in a divergent if - if (this.getRandomFloat() < 0.1) { + if (this.getRandomFloat() < 0.1 && !this.onlyUniform) { const r = this.getRandomUint(this.numMasks - 1) + 1; this.ops.push(new Op(OpType.IfMask, r)); this.ops.push(new Op(OpType.Continue, 0)); @@ -896,7 +1015,7 @@ export class Program { (this.callNesting > 0 && this.loopNestingThisFunction > 1 && r < 0.5)) ) { this.genBallot(); - if (this.getRandomFloat() < 0.1) { + if (this.getRandomFloat() < 0.1 && !this.onlyUniform) { this.ops.push(new Op(OpType.IfMask, 0)); this.ops.push(new Op(OpType.Return, this.callNesting)); this.ops.push(new Op(OpType.ElseMask, 0)); @@ -2760,6 +2879,65 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); } + + /** + * Equivalent to: + * + * for (var i0 = 0u; i0 < inputs[3]; i0++) + * ballot(); + * if subgroup_invocation_id < inputs[128] + * ballot(); + * if subgroup_invocation_id < inputs[128] + * ballot(); + * if subgroup_invocation_id < inputs[128] + * for (var i1 = 0u; i1 < inputs[3]; i1++) + * if subgroup_invocation_id < inputs[128] + * ballot(); + * break; + * if inputs[3] == 3 + * ballot(); + * ballot(); + * + */ + public predefinedProgramAllUniform() { + this.ops.push(new Op(OpType.ForUniform, 3)); // for 0 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + + this.ops.push(new Op(OpType.IfId, 128)); // if 0 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + + this.ops.push(new Op(OpType.IfId, 128)); // if 1 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + + this.ops.push(new Op(OpType.IfId, 128)); // if 2 + this.ops.push(new Op(OpType.ForUniform, 3)); // for 1 + this.ops.push(new Op(OpType.IfId, 128)); // if 3 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.Break, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); // end if 3 + + this.ops.push(new Op(OpType.IfMask, 0)); // if 4 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); // end if 4 + + this.ops.push(new Op(OpType.EndForUniform, 0)); // end for 1 + + this.ops.push(new Op(OpType.ElseId, 128)); // else if 2 + this.ops.push(new Op(OpType.EndIf, 0)); // end if 2 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + + this.ops.push(new Op(OpType.EndIf, 0)); // end if 1 + + this.ops.push(new Op(OpType.EndIf, 0)); // end if 0 + + this.ops.push(new Op(OpType.EndForUniform, 0)); // end for 0 + } } export function generateSeeds(numCases: number): number[] {