diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java b/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java index 93e88b81..c3cce743 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java @@ -14,6 +14,9 @@ public class MessageHeader { + static final short SUPPORTED_PROTOCOL_VERSION = 0; + + static final short VERSION_MASK = 0x03FF; static final short DONE_FLAG = 0x0001; static final int REQUIRES_ACK_FLAG = 0x8000; @@ -43,10 +46,6 @@ public long encode() { return res; } - public MessageHeader copyWithFlags(short flag) { - return new MessageHeader(type, flag, length); - } - public static MessageHeader parse(long encoded) throws ProtocolException { var ty_code = (short) (encoded >> 48); var flags = (short) (encoded >> 32); @@ -127,4 +126,20 @@ public static MessageHeader fromMessage(MessageLite msg) { } throw new IllegalStateException(); } + + public static void checkProtocolVersion(MessageHeader header) { + if (header.type != MessageType.StartMessage) { + throw new IllegalStateException("Expected StartMessage, got " + header.type); + } + + short version = (short) (header.flags & VERSION_MASK); + if (version != SUPPORTED_PROTOCOL_VERSION) { + throw new IllegalStateException( + "Unsupported protocol version " + + version + + ", only version " + + SUPPORTED_PROTOCOL_VERSION + + " is supported"); + } + } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java index db0b7aaa..3de3267d 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java @@ -9,6 +9,7 @@ package dev.restate.sdk.core; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import org.junit.jupiter.api.Test; @@ -24,4 +25,16 @@ void requiresAckFlag() { .encode()) .isEqualTo(0x0C01_8001_0000_0002L); } + + @Test + void checkProtocolVersion() { + int unknownVersion = Integer.MAX_VALUE & MessageHeader.VERSION_MASK; + assertThatThrownBy( + () -> + MessageHeader.checkProtocolVersion( + new MessageHeader(MessageType.StartMessage, unknownVersion, 0))) + .hasMessage( + "Unsupported protocol version %d, only version %d is supported", + unknownVersion, MessageHeader.SUPPORTED_PROTOCOL_VERSION); + } } diff --git a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/MessageDecoder.java b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/MessageDecoder.java index 19a7b631..3f47eb63 100644 --- a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/MessageDecoder.java +++ b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/MessageDecoder.java @@ -20,6 +20,7 @@ class MessageDecoder { private enum State { + WAITING_START_HEADER, WAITING_HEADER, WAITING_PAYLOAD, FAILED @@ -36,7 +37,7 @@ private enum State { this.parsedMessages = new ArrayDeque<>(); this.internalBuffer = Unpooled.compositeBuffer(); - this.state = State.WAITING_HEADER; + this.state = State.WAITING_START_HEADER; this.lastParsedMessageHeader = null; this.lastParsingFailure = null; } @@ -59,34 +60,38 @@ void offer(Buffer buffer) { private void tryConsumeInternalBuffer() { while (this.state != State.FAILED && this.internalBuffer.readableBytes() >= wantBytes()) { - if (state == State.WAITING_HEADER) { - try { - this.lastParsedMessageHeader = MessageHeader.parse(this.internalBuffer.readLong()); - this.state = State.WAITING_PAYLOAD; - } catch (RuntimeException e) { - this.lastParsingFailure = e; - this.state = State.FAILED; - } - } else { - try { - this.parsedMessages.offer( - InvocationFlow.InvocationInput.of( - this.lastParsedMessageHeader, - this.lastParsedMessageHeader - .getType() - .messageParser() - .parseFrom( - this.internalBuffer - .readBytes(this.lastParsedMessageHeader.getLength()) - .nioBuffer()))); - this.state = State.WAITING_HEADER; - } catch (InvalidProtocolBufferException e) { - this.lastParsingFailure = new RuntimeException("Cannot parse the protobuf message", e); - this.state = State.FAILED; - } catch (RuntimeException e) { - this.lastParsingFailure = e; - this.state = State.FAILED; + try { + switch (state) { + case WAITING_START_HEADER: + this.lastParsedMessageHeader = MessageHeader.parse(this.internalBuffer.readLong()); + MessageHeader.checkProtocolVersion(this.lastParsedMessageHeader); + this.state = State.WAITING_PAYLOAD; + break; + case WAITING_HEADER: + this.lastParsedMessageHeader = MessageHeader.parse(this.internalBuffer.readLong()); + this.state = State.WAITING_PAYLOAD; + break; + case WAITING_PAYLOAD: + try { + this.parsedMessages.offer( + InvocationFlow.InvocationInput.of( + this.lastParsedMessageHeader, + this.lastParsedMessageHeader + .getType() + .messageParser() + .parseFrom( + this.internalBuffer + .readBytes(this.lastParsedMessageHeader.getLength()) + .nioBuffer()))); + this.state = State.WAITING_HEADER; + break; + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Cannot parse the protobuf message", e); + } } + } catch (RuntimeException e) { + this.lastParsingFailure = e; + this.state = State.FAILED; } } }