Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow named types in unions #469

Merged
merged 24 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 83 additions & 53 deletions lib/types.js
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class Type {
wrapUnions = 'auto';
} else if (typeof wrapUnions == 'string') {
wrapUnions = wrapUnions.toLowerCase();
} else if (typeof wrapUnions === 'function') {
wrapUnions = 'auto';
}
switch (wrapUnions) {
case 'always':
Expand Down Expand Up @@ -196,11 +198,20 @@ class Type {
let types = schema.map((obj) => {
return Type.forSchema(obj, opts);
});
let projectionFn;
if (!UnionType) {
UnionType = isAmbiguous(types) ? WrappedUnionType : UnwrappedUnionType;
if (typeof opts.wrapUnions === 'function') {
// we have a projection function
joscha marked this conversation as resolved.
Show resolved Hide resolved
projectionFn = opts.wrapUnions(types);
UnionType = typeof projectionFn !== 'undefined'
? UnwrappedUnionType
: WrappedUnionType;
} else {
UnionType = isAmbiguous(types) ? WrappedUnionType : UnwrappedUnionType;
}
}
LOGICAL_TYPE = logicalType;
type = new UnionType(types, opts);
type = new UnionType(types, opts, projectionFn);
} else { // New type definition.
type = (function (typeName) {
let Type = TYPES[typeName];
Expand Down Expand Up @@ -341,10 +352,10 @@ class Type {
return branchTypes[name];
}), opts);
} catch (err) {
opts.wrapUnions = wrapUnions;
throw err;
} finally {
opts.wrapUnions = wrapUnions;
}
opts.wrapUnions = wrapUnions;
return unionType;
}

Expand Down Expand Up @@ -1226,6 +1237,60 @@ UnionType.prototype._branchConstructor = function () {
throw new Error('unions cannot be directly wrapped');
};


function generateProjectionIndexer(projectionFn) {
return (val) => {
const index = projectionFn(val);
if (typeof index !== 'number') {
throw new Error(`Projected index '${index}' is not valid`);
}
return index;
};
}

function generateDefaultIndexer(types) {
const dynamicBranches = [];
const bucketIndices = {};

const getBranchIndex = (any, index) => {
let logicalBranches = dynamicBranches;
for (let i = 0, l = logicalBranches.length; i < l; i++) {
let branch = logicalBranches[i];
if (branch.type._check(any)) {
if (index === undefined) {
index = branch.index;
} else {
// More than one branch matches the value so we aren't guaranteed to
// infer the correct type. We throw rather than corrupt data. This can
// be fixed by "tightening" the logical types.
throw new Error('ambiguous conversion');
}
}
}
return index;
}

types.forEach(function (type, index) {
if (Type.isType(type, 'abstract', 'logical')) {
dynamicBranches.push({index, type});
} else {
let bucket = getTypeBucket(type);
if (bucketIndices[bucket] !== undefined) {
throw new Error(`ambiguous unwrapped union: ${j(this)}`);
}
bucketIndices[bucket] = index;
}
});
return (val) => {
let index = bucketIndices[getValueBucket(val)];
if (dynamicBranches.length) {
// Slower path, we must run the value through all branches.
index = getBranchIndex(val, index);
}
return index;
};
}

