Skip to content

Commit

Permalink
feat: support custom TimeProvider when validating tokens (introspect,…
Browse files Browse the repository at this point in the history
… userinfo) (#730)

* feat: support custom TimeProvider when validating tokens
* add verify function to OAuth2TokenProvider and use the TimeProvider if set - i.e. via overriding Nimbus DefaultJWTClaimsVerifier's currentTime function
* refactor tests for simplicity
* fix: use jwkSelector when returning keys in KeyProvider
* necessary to use jwkSelector to only get keys for supported algorithm
* use Instant.now for currentTime when TimeProvider not set
  • Loading branch information
tommytroen authored Aug 21, 2024
1 parent 014faf0 commit 5fe5d8e
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@ package no.nav.security.mock.oauth2.introspect
import com.fasterxml.jackson.annotation.JsonInclude
import com.fasterxml.jackson.annotation.JsonProperty
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.oauth2.sdk.OAuth2Error
import com.nimbusds.oauth2.sdk.id.Issuer
import mu.KotlinLogging
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.OAuth2Endpoints.INTROSPECT
import no.nav.security.mock.oauth2.extensions.issuerId
import no.nav.security.mock.oauth2.extensions.toIssuerUrl
import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer
import no.nav.security.mock.oauth2.http.OAuth2HttpRequest
import no.nav.security.mock.oauth2.http.Route
import no.nav.security.mock.oauth2.http.json
Expand Down Expand Up @@ -51,12 +47,10 @@ internal fun Route.Builder.introspect(tokenProvider: OAuth2TokenProvider) =
}

private fun OAuth2HttpRequest.verifyToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet? {
val tokenString = this.formParameters.get("token")
val issuer = url.toIssuerUrl()
val jwkSet = tokenProvider.publicJwkSet(issuer.issuerId())
val algorithm = tokenProvider.getAlgorithm()
return try {
SignedJWT.parse(tokenString).verifySignatureAndIssuer(Issuer(issuer.toString()), jwkSet, algorithm)
this.formParameters.get("token")?.let {
tokenProvider.verify(url.toIssuerUrl(), it)
}
} catch (e: Exception) {
log.debug("token_introspection: failed signature validation")
return null
Expand Down
13 changes: 12 additions & 1 deletion src/main/kotlin/no/nav/security/mock/oauth2/token/KeyProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package no.nav.security.mock.oauth2.token
import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.ECKey
import com.nimbusds.jose.jwk.JWK
import com.nimbusds.jose.jwk.JWKSelector
import com.nimbusds.jose.jwk.JWKSet
import com.nimbusds.jose.jwk.KeyType
import com.nimbusds.jose.jwk.RSAKey
import com.nimbusds.jose.jwk.source.JWKSource
import com.nimbusds.jose.proc.SecurityContext
import no.nav.security.mock.oauth2.OAuth2Exception
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.LinkedBlockingDeque
Expand All @@ -15,7 +18,7 @@ open class KeyProvider
constructor(
private val initialKeys: List<JWK> = keysFromFile(INITIAL_KEYS_FILE),
private val algorithm: String = JWSAlgorithm.RS256.name,
) {
) : JWKSource<SecurityContext> {
private val signingKeys: ConcurrentHashMap<String, JWK> = ConcurrentHashMap()

private var generator: KeyGenerator = KeyGenerator(JWSAlgorithm.parse(algorithm))
Expand All @@ -35,9 +38,11 @@ open class KeyProvider
KeyType.RSA.value -> {
RSAKey.Builder(polledJwk.toRSAKey()).keyID(keyId).build()
}

KeyType.EC.value -> {
ECKey.Builder(polledJwk.toECKey()).keyID(keyId).build()
}

else -> {
throw OAuth2Exception("Unsupported key type: ${polledJwk.keyType.value}")
}
Expand All @@ -63,4 +68,10 @@ open class KeyProvider
return emptyList()
}
}

override fun get(
jwkSelector: JWKSelector?,
context: SecurityContext?,
): MutableList<JWK> = jwkSelector?.select(JWKSet(signingKeys.values.toList()).toPublicJWKSet()) ?: mutableListOf()

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ import com.nimbusds.jose.crypto.ECDSASigner
import com.nimbusds.jose.crypto.RSASSASigner
import com.nimbusds.jose.jwk.JWKSet
import com.nimbusds.jose.jwk.KeyType
import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier
import com.nimbusds.jose.proc.JWSVerificationKeySelector
import com.nimbusds.jose.proc.SecurityContext
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier
import com.nimbusds.jwt.proc.DefaultJWTProcessor
import com.nimbusds.oauth2.sdk.TokenRequest
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.clientIdAsString
Expand Down Expand Up @@ -106,6 +111,11 @@ class OAuth2TokenProvider
builder.build()
}.sign(issuerId, JOSEObjectType.JWT.type)

