diff --git a/spring-graphql/src/main/java/org/springframework/graphql/observation/GraphQlObservationInstrumentation.java b/spring-graphql/src/main/java/org/springframework/graphql/observation/GraphQlObservationInstrumentation.java index ac33db1dd..475675534 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/observation/GraphQlObservationInstrumentation.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/observation/GraphQlObservationInstrumentation.java @@ -155,13 +155,13 @@ public DataFetcher instrumentDataFetcher(DataFetcher dataFetcher, throw new CompletionException(error); } dataFetcherObservation.stop(); - return wrapAsDataFetcherResult(result, dataFetcherObservation); + return wrapAsDataFetcherResult(result, dataFetcherObservation, environment.getLocalContext()); }); } else { observationContext.setValue(value); dataFetcherObservation.stop(); - return wrapAsDataFetcherResult(value, dataFetcherObservation); + return wrapAsDataFetcherResult(value, dataFetcherObservation, environment.getLocalContext()); } } catch (Throwable throwable) { @@ -187,7 +187,8 @@ private static Observation getCurrentObservation(DataFetchingEnvironment environ return currentObservation; } - private static DataFetcherResult wrapAsDataFetcherResult(Object value, Observation dataFetcherObservation) { + private static DataFetcherResult wrapAsDataFetcherResult(Object value, Observation dataFetcherObservation, + @Nullable GraphQLContext dataFetcherLocalContext) { if (value instanceof DataFetcherResult result) { if (result.getLocalContext() == null) { return result.transform(builder -> builder.localContext(GraphQLContext.newContext().of(ObservationThreadLocalAccessor.KEY, dataFetcherObservation).build())); @@ -201,9 +202,11 @@ else if (result.getLocalContext() instanceof GraphQLContext) { return result; } else { + GraphQLContext localContext = dataFetcherLocalContext == null ? + GraphQLContext.getDefault() : dataFetcherLocalContext; return DataFetcherResult.newResult() .data(value) - .localContext(GraphQLContext.newContext().of(ObservationThreadLocalAccessor.KEY, dataFetcherObservation).build()) + .localContext(localContext.put(ObservationThreadLocalAccessor.KEY, dataFetcherObservation)) .build(); }