From 4763d23c1bbaa39c327630b7123a6b0c04e0d45f Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Wed, 10 Jan 2024 15:46:46 +0100 Subject: [PATCH] Check protocol version is supported --- .../sdk/core/InvocationStateMachine.java | 1 + .../dev/restate/sdk/core/MessageHeader.java | 23 +++++++++++++++---- .../restate/sdk/core/MessageHeaderTest.java | 13 +++++++++++ 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java index 5cd80b06..08f4159f 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java @@ -134,6 +134,7 @@ public void onNext(InvocationFlow.InvocationInput invocationInput) { MessageLite msg = invocationInput.message(); LOG.trace("Received input message {} {}", msg.getClass(), msg); if (this.invocationState == InvocationState.WAITING_START) { + MessageHeader.checkProtocolVersion(invocationInput.header()); this.onStart(msg); } else if (msg instanceof Protocol.CompletionMessage) { // We check the instance rather than the state, because the user code might still be diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java b/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java index 93e88b81..c3cce743 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java @@ -14,6 +14,9 @@ public class MessageHeader { + static final short SUPPORTED_PROTOCOL_VERSION = 0; + + static final short VERSION_MASK = 0x03FF; static final short DONE_FLAG = 0x0001; static final int REQUIRES_ACK_FLAG = 0x8000; @@ -43,10 +46,6 @@ public long encode() { return res; } - public MessageHeader copyWithFlags(short flag) { - return new MessageHeader(type, flag, length); - } - public static MessageHeader parse(long encoded) throws ProtocolException { var ty_code = (short) (encoded >> 48); var flags = (short) (encoded >> 32); @@ -127,4 +126,20 @@ public static MessageHeader fromMessage(MessageLite msg) { } throw new IllegalStateException(); } + + public static void checkProtocolVersion(MessageHeader header) { + if (header.type != MessageType.StartMessage) { + throw new IllegalStateException("Expected StartMessage, got " + header.type); + } + + short version = (short) (header.flags & VERSION_MASK); + if (version != SUPPORTED_PROTOCOL_VERSION) { + throw new IllegalStateException( + "Unsupported protocol version " + + version + + ", only version " + + SUPPORTED_PROTOCOL_VERSION + + " is supported"); + } + } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java index db0b7aaa..3de3267d 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java @@ -9,6 +9,7 @@ package dev.restate.sdk.core; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import org.junit.jupiter.api.Test; @@ -24,4 +25,16 @@ void requiresAckFlag() { .encode()) .isEqualTo(0x0C01_8001_0000_0002L); } + + @Test + void checkProtocolVersion() { + int unknownVersion = Integer.MAX_VALUE & MessageHeader.VERSION_MASK; + assertThatThrownBy( + () -> + MessageHeader.checkProtocolVersion( + new MessageHeader(MessageType.StartMessage, unknownVersion, 0))) + .hasMessage( + "Unsupported protocol version %d, only version %d is supported", + unknownVersion, MessageHeader.SUPPORTED_PROTOCOL_VERSION); + } }