fun verify(
issuerUrl: HttpUrl,
token: String,
): JWTClaimsSet = SignedJWT.parse(token).verify(issuerUrl)

private fun JWTClaimsSet.sign(
issuerId: String,
type: String,
Expand All @@ -124,6 +134,7 @@ class OAuth2TokenProvider
sign(RSASSASigner(key.toRSAKey().toPrivateKey()))
}
}

supported && keyType == KeyType.EC.value -> {
SignedJWT(
jwsHeader(key.keyID, type, algorithm),
Expand All @@ -132,6 +143,7 @@ class OAuth2TokenProvider
sign(ECDSASigner(key.toECKey().toECPrivateKey()))
}
}

else -> {
throw OAuth2Exception("Unsupported algorithm: ${algorithm.name}")
}
Expand Down Expand Up @@ -178,4 +190,20 @@ class OAuth2TokenProvider
}

private fun Instant?.orNow(): Instant = this ?: Instant.now()

private fun SignedJWT.verify(issuerUrl: HttpUrl): JWTClaimsSet {
val jwtProcessor =
DefaultJWTProcessor<SecurityContext?>().apply {
jwsTypeVerifier = DefaultJOSEObjectTypeVerifier(JOSEObjectType("JWT"))
jwsKeySelector = JWSVerificationKeySelector(keyProvider.algorithm(), keyProvider)
jwtClaimsSetVerifier =
object : DefaultJWTClaimsVerifier<SecurityContext?>(
JWTClaimsSet.Builder().issuer(issuerUrl.toString()).build(),
HashSet(listOf("iat", "exp")),
) {
override fun currentTime(): Date = Date.from(timeProvider().orNow())
}
}
return jwtProcessor.process(this, null)
}
}
15 changes: 3 additions & 12 deletions src/main/kotlin/no/nav/security/mock/oauth2/userinfo/UserInfo.kt
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
package no.nav.security.mock.oauth2.userinfo

import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.oauth2.sdk.ErrorObject
import com.nimbusds.oauth2.sdk.http.HTTPResponse
import com.nimbusds.oauth2.sdk.id.Issuer
import mu.KotlinLogging
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.OAuth2Endpoints.USER_INFO
import no.nav.security.mock.oauth2.extensions.issuerId
import no.nav.security.mock.oauth2.extensions.toIssuerUrl
import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer
import no.nav.security.mock.oauth2.http.OAuth2HttpRequest
import no.nav.security.mock.oauth2.http.Route
import no.nav.security.mock.oauth2.http.json
Expand All @@ -26,17 +22,12 @@ internal fun Route.Builder.userInfo(tokenProvider: OAuth2TokenProvider) =
json(claims)
}

private fun OAuth2HttpRequest.verifyBearerToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet {
val tokenString = this.headers.bearerToken()
val issuer = url.toIssuerUrl()
val jwkSet = tokenProvider.publicJwkSet(issuer.issuerId())
val algorithm = tokenProvider.getAlgorithm()
return try {
SignedJWT.parse(tokenString).verifySignatureAndIssuer(Issuer(issuer.toString()), jwkSet, algorithm)
private fun OAuth2HttpRequest.verifyBearerToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet =
try {
tokenProvider.verify(url.toIssuerUrl(), this.headers.bearerToken())
} catch (e: Exception) {
throw invalidToken(e.message ?: "could not verify bearer token")
}
}

private fun Headers.bearerToken(): String =
this["Authorization"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import no.nav.security.mock.oauth2.token.OAuth2TokenProvider
import okhttp3.Headers
import okhttp3.HttpUrl.Companion.toHttpUrl
import org.junit.jupiter.api.Test
import java.time.Instant
import java.time.temporal.ChronoUnit

internal class IntrospectTest {
private val rs384TokenProvider = OAuth2TokenProvider(keyProvider = KeyProvider(initialKeys = emptyList(), algorithm = JWSAlgorithm.RS384.name))
Expand Down Expand Up @@ -66,6 +68,29 @@ internal class IntrospectTest {
}
}

@Test
fun `introspect should return active and claims from token when using a custom timeProvider in the OAuth2TokenProvider`() {
val issuerUrl = "http://localhost/default"
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(timeProvider = { yesterday })
val claims =
mapOf(
"iss" to issuerUrl,
"client_id" to "yolo",
"token_type" to "token",
"sub" to "foo",
)
val token = tokenProvider.jwt(claims)
val request = request("$issuerUrl$INTROSPECT", token.serialize())

routes { introspect(tokenProvider) }.invoke(request).asClue {
it.status shouldBe 200
val response = it.parse<Map<String, Any>>()
response shouldContainAll claims
response shouldContain ("active" to true)
}
}

@Test
fun `introspect should return active false when token is missing`() {
val url = "http://localhost/default$INTROSPECT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ import okhttp3.HttpUrl.Companion.toHttpUrl
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import java.time.Clock
import java.time.Instant
import java.time.ZoneId
import java.time.temporal.ChronoUnit
import java.util.Date

Expand Down Expand Up @@ -106,87 +104,71 @@ internal class OAuth2TokenProviderRSATest {
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(systemTime = yesterday)

tokenProvider
.exchangeAccessToken(
tokenRequest =
nimbusTokenRequest(
"id",
"grant_type" to GrantType.CLIENT_CREDENTIALS.value,
"scope" to "scope1",
),
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
).asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime)
println(it.serialize())
}
tokenProvider.clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime)
}