/**
* "Natural" union type.
*
Expand All @@ -1246,54 +1311,17 @@ UnionType.prototype._branchConstructor = function () {
* + `map`, `record`
*/
class UnwrappedUnionType extends UnionType {
constructor (schema, opts) {
constructor (schema, opts, /* @private parameter */ _projectionFn) {
super(schema, opts);

this._dynamicBranches = null;
this._bucketIndices = {};
this.types.forEach(function (type, index) {
if (Type.isType(type, 'abstract', 'logical')) {
if (!this._dynamicBranches) {
this._dynamicBranches = [];
}
this._dynamicBranches.push({index, type});
} else {
let bucket = getTypeBucket(type);
if (this._bucketIndices[bucket] !== undefined) {
throw new Error(`ambiguous unwrapped union: ${j(this)}`);
}
this._bucketIndices[bucket] = index;
}
}, this);

Object.freeze(this);
}

_getIndex (val) {
let index = this._bucketIndices[getValueBucket(val)];
if (this._dynamicBranches) {
// Slower path, we must run the value through all branches.
index = this._getBranchIndex(val, index);
if (!_projectionFn && opts && typeof opts.wrapUnions === 'function') {
_projectionFn = opts.wrapUnions(this.types);
}
return index;
}
this._getIndex = _projectionFn
? generateProjectionIndexer(_projectionFn)
: generateDefaultIndexer(this.types);

_getBranchIndex (any, index) {
let logicalBranches = this._dynamicBranches;
for (let i = 0, l = logicalBranches.length; i < l; i++) {
let branch = logicalBranches[i];
if (branch.type._check(any)) {
if (index === undefined) {
index = branch.index;
} else {
// More than one branch matches the value so we aren't guaranteed to
// infer the correct type. We throw rather than corrupt data. This can
// be fixed by "tightening" the logical types.
throw new Error('ambiguous conversion');
}
}
}
return index;
Object.freeze(this);
}

_check (val, flags, hook, path) {
Expand Down Expand Up @@ -1355,16 +1383,18 @@ class UnwrappedUnionType extends UnionType {
// Using the `coerceBuffers` option can cause corruption and erroneous
// failures with unwrapped unions (in rare cases when the union also
// contains a record which matches a buffer's JSON representation).
if (isJsonBuffer(val) && this._bucketIndices.buffer !== undefined) {
index = this._bucketIndices.buffer;
} else {
index = this._getIndex(val);
if (isJsonBuffer(val)) {
let bufIndex = this.types.findIndex(t => getTypeBucket(t) === 'buffer');
if (bufIndex !== -1) {
index = bufIndex;
}
}
index ??= this._getIndex(val);
break;
case 2:
// Decoding from JSON, we must unwrap the value.
if (val === null) {
index = this._bucketIndices['null'];
index = this._getIndex(null);
} else if (typeof val === 'object') {
let keys = Object.keys(val);
if (keys.length === 1) {
Expand Down
51 changes: 51 additions & 0 deletions test/test_types.js
Original file line number Diff line number Diff line change
Expand Up @@ -3505,6 +3505,57 @@ suite('types', () => {
assert(Type.isType(t.field('unwrapped').type, 'union:unwrapped'));
});

test('union projection', () => {
joscha marked this conversation as resolved.
Show resolved Hide resolved
const Dog = {
type: 'record',
name: 'Dog',
fields: [
{ type: 'string', name: 'bark' }
],
};
const Cat = {
type: 'record',
name: 'Cat',
fields: [
{ type: 'string', name: 'meow' }
],
};
const animalTypes = [Dog, Cat];

let callsToWrapUnions = 0;
const wrapUnions = (types) => {
callsToWrapUnions++;
assert.deepEqual(types.map(t => t.name), ['Dog', 'Cat']);
return (animal) => {
const animalType = ((animal) => {
if ('bark' in animal) {
return 'Dog';
} else if ('meow' in animal) {
return 'Cat';
}
throw new Error('Unknown animal');
})(animal);
return types.indexOf(types.find(type => type.name === animalType));
joscha marked this conversation as resolved.
Show resolved Hide resolved
}
};

// Ambiguous, but we have a projection function
const Animal = Type.forSchema(animalTypes, { wrapUnions });
Animal.toBuffer({ meow: '🐈' });
assert.equal(callsToWrapUnions, 1);
assert.throws(() => Animal.toBuffer({ snap: '🐊' }), /Unknown animal/)
});

test('union projection with fallback', () => {
let t = Type.forSchema({
type: 'record',
fields: [
{name: 'wrapped', type: ['int', 'double' ]}, // Ambiguous.
]
}, {wrapUnions: () => undefined });
assert(Type.isType(t.field('wrapped').type, 'union:wrapped'));
});

test('invalid wrap unions option', () => {
assert.throws(() => {
Type.forSchema('string', {wrapUnions: 'FOO'});
Expand Down
17 changes: 16 additions & 1 deletion types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,21 @@ interface EncoderOptions {
syncMarker: Buffer;
}

/**
* A projection function that is used when unwrapping unions.
* This function is called at schema parsing time on each union with its branches'
* types.
* If it returns a non-null (function) value, that function will be called each
* time a value's branch needs to be inferred and should return the branch's
* index.
* The index muss be a number between 0 and length-1 of the passed types.
* In this case (a branch index) the union will use an unwrapped representation.
* Otherwise (undefined), the union will be wrapped.
*/
type BranchProjection = (types: ReadonlyArray<Type>) =>
| ((val: unknown) => number)
| undefined;

interface ForSchemaOptions {
assertLogicalTypes: boolean;
logicalTypes: { [type: string]: new (schema: Schema, opts?: any) => types.LogicalType; };
Expand All @@ -103,7 +118,7 @@ interface ForSchemaOptions {
omitRecordMethods: boolean;
registry: { [name: string]: Type };
typeHook: (schema: Schema | string, opts: ForSchemaOptions) => Type | undefined;
wrapUnions: boolean | 'auto' | 'always' | 'never';
wrapUnions: BranchProjection | boolean | 'auto' | 'always' | 'never';
}

interface TypeOptions extends ForSchemaOptions {
Expand Down