Skip to content

Commit

Permalink
Fix fragment shader subgroup builtin io test (#4024)
Browse files Browse the repository at this point in the history
This PR fix the expectation of fragment shader subgroup_invocation_id.
Any invocation id (including 0) within subgroup can be assigned to
inactivate invocations, and ids of active invocations can go larger than
active invocations number but still smaller than subgroup size. This PR
also fix the draw call for fragment subgroup tests.
  • Loading branch information
jzm-intel authored Nov 4, 2024
1 parent 8328265 commit f2e2ada
Showing 1 changed file with 74 additions and 47 deletions.
121 changes: 74 additions & 47 deletions src/webgpu/shader/execution/shader_io/fragment_builtins.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1578,23 +1578,6 @@ fn vsMain(@builtin(vertex_index) index : u32) -> @builtin(position) vec4f {
const byteLength = bytesPerRow * blocksPerColumn;
const uintLength = byteLength / 4;

const buffer = t.makeBufferWithContents(
new Uint32Array([1]),
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
);

const bg = t.device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: [
{
binding: 0,
resource: {
buffer,
},
},
],
});

for (let i = 0; i < 2; i++) {
const framebuffer = t.createTextureTracked({
size: [width, height],
Expand All @@ -1617,8 +1600,8 @@ fn vsMain(@builtin(vertex_index) index : u32) -> @builtin(position) vec4f {
],
});
pass.setPipeline(pipeline);
pass.setBindGroup(0, bg);
pass.draw(3, 1, i);
// Draw the uperr-left triangle (vertices 0-2) or the lower-right triangle (vertices 3-5)
pass.draw(3, 1, i * 3);
pass.end();
t.queue.submit([encoder.finish()]);

Expand Down Expand Up @@ -1659,15 +1642,11 @@ enable subgroups;
const width = ${t.params.size[0]};
const height = ${t.params.size[1]};
@group(0) @binding(0) var<storage, read_write> for_layout : u32;
@fragment
fn fsMain(
@builtin(position) pos : vec4f,
@builtin(subgroup_size) sg_size : u32,
) -> @location(0) vec4u {
_ = for_layout;
let ballot = countOneBits(subgroupBallot(true));
let ballotSize = ballot.x + ballot.y + ballot.z + ballot.w;
Expand Down Expand Up @@ -1699,17 +1678,23 @@ fn fsMain(
);
});

// A non-zero magic number indicating no expectation error, in order to prevent the false no-error
// result from zero-initialization.
const kSubgroupInvocationIdNoError = 17;

/**
* Checks subgroup_invocation_id value consistency
*
* Very little uniformity is expected for subgroup_invocation_id.
* This function checks that all ids are less than the subgroup size
* and no id is repeated.
* (not the ballot size, since the subgroup id can be allocated to
* inactivate invocations between active ones) and no id is repeated.
* @param data An array of vec4u that contains (per texel):
* * subgroup_invocation_id
* * ballot size
* * non-zero ID unique to each subgroup
* * 0
* * subgroup size
* * ballot active invocation number
* * error flag, should be equal to kSubgroupInvocationIdNoError or shader found
* expection failed otherwise.
* @param format The texture format of data
* @param width The width of the framebuffer
* @param height The height of the framebuffer
Expand All @@ -1726,31 +1711,44 @@ function checkSubgroupInvocationIdConsistency(
const uintsPerRow = bytesPerRow / 4;
const uintsPerTexel = (bytesPerBlock ?? 1) / blockWidth / blockHeight / 4;

const mappings = new Map<number, bigint>();
for (let row = 0; row < height; row++) {
for (let col = 0; col < width; col++) {
const offset = uintsPerRow * row + col * uintsPerTexel;
const id = data[offset];
const size = data[offset + 1];
const repId = data[offset + 2];

if (repId === 0) {
const sgSize = data[offset + 1];
const ballotSize = data[offset + 2];
const error = data[offset + 3];

if (error === 0) {
// Inactive fragment get error `0` instead of noError. Check all output being zero.
if (id !== 0 || sgSize !== 0 || ballotSize !== 0) {
return new Error(
`Unexpected zero error with non-zero outputs for (${row}, ${col}): got output [${id}, ${sgSize}, ${ballotSize}, ${error}]`
);
}
continue;
}

if (size < id) {
if (sgSize < id) {
return new Error(
`Invocation id '${id}' is greater than subgroup size '${size}' for (${row}, ${col})`
`Invocation id '${id}' is greater than subgroup size '${sgSize}' for (${row}, ${col})`
);
}

let v = mappings.get(repId) ?? 0n;
const mask = 1n << BigInt(id);
if ((mask & v) !== 0n) {
return new Error(`Multiple invocations with id '${id}' in subgroup '${repId}'`);
if (sgSize < ballotSize) {
return new Error(
`Ballot size '${ballotSize}' is greater than subgroup size '${sgSize}' for (${row}, ${col})`
);
}

if (error !== kSubgroupInvocationIdNoError) {
return new Error(
`Unexpected error value
- icoord: (${row}, ${col})
- expected: noError (${kSubgroupInvocationIdNoError})
- got: ${error}`
);
}
v |= mask;
mappings.set(repId, v);
}
}

Expand All @@ -1775,22 +1773,51 @@ enable subgroups;
const width = ${t.params.size[0]};
const height = ${t.params.size[1]};
@group(0) @binding(0) var<storage, read_write> counter : atomic<u32>;
const maxSubgroupSize = 128u;
// A non-zero magic number indicating no expectation error, in order to prevent the
// false no-error result from zero-initialization.
const noError = ${kSubgroupInvocationIdNoError}u;
@fragment
fn fsMain(
@builtin(position) pos : vec4f,
@builtin(subgroup_invocation_id) id : u32,
@builtin(subgroup_size) sg_size : u32,
) -> @location(0) vec4u {
let ballot = countOneBits(subgroupBallot(true));
let ballotSize = ballot.x + ballot.y + ballot.z + ballot.w;
// Generate representative id for this subgroup.
var repId = atomicAdd(&counter, 1);
repId = subgroupBroadcast(repId, 0);
var error: u32 = noError;
// Validate that reported subgroup size is no larger than maxSubgroupSize
if (sg_size > maxSubgroupSize) {
error++;
}
// Validate that reported subgroup invocation id is smaller than subgroup size
if (id >= sg_size) {
error++;
}
// Validate that each subgroup id is assigned to at most one active invocation
// in the subgroup
var countAssignedId: u32 = 0u;
for (var i: u32 = 0; i < maxSubgroupSize; i++) {
let ballotIdEqualsI = countOneBits(subgroupBallot(id == i));
let countInvocationIdEqualsI = ballotIdEqualsI.x + ballotIdEqualsI.y + ballotIdEqualsI.z + ballotIdEqualsI.w;
// Validate an id assigned at most once
error += select(1u, 0u, countInvocationIdEqualsI <= 1);
// Validate id larger than subgroup size will not get balloted
error += select(1u, 0u, (id < sg_size) || (countInvocationIdEqualsI == 0));
// Sum up the assigned invocation number of each id
countAssignedId += countInvocationIdEqualsI;
}
// Validate that all active invocation get counted during the above loop
let ballotActive = countOneBits(subgroupBallot(true));
let activeInvocations = ballotActive.x + ballotActive.y + ballotActive.z + ballotActive.w;
if (activeInvocations != countAssignedId) {
error++;
}
return vec4u(id, ballotSize, repId, 0);
return vec4u(id, sg_size, activeInvocations, error);
}`;

await runSubgroupTest(
Expand Down

0 comments on commit f2e2ada

Please sign in to comment.