diff --git a/spring-graphql/src/main/java/org/springframework/graphql/observation/GraphQlObservationInstrumentation.java b/spring-graphql/src/main/java/org/springframework/graphql/observation/GraphQlObservationInstrumentation.java index b3b115a8..1fee0bd6 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/observation/GraphQlObservationInstrumentation.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/observation/GraphQlObservationInstrumentation.java @@ -190,15 +190,19 @@ private static Observation getCurrentObservation(DataFetchingEnvironment environ } private static DataFetchingEnvironment wrapDataFetchingEnvironment(DataFetchingEnvironment environment, Observation dataFetcherObservation) { - GraphQLContext.Builder localContextBuilder = GraphQLContext.newContext(); - if (environment.getLocalContext() instanceof GraphQLContext localContext) { - localContextBuilder.of(localContext); + if (environment.getLocalContext() == null || environment.getLocalContext() instanceof GraphQLContext) { + GraphQLContext.Builder localContextBuilder = GraphQLContext.newContext(); + if (environment.getLocalContext() instanceof GraphQLContext localContext) { + localContextBuilder.of(localContext); + } + localContextBuilder.of(ObservationThreadLocalAccessor.KEY, dataFetcherObservation); + return DataFetchingEnvironmentImpl + .newDataFetchingEnvironment(environment) + .localContext(localContextBuilder.build()) + .build(); } - localContextBuilder.of(ObservationThreadLocalAccessor.KEY, dataFetcherObservation); - return DataFetchingEnvironmentImpl - .newDataFetchingEnvironment(environment) - .localContext(localContextBuilder.build()) - .build(); + // do not wrap environment, there is an existing custom context + return environment; } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/observation/GraphQlObservationInstrumentationTests.java b/spring-graphql/src/test/java/org/springframework/graphql/observation/GraphQlObservationInstrumentationTests.java index a2dd367a..8d6ad88a 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/observation/GraphQlObservationInstrumentationTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/observation/GraphQlObservationInstrumentationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2023 the original author or authors. + * Copyright 2020-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,7 +31,15 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.springframework.graphql.*; + +import org.springframework.graphql.Author; +import org.springframework.graphql.Book; +import org.springframework.graphql.BookSource; +import org.springframework.graphql.ExecutionGraphQlRequest; +import org.springframework.graphql.ExecutionGraphQlResponse; +import org.springframework.graphql.GraphQlSetup; +import org.springframework.graphql.ResponseHelper; +import org.springframework.graphql.TestExecutionRequest; import org.springframework.graphql.execution.DataFetcherExceptionResolver; import org.springframework.graphql.execution.ErrorType; import reactor.core.publisher.Mono; @@ -319,4 +327,42 @@ void shouldNotOverrideExistingLocalContext() { ResponseHelper.forResponse(responseMono); } + @Test + void shouldNotOverrideCustomLocalContext() { + + String document = """ + { + bookById(id: 1) { + author { + firstName, + lastName + } + } + } + """; + DataFetcher> bookDataFetcher = environment -> DataFetcherResult.newResult() + .data(BookSource.getBook(1L)) + .localContext(new CustomLocalContext()) + .build(); + DataFetcher authorDataFetcher = environment -> BookSource.getAuthor(101L); + DataFetcher authorFirstNameDataFetcher = environment -> { + Object context = environment.getLocalContext(); + assertThat(context).isInstanceOf(CustomLocalContext.class); + return BookSource.getAuthor(101L).getFirstName(); + }; + + ExecutionGraphQlRequest request = TestExecutionRequest.forDocument(document); + Mono responseMono = graphQlSetup + .queryFetcher("bookById", bookDataFetcher) + .dataFetcher("Book", "author", authorDataFetcher) + .dataFetcher("Author", "firstName", authorFirstNameDataFetcher) + .toGraphQlService() + .execute(request); + ResponseHelper.forResponse(responseMono); + } + + static class CustomLocalContext { + + } + }