Skip to content

Commit

Permalink
Initial commit for subgroupInclusive add/mul cts (#3964)
Browse files Browse the repository at this point in the history
* Initial commit for subgroupInclusive add/mul cts

* Change to using bounds for subgroupsInclusive (add/mul)

---------

Co-authored-by: Peter McNeeley <[email protected]>
  • Loading branch information
petermcneeleychromium and Peter McNeeley committed Sep 25, 2024
1 parent e46cff2 commit 77f1e4a
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export const description = `
Execution tests for subgroupAdd and subgroupExclusiveAdd
Execution tests for subgroupAdd, subgroupExclusiveAdd, and subgroupInclusiveAdd
Note: There is a lack of portability for non-uniform execution so these tests
restrict themselves to uniform control flow.
Expand Down Expand Up @@ -38,7 +38,7 @@ const kIdentity = 0;

const kDataTypes = objectsToRecord(kConcreteNumericScalarsAndVectors);

const kOperations = ['subgroupAdd', 'subgroupExclusiveAdd'] as const;
const kOperations = ['subgroupAdd', 'subgroupExclusiveAdd', 'subgroupInclusiveAdd'] as const;

g.test('fp_accuracy')
.desc(
Expand Down Expand Up @@ -86,17 +86,20 @@ and limit the number of permutations needed to calculate the final result.`
* Expected results:
* - subgroupAdd: each invocation should have result equal to real subgroup size
* - subgroupExclusiveAdd: each invocation should have result equal to its subgroup invocation id
* - subgroupInclusiveAdd: each invocation should be equal to the result of subgroupExclusiveAdd plus the fill value
* @param metadata An array containing actual subgroup size per invocation followed by
* subgroup invocation id per invocation
* @param output An array of additions
* @param type The data type
* @param operation Type of addition
* @param expectedfillValue The original value used to fill the test array
*/
function checkAddition(
metadata: Uint32Array,
output: Uint32Array,
type: Type,
operation: 'subgroupAdd' | 'subgroupExclusiveAdd'
operation: 'subgroupAdd' | 'subgroupExclusiveAdd' | 'subgroupInclusiveAdd',
expectedfillValue: number
): undefined | Error {
let numEles = 1;
if (type instanceof VectorType) {
Expand All @@ -105,7 +108,11 @@ function checkAddition(
const scalarTy = scalarTypeOf(type);
const expectedOffset = operation === 'subgroupAdd' ? 0 : metadata.length / 2;
for (let i = 0; i < metadata.length / 2; i++) {
const expected = metadata[i + expectedOffset];
let expected = metadata[i + expectedOffset];
if (operation === 'subgroupInclusiveAdd') {
expected += expectedfillValue;
}

for (let j = 0; j < numEles; j++) {
let idx = i * numEles + j;
const isOdd = idx & 0x1;
Expand Down Expand Up @@ -217,8 +224,8 @@ fn main(
outputs[lid] = ${t.params.operation}(inputs[lid]);
}`;

let fillValue = 1;
const expectedFillValue = 1;
let fillValue = expectedFillValue;
let numUints = wgThreads * numEles;
if (scalarType === Type.f32) {
fillValue = numberToFloatBits(1, kFloat32Format);
Expand All @@ -234,7 +241,7 @@ fn main(
numUints,
new Uint32Array([...iterRange(numUints, x => fillValue)]),
(metadata: Uint32Array, output: Uint32Array) => {
return checkAddition(metadata, output, type, t.params.operation);
return checkAddition(metadata, output, type, t.params.operation, expectedFillValue);
}
);
});
Expand All @@ -255,15 +262,16 @@ g.test('fragment').unimplemented();
function checkPredicatedAddition(
metadata: Uint32Array,
output: Uint32Array,
operation: 'subgroupAdd' | 'subgroupExclusiveAdd',
operation: 'subgroupAdd' | 'subgroupExclusiveAdd' | 'subgroupInclusiveAdd',
filter: (id: number, size: number) => boolean
): Error | undefined {
for (let i = 0; i < output.length; i++) {
const size = metadata[i];
const id = metadata[output.length + i];
let expected = 0;
if (filter(id, size)) {
const bound = operation === 'subgroupAdd' ? size : id;
const bound =
operation === 'subgroupInclusiveAdd' ? id + 1 : operation === 'subgroupAdd' ? size : id;
for (let j = 0; j < bound; j++) {
if (filter(j, size)) {
expected += j;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export const description = `
Execution tests for subgroupMul and subgroupExclusiveMul
Execution tests for subgroupMul, subgroupExclusiveMul, and subgroupInclusiveMul
Note: There is a lack of portability for non-uniform execution so these tests
restrict themselves to uniform control flow.
Expand Down Expand Up @@ -38,7 +38,7 @@ const kIdentity = 1;

const kDataTypes = objectsToRecord(kConcreteNumericScalarsAndVectors);

const kOperations = ['subgroupMul', 'subgroupExclusiveMul'] as const;
const kOperations = ['subgroupMul', 'subgroupExclusiveMul', 'subgroupInclusiveMul'] as const;

g.test('fp_accuracy')
.desc(
Expand Down Expand Up @@ -86,17 +86,20 @@ and limit the number of permutations needed to calculate the final result.`
* Expected results:
* - subgroupMul: each invocation should have result equal to 2 to the real subgroup size
* - subgroupExclusiveMul: each invocation should have result equal to 2 to its subgroup invocation id
* - subgroupInclusiveMul: each invocation should be equal to subgroupExclusiveMul result multiplied by the fill value
* @param metadata An array containing actual subgroup size per invocation followed by
* subgroup invocation id per invocation
* @param output An array of multiplications
* @param type The data type
* @param operation Type of multiplication
* @param expectedFillValue The original value used to fill the test array
*/
function checkMultiplication(
metadata: Uint32Array,
output: Uint32Array,
type: Type,
operation: 'subgroupMul' | 'subgroupExclusiveMul'
operation: 'subgroupMul' | 'subgroupExclusiveMul' | 'subgroupInclusiveMul',
expectedfillValue: number
): undefined | Error {
let numEles = 1;
if (type instanceof VectorType) {
Expand All @@ -105,7 +108,10 @@ function checkMultiplication(
const scalarTy = scalarTypeOf(type);
const expectedOffset = operation === 'subgroupMul' ? 0 : metadata.length / 2;
for (let i = 0; i < metadata.length / 2; i++) {
const expected = Math.pow(2, metadata[i + expectedOffset]);
let expected = Math.pow(2, metadata[i + expectedOffset]);
if (operation === 'subgroupInclusiveMul') {
expected *= expectedfillValue;
}
for (let j = 0; j < numEles; j++) {
let idx = i * numEles + j;
const isOdd = idx & 0x1;
Expand Down Expand Up @@ -237,7 +243,8 @@ fn main(
outputs[lid] = ${t.params.operation}(inputs[lid]);
}`;

let fillValue = 2;
const expectedfillValue = 2;
let fillValue = expectedfillValue;
let numUints = wgThreads * numEles;
if (scalarType === Type.f32) {
fillValue = numberToFloatBits(fillValue, kFloat32Format);
Expand All @@ -253,7 +260,7 @@ fn main(
numUints,
new Uint32Array([...iterRange(numUints, x => fillValue)]),
(metadata: Uint32Array, output: Uint32Array) => {
return checkMultiplication(metadata, output, type, t.params.operation);
return checkMultiplication(metadata, output, type, t.params.operation, expectedfillValue);
}
);
});
Expand All @@ -274,18 +281,23 @@ g.test('fragment').unimplemented();
function checkPredicatedMultiplication(
metadata: Uint32Array,
output: Uint32Array,
operation: 'subgroupMul' | 'subgroupExclusiveMul',
operation: 'subgroupMul' | 'subgroupExclusiveMul' | 'subgroupInclusiveMul',
filter: (id: number, size: number) => boolean
): Error | undefined {
for (let i = 0; i < output.length; i++) {
const size = metadata[i];
const id = metadata[output.length + i];
let expected = 1;
if (filter(id, size)) {
const bound = operation === 'subgroupMul' ? size : id;
// This function replicates the behavior in the shader.
const valueModFun = function (id: number) {
return (id % 4) + 1;
};
const bound =
operation === 'subgroupInclusiveMul' ? id + 1 : operation === 'subgroupMul' ? size : id;
for (let j = 0; j < bound; j++) {
if (filter(j, size)) {
expected *= (j % 4) + 1;
expected *= valueModFun(j);
}
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { ShaderValidationTest } from '../../../shader_validation_test.js';

export const g = makeTestGroup(ShaderValidationTest);

const kBuiltins = ['subgroupAdd', 'subgroupExclusiveAdd'] as const;
const kBuiltins = ['subgroupAdd', 'subgroupExclusiveAdd', 'subgroupInclusiveAdd'] as const;

const kStages: Record<string, (builtin: string) => string> = {
constant: (builtin: string) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export const description = `
Validation tests for subgroupMul and subgroupExclusiveMul
Validation tests for subgroupMul, subgroupExclusiveMul, and subgroupInclusiveMul
`;

import { makeTestGroup } from '../../../../../../common/framework/test_group.js';
Expand All @@ -9,7 +9,7 @@ import { ShaderValidationTest } from '../../../shader_validation_test.js';

export const g = makeTestGroup(ShaderValidationTest);

const kBuiltins = ['subgroupMul', 'subgroupExclusiveMul'] as const;
const kBuiltins = ['subgroupMul', 'subgroupExclusiveMul', 'subgroupInclusiveMul'] as const;

const kStages: Record<string, (builtin: string) => string> = {
constant: (builtin: string) => {
Expand Down

0 comments on commit 77f1e4a

Please sign in to comment.