From c2443326eb933cb413a664a724f976cb08121193 Mon Sep 17 00:00:00 2001 From: Koen Punt Date: Thu, 14 Sep 2023 13:24:47 +0200 Subject: [PATCH] Raise exception if Principal is required but not present See gh-790 --- ...thenticationPrincipalArgumentResolver.java | 15 +++- .../support/BatchLoaderHandlerMethod.java | 2 +- .../PrincipalMethodArgumentResolver.java | 28 +++++- ...gPrincipalMethodArgumentResolverTests.java | 90 ++++++++++++++++++- 4 files changed, 123 insertions(+), 12 deletions(-) diff --git a/spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/AuthenticationPrincipalArgumentResolver.java b/spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/AuthenticationPrincipalArgumentResolver.java index 3f1a8bcb7..48462e9b1 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/AuthenticationPrincipalArgumentResolver.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/AuthenticationPrincipalArgumentResolver.java @@ -96,7 +96,7 @@ private static AuthenticationPrincipal findMethodAnnotation(MethodParameter para @Override public Object resolveArgument(MethodParameter parameter, DataFetchingEnvironment environment) throws Exception { - return getCurrentAuthentication() + return getCurrentAuthentication(parameter.isOptional()) .flatMap(auth -> Mono.justOrEmpty(resolvePrincipal(parameter, auth.getPrincipal()))) .transform((argument) -> isParameterMonoAssignable(parameter) ? Mono.just(argument) : argument); } @@ -106,9 +106,16 @@ private static boolean isParameterMonoAssignable(MethodParameter parameter) { return (Publisher.class.equals(type) || Mono.class.equals(type)); } - private Mono getCurrentAuthentication() { - return Mono.justOrEmpty(SecurityContextHolder.getContext().getAuthentication()) - .switchIfEmpty(ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication)); + @SuppressWarnings("unchecked") + private Mono getCurrentAuthentication(boolean optional) { + Object principal = PrincipalMethodArgumentResolver.doResolve(optional); + if (principal instanceof Authentication) { + return Mono.just((Authentication) principal); + } + else if (principal instanceof Mono) { + return (Mono) principal; + } + return Mono.error(new IllegalStateException("Unexpected return value: " + principal)); } @Nullable diff --git a/spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/BatchLoaderHandlerMethod.java b/spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/BatchLoaderHandlerMethod.java index ae72c289b..cf1023be2 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/BatchLoaderHandlerMethod.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/BatchLoaderHandlerMethod.java @@ -145,7 +145,7 @@ else if ("kotlin.coroutines.Continuation".equals(parameterType.getName())) { return null; } else if (springSecurityPresent && Principal.class.isAssignableFrom(parameter.getParameterType())) { - return PrincipalMethodArgumentResolver.doResolve(); + return PrincipalMethodArgumentResolver.doResolve(parameter.isOptional()); } else { throw new IllegalStateException(formatArgumentError(parameter, "Unexpected argument type.")); diff --git a/spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/PrincipalMethodArgumentResolver.java b/spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/PrincipalMethodArgumentResolver.java index cc45ddbb6..21864295a 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/PrincipalMethodArgumentResolver.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/PrincipalMethodArgumentResolver.java @@ -16,15 +16,19 @@ package org.springframework.graphql.data.method.annotation.support; import java.security.Principal; +import java.util.function.Function; import graphql.schema.DataFetchingEnvironment; import org.springframework.core.MethodParameter; import org.springframework.graphql.data.method.HandlerMethodArgumentResolver; +import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import reactor.core.publisher.Mono; /** * Resolver to obtain {@link Principal} from Spring Security context via @@ -50,13 +54,29 @@ public boolean supportsParameter(MethodParameter parameter) { @Override public Object resolveArgument(MethodParameter parameter, DataFetchingEnvironment environment) { - return doResolve(); + return doResolve(parameter.isOptional()); } - static Object doResolve() { + static Object doResolve(boolean optional) { Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - return (authentication != null ? authentication : - ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication)); + + if (authentication != null) { + return authentication; + } + + return ReactiveSecurityContextHolder.getContext() + .switchIfEmpty(optional ? Mono.empty() : Mono.error(new AuthenticationCredentialsNotFoundException("SecurityContext not available"))) + .handle((context, sink) -> { + Authentication auth = context.getAuthentication(); + + if (auth != null) { + sink.next(auth); + } else if (!optional) { + sink.error(new AuthenticationCredentialsNotFoundException("An Authentication object was not found in the SecurityContext")); + } else { + sink.complete(); + } + }); } } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/data/method/annotation/support/SchemaMappingPrincipalMethodArgumentResolverTests.java b/spring-graphql/src/test/java/org/springframework/graphql/data/method/annotation/support/SchemaMappingPrincipalMethodArgumentResolverTests.java index b0a00d42b..2a9c8866f 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/data/method/annotation/support/SchemaMappingPrincipalMethodArgumentResolverTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/data/method/annotation/support/SchemaMappingPrincipalMethodArgumentResolverTests.java @@ -20,11 +20,14 @@ import java.time.Duration; import java.util.function.Function; +import graphql.GraphqlErrorBuilder; import io.micrometer.context.ContextSnapshot; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.graphql.execution.DataFetcherExceptionResolver; +import org.springframework.graphql.execution.ErrorType; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -63,6 +66,9 @@ public class SchemaMappingPrincipalMethodArgumentResolverTests { private final Function reactiveContextWriter = context -> ReactiveSecurityContextHolder.withAuthentication(this.authentication); + private final Function reactiveContextWriterWithoutAuthentication = context -> + ReactiveSecurityContextHolder.withSecurityContext(Mono.just(SecurityContextHolder.createEmptyContext())); + private final Function threadLocalContextWriter = context -> ContextSnapshot.captureAll().updateContext(context); @@ -100,6 +106,68 @@ void resolveFromThreadLocalContext(String field) { } } + @Test + void nullablePrincipalDoesntRequireSecurityContext() { + Mono responseMono = executeAsync( + "type Query { greetingMonoNullable: String }", "{ greetingMonoNullable }", + context -> context); + + ResponseHelper responseHelper = ResponseHelper.forResponse(responseMono); + + assertThat(responseHelper.errorCount()).isEqualTo(0); + } + + @Test + void nonNullPrincipalRequiresSecurityContext() { + DataFetcherExceptionResolver exceptionResolver = + DataFetcherExceptionResolver.forSingleError((ex, env) -> GraphqlErrorBuilder.newError(env) + .message("Resolved error: " + ex.getMessage()) + .errorType(ErrorType.UNAUTHORIZED) + .build()); + + Mono responseMono = executeAsync( + "type Query { greetingMono: String }", "{ greetingMono }", + context -> context, + exceptionResolver); + + ResponseHelper responseHelper = ResponseHelper.forResponse(responseMono); + + assertThat(responseHelper.errorCount()).isEqualTo(1); + assertThat(responseHelper.error(0).errorType()).isEqualTo("UNAUTHORIZED"); + assertThat(responseHelper.error(0).message()).isEqualTo("Resolved error: SecurityContext not available"); + } + + @Test + void nonNullPrincipalRequiresAuthentication() { + DataFetcherExceptionResolver exceptionResolver = + DataFetcherExceptionResolver.forSingleError((ex, env) -> GraphqlErrorBuilder.newError(env) + .message("Resolved error: " + ex.getMessage()) + .errorType(ErrorType.UNAUTHORIZED) + .build()); + + Mono responseMono = executeAsync( + "type Query { greetingMono: String }", "{ greetingMono }", + reactiveContextWriterWithoutAuthentication, + exceptionResolver); + + ResponseHelper responseHelper = ResponseHelper.forResponse(responseMono); + + assertThat(responseHelper.errorCount()).isEqualTo(1); + assertThat(responseHelper.error(0).errorType()).isEqualTo("UNAUTHORIZED"); + assertThat(responseHelper.error(0).message()).isEqualTo("Resolved error: An Authentication object was not found in the SecurityContext"); + } + + @Test + void nullablePrincipalDoesntRequireAuthentication() { + Mono responseMono = executeAsync( + "type Query { greetingMonoNullable: String }", "{ greetingMonoNullable }", + reactiveContextWriterWithoutAuthentication); + + ResponseHelper responseHelper = ResponseHelper.forResponse(responseMono); + + assertThat(responseHelper.errorCount()).isEqualTo(0); + } + private void testQuery(String field, Function contextWriter) { Mono responseMono = executeAsync( "type Query { " + field + ": String }", "{ " + field + " }", contextWriter); @@ -150,14 +218,24 @@ private void testSubscription(Function contextModifier) { private Mono executeAsync( String schema, String document, Function contextWriter) { + return executeAsync(schema, document, contextWriter, null); + } + + private Mono executeAsync( + String schema, String document, Function contextWriter, @Nullable DataFetcherExceptionResolver exceptionResolver) { AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); context.registerBean(GreetingController.class, () -> greetingController); context.refresh(); - TestExecutionGraphQlService graphQlService = GraphQlSetup.schemaContent(schema) - .runtimeWiringForAnnotatedControllers(context) - .toGraphQlService(); + GraphQlSetup graphQlSetup = GraphQlSetup.schemaContent(schema) + .runtimeWiringForAnnotatedControllers(context); + + if (exceptionResolver != null) { + graphQlSetup.exceptionResolver(exceptionResolver); + } + + TestExecutionGraphQlService graphQlService = graphQlSetup.toGraphQlService(); return Mono.delay(Duration.ofMillis(10)) .flatMap(aLong -> graphQlService.execute(document)) @@ -197,6 +275,12 @@ Mono greetingMono(Principal principal) { return Mono.just("Hello"); } + @QueryMapping + Mono greetingMonoNullable(@Nullable Principal principal) { + this.principal = principal; + return Mono.just("Hello"); + } + @SubscriptionMapping Flux greetingSubscription(Principal principal) { this.principal = principal;