From e1fe9c5f7a8e32c65cf1edb2483f3c3f1f44b3e1 Mon Sep 17 00:00:00 2001 From: Yaacov Rydzinski Date: Fri, 11 Aug 2023 20:44:27 +0300 Subject: [PATCH] add Abort signals --- src/execution/__tests__/executor-test.ts | 293 +++++++++++++++++++++++ src/execution/__tests__/stream-test.ts | 80 +++++++ src/execution/execute.ts | 263 +++++++++++++++----- src/jsutils/addAbortListener.ts | 64 +++++ src/type/definition.ts | 1 + 5 files changed, 646 insertions(+), 55 deletions(-) create mode 100644 src/jsutils/addAbortListener.ts diff --git a/src/execution/__tests__/executor-test.ts b/src/execution/__tests__/executor-test.ts index c29b4ae60dd..1eba3a86693 100644 --- a/src/execution/__tests__/executor-test.ts +++ b/src/execution/__tests__/executor-test.ts @@ -635,6 +635,299 @@ describe('Execute: Handles basic execution tasks', () => { expect(isAsyncResolverFinished).to.equal(true); }); + it('exits early on early abort', () => { + let isExecuted = false; + + const schema = new GraphQLSchema({ + query: new GraphQLObjectType({ + name: 'Query', + fields: { + field: { + type: GraphQLString, + /* c8 ignore next 3 */ + resolve() { + isExecuted = true; + }, + }, + }, + }), + }); + + const document = parse(` + { + field + } + `); + + const abortController = new AbortController(); + abortController.abort(); + + const result = execute({ + schema, + document, + abortSignal: abortController.signal, + }); + + expect(isExecuted).to.equal(false); + expectJSON(result).toDeepEqual({ + data: { field: null }, + errors: [ + { + message: 'This operation was aborted', + locations: [{ line: 3, column: 9 }], + path: ['field'], + }, + ], + }); + }); + + it('exits early on abort mid-execution', async () => { + let isExecuted = false; + + const asyncObjectType = new GraphQLObjectType({ + name: 'AsyncObject', + fields: { + field: { + type: GraphQLString, + /* c8 ignore next 3 */ + resolve() { + isExecuted = true; + }, + }, + }, + }); + + const schema = new GraphQLSchema({ + query: new GraphQLObjectType({ + name: 'Query', + fields: { + asyncObject: { + type: asyncObjectType, + async resolve() { + await resolveOnNextTick(); + return {}; + }, + }, + }, + }), + }); + + const document = parse(` + { + asyncObject { + field + } + } + `); + + const abortController = new AbortController(); + + const result = execute({ + schema, + document, + abortSignal: abortController.signal, + }); + + abortController.abort(); + + expect(isExecuted).to.equal(false); + expectJSON(await result).toDeepEqual({ + data: { asyncObject: { field: null } }, + errors: [ + { + message: 'This operation was aborted', + locations: [{ line: 4, column: 11 }], + path: ['asyncObject', 'field'], + }, + ], + }); + expect(isExecuted).to.equal(false); + }); + + it('exits early on abort mid-resolver', async () => { + const schema = new GraphQLSchema({ + query: new GraphQLObjectType({ + name: 'Query', + fields: { + asyncField: { + type: GraphQLString, + async resolve(_parent, _args, _context, _info, abortSignal) { + await resolveOnNextTick(); + abortSignal?.throwIfAborted(); + }, + }, + }, + }), + }); + + const document = parse(` + { + asyncField + } + `); + + const abortController = new AbortController(); + + const result = execute({ + schema, + document, + abortSignal: abortController.signal, + }); + + abortController.abort(); + + expectJSON(await result).toDeepEqual({ + data: { asyncField: null }, + errors: [ + { + message: 'This operation was aborted', + locations: [{ line: 3, column: 9 }], + path: ['asyncField'], + }, + ], + }); + }); + + it('exits early on abort mid-nested resolver', async () => { + const syncObjectType = new GraphQLObjectType({ + name: 'SyncObject', + fields: { + asyncField: { + type: GraphQLString, + async resolve(_parent, _args, _context, _info, abortSignal) { + await resolveOnNextTick(); + abortSignal?.throwIfAborted(); + }, + }, + }, + }); + + const schema = new GraphQLSchema({ + query: new GraphQLObjectType({ + name: 'Query', + fields: { + syncObject: { + type: syncObjectType, + resolve() { + return {}; + }, + }, + }, + }), + }); + + const document = parse(` + { + syncObject { + asyncField + } + } + `); + + const abortController = new AbortController(); + + const result = execute({ + schema, + document, + abortSignal: abortController.signal, + }); + + abortController.abort(); + + expectJSON(await result).toDeepEqual({ + data: { syncObject: { asyncField: null } }, + errors: [ + { + message: 'This operation was aborted', + locations: [{ line: 4, column: 11 }], + path: ['syncObject', 'asyncField'], + }, + ], + }); + }); + + it('exits early on error', async () => { + const objectType = new GraphQLObjectType({ + name: 'Object', + fields: { + nonNullNestedAsyncField: { + type: new GraphQLNonNull(GraphQLString), + async resolve() { + await resolveOnNextTick(); + throw new Error('Oops'); + }, + }, + nestedAsyncField: { + type: GraphQLString, + async resolve(_parent, _args, _context, _info, abortSignal) { + await resolveOnNextTick(); + abortSignal?.throwIfAborted(); + }, + }, + }, + }); + + const schema = new GraphQLSchema({ + query: new GraphQLObjectType({ + name: 'Query', + fields: { + object: { + type: objectType, + resolve() { + return {}; + }, + }, + asyncField: { + type: GraphQLString, + async resolve() { + await resolveOnNextTick(); + return 'asyncValue'; + }, + }, + }, + }), + }); + + const document = parse(` + { + object { + nonNullNestedAsyncField + nestedAsyncField + } + asyncField + } + `); + + const abortController = new AbortController(); + + const result = execute({ + schema, + document, + abortSignal: abortController.signal, + }); + + abortController.abort(); + + expectJSON(await result).toDeepEqual({ + data: { + object: null, + asyncField: 'asyncValue', + }, + errors: [ + { + message: 'This operation was aborted', + locations: [{ line: 5, column: 11 }], + path: ['object', 'nestedAsyncField'], + }, + { + message: 'Oops', + locations: [{ line: 4, column: 11 }], + path: ['object', 'nonNullNestedAsyncField'], + }, + ], + }); + }); + it('Full response path is included for non-nullable fields', () => { const A: GraphQLObjectType = new GraphQLObjectType({ name: 'A', diff --git a/src/execution/__tests__/stream-test.ts b/src/execution/__tests__/stream-test.ts index 194ed0d84b4..f4756af8bcc 100644 --- a/src/execution/__tests__/stream-test.ts +++ b/src/execution/__tests__/stream-test.ts @@ -1160,6 +1160,45 @@ describe('Execute: stream directive', () => { }, ]); }); + it('Handles nested errors thrown by completeValue after initialCount is reached for a non-nullable list', async () => { + const document = parse(` + query { + nonNullFriendList @stream(initialCount: 1) { + nonNullName + } + } + `); + const result = await complete(document, { + nonNullFriendList: () => [ + { nonNullName: friends[0].name }, + { nonNullName: new Error('Oops') }, + ], + }); + expectJSON(result).toDeepEqual([ + { + data: { + nonNullFriendList: [{ nonNullName: 'Luke' }], + }, + hasNext: true, + }, + { + incremental: [ + { + items: null, + path: ['nonNullFriendList', 1], + errors: [ + { + message: 'Oops', + locations: [{ line: 4, column: 11 }], + path: ['nonNullFriendList', 1, 'nonNullName'], + }, + ], + }, + ], + hasNext: false, + }, + ]); + }); it('Handles nested errors thrown by completeValue after initialCount is reached from async iterable', async () => { const document = parse(` query { @@ -1214,6 +1253,47 @@ describe('Execute: stream directive', () => { }, ]); }); + it('Handles nested errors thrown by completeValue after initialCount is reached from async iterable for a non-nullable list', async () => { + const document = parse(` + query { + nonNullFriendList @stream(initialCount: 1) { + nonNullName + } + } + `); + const result = await complete(document, { + async *nonNullFriendList() { + yield await Promise.resolve({ nonNullName: friends[0].name }); + yield await Promise.resolve({ + nonNullName: () => new Error('Oops'), + }); /* c8 ignore start */ + } /* c8 ignore stop */, + }); + expectJSON(result).toDeepEqual([ + { + data: { + nonNullFriendList: [{ nonNullName: 'Luke' }], + }, + hasNext: true, + }, + { + incremental: [ + { + items: null, + path: ['nonNullFriendList', 1], + errors: [ + { + message: 'Oops', + locations: [{ line: 4, column: 11 }], + path: ['nonNullFriendList', 1, 'nonNullName'], + }, + ], + }, + ], + hasNext: false, + }, + ]); + }); it('Handles nested async errors thrown by completeValue after initialCount is reached', async () => { const document = parse(` query { diff --git a/src/execution/execute.ts b/src/execution/execute.ts index 7a6f2ce27d0..ad500e110be 100644 --- a/src/execution/execute.ts +++ b/src/execution/execute.ts @@ -1,3 +1,4 @@ +import { addAbortListener } from '../jsutils/addAbortListener.js'; import { inspect } from '../jsutils/inspect.js'; import { invariant } from '../jsutils/invariant.js'; import { isAsyncIterable } from '../jsutils/isAsyncIterable.js'; @@ -132,6 +133,7 @@ export interface ExecutionContext { typeResolver: GraphQLTypeResolver; subscribeFieldResolver: GraphQLFieldResolver; incrementalPublisher: IncrementalPublisher; + abortSignal: AbortSignal | undefined; } /** @@ -201,6 +203,7 @@ export interface ExecutionArgs { fieldResolver?: Maybe>; typeResolver?: Maybe>; subscribeFieldResolver?: Maybe>; + abortSignal?: AbortSignal; } const UNEXPECTED_EXPERIMENTAL_DIRECTIVES = @@ -389,6 +392,7 @@ export function buildExecutionContext( fieldResolver, typeResolver, subscribeFieldResolver, + abortSignal, } = args; // If the schema used for execution is invalid, throw an error. @@ -453,6 +457,7 @@ export function buildExecutionContext( typeResolver: typeResolver ?? defaultTypeResolver, subscribeFieldResolver: subscribeFieldResolver ?? defaultFieldResolver, incrementalPublisher: new IncrementalPublisher(), + abortSignal, }; } @@ -473,8 +478,14 @@ function executeOperation( exeContext: ExecutionContext, initialResultRecord: InitialResultRecord, ): PromiseOrValue> { - const { operation, schema, fragments, variableValues, rootValue } = - exeContext; + const { + operation, + schema, + fragments, + variableValues, + rootValue, + abortSignal, + } = exeContext; const rootType = schema.getRootType(operation.operation); if (rootType == null) { throw new GraphQLError( @@ -502,6 +513,7 @@ function executeOperation( path, groupedFieldSet, initialResultRecord, + abortSignal, ); break; case OperationTypeNode.MUTATION: @@ -512,6 +524,7 @@ function executeOperation( path, groupedFieldSet, initialResultRecord, + abortSignal, ); break; case OperationTypeNode.SUBSCRIPTION: @@ -524,6 +537,7 @@ function executeOperation( path, groupedFieldSet, initialResultRecord, + abortSignal, ); } @@ -535,6 +549,7 @@ function executeOperation( rootValue, patchGroupedFieldSet, initialResultRecord, + abortSignal, label, path, ); @@ -554,6 +569,7 @@ function executeFieldsSerially( path: Path | undefined, groupedFieldSet: GroupedFieldSet, incrementalDataRecord: InitialResultRecord, + abortSignal: AbortSignal | undefined, ): PromiseOrValue> { return promiseReduce( groupedFieldSet, @@ -566,6 +582,7 @@ function executeFieldsSerially( fieldGroup, fieldPath, incrementalDataRecord, + abortSignal, ); if (result === undefined) { return results; @@ -594,6 +611,7 @@ function executeFields( path: Path | undefined, groupedFieldSet: GroupedFieldSet, incrementalDataRecord: IncrementalDataRecord, + abortSignal: AbortSignal | undefined, ): PromiseOrValue> { const results = Object.create(null); let containsPromise = false; @@ -608,6 +626,7 @@ function executeFields( fieldGroup, fieldPath, incrementalDataRecord, + abortSignal, ); if (result !== undefined) { @@ -652,6 +671,7 @@ function executeField( fieldGroup: FieldGroup, path: Path, incrementalDataRecord: IncrementalDataRecord, + abortSignal: AbortSignal | undefined, ): PromiseOrValue { const fieldName = fieldGroup[0].name.value; const fieldDef = exeContext.schema.getField(parentType, fieldName); @@ -688,7 +708,11 @@ function executeField( // used to represent an authenticated user, or request-specific caches. const contextValue = exeContext.contextValue; - result = resolveFn(source, args, contextValue, info); + if (abortSignal?.aborted) { + abortSignal.throwIfAborted(); + } + + result = resolveFn(source, args, contextValue, info, abortSignal); if (isPromise(result)) { return completePromisedValue( @@ -699,6 +723,7 @@ function executeField( path, result, incrementalDataRecord, + abortSignal, ); } @@ -736,6 +761,13 @@ function executeField( return null; } + const abortController = new AbortController(); + let removeAbortListener: (() => void) | undefined; + if (abortSignal !== undefined) { + removeAbortListener = addAbortListener(abortSignal, () => + abortController.abort(), + ); + } let completed; try { completed = completeNonLeafValue( @@ -746,8 +778,11 @@ function executeField( path, result, incrementalDataRecord, + abortController.signal, ); } catch (rawError) { + removeAbortListener?.(); + abortController.abort(); handleFieldError( rawError, exeContext, @@ -763,19 +798,29 @@ function executeField( if (isPromise(completed)) { // Note: we don't rely on a `catch` method, but we do expect "thenable" // to take a second callback for the error case. - return completed.then(undefined, (rawError) => { - handleFieldError( - rawError, - exeContext, - returnType, - fieldGroup, - path, - incrementalDataRecord, - ); - exeContext.incrementalPublisher.filter(path, incrementalDataRecord); - return null; - }); + return completed.then( + (resolved) => { + removeAbortListener?.(); + return resolved; + }, + (rawError) => { + removeAbortListener?.(); + abortController.abort(); + handleFieldError( + rawError, + exeContext, + returnType, + fieldGroup, + path, + incrementalDataRecord, + ); + exeContext.incrementalPublisher.filter(path, incrementalDataRecord); + return null; + }, + ); } + + removeAbortListener?.(); return completed; } @@ -848,6 +893,7 @@ function completeNonLeafValue( path: Path, result: unknown, incrementalDataRecord: IncrementalDataRecord, + abortSignal: AbortSignal, ): PromiseOrValue { // If field type is List, complete each item in the list with the inner type if (isListType(nullableType)) { @@ -859,6 +905,7 @@ function completeNonLeafValue( path, result, incrementalDataRecord, + abortSignal, ); } @@ -873,6 +920,7 @@ function completeNonLeafValue( path, result, incrementalDataRecord, + abortSignal, ); } @@ -886,6 +934,7 @@ function completeNonLeafValue( path, result, incrementalDataRecord, + abortSignal, ); } /* c8 ignore next 6 */ @@ -904,6 +953,7 @@ async function completePromisedValue( path: Path, result: Promise, incrementalDataRecord: IncrementalDataRecord, + abortSignal: AbortSignal | undefined, ): Promise { let resolved; let nullableType: GraphQLNullableOutputType; @@ -944,6 +994,13 @@ async function completePromisedValue( return null; } + const abortController = new AbortController(); + let removeAbortListener: (() => void) | undefined; + if (abortSignal !== undefined) { + removeAbortListener = addAbortListener(abortSignal, () => + abortController.abort(), + ); + } try { let completed = completeNonLeafValue( exeContext, @@ -953,12 +1010,16 @@ async function completePromisedValue( path, resolved, incrementalDataRecord, + abortController.signal, ); if (isPromise(completed)) { completed = await completed; } + removeAbortListener?.(); return completed; } catch (rawError) { + removeAbortListener?.(); + abortController.abort(); handleFieldError( rawError, exeContext, @@ -1041,6 +1102,7 @@ async function completeAsyncIteratorValue( path: Path, asyncIterator: AsyncIterator, incrementalDataRecord: IncrementalDataRecord, + abortSignal: AbortSignal, ): Promise> { const stream = getStreamValues(exeContext, fieldGroup, path); let containsPromise = false; @@ -1063,6 +1125,7 @@ async function completeAsyncIteratorValue( itemType, path, incrementalDataRecord, + abortSignal, stream.label, ); break; @@ -1090,6 +1153,7 @@ async function completeAsyncIteratorValue( info, itemPath, incrementalDataRecord, + abortSignal, ) ) { containsPromise = true; @@ -1111,6 +1175,7 @@ function completeListValue( path: Path, result: unknown, incrementalDataRecord: IncrementalDataRecord, + abortSignal: AbortSignal, ): PromiseOrValue> { const itemType = returnType.ofType; @@ -1125,6 +1190,7 @@ function completeListValue( path, asyncIterator, incrementalDataRecord, + abortSignal, ); } @@ -1161,6 +1227,7 @@ function completeListValue( info, itemType, previousIncrementalDataRecord, + abortSignal, stream.label, ); index++; @@ -1177,6 +1244,7 @@ function completeListValue( info, itemPath, incrementalDataRecord, + abortSignal, ) ) { containsPromise = true; @@ -1202,6 +1270,7 @@ function completeListItemValue( info: GraphQLResolveInfo, itemPath: Path, incrementalDataRecord: IncrementalDataRecord, + abortSignal: AbortSignal, ): boolean { if (isPromise(item)) { completedResults.push( @@ -1213,6 +1282,7 @@ function completeListItemValue( itemPath, item, incrementalDataRecord, + abortSignal, ), ); @@ -1258,6 +1328,10 @@ function completeListItemValue( return false; } + const abortController = new AbortController(); + const removeAbortListener = addAbortListener(abortSignal, () => + abortController.abort(), + ); let completedItem; try { completedItem = completeNonLeafValue( @@ -1268,8 +1342,11 @@ function completeListItemValue( itemPath, item, incrementalDataRecord, + abortController.signal, ); } catch (rawError) { + removeAbortListener(); + abortController.abort(); handleFieldError( rawError, exeContext, @@ -1287,23 +1364,35 @@ function completeListItemValue( // Note: we don't rely on a `catch` method, but we do expect "thenable" // to take a second callback for the error case. completedResults.push( - completedItem.then(undefined, (rawError) => { - handleFieldError( - rawError, - exeContext, - itemType, - fieldGroup, - itemPath, - incrementalDataRecord, - ); - exeContext.incrementalPublisher.filter(itemPath, incrementalDataRecord); - return null; - }), + completedItem.then( + (resolved) => { + removeAbortListener(); + return resolved; + }, + (rawError) => { + removeAbortListener(); + abortController.abort(); + handleFieldError( + rawError, + exeContext, + itemType, + fieldGroup, + itemPath, + incrementalDataRecord, + ); + exeContext.incrementalPublisher.filter( + itemPath, + incrementalDataRecord, + ); + return null; + }, + ), ); return true; } + removeAbortListener(); completedResults.push(completedItem); return false; @@ -1339,6 +1428,7 @@ function completeAbstractValue( path: Path, result: unknown, incrementalDataRecord: IncrementalDataRecord, + abortSignal: AbortSignal, ): PromiseOrValue> { const resolveTypeFn = returnType.resolveType ?? exeContext.typeResolver; const contextValue = exeContext.contextValue; @@ -1361,6 +1451,7 @@ function completeAbstractValue( path, result, incrementalDataRecord, + abortSignal, ), ); } @@ -1380,6 +1471,7 @@ function completeAbstractValue( path, result, incrementalDataRecord, + abortSignal, ); } @@ -1449,6 +1541,7 @@ function completeObjectValue( path: Path, result: unknown, incrementalDataRecord: IncrementalDataRecord, + abortSignal: AbortSignal, ): PromiseOrValue> { // If there is an isTypeOf predicate function, call it with the // current result. If isTypeOf returns false, then raise an error rather @@ -1468,6 +1561,7 @@ function completeObjectValue( path, result, incrementalDataRecord, + abortSignal, ); }); } @@ -1484,6 +1578,7 @@ function completeObjectValue( path, result, incrementalDataRecord, + abortSignal, ); } @@ -1505,6 +1600,7 @@ function collectAndExecuteSubfields( path: Path, result: unknown, incrementalDataRecord: IncrementalDataRecord, + abortSignal: AbortSignal, ): PromiseOrValue> { // Collect sub-fields to execute to complete this value. const { groupedFieldSet: subGroupedFieldSet, patches: subPatches } = @@ -1517,6 +1613,7 @@ function collectAndExecuteSubfields( path, subGroupedFieldSet, incrementalDataRecord, + abortSignal, ); for (const subPatch of subPatches) { @@ -1527,6 +1624,7 @@ function collectAndExecuteSubfields( result, subPatchGroupedFieldSet, incrementalDataRecord, + abortSignal, label, path, ); @@ -1737,8 +1835,14 @@ function createSourceEventStreamImpl( function executeSubscription( exeContext: ExecutionContext, ): PromiseOrValue> { - const { schema, fragments, operation, variableValues, rootValue } = - exeContext; + const { + schema, + fragments, + operation, + variableValues, + rootValue, + abortSignal, + } = exeContext; const rootType = schema.getSubscriptionType(); if (rootType == null) { @@ -1793,7 +1897,7 @@ function executeSubscription( // Call the `subscribe()` resolver or the default resolver to produce an // AsyncIterable yielding raw payloads. const resolveFn = fieldDef.subscribe ?? exeContext.subscribeFieldResolver; - const result = resolveFn(rootValue, args, contextValue, info); + const result = resolveFn(rootValue, args, contextValue, info, abortSignal); if (isPromise(result)) { return result.then(assertEventStream).then(undefined, (error) => { @@ -1829,6 +1933,7 @@ function executeDeferredFragment( sourceValue: unknown, fields: GroupedFieldSet, parentContext: IncrementalDataRecord, + abortSignal: AbortSignal | undefined, label?: string, path?: Path, ): void { @@ -1849,6 +1954,7 @@ function executeDeferredFragment( path, fields, incrementalDataRecord, + abortSignal, ); if (isPromise(promiseOrData)) { @@ -1890,6 +1996,7 @@ function executeStreamField( info: GraphQLResolveInfo, itemType: GraphQLOutputType, parentContext: IncrementalDataRecord, + abortSignal: AbortSignal, label?: string, ): SubsequentDataRecord { const incrementalPublisher = exeContext.incrementalPublisher; @@ -1909,6 +2016,7 @@ function executeStreamField( itemPath, item, incrementalDataRecord, + abortSignal, ).then( (value) => incrementalPublisher.completeStreamItemsRecord(incrementalDataRecord, [ @@ -1929,8 +2037,8 @@ function executeStreamField( } let completedItem: PromiseOrValue; + let nullableType: GraphQLNullableOutputType; try { - let nullableType: GraphQLNullableOutputType; try { if (item instanceof Error) { throw item; @@ -1975,7 +2083,18 @@ function executeStreamField( ]); return incrementalDataRecord; } + } catch (error) { + incrementalPublisher.addFieldError(incrementalDataRecord, error); + incrementalPublisher.filter(path, incrementalDataRecord); + incrementalPublisher.completeStreamItemsRecord(incrementalDataRecord, null); + return incrementalDataRecord; + } + const abortController = new AbortController(); + const removeAbortListener = addAbortListener(abortSignal, () => + abortController.abort(), + ); + try { try { completedItem = completeNonLeafValue( exeContext, @@ -1985,8 +2104,11 @@ function executeStreamField( itemPath, item, incrementalDataRecord, + abortController.signal, ); } catch (rawError) { + removeAbortListener(); + abortController.abort(); handleFieldError( rawError, exeContext, @@ -2010,18 +2132,29 @@ function executeStreamField( if (isPromise(completedItem)) { completedItem - .then(undefined, (rawError) => { - handleFieldError( - rawError, - exeContext, - itemType, - fieldGroup, - itemPath, - incrementalDataRecord, - ); - exeContext.incrementalPublisher.filter(itemPath, incrementalDataRecord); - return null; - }) + .then( + (resolvedItem) => { + removeAbortListener(); + return resolvedItem; + }, + (rawError) => { + removeAbortListener(); + abortController.abort(); + handleFieldError( + rawError, + exeContext, + itemType, + fieldGroup, + itemPath, + incrementalDataRecord, + ); + exeContext.incrementalPublisher.filter( + itemPath, + incrementalDataRecord, + ); + return null; + }, + ) .then( (value) => incrementalPublisher.completeStreamItemsRecord( @@ -2041,6 +2174,7 @@ function executeStreamField( return incrementalDataRecord; } + removeAbortListener(); incrementalPublisher.completeStreamItemsRecord(incrementalDataRecord, [ completedItem, ]); @@ -2056,6 +2190,7 @@ async function executeStreamAsyncIteratorItem( incrementalDataRecord: StreamItemsRecord, path: Path, itemPath: Path, + abortSignal: AbortSignal, ): Promise> { let item; try { @@ -2108,6 +2243,10 @@ async function executeStreamAsyncIteratorItem( return { done: false, value: null }; } + const abortController = new AbortController(); + const removeAbortListener = addAbortListener(abortSignal, () => + abortController.abort(), + ); let completedItem; try { completedItem = completeNonLeafValue( @@ -2118,8 +2257,11 @@ async function executeStreamAsyncIteratorItem( itemPath, item, incrementalDataRecord, + abortController.signal, ); } catch (rawError) { + removeAbortListener(); + abortController.abort(); handleFieldError( rawError, exeContext, @@ -2133,19 +2275,28 @@ async function executeStreamAsyncIteratorItem( } if (isPromise(completedItem)) { - completedItem = completedItem.then(undefined, (rawError) => { - handleFieldError( - rawError, - exeContext, - itemType, - fieldGroup, - itemPath, - incrementalDataRecord, - ); - exeContext.incrementalPublisher.filter(itemPath, incrementalDataRecord); - return null; - }); + completedItem = completedItem.then( + (resolvedItem) => { + removeAbortListener(); + return resolvedItem; + }, + (rawError) => { + removeAbortListener(); + abortController.abort(); + handleFieldError( + rawError, + exeContext, + itemType, + fieldGroup, + itemPath, + incrementalDataRecord, + ); + exeContext.incrementalPublisher.filter(itemPath, incrementalDataRecord); + return null; + }, + ); } + removeAbortListener(); return { done: false, value: completedItem }; } @@ -2158,6 +2309,7 @@ async function executeStreamAsyncIterator( itemType: GraphQLOutputType, path: Path, parentContext: IncrementalDataRecord, + abortSignal: AbortSignal, label?: string, ): Promise { const incrementalPublisher = exeContext.incrementalPublisher; @@ -2186,6 +2338,7 @@ async function executeStreamAsyncIterator( incrementalDataRecord, path, itemPath, + abortSignal, ); } catch (error) { incrementalPublisher.addFieldError(incrementalDataRecord, error); diff --git a/src/jsutils/addAbortListener.ts b/src/jsutils/addAbortListener.ts new file mode 100644 index 00000000000..cdada905311 --- /dev/null +++ b/src/jsutils/addAbortListener.ts @@ -0,0 +1,64 @@ +type Callback = () => void; +interface AbortInfo { + listeners: Set; + dispose: Callback; +} +type Cache = WeakMap; + +let maybeCache: Cache | undefined; + +/** + * Helper function to add a callback to be triggered when the abort signal fires. + * Returns a function that will remove the callback when called. + * + * This helper function also avoids hitting the max listener limit on AbortSignals, + * which could be a common occurrence when setting up multiple contingent + * abort signals. + */ +export function addAbortListener( + abortSignal: AbortSignal, + callback: Callback, +): Callback { + if (abortSignal.aborted) { + callback(); + return () => { + /* noop */ + }; + } + + const cache = (maybeCache ??= new WeakMap()); + + const abortInfo = cache.get(abortSignal); + + if (abortInfo !== undefined) { + abortInfo.listeners.add(callback); + return () => removeAbortListener(abortInfo, callback); + } + + const listeners = new Set([callback]); + const onAbort = () => triggerCallbacks(listeners); + const dispose = () => { + abortSignal.removeEventListener('abort', onAbort); + }; + const newAbortInfo = { listeners, dispose }; + cache.set(abortSignal, newAbortInfo); + abortSignal.addEventListener('abort', onAbort); + + return () => removeAbortListener(newAbortInfo, callback); +} + +function triggerCallbacks(listeners: Set): void { + for (const listener of listeners) { + listener(); + } +} + +function removeAbortListener(abortInfo: AbortInfo, callback: Callback): void { + const listeners = abortInfo.listeners; + + listeners.delete(callback); + + if (listeners.size === 0) { + abortInfo.dispose(); + } +} diff --git a/src/type/definition.ts b/src/type/definition.ts index 0ca4152bd2f..140d891bf37 100644 --- a/src/type/definition.ts +++ b/src/type/definition.ts @@ -883,6 +883,7 @@ export type GraphQLFieldResolver< args: TArgs, context: TContext, info: GraphQLResolveInfo, + abortSignal: AbortSignal | undefined, ) => TResult; export interface GraphQLResolveInfo {