val now = Instant.now().minus(1, ChronoUnit.SECONDS)
OAuth2TokenProvider().clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBeAfter now
}
}

@Test
fun `token should have issuedAt set dynamically according to timeProvider`() {
val clock =
object : Clock() {
private var clock = systemDefaultZone()
val timeProvider =
object : TimeProvider {
var time = Instant.now()

override fun instant() = clock.instant()

override fun withZone(zone: ZoneId) = clock.withZone(zone)

override fun getZone() = clock.zone

fun fixed(instant: Instant) {
clock = fixed(instant, zone)
}
override fun invoke(): Instant = time
}

val tokenProvider = OAuth2TokenProvider { clock.instant() }
val tokenProvider = OAuth2TokenProvider(timeProvider = timeProvider)

val instant1 = Instant.parse("2000-12-03T10:15:30.00Z")
val instant2 = Instant.parse("2020-01-21T00:00:00.00Z")
instant1 shouldNotBe instant2

run {
clock.fixed(instant1)
tokenProvider.systemTime shouldBe instant1
timeProvider.time = instant1
tokenProvider.systemTime shouldBe instant1

tokenProvider.exchangeAccessToken(
tokenRequest =
nimbusTokenRequest(
"id",
"grant_type" to GrantType.CLIENT_CREDENTIALS.value,
"scope" to "scope1",
),
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)
}.asClue {
tokenProvider.clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(instant1)
println(it.serialize())
}

run {
clock.fixed(instant2)
tokenProvider.systemTime shouldBe instant2
timeProvider.time = instant2
tokenProvider.systemTime shouldBe instant2

tokenProvider.exchangeAccessToken(
tokenRequest =
nimbusTokenRequest(
"id",
"grant_type" to GrantType.CLIENT_CREDENTIALS.value,
"scope" to "scope1",
),
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)
}.asClue {
tokenProvider.clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(instant2)
println(it.serialize())
}
}

@Test
fun `token with issueTime set to yesterday should be able to validate with the verify function using the same timeprovider`() {
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(timeProvider = { yesterday })

val token = tokenProvider.clientCredentialsToken("http://localhost/default")

token.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime)

tokenProvider.verify("http://localhost/default".toHttpUrl(), token.serialize()).toJSONObject().asClue {
it shouldBe token.jwtClaimsSet.toJSONObject()
}
}

private fun OAuth2TokenProvider.clientCredentialsToken(issuerUrl: String): SignedJWT =
accessToken(
tokenRequest =
nimbusTokenRequest(
"client1",
"grant_type" to "client_credentials",
"scope" to "scope1",
),
issuerUrl = issuerUrl.toHttpUrl(),
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)

private fun idToken(issuerUrl: String): SignedJWT =
tokenProvider.idToken(
tokenRequest =
Expand All @@ -198,4 +180,6 @@ internal class OAuth2TokenProviderRSATest {
issuerUrl = issuerUrl.toHttpUrl(),
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)

private infix fun Date.shouldBeAfter(instant: Instant?) = this.after(Date.from(instant)) shouldBe true
}
Loading

0 comments on commit 5fe5d8e

Please sign in to comment.