Skip to content

Commit

Permalink
Raise exception if Principal is required but not present
Browse files Browse the repository at this point in the history
See gh-790
  • Loading branch information
koenpunt authored and rstoyanchev committed Sep 14, 2023
1 parent ab39246 commit c244332
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ private static AuthenticationPrincipal findMethodAnnotation(MethodParameter para

@Override
public Object resolveArgument(MethodParameter parameter, DataFetchingEnvironment environment) throws Exception {
return getCurrentAuthentication()
return getCurrentAuthentication(parameter.isOptional())
.flatMap(auth -> Mono.justOrEmpty(resolvePrincipal(parameter, auth.getPrincipal())))
.transform((argument) -> isParameterMonoAssignable(parameter) ? Mono.just(argument) : argument);
}
Expand All @@ -106,9 +106,16 @@ private static boolean isParameterMonoAssignable(MethodParameter parameter) {
return (Publisher.class.equals(type) || Mono.class.equals(type));
}

private Mono<Authentication> getCurrentAuthentication() {
return Mono.justOrEmpty(SecurityContextHolder.getContext().getAuthentication())
.switchIfEmpty(ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication));
@SuppressWarnings("unchecked")
private Mono<Authentication> getCurrentAuthentication(boolean optional) {
Object principal = PrincipalMethodArgumentResolver.doResolve(optional);
if (principal instanceof Authentication) {
return Mono.just((Authentication) principal);
}
else if (principal instanceof Mono) {
return (Mono<Authentication>) principal;
}
return Mono.error(new IllegalStateException("Unexpected return value: " + principal));
}

@Nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ else if ("kotlin.coroutines.Continuation".equals(parameterType.getName())) {
return null;
}
else if (springSecurityPresent && Principal.class.isAssignableFrom(parameter.getParameterType())) {
return PrincipalMethodArgumentResolver.doResolve();
return PrincipalMethodArgumentResolver.doResolve(parameter.isOptional());
}
else {
throw new IllegalStateException(formatArgumentError(parameter, "Unexpected argument type."));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@
package org.springframework.graphql.data.method.annotation.support;

import java.security.Principal;
import java.util.function.Function;

import graphql.schema.DataFetchingEnvironment;

import org.springframework.core.MethodParameter;
import org.springframework.graphql.data.method.HandlerMethodArgumentResolver;
import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import reactor.core.publisher.Mono;

/**
* Resolver to obtain {@link Principal} from Spring Security context via
Expand All @@ -50,13 +54,29 @@ public boolean supportsParameter(MethodParameter parameter) {

@Override
public Object resolveArgument(MethodParameter parameter, DataFetchingEnvironment environment) {
return doResolve();
return doResolve(parameter.isOptional());
}

static Object doResolve() {
static Object doResolve(boolean optional) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
return (authentication != null ? authentication :
ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication));

if (authentication != null) {
return authentication;
}

return ReactiveSecurityContextHolder.getContext()
.switchIfEmpty(optional ? Mono.empty() : Mono.error(new AuthenticationCredentialsNotFoundException("SecurityContext not available")))
.handle((context, sink) -> {
Authentication auth = context.getAuthentication();

if (auth != null) {
sink.next(auth);
} else if (!optional) {
sink.error(new AuthenticationCredentialsNotFoundException("An Authentication object was not found in the SecurityContext"));
} else {
sink.complete();
}
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
import java.time.Duration;
import java.util.function.Function;

import graphql.GraphqlErrorBuilder;
import io.micrometer.context.ContextSnapshot;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.graphql.execution.DataFetcherExceptionResolver;
import org.springframework.graphql.execution.ErrorType;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
Expand Down Expand Up @@ -63,6 +66,9 @@ public class SchemaMappingPrincipalMethodArgumentResolverTests {
private final Function<Context, Context> reactiveContextWriter = context ->
ReactiveSecurityContextHolder.withAuthentication(this.authentication);

private final Function<Context, Context> reactiveContextWriterWithoutAuthentication = context ->
ReactiveSecurityContextHolder.withSecurityContext(Mono.just(SecurityContextHolder.createEmptyContext()));

private final Function<Context, Context> threadLocalContextWriter = context ->
ContextSnapshot.captureAll().updateContext(context);

Expand Down Expand Up @@ -100,6 +106,68 @@ void resolveFromThreadLocalContext(String field) {
}
}

@Test
void nullablePrincipalDoesntRequireSecurityContext() {
Mono<ExecutionGraphQlResponse> responseMono = executeAsync(
"type Query { greetingMonoNullable: String }", "{ greetingMonoNullable }",
context -> context);

ResponseHelper responseHelper = ResponseHelper.forResponse(responseMono);

assertThat(responseHelper.errorCount()).isEqualTo(0);
}

@Test
void nonNullPrincipalRequiresSecurityContext() {
DataFetcherExceptionResolver exceptionResolver =
DataFetcherExceptionResolver.forSingleError((ex, env) -> GraphqlErrorBuilder.newError(env)
.message("Resolved error: " + ex.getMessage())
.errorType(ErrorType.UNAUTHORIZED)
.build());

Mono<ExecutionGraphQlResponse> responseMono = executeAsync(
"type Query { greetingMono: String }", "{ greetingMono }",
context -> context,
exceptionResolver);

ResponseHelper responseHelper = ResponseHelper.forResponse(responseMono);

assertThat(responseHelper.errorCount()).isEqualTo(1);
assertThat(responseHelper.error(0).errorType()).isEqualTo("UNAUTHORIZED");
assertThat(responseHelper.error(0).message()).isEqualTo("Resolved error: SecurityContext not available");
}

@Test
void nonNullPrincipalRequiresAuthentication() {
DataFetcherExceptionResolver exceptionResolver =
DataFetcherExceptionResolver.forSingleError((ex, env) -> GraphqlErrorBuilder.newError(env)
.message("Resolved error: " + ex.getMessage())
.errorType(ErrorType.UNAUTHORIZED)
.build());

Mono<ExecutionGraphQlResponse> responseMono = executeAsync(
"type Query { greetingMono: String }", "{ greetingMono }",
reactiveContextWriterWithoutAuthentication,
exceptionResolver);

ResponseHelper responseHelper = ResponseHelper.forResponse(responseMono);

assertThat(responseHelper.errorCount()).isEqualTo(1);
assertThat(responseHelper.error(0).errorType()).isEqualTo("UNAUTHORIZED");
assertThat(responseHelper.error(0).message()).isEqualTo("Resolved error: An Authentication object was not found in the SecurityContext");
}

@Test
void nullablePrincipalDoesntRequireAuthentication() {
Mono<ExecutionGraphQlResponse> responseMono = executeAsync(
"type Query { greetingMonoNullable: String }", "{ greetingMonoNullable }",
reactiveContextWriterWithoutAuthentication);

ResponseHelper responseHelper = ResponseHelper.forResponse(responseMono);

assertThat(responseHelper.errorCount()).isEqualTo(0);
}

private void testQuery(String field, Function<Context, Context> contextWriter) {
Mono<ExecutionGraphQlResponse> responseMono = executeAsync(
"type Query { " + field + ": String }", "{ " + field + " }", contextWriter);
Expand Down Expand Up @@ -150,14 +218,24 @@ private void testSubscription(Function<Context, Context> contextModifier) {

private Mono<ExecutionGraphQlResponse> executeAsync(
String schema, String document, Function<Context, Context> contextWriter) {
return executeAsync(schema, document, contextWriter, null);
}

private Mono<ExecutionGraphQlResponse> executeAsync(
String schema, String document, Function<Context, Context> contextWriter, @Nullable DataFetcherExceptionResolver exceptionResolver) {

AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
context.registerBean(GreetingController.class, () -> greetingController);
context.refresh();

TestExecutionGraphQlService graphQlService = GraphQlSetup.schemaContent(schema)
.runtimeWiringForAnnotatedControllers(context)
.toGraphQlService();
GraphQlSetup graphQlSetup = GraphQlSetup.schemaContent(schema)
.runtimeWiringForAnnotatedControllers(context);

if (exceptionResolver != null) {
graphQlSetup.exceptionResolver(exceptionResolver);
}

TestExecutionGraphQlService graphQlService = graphQlSetup.toGraphQlService();

return Mono.delay(Duration.ofMillis(10))
.flatMap(aLong -> graphQlService.execute(document))
Expand Down Expand Up @@ -197,6 +275,12 @@ Mono<String> greetingMono(Principal principal) {
return Mono.just("Hello");
}

@QueryMapping
Mono<String> greetingMonoNullable(@Nullable Principal principal) {
this.principal = principal;
return Mono.just("Hello");
}

@SubscriptionMapping
Flux<String> greetingSubscription(Principal principal) {
this.principal = principal;
Expand Down

0 comments on commit c244332

Please sign in to comment.