Skip to content

Commit

Permalink
Add a new test suite
Browse files Browse the repository at this point in the history
* Added a new test suite 'uniform_maximal' that tests that ballots all
  work as expected when no divergent branches exist in the code
* The generator has a mode to only generate uniform conditions
  * removes several if, loop, and switch styles
  * restricts types of breaks and continues that generated
  * removes the generation of the election based noise operation
* Adds a predefined test to cover some operations
  • Loading branch information
alan-baker committed Sep 11, 2023
1 parent d7ea45d commit 8072695
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 12 deletions.
28 changes: 25 additions & 3 deletions src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function checkIds(data: Uint32Array, subgroupSize: number): Error | undefined {
*
* So setting kDebugLevel to 0x5 would dump WGSL and end the test.
*/
const kDebugLevel = 0x0;
const kDebugLevel = 0x00;

async function testProgram(t: GPUTest, program: Program) {
const wgsl = program.genCode();
Expand Down Expand Up @@ -159,7 +159,7 @@ async function testProgram(t: GPUTest, program: Program) {

// Inputs have a value equal to their index.
const inputBuffer = t.makeBufferWithContents(
new Uint32Array([...iterRange(128, x => x)]),
new Uint32Array([...iterRange(129, x => x)]),
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
);
t.trackForCleanup(inputBuffer);
Expand Down Expand Up @@ -368,6 +368,10 @@ async function predefinedTest(t: GPUTest, style: Style, test: number) {
program.predefinedProgramWGSLv1();
break;
}
case 15: {
program.predefinedProgramAllUniform();
break;
}
default: {
unreachable('Unhandled testcase');
}
Expand All @@ -376,7 +380,7 @@ async function predefinedTest(t: GPUTest, style: Style, test: number) {
await testProgram(t, program);
}

const kPredefinedTestCases = [...iterRange(15, x => x)];
const kPredefinedTestCases = [...iterRange(16, x => x)];

g.test('predefined_workgroup')
.desc(`Test workgroup reconvergence using some predefined programs`)
Expand Down Expand Up @@ -495,3 +499,21 @@ g.test('random_wgslv1')

await testProgram(t, program);
});

g.test('uniform_maximal')
.desc(`Test workgroup reconvergence with only uniform branches`)
.params(u => u.combine('seed', generateSeeds(500)).beginSubcases())
.beforeAllSubcases(t => {
t.selectDeviceOrSkipTestCase({
requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName],
});
})
.fn(async t => {
const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize;

const onlyUniform: boolean = true;
const program: Program = new Program(Style.Maximal, t.params.seed, invocations, onlyUniform);
program.generate();

await testProgram(t, program);
});
196 changes: 187 additions & 9 deletions src/webgpu/shader/execution/reconvergence/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,21 @@ export class Program {
// Indicates if the program satisfies uniform control flow for |style|
// This depends on simulating a particular subgroup size
public ucf: boolean;
// Indicates that only uniform branches should be generated.
private onlyUniform: boolean;

/**
* constructor
*
* @param style Enum indicating the type of reconvergence being tested
* @param seed Value used to seed the PRNG
*/
constructor(style: Style = Style.Workgroup, seed: number = 1, invocations: number) {
constructor(
style: Style = Style.Workgroup,
seed: number = 1,
invocations: number,
onlyUniform: boolean = false
) {
this.invocations = invocations;
assert(invocations <= 128);
this.prng = new PRNG(seed);
Expand Down Expand Up @@ -378,6 +385,7 @@ export class Program {
this.maxProgramNesting = 10; // default stack allocation
this.maxLocations = 130000; // keep the buffer under 256MiB
this.ucf = false;
this.onlyUniform = onlyUniform;
}

/** @returns A random float between 0 and 1 */
Expand All @@ -390,13 +398,31 @@ export class Program {
return this.prng.randomU32() % max;
}

/**
* Pick |count| random instructions
*
* @param count The number of instructions
*
* If |this.onlyUniform| is true then only uniform instructions will be
* selected.
*
*/
private pickOp(count: number) {
if (this.onlyUniform) {
this.pickUniformOp(count);
} else {
this.pickAnyOp(count);
}
}

/**
* Pick |count| random instructions generators
*
* @param count the number of instructions
*
* These instructions could be uniform or non-uniform.
*/
private pickOp(count: number) {
private pickAnyOp(count: number) {
for (let i = 0; i < count; i++) {
if (this.ops.length >= this.maxCount) {
return;
Expand Down Expand Up @@ -527,6 +553,97 @@ export class Program {
}
}

/**
* Pick |count| random uniform instructions generators
*
* @param count the number of instructions
*
*/
private pickUniformOp(count: number) {
for (let i = 0; i < count; i++) {
if (this.ops.length >= this.maxCount) {
return;
}

this.genBallot();
if (this.nesting < this.maxNesting) {
const r = this.getRandomUint(10);
switch (r) {
case 0:
case 1: {
this.genIf(IfType.Lid);
break;
}
case 2:
case 3: {
this.genIf(IfType.Uniform);
break;
}
case 4: {
// Avoid very deep loop nests to limit memory and runtime.
if (this.loopNesting < this.maxLoopNesting) {
this.genForUniform();
}
break;
}
case 5: {
this.genBreak();
break;
}
case 6: {
this.genContinue();
break;
}
case 7: {
// Calls and returns.
if (
this.getRandomFloat() < 0.2 &&
this.callNesting === 0 &&
this.nesting < this.maxNesting - 1
) {
this.genCall();
} else {
this.genReturn();
}
break;
}
case 8: {
if (this.loopNesting < this.maxLoopNesting) {
this.genLoopUniform();
}
break;
}
case 9: {
// crbug.com/tint/2039
// Tint generates invalid code for switch inside loops.
if (this.loopNestingThisFunction > 0) {
break;
}
const r2 = this.getRandomUint(2);
switch (r2) {
case 1: {
if (this.loopNesting > 0) {
this.genSwitchLoopCount();
break;
}
// fallthrough
}
default: {
this.genSwitchUniform();
break;
}
}
break;
}
default: {
break;
}
}
}
this.genBallot();
}
}

/**
* Ballot generation
*
Expand Down Expand Up @@ -572,7 +689,7 @@ export class Program {
}

const r = this.getRandomUint(10000);
if (r < 3) {
if (r < 3 && !this.onlyUniform) {
this.ops.push(new Op(OpType.Noise, 0));
} else if (r < 10) {
this.ops.push(new Op(OpType.Noise, 1));
Expand All @@ -590,7 +707,7 @@ export class Program {
let maskIdx = this.getRandomUint(this.numMasks);
if (type === IfType.Uniform) maskIdx = 0;

const lid = this.getRandomUint(this.invocations);
const lid = this.onlyUniform ? this.invocations : this.getRandomUint(this.invocations);
if (type === IfType.Lid) {
this.ops.push(new Op(OpType.IfId, lid));
} else if (type === IfType.LoopCount) {
Expand Down Expand Up @@ -820,13 +937,14 @@ export class Program {
* Generate a break if in a loop.
*
* Only generates a break within a loop, but may break out of a switch and
* not just a loop. Sometimes the break uses a non-uniform if/else to break.
* not just a loop. Sometimes the break uses a non-uniform if/else to break
* (unless only uniform branches are specified).
*
*/
private genBreak() {
if (this.loopNestingThisFunction > 0) {
// Sometimes put the break in a divergent if
if (this.getRandomFloat() < 0.1) {
if (this.getRandomFloat() < 0.1 && !this.onlyUniform) {
const r = this.getRandomUint(this.numMasks - 1) + 1;
this.ops.push(new Op(OpType.IfMask, r));
this.ops.push(new Op(OpType.Break, 0));
Expand All @@ -843,12 +961,13 @@ export class Program {
/**
* Generate a continue if in a loop
*
* Sometimes uses a non-uniform if/else to continue.
* Sometimes uses a non-uniform if/else to continue (unless only uniform
* branches are specified).
*/
private genContinue() {
if (this.loopNestingThisFunction > 0 && !this.isLoopInf.get(this.loopNesting)) {
// Sometimes put the continue in a divergent if
if (this.getRandomFloat() < 0.1) {
if (this.getRandomFloat() < 0.1 && !this.onlyUniform) {
const r = this.getRandomUint(this.numMasks - 1) + 1;
this.ops.push(new Op(OpType.IfMask, r));
this.ops.push(new Op(OpType.Continue, 0));
Expand Down Expand Up @@ -896,7 +1015,7 @@ export class Program {
(this.callNesting > 0 && this.loopNestingThisFunction > 1 && r < 0.5))
) {
this.genBallot();
if (this.getRandomFloat() < 0.1) {
if (this.getRandomFloat() < 0.1 && !this.onlyUniform) {
this.ops.push(new Op(OpType.IfMask, 0));
this.ops.push(new Op(OpType.Return, this.callNesting));
this.ops.push(new Op(OpType.ElseMask, 0));
Expand Down Expand Up @@ -2760,6 +2879,65 @@ ${this.functions[i]}`;
this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length));
this.ops.push(new Op(OpType.Ballot, 0));
}

/**
* Equivalent to:
*
* for (var i0 = 0u; i0 < inputs[3]; i0++)
* ballot();
* if subgroup_invocation_id < inputs[128]
* ballot();
* if subgroup_invocation_id < inputs[128]
* ballot();
* if subgroup_invocation_id < inputs[128]
* for (var i1 = 0u; i1 < inputs[3]; i1++)
* if subgroup_invocation_id < inputs[128]
* ballot();
* break;
* if inputs[3] == 3
* ballot();
* ballot();
*
*/
public predefinedProgramAllUniform() {
this.ops.push(new Op(OpType.ForUniform, 3)); // for 0
this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length));
this.ops.push(new Op(OpType.Ballot, 0));

this.ops.push(new Op(OpType.IfId, 128)); // if 0
this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length));
this.ops.push(new Op(OpType.Ballot, 0));

this.ops.push(new Op(OpType.IfId, 128)); // if 1
this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length));
this.ops.push(new Op(OpType.Ballot, 0));

this.ops.push(new Op(OpType.IfId, 128)); // if 2
this.ops.push(new Op(OpType.ForUniform, 3)); // for 1
this.ops.push(new Op(OpType.IfId, 128)); // if 3
this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length));
this.ops.push(new Op(OpType.Ballot, 0));
this.ops.push(new Op(OpType.Break, 0));
this.ops.push(new Op(OpType.EndIf, 0)); // end if 3

this.ops.push(new Op(OpType.IfMask, 0)); // if 4
this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length));
this.ops.push(new Op(OpType.Ballot, 0));
this.ops.push(new Op(OpType.EndIf, 0)); // end if 4

this.ops.push(new Op(OpType.EndForUniform, 0)); // end for 1

this.ops.push(new Op(OpType.ElseId, 128)); // else if 2
this.ops.push(new Op(OpType.EndIf, 0)); // end if 2
this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length));
this.ops.push(new Op(OpType.Ballot, 0));

this.ops.push(new Op(OpType.EndIf, 0)); // end if 1

this.ops.push(new Op(OpType.EndIf, 0)); // end if 0

this.ops.push(new Op(OpType.EndForUniform, 0)); // end for 0
}
}

export function generateSeeds(numCases: number): number[] {
Expand Down

0 comments on commit 8072695

Please sign in to comment.