Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trusted entitlements: Support signing static endpoints #1119

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,12 @@ class HTTPClient(
val urlPathWithVersion = "/v1$path"
val connection: HttpURLConnection
val shouldSignResponse = signingManager.shouldVerifyEndpoint(endpoint)
val shouldAddNonce = shouldSignResponse && endpoint.needsNonceToPerformSigning
val nonce: String?
try {
val fullURL = URL(baseURL, urlPathWithVersion)

nonce = if (shouldSignResponse) signingManager.createRandomNonce() else null
nonce = if (shouldAddNonce) signingManager.createRandomNonce() else null
val headers = getHeaders(requestHeaders, urlPathWithVersion, refreshETag, nonce, shouldSignResponse)

val httpRequest = HTTPRequest(fullURL, headers, jsonBody)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,23 @@ sealed class Endpoint(val pathTemplate: String, val name: String) {
override fun getPath() = pathTemplate
}

val supportsSignatureValidation: Boolean
val supportsSignatureVerification: Boolean
get() = when (this) {
is GetCustomerInfo,
LogIn,
PostReceipt,
is GetOfferings,
GetProductEntitlementMapping,
->
true
is GetAmazonReceipt,
is PostAttributes,
PostDiagnostics,
->
false
}

val needsNonceToPerformSigning: Boolean
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great 👍🏻

get() = when (this) {
is GetCustomerInfo,
LogIn,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.revenuecat.purchases.VerificationResult
import com.revenuecat.purchases.common.AppConfig
import com.revenuecat.purchases.common.errorLog
import com.revenuecat.purchases.common.networking.Endpoint
import com.revenuecat.purchases.common.verboseLog
import com.revenuecat.purchases.common.warnLog
import com.revenuecat.purchases.strings.NetworkStrings
import com.revenuecat.purchases.utils.Result
Expand Down Expand Up @@ -69,7 +70,7 @@ class SigningManager(
}

fun shouldVerifyEndpoint(endpoint: Endpoint): Boolean {
return endpoint.supportsSignatureValidation && signatureVerificationMode.shouldVerify
return endpoint.supportsSignatureVerification && signatureVerificationMode.shouldVerify
}

fun createRandomNonce(): String {
Expand Down Expand Up @@ -142,6 +143,7 @@ class SigningManager(
)

return if (verificationResult) {
verboseLog(NetworkStrings.VERIFICATION_SUCCESS.format(urlPath))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻

VerificationResult.VERIFIED
} else {
errorLog(NetworkStrings.VERIFICATION_ERROR.format(urlPath))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ object NetworkStrings {
" but none provided."
const val VERIFICATION_INVALID_SIZE = "Verification: Request to '%s' has signature with wrong size. '%s'"
const val VERIFICATION_ERROR = "Verification: Request to '%s' failed verification."
const val VERIFICATION_SUCCESS = "Verification: Request to '%s' verified successfully."
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import io.mockk.Runs
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import io.mockk.spyk
import io.mockk.verify
import org.assertj.core.api.Assertions.assertThat
import org.junit.Assert.fail
Expand Down Expand Up @@ -65,6 +66,14 @@ abstract class BaseBackendIntegrationTest {

@Before
fun setUp() {
setupTest()
}

abstract fun apiKey(): String

protected fun setupTest(
signatureVerificationMode: SignatureVerificationMode = SignatureVerificationMode.Disabled
) {
appConfig = mockk<AppConfig>().apply {
every { baseURL } returns URL("https://api.revenuecat.com")
every { store } returns Store.PLAY_STORE
Expand All @@ -87,14 +96,12 @@ abstract class BaseBackendIntegrationTest {
every { edit() } returns sharedPreferencesEditor
}
eTagManager = ETagManager(sharedPreferences)
signingManager = SigningManager(SignatureVerificationMode.Disabled, appConfig, apiKey())
signingManager = spyk(SigningManager(signatureVerificationMode, appConfig, apiKey()))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I changed to a spy so we can see and verify the calls more easily, but it's using the actual implementation. Since there is no way to know that the offerings/product-entitlement mapping calls were verified or not, this made easier to make sure we were verifying those endpoints

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since there is no way to know that the offerings/product-entitlement mapping calls were verified or not

I checked that by using the "fake invalid signature" thing. Did you add a way to do that in Android too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I meant in the "success" case, we should make sure we are verifying the signature. But right now there is no way to detect whether a SUCCESS or a NOT_REQUESTED happened. We could maybe if we could verify the logs, but I haven't added that in Android yet. Until then, this makes sure we are at least calling the signingManager.verify method, so it shouldn't be NOT_REQUESTED.

httpClient = HTTPClient(appConfig, eTagManager, diagnosticsTrackerIfEnabled = null, signingManager)
backendHelper = BackendHelper(apiKey(), dispatcher, appConfig, httpClient)
backend = Backend(appConfig, dispatcher, diagnosticsDispatcher, httpClient, backendHelper)
}

abstract fun apiKey(): String

protected fun ensureBlockFinishes(block: (CountDownLatch) -> Unit) {
val latch = CountDownLatch(1)
block(latch)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package com.revenuecat.purchases.backend_integration_tests
import com.revenuecat.purchases.PurchasesError
import com.revenuecat.purchases.common.networking.Endpoint
import com.revenuecat.purchases.common.offlineentitlements.ProductEntitlementMapping
import com.revenuecat.purchases.common.verification.SignatureVerificationMode
import io.mockk.verify
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.fail
import org.junit.Test

class ProductionBackendIntegrationTest: BaseBackendIntegrationTest() {
Expand Down Expand Up @@ -38,5 +40,72 @@ class ProductionBackendIntegrationTest: BaseBackendIntegrationTest() {
sharedPreferencesEditor.putString("/v1${Endpoint.GetProductEntitlementMapping.getPath()}", any())
}
verify(exactly = 1) { sharedPreferencesEditor.apply() }
verify(exactly = 0) { signingManager.verifyResponse(any(), any(), any(), any(), any(), any()) }
}

@Test
fun `can perform verified product entitlement mapping backend request`() {
setupTest(SignatureVerificationMode.Enforced())
ensureBlockFinishes { latch ->
backend.getProductEntitlementMapping(
onSuccessHandler = { productEntitlementMapping ->
assertThat(productEntitlementMapping.mappings.size).isEqualTo(36)
assertThat(productEntitlementMapping.mappings["annual_freetrial"]).isEqualTo(
ProductEntitlementMapping.Mapping(
productIdentifier = "annual_freetrial",
basePlanId = "p1y",
entitlements = listOf("pro_cat")
)
)
latch.countDown()
},
onErrorHandler = {
fail("Expected success but got error: $it")
}
)
}
verify(exactly = 1) { signingManager.verifyResponse(any(), any(), any(), any(), any(), any()) }
}

@Test
fun `can perform offerings backend request`() {
ensureBlockFinishes { latch ->
backend.getOfferings(
appUserID = "test-user-id",
appInBackground = false,
onSuccess = { offeringsResponse ->
assertThat(offeringsResponse.length()).isPositive
latch.countDown()
},
onError = { _, _ ->
fail("Expected success")
}
)
}
verify(exactly = 1) {
// Verify we save the backend response in the shared preferences
sharedPreferencesEditor.putString("/v1${Endpoint.GetOfferings("test-user-id").getPath()}", any())
}
verify(exactly = 1) { sharedPreferencesEditor.apply() }
verify(exactly = 0) { signingManager.verifyResponse(any(), any(), any(), any(), any(), any()) }
}

@Test
fun `can perform verified offerings backend request`() {
setupTest(SignatureVerificationMode.Enforced())
ensureBlockFinishes { latch ->
backend.getOfferings(
appUserID = "test-user-id",
appInBackground = false,
onSuccess = { offeringsResponse ->
assertThat(offeringsResponse.length()).isPositive
latch.countDown()
},
onError = { error, _ ->
fail("Expected success. Got error: $error")
}
)
}
verify(exactly = 1) { signingManager.verifyResponse(any(), any(), any(), any(), any(), any()) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import io.mockk.mockk
import io.mockk.verify
import okhttp3.mockwebserver.MockResponse
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
Expand Down Expand Up @@ -249,6 +250,30 @@ class HTTPClientVerificationTest: BaseHTTPClientTest() {
assertThat(result.verificationResult).isEqualTo(VerificationResult.FAILED)
}

@Test
fun `performRequest on informational client without nonce does not throw verification error`() {
val endpoint = Endpoint.GetOfferings("test-user-id")
enqueue(
endpoint = endpoint,
expectedResult = HTTPResult.createResult(verificationResult = VerificationResult.FAILED),
verificationResult = VerificationResult.FAILED
)

every {
mockSigningManager.verifyResponse(any(), any(), any(), any(), any(), any())
} returns VerificationResult.FAILED

val result = client.performRequest(
baseURL,
endpoint,
body = null,
requestHeaders = emptyMap()
)

server.takeRequest()
assertThat(result.verificationResult).isEqualTo(VerificationResult.FAILED)
}

@Test
fun `performRequest on enforced client throws verification error`() {
every { mockSigningManager.signatureVerificationMode } returns mockk<SignatureVerificationMode.Enforced>()
Expand Down Expand Up @@ -281,6 +306,34 @@ class HTTPClientVerificationTest: BaseHTTPClientTest() {
}
}

@Test
fun `performRequest on enforced client in request without nonce throws verification error`() {
every { mockSigningManager.signatureVerificationMode } returns mockk<SignatureVerificationMode.Enforced>()
val endpoint = Endpoint.GetOfferings("test-user-id")
enqueue(
endpoint = endpoint,
expectedResult = HTTPResult.createResult(verificationResult = VerificationResult.FAILED),
verificationResult = VerificationResult.FAILED
)

every {
mockSigningManager.verifyResponse(any(), any(), any(), any(), any(), any())
} returns VerificationResult.FAILED

assertThatExceptionOfType(SignatureVerificationException::class.java).isThrownBy {
client.performRequest(
baseURL,
endpoint,
body = null,
requestHeaders = emptyMap()
)
}

verify(exactly = 0) {
mockETagManager.getHTTPResultFromCacheOrBackend(any(), any(), any(), any(), any(), any(), any())
}
}

@Test
fun `performRequest on enforced client does not throw if verification success`() {
val endpoint = Endpoint.GetCustomerInfo("test-user-id")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ import org.robolectric.annotation.Config
@Config(manifest = Config.NONE)
class EndpointTest {

private val allEndpoints = listOf(
Endpoint.GetCustomerInfo("test-user-id"),
Endpoint.LogIn,
Endpoint.PostReceipt,
Endpoint.GetOfferings("test-user-id"),
Endpoint.GetProductEntitlementMapping,
Endpoint.GetAmazonReceipt("test-user-id", "test-receipt-id"),
Endpoint.PostAttributes("test-user-id"),
Endpoint.PostDiagnostics,
)

@Test
fun `GetCustomerInfo has correct path`() {
val endpoint = Endpoint.GetCustomerInfo("test user-id")
Expand Down Expand Up @@ -67,28 +78,73 @@ class EndpointTest {
}

@Test
fun `supportsSignatureValidation returns true for expected values`() {
fun `supportsSignatureVerification returns true for expected values`() {
val expectedSupportsValidationEndpoints = listOf(
Endpoint.GetCustomerInfo("test-user-id"),
Endpoint.LogIn,
Endpoint.PostReceipt
Endpoint.PostReceipt,
Endpoint.GetOfferings("test-user-id"),
Endpoint.GetProductEntitlementMapping,
)
for (endpoint in expectedSupportsValidationEndpoints) {
assertThat(endpoint.supportsSignatureValidation).isTrue
assertThat(endpoint.supportsSignatureVerification)
.withFailMessage { "Endpoint $endpoint expected to support signature validation" }
.isTrue
}
}

@Test
fun `supportsSignatureValidation returns false for expected values`() {
fun `supportsSignatureVerification returns false for expected values`() {
val expectedNotSupportsValidationEndpoints = listOf(
Endpoint.GetAmazonReceipt("test-user-id", "test-receipt-id"),
Endpoint.GetOfferings("test-user-id"),
Endpoint.PostAttributes("test-user-id"),
Endpoint.PostDiagnostics,
Endpoint.GetProductEntitlementMapping
)
for (endpoint in expectedNotSupportsValidationEndpoints) {
assertThat(endpoint.supportsSignatureValidation).isFalse
assertThat(endpoint.supportsSignatureVerification)
.withFailMessage { "Endpoint $endpoint expected to not support signature validation" }
.isFalse
}
}

@Test
fun `verify needsNonceToPerformSigning is true only if supportsSignatureVerification is true`() {
for (endpoint in allEndpoints) {
if (!endpoint.supportsSignatureVerification) {
assertThat(endpoint.needsNonceToPerformSigning)
.withFailMessage { "Endpoint $endpoint requires nonce but does not support signature validation" }
.isFalse
}
}
}

@Test
fun `needsNonceToPerformSigning is true for expected values`() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are great.

val expectedEndpoints = listOf(
Endpoint.GetCustomerInfo("test-user-id"),
Endpoint.LogIn,
Endpoint.PostReceipt,
)
for (endpoint in expectedEndpoints) {
assertThat(endpoint.needsNonceToPerformSigning)
.withFailMessage { "Endpoint $endpoint expected to require nonce for signing" }
.isTrue
}
}

@Test
fun `needsNonceToPerformSigning is false for expected values`() {
val expectedEndpoints = listOf(
Endpoint.GetOfferings("test-user-id"),
Endpoint.GetProductEntitlementMapping,
Endpoint.GetAmazonReceipt("test-user-id", "test-receipt-id"),
Endpoint.PostAttributes("test-user-id"),
Endpoint.PostDiagnostics,
)
for (endpoint in expectedEndpoints) {
assertThat(endpoint.needsNonceToPerformSigning)
.withFailMessage { "Endpoint $endpoint expected to not require nonce for signing" }
.isFalse
}
}
}