diff --git a/crates/invoker-impl/src/invocation_task.rs b/crates/invoker-impl/src/invocation_task.rs index bed21d83d..87b51cb71 100644 --- a/crates/invoker-impl/src/invocation_task.rs +++ b/crates/invoker-impl/src/invocation_task.rs @@ -599,8 +599,10 @@ where ProtocolMessage::new_start_message( Bytes::copy_from_slice(&self.full_invocation_id.to_invocation_id_bytes()), self.full_invocation_id.to_string(), + Some(self.full_invocation_id.service_id.key.clone()), journal_size, is_partial, + iter::empty(), state_entries, ), ) diff --git a/crates/service-protocol/src/message/encoding.rs b/crates/service-protocol/src/message/encoding.rs index 3a2050110..25f0f80e2 100644 --- a/crates/service-protocol/src/message/encoding.rs +++ b/crates/service-protocol/src/message/encoding.rs @@ -352,8 +352,15 @@ mod tests { let encoder = Encoder::new(protocol_version); let mut decoder = Decoder::default(); - let expected_msg_0 = - ProtocolMessage::new_start_message("key".into(), "key".into(), 1, true, vec![]); + let expected_msg_0 = ProtocolMessage::new_start_message( + "key".into(), + "key".into(), + Some("key".into()), + 1, + true, + vec![], + vec![], + ); let expected_msg_1: ProtocolMessage = ProtobufRawEntryCodec::serialize_as_unary_input_entry(Bytes::from_static( diff --git a/crates/service-protocol/src/message/mod.rs b/crates/service-protocol/src/message/mod.rs index 01145bdce..fc0e9661c 100644 --- a/crates/service-protocol/src/message/mod.rs +++ b/crates/service-protocol/src/message/mod.rs @@ -44,8 +44,10 @@ impl ProtocolMessage { pub fn new_start_message( id: Bytes, debug_id: String, + key: Option, known_entries: u32, partial_state: bool, + headers: impl IntoIterator, state_map_entries: impl IntoIterator, ) -> Self { Self::Start(pb::protocol::StartMessage { @@ -57,7 +59,13 @@ impl ProtocolMessage { .into_iter() .map(|(key, value)| pb::protocol::start_message::StateEntry { key, value }) .collect(), - ..pb::protocol::StartMessage::default() + key: key + .and_then(|b| String::from_utf8(b.to_vec()).ok()) + .unwrap_or_default(), + headers: headers + .into_iter() + .map(|(key, value)| pb::protocol::Header { key, value }) + .collect(), }) } diff --git a/crates/worker/src/partition/services/non_deterministic/remote_context.rs b/crates/worker/src/partition/services/non_deterministic/remote_context.rs index 55eb300e2..66ef41d37 100644 --- a/crates/worker/src/partition/services/non_deterministic/remote_context.rs +++ b/crates/worker/src/partition/services/non_deterministic/remote_context.rs @@ -788,9 +788,11 @@ impl<'a, State: StateReader + Send + Sync> RemoteContextBuiltInService ProtocolMessage::new_start_message( Bytes::copy_from_slice(&virtual_invocation_id.to_bytes()), virtual_invocation_id.to_string(), + None, length, true, // TODO add eager state iter::empty(), + iter::empty(), ), ) .expect("Encoding messages to a BytesMut should be infallible, unless OOM is reached.");