Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow named types in unions #469

Merged
merged 24 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions lib/types.js
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class Type {
wrapUnions = 'auto';
} else if (typeof wrapUnions == 'string') {
wrapUnions = wrapUnions.toLowerCase();
} else if (typeof wrapUnions === 'function') {
wrapUnions = 'auto';
}
switch (wrapUnions) {
case 'always':
Expand Down Expand Up @@ -197,7 +199,20 @@ class Type {
return Type.forSchema(obj, opts);
});
if (!UnionType) {
UnionType = isAmbiguous(types) ? WrappedUnionType : UnwrappedUnionType;
// either automatic detection or we have a projection function
if (typeof opts.wrapUnions === 'function') {
// we have a projection function
joscha marked this conversation as resolved.
Show resolved Hide resolved
try {
UnionType = typeof opts.wrapUnions(types) !== 'undefined'
joscha marked this conversation as resolved.
Show resolved Hide resolved
// projection function yields a function, we can use an Unwrapped type
? UnwrappedUnionType
: WrappedUnionType;
} catch(e) {
throw new Error(`Union projection function errored: ${e}`);
}
} else {
UnionType = isAmbiguous(types) ? WrappedUnionType : UnwrappedUnionType;
}
}
LOGICAL_TYPE = logicalType;
type = new UnionType(types, opts);
Expand Down Expand Up @@ -341,10 +356,10 @@ class Type {
return branchTypes[name];
}), opts);
} catch (err) {
opts.wrapUnions = wrapUnions;
throw err;
} finally {
opts.wrapUnions = wrapUnions;
}
opts.wrapUnions = wrapUnions;
return unionType;
}

Expand Down Expand Up @@ -1251,13 +1266,31 @@ class UnwrappedUnionType extends UnionType {

this._dynamicBranches = null;
this._bucketIndices = {};

this.projectionFunction = (val) => this._bucketIndices[getValueBucket(val)];
joscha marked this conversation as resolved.
Show resolved Hide resolved
let hasWrapUnionsFn = opts && typeof opts.wrapUnions === 'function';
if (hasWrapUnionsFn) {
const projectionFunction = opts.wrapUnions(this.types);
if (typeof projectionFunction === 'undefined') {
hasWrapUnionsFn = false;
} else {
this.projectionFunction = (val) => {
const index = projectionFunction(val);
if (typeof index !== 'number' || index >= this._bucketIndices.length) {
throw new Error(`Projected index ${index} is not valid`);
}
return index;
}
}
}

this.types.forEach(function (type, index) {
if (Type.isType(type, 'abstract', 'logical')) {
if (!this._dynamicBranches) {
this._dynamicBranches = [];
}
this._dynamicBranches.push({index, type});
} else {
} else if (!hasWrapUnionsFn) {
let bucket = getTypeBucket(type);
if (this._bucketIndices[bucket] !== undefined) {
throw new Error(`ambiguous unwrapped union: ${j(this)}`);
Expand All @@ -1270,7 +1303,7 @@ class UnwrappedUnionType extends UnionType {
}

_getIndex (val) {
let index = this._bucketIndices[getValueBucket(val)];
let index = this.projectionFunction(val);
if (this._dynamicBranches) {
// Slower path, we must run the value through all branches.
index = this._getBranchIndex(val, index);
Expand Down
48 changes: 48 additions & 0 deletions test/test_types.js
Original file line number Diff line number Diff line change
Expand Up @@ -3492,6 +3492,54 @@ suite('types', () => {
assert(Type.isType(t.field('unwrapped').type, 'union:unwrapped'));
});

test('union projection', () => {
joscha marked this conversation as resolved.
Show resolved Hide resolved
const Dog = {
type: 'record',
name: 'Dog',
fields: [
{ type: 'string', name: 'bark' }
],
};
const Cat = {
type: 'record',
name: 'Cat',
fields: [
{ type: 'string', name: 'meow' }
],
};
const animalTypes = [Dog, Cat];

const wrapUnions = (types) => {
assert.deepEqual(types.map(t => t.name), ['Dog', 'Cat']);
return (animal) => {
const animalType = ((animal) => {
if ('bark' in animal) {
return 'Dog';
} else if ('meow' in animal) {
return 'Cat';
}
throw new Error('Unknown animal');
})(animal);
return types.indexOf(types.find(type => type.name === animalType));
joscha marked this conversation as resolved.
Show resolved Hide resolved
}
};

// TODO: replace this with a mock when available
// currently we're on mocha without sinon
function mockWrapUnions() {
mockWrapUnions.calls = typeof mockWrapUnions.calls === 'undefined'
? 1
: ++mockWrapUnions.calls;
return wrapUnions.apply(null, arguments);
}
joscha marked this conversation as resolved.
Show resolved Hide resolved

// Ambiguous, but we have a projection function
const Animal = Type.forSchema(animalTypes, { wrapUnions: mockWrapUnions });
Animal.toBuffer({ meow: '🐈' });
assert.equal(mockWrapUnions.calls, 2);
assert.throws(() => Animal.toBuffer({ snap: '🐊' }), /Unknown animal/)
});

test('invalid wrap unions option', () => {
assert.throws(() => {
Type.forSchema('string', {wrapUnions: 'FOO'});
Expand Down
11 changes: 10 additions & 1 deletion types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ interface EncoderOptions {
syncMarker: Buffer;
}

/**
* A projection function that is used when unwrapping unions.
* This function is called at schema parsing time on each union with its branches' types.
* If it returns a non-null (function) value, that function will be called each time a value's branch needs to be inferred and should return the branch's index.
* The index muss be a number between 0 and length-1 of the passed types.
* In this case (a branch index) the union will use an unwrapped representation. Otherwise (undefined), the union will be wrapped.
joscha marked this conversation as resolved.
Show resolved Hide resolved
*/
type ProjectionFn = (types: ReadonlyArray<Type>) => ((val: unknown) => number) | undefined;
joscha marked this conversation as resolved.
Show resolved Hide resolved

interface ForSchemaOptions {
assertLogicalTypes: boolean;
logicalTypes: { [type: string]: new (schema: Schema, opts?: any) => types.LogicalType; };
Expand All @@ -104,7 +113,7 @@ interface ForSchemaOptions {
omitRecordMethods: boolean;
registry: { [name: string]: Type };
typeHook: (schema: Schema | string, opts: ForSchemaOptions) => Type | undefined;
wrapUnions: boolean | 'auto' | 'always' | 'never';
wrapUnions: ProjectionFn | boolean | 'auto' | 'always' | 'never';
}

interface TypeOptions extends ForSchemaOptions {
Expand Down
Loading