Skip to content

Commit

Permalink
Check protocol version is supported
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper committed Jan 10, 2024
1 parent b715d5b commit 9f9f613
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 32 deletions.
23 changes: 19 additions & 4 deletions sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

public class MessageHeader {

static final short SUPPORTED_PROTOCOL_VERSION = 0;

static final short VERSION_MASK = 0x03FF;
static final short DONE_FLAG = 0x0001;
static final int REQUIRES_ACK_FLAG = 0x8000;

Expand Down Expand Up @@ -43,10 +46,6 @@ public long encode() {
return res;
}

public MessageHeader copyWithFlags(short flag) {
return new MessageHeader(type, flag, length);
}

public static MessageHeader parse(long encoded) throws ProtocolException {
var ty_code = (short) (encoded >> 48);
var flags = (short) (encoded >> 32);
Expand Down Expand Up @@ -127,4 +126,20 @@ public static MessageHeader fromMessage(MessageLite msg) {
}
throw new IllegalStateException();
}

public static void checkProtocolVersion(MessageHeader header) {
if (header.type != MessageType.StartMessage) {
throw new IllegalStateException("Expected StartMessage, got " + header.type);
}

short version = (short) (header.flags & VERSION_MASK);
if (version != SUPPORTED_PROTOCOL_VERSION) {
throw new IllegalStateException(
"Unsupported protocol version "
+ version
+ ", only version "
+ SUPPORTED_PROTOCOL_VERSION
+ " is supported");
}
}
}
13 changes: 13 additions & 0 deletions sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package dev.restate.sdk.core;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import org.junit.jupiter.api.Test;

Expand All @@ -24,4 +25,16 @@ void requiresAckFlag() {
.encode())
.isEqualTo(0x0C01_8001_0000_0002L);
}

@Test
void checkProtocolVersion() {
int unknownVersion = Integer.MAX_VALUE & MessageHeader.VERSION_MASK;
assertThatThrownBy(
() ->
MessageHeader.checkProtocolVersion(
new MessageHeader(MessageType.StartMessage, unknownVersion, 0)))
.hasMessage(
"Unsupported protocol version %d, only version %d is supported",
unknownVersion, MessageHeader.SUPPORTED_PROTOCOL_VERSION);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
class MessageDecoder {

private enum State {
WAITING_START_HEADER,
WAITING_HEADER,
WAITING_PAYLOAD,
FAILED
Expand All @@ -36,7 +37,7 @@ private enum State {
this.parsedMessages = new ArrayDeque<>();
this.internalBuffer = Unpooled.compositeBuffer();

this.state = State.WAITING_HEADER;
this.state = State.WAITING_START_HEADER;
this.lastParsedMessageHeader = null;
this.lastParsingFailure = null;
}
Expand All @@ -59,34 +60,38 @@ void offer(Buffer buffer) {

private void tryConsumeInternalBuffer() {
while (this.state != State.FAILED && this.internalBuffer.readableBytes() >= wantBytes()) {
if (state == State.WAITING_HEADER) {
try {
this.lastParsedMessageHeader = MessageHeader.parse(this.internalBuffer.readLong());
this.state = State.WAITING_PAYLOAD;
} catch (RuntimeException e) {
this.lastParsingFailure = e;
this.state = State.FAILED;
}
} else {
try {
this.parsedMessages.offer(
InvocationFlow.InvocationInput.of(
this.lastParsedMessageHeader,
this.lastParsedMessageHeader
.getType()
.messageParser()
.parseFrom(
this.internalBuffer
.readBytes(this.lastParsedMessageHeader.getLength())
.nioBuffer())));
this.state = State.WAITING_HEADER;
} catch (InvalidProtocolBufferException e) {
this.lastParsingFailure = new RuntimeException("Cannot parse the protobuf message", e);
this.state = State.FAILED;
} catch (RuntimeException e) {
this.lastParsingFailure = e;
this.state = State.FAILED;
try {
switch (state) {
case WAITING_START_HEADER:
this.lastParsedMessageHeader = MessageHeader.parse(this.internalBuffer.readLong());
MessageHeader.checkProtocolVersion(this.lastParsedMessageHeader);
this.state = State.WAITING_PAYLOAD;
break;
case WAITING_HEADER:
this.lastParsedMessageHeader = MessageHeader.parse(this.internalBuffer.readLong());
this.state = State.WAITING_PAYLOAD;
break;
case WAITING_PAYLOAD:
try {
this.parsedMessages.offer(
InvocationFlow.InvocationInput.of(
this.lastParsedMessageHeader,
this.lastParsedMessageHeader
.getType()
.messageParser()
.parseFrom(
this.internalBuffer
.readBytes(this.lastParsedMessageHeader.getLength())
.nioBuffer())));
this.state = State.WAITING_HEADER;
break;
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException("Cannot parse the protobuf message", e);
}
}
} catch (RuntimeException e) {
this.lastParsingFailure = e;
this.state = State.FAILED;
}
}
}
Expand Down

0 comments on commit 9f9f613

Please sign in to comment.