diff --git a/datahub-frontend/app/auth/sso/oidc/OidcAuthorizationGenerator.java b/datahub-frontend/app/auth/sso/oidc/OidcAuthorizationGenerator.java index 18b482de916e3..3f864ed5abddf 100644 --- a/datahub-frontend/app/auth/sso/oidc/OidcAuthorizationGenerator.java +++ b/datahub-frontend/app/auth/sso/oidc/OidcAuthorizationGenerator.java @@ -1,8 +1,19 @@ package auth.sso.oidc; +import java.text.ParseException; import java.util.Map.Entry; import java.util.Optional; +import com.nimbusds.jose.Algorithm; +import com.nimbusds.jose.Header; +import com.nimbusds.jose.JWEAlgorithm; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.util.Base64URL; +import com.nimbusds.jose.util.JSONObjectUtils; +import com.nimbusds.jwt.EncryptedJWT; +import com.nimbusds.jwt.JWTParser; +import com.nimbusds.jwt.SignedJWT; +import net.minidev.json.JSONObject; import org.pac4j.core.authorization.generator.AuthorizationGenerator; import org.pac4j.core.context.WebContext; import org.pac4j.core.profile.AttributeLocation; @@ -14,7 +25,6 @@ import org.slf4j.LoggerFactory; import com.nimbusds.jwt.JWT; -import com.nimbusds.jwt.JWTParser; public class OidcAuthorizationGenerator implements AuthorizationGenerator { @@ -53,5 +63,32 @@ public Optional generate(WebContext context, UserProfile profile) { return Optional.ofNullable(profile); } + + private static JWT parse(final String s) throws ParseException { + final int firstDotPos = s.indexOf("."); + + if (firstDotPos == -1) { + throw new ParseException("Invalid JWT serialization: Missing dot delimiter(s)", 0); + } + + Base64URL header = new Base64URL(s.substring(0, firstDotPos)); + JSONObject jsonObject; + + try { + jsonObject = JSONObjectUtils.parse(header.decodeToString()); + } catch (ParseException e) { + throw new ParseException("Invalid unsecured/JWS/JWE header: " + e.getMessage(), 0); + } + + Algorithm alg = Header.parseAlgorithm(jsonObject); + + if (alg instanceof JWSAlgorithm) { + return SignedJWT.parse(s); + } else if (alg instanceof JWEAlgorithm) { + return EncryptedJWT.parse(s); + } else { + throw new AssertionError("Unexpected algorithm type: " + alg); + } + } }