From 8521ac03d3e6b23622032bc459b78ffa1e1ecdb4 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Mon, 12 Aug 2024 11:23:55 +0100 Subject: [PATCH] Add awakeable example (#26) --- examples/codegen/main.go | 72 ++++- examples/codegen/proto/helloworld.pb.go | 264 +++++++++++++++--- examples/codegen/proto/helloworld.proto | 18 ++ .../codegen/proto/helloworld_restate.pb.go | 36 +++ 4 files changed, 353 insertions(+), 37 deletions(-) diff --git a/examples/codegen/main.go b/examples/codegen/main.go index b3d7132..f1cc262 100644 --- a/examples/codegen/main.go +++ b/examples/codegen/main.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "os" + "time" restate "github.com/restatedev/sdk-go" helloworld "github.com/restatedev/sdk-go/examples/codegen/proto" @@ -38,11 +39,23 @@ func (c counter) Add(ctx restate.ObjectContext, req *helloworld.AddRequest) (*he return nil, err } - count += 1 + watchers, err := restate.GetAs[[]string](ctx, "watchers") + if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { + return nil, err + } + + count += req.Delta if err := ctx.Set("counter", count); err != nil { return nil, err } + for _, awakeableID := range watchers { + if err := ctx.ResolveAwakeable(awakeableID, count); err != nil { + return nil, err + } + } + ctx.Clear("watchers") + return &helloworld.GetResponse{Value: count}, nil } @@ -55,6 +68,63 @@ func (c counter) Get(ctx restate.ObjectSharedContext, _ *helloworld.GetRequest) return &helloworld.GetResponse{Value: count}, nil } +func (c counter) AddWatcher(ctx restate.ObjectContext, req *helloworld.AddWatcherRequest) (*helloworld.AddWatcherResponse, error) { + watchers, err := restate.GetAs[[]string](ctx, "watchers") + if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { + return nil, err + } + watchers = append(watchers, req.AwakeableId) + if err := ctx.Set("watchers", watchers); err != nil { + return nil, err + } + return &helloworld.AddWatcherResponse{}, nil +} + +func (c counter) Watch(ctx restate.ObjectSharedContext, req *helloworld.WatchRequest) (*helloworld.GetResponse, error) { + awakeable := restate.AwakeableAs[int64](ctx) + + // since this is a shared handler, we need to use a separate exclusive handler to store the awakeable ID + // if there is an in-flight Add call, this will take effect after it completes + // we could add a version counter check here to detect changes that happen mid-request and return immediately + if _, err := helloworld.NewCounterClient(ctx, ctx.Key()). + AddWatcher(). + Request(&helloworld.AddWatcherRequest{AwakeableId: awakeable.Id()}); err != nil { + return nil, err + } + + timeout := time.Duration(req.TimeoutMillis) * time.Millisecond + if timeout == 0 { + // infinite timeout case; just await the next value + next, err := awakeable.Result() + if err != nil { + return nil, err + } + + return &helloworld.GetResponse{Value: next}, nil + } + + after := ctx.After(timeout) + + // this is the safe way to race two results + selector := ctx.Select(after, awakeable) + + if selector.Select() == after { + // the timeout won + if err := after.Done(); err != nil { + // an error here implies this invocation was cancelled + return nil, err + } + return nil, restate.TerminalError(context.DeadlineExceeded, 408) + } + + // otherwise, the awakeable won + next, err := awakeable.Result() + if err != nil { + return nil, err + } + return &helloworld.GetResponse{Value: next}, nil +} + func main() { server := server.NewRestate(). Bind(helloworld.NewGreeterServer(greeter{})). diff --git a/examples/codegen/proto/helloworld.pb.go b/examples/codegen/proto/helloworld.pb.go index 152676d..c1e7b34 100644 --- a/examples/codegen/proto/helloworld.pb.go +++ b/examples/codegen/proto/helloworld.pb.go @@ -247,6 +247,138 @@ func (x *GetResponse) GetValue() int64 { return 0 } +type AddWatcherRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + AwakeableId string `protobuf:"bytes,1,opt,name=awakeable_id,json=awakeableId,proto3" json:"awakeable_id,omitempty"` +} + +func (x *AddWatcherRequest) Reset() { + *x = AddWatcherRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_helloworld_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AddWatcherRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AddWatcherRequest) ProtoMessage() {} + +func (x *AddWatcherRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_helloworld_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AddWatcherRequest.ProtoReflect.Descriptor instead. +func (*AddWatcherRequest) Descriptor() ([]byte, []int) { + return file_proto_helloworld_proto_rawDescGZIP(), []int{5} +} + +func (x *AddWatcherRequest) GetAwakeableId() string { + if x != nil { + return x.AwakeableId + } + return "" +} + +type AddWatcherResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *AddWatcherResponse) Reset() { + *x = AddWatcherResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_helloworld_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AddWatcherResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AddWatcherResponse) ProtoMessage() {} + +func (x *AddWatcherResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_helloworld_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AddWatcherResponse.ProtoReflect.Descriptor instead. +func (*AddWatcherResponse) Descriptor() ([]byte, []int) { + return file_proto_helloworld_proto_rawDescGZIP(), []int{6} +} + +type WatchRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + TimeoutMillis int64 `protobuf:"varint,1,opt,name=timeout_millis,json=timeoutMillis,proto3" json:"timeout_millis,omitempty"` +} + +func (x *WatchRequest) Reset() { + *x = WatchRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_helloworld_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *WatchRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WatchRequest) ProtoMessage() {} + +func (x *WatchRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_helloworld_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WatchRequest.ProtoReflect.Descriptor instead. +func (*WatchRequest) Descriptor() ([]byte, []int) { + return file_proto_helloworld_proto_rawDescGZIP(), []int{7} +} + +func (x *WatchRequest) GetTimeoutMillis() int64 { + if x != nil { + return x.TimeoutMillis + } + return 0 +} + var File_proto_helloworld_proto protoreflect.FileDescriptor var file_proto_helloworld_proto_rawDesc = []byte{ @@ -264,31 +396,48 @@ var file_proto_helloworld_proto_rawDesc = []byte{ 0x61, 0x22, 0x0c, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x23, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x76, - 0x61, 0x6c, 0x75, 0x65, 0x32, 0x4c, 0x0a, 0x07, 0x47, 0x72, 0x65, 0x65, 0x74, 0x65, 0x72, 0x12, - 0x41, 0x0a, 0x08, 0x53, 0x61, 0x79, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x12, 0x18, 0x2e, 0x68, 0x65, - 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, - 0x6c, 0x64, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x32, 0x87, 0x01, 0x0a, 0x07, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x12, 0x38, - 0x0a, 0x03, 0x41, 0x64, 0x64, 0x12, 0x16, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, - 0x6c, 0x64, 0x2e, 0x41, 0x64, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, + 0x61, 0x6c, 0x75, 0x65, 0x22, 0x36, 0x0a, 0x11, 0x41, 0x64, 0x64, 0x57, 0x61, 0x74, 0x63, 0x68, + 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x77, 0x61, + 0x6b, 0x65, 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0b, 0x61, 0x77, 0x61, 0x6b, 0x65, 0x61, 0x62, 0x6c, 0x65, 0x49, 0x64, 0x22, 0x14, 0x0a, 0x12, + 0x41, 0x64, 0x64, 0x57, 0x61, 0x74, 0x63, 0x68, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x35, 0x0a, 0x0c, 0x57, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x5f, 0x6d, 0x69, + 0x6c, 0x6c, 0x69, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x74, 0x69, 0x6d, 0x65, + 0x6f, 0x75, 0x74, 0x4d, 0x69, 0x6c, 0x6c, 0x69, 0x73, 0x32, 0x4c, 0x0a, 0x07, 0x47, 0x72, 0x65, + 0x65, 0x74, 0x65, 0x72, 0x12, 0x41, 0x0a, 0x08, 0x53, 0x61, 0x79, 0x48, 0x65, 0x6c, 0x6c, 0x6f, + 0x12, 0x18, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, + 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x68, 0x65, 0x6c, + 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x32, 0x98, 0x02, 0x0a, 0x07, 0x43, 0x6f, 0x75, 0x6e, + 0x74, 0x65, 0x72, 0x12, 0x38, 0x0a, 0x03, 0x41, 0x64, 0x64, 0x12, 0x16, 0x2e, 0x68, 0x65, 0x6c, + 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x41, 0x64, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, + 0x47, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x3c, 0x0a, + 0x03, 0x47, 0x65, 0x74, 0x12, 0x16, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, + 0x64, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x68, + 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x04, 0x98, 0x80, 0x01, 0x02, 0x12, 0x4d, 0x0a, 0x0a, 0x41, + 0x64, 0x64, 0x57, 0x61, 0x74, 0x63, 0x68, 0x65, 0x72, 0x12, 0x1d, 0x2e, 0x68, 0x65, 0x6c, 0x6c, + 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x41, 0x64, 0x64, 0x57, 0x61, 0x74, 0x63, 0x68, 0x65, + 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, + 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x41, 0x64, 0x64, 0x57, 0x61, 0x74, 0x63, 0x68, 0x65, 0x72, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x40, 0x0a, 0x05, 0x57, 0x61, + 0x74, 0x63, 0x68, 0x12, 0x18, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, + 0x2e, 0x57, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x3c, 0x0a, 0x03, 0x47, 0x65, 0x74, 0x12, - 0x16, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x47, 0x65, 0x74, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, - 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x04, 0x98, 0x80, 0x01, 0x02, 0x1a, 0x04, 0x98, 0x80, 0x01, 0x01, 0x42, 0x9e, 0x01, 0x0a, - 0x0e, 0x63, 0x6f, 0x6d, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x42, - 0x0f, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x50, 0x01, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, - 0x65, 0x73, 0x74, 0x61, 0x74, 0x65, 0x64, 0x65, 0x76, 0x2f, 0x73, 0x64, 0x6b, 0x2d, 0x67, 0x6f, - 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x67, 0x65, - 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0xa2, 0x02, 0x03, 0x48, 0x58, 0x58, 0xaa, 0x02, 0x0a, - 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0xca, 0x02, 0x0a, 0x48, 0x65, 0x6c, - 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0xe2, 0x02, 0x16, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, - 0x6f, 0x72, 0x6c, 0x64, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, - 0xea, 0x02, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x04, 0x98, 0x80, 0x01, 0x02, 0x1a, 0x04, 0x98, 0x80, + 0x01, 0x01, 0x42, 0x9e, 0x01, 0x0a, 0x0e, 0x63, 0x6f, 0x6d, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, + 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x42, 0x0f, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, + 0x64, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x61, 0x74, 0x65, 0x64, 0x65, 0x76, 0x2f, + 0x73, 0x64, 0x6b, 0x2d, 0x67, 0x6f, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, + 0x63, 0x6f, 0x64, 0x65, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0xa2, 0x02, 0x03, + 0x48, 0x58, 0x58, 0xaa, 0x02, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, + 0xca, 0x02, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0xe2, 0x02, 0x16, + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, + 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, + 0x72, 0x6c, 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -303,23 +452,30 @@ func file_proto_helloworld_proto_rawDescGZIP() []byte { return file_proto_helloworld_proto_rawDescData } -var file_proto_helloworld_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_proto_helloworld_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var file_proto_helloworld_proto_goTypes = []any{ - (*HelloRequest)(nil), // 0: helloworld.HelloRequest - (*HelloResponse)(nil), // 1: helloworld.HelloResponse - (*AddRequest)(nil), // 2: helloworld.AddRequest - (*GetRequest)(nil), // 3: helloworld.GetRequest - (*GetResponse)(nil), // 4: helloworld.GetResponse + (*HelloRequest)(nil), // 0: helloworld.HelloRequest + (*HelloResponse)(nil), // 1: helloworld.HelloResponse + (*AddRequest)(nil), // 2: helloworld.AddRequest + (*GetRequest)(nil), // 3: helloworld.GetRequest + (*GetResponse)(nil), // 4: helloworld.GetResponse + (*AddWatcherRequest)(nil), // 5: helloworld.AddWatcherRequest + (*AddWatcherResponse)(nil), // 6: helloworld.AddWatcherResponse + (*WatchRequest)(nil), // 7: helloworld.WatchRequest } var file_proto_helloworld_proto_depIdxs = []int32{ 0, // 0: helloworld.Greeter.SayHello:input_type -> helloworld.HelloRequest 2, // 1: helloworld.Counter.Add:input_type -> helloworld.AddRequest 3, // 2: helloworld.Counter.Get:input_type -> helloworld.GetRequest - 1, // 3: helloworld.Greeter.SayHello:output_type -> helloworld.HelloResponse - 4, // 4: helloworld.Counter.Add:output_type -> helloworld.GetResponse - 4, // 5: helloworld.Counter.Get:output_type -> helloworld.GetResponse - 3, // [3:6] is the sub-list for method output_type - 0, // [0:3] is the sub-list for method input_type + 5, // 3: helloworld.Counter.AddWatcher:input_type -> helloworld.AddWatcherRequest + 7, // 4: helloworld.Counter.Watch:input_type -> helloworld.WatchRequest + 1, // 5: helloworld.Greeter.SayHello:output_type -> helloworld.HelloResponse + 4, // 6: helloworld.Counter.Add:output_type -> helloworld.GetResponse + 4, // 7: helloworld.Counter.Get:output_type -> helloworld.GetResponse + 6, // 8: helloworld.Counter.AddWatcher:output_type -> helloworld.AddWatcherResponse + 4, // 9: helloworld.Counter.Watch:output_type -> helloworld.GetResponse + 5, // [5:10] is the sub-list for method output_type + 0, // [0:5] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -391,6 +547,42 @@ func file_proto_helloworld_proto_init() { return nil } } + file_proto_helloworld_proto_msgTypes[5].Exporter = func(v any, i int) any { + switch v := v.(*AddWatcherRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_helloworld_proto_msgTypes[6].Exporter = func(v any, i int) any { + switch v := v.(*AddWatcherResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_helloworld_proto_msgTypes[7].Exporter = func(v any, i int) any { + switch v := v.(*WatchRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -398,7 +590,7 @@ func file_proto_helloworld_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proto_helloworld_proto_rawDesc, NumEnums: 0, - NumMessages: 5, + NumMessages: 8, NumExtensions: 0, NumServices: 2, }, diff --git a/examples/codegen/proto/helloworld.proto b/examples/codegen/proto/helloworld.proto index a89058e..127d1e0 100644 --- a/examples/codegen/proto/helloworld.proto +++ b/examples/codegen/proto/helloworld.proto @@ -12,10 +12,18 @@ service Greeter { service Counter { option (dev.restate.sdk.go.service_type) = VIRTUAL_OBJECT; + // Mutate the value rpc Add (AddRequest) returns (GetResponse) {} + // Get the current value rpc Get (GetRequest) returns (GetResponse) { option (dev.restate.sdk.go.handler_type) = SHARED; } + // Internal method to store an awakeable ID for the Watch method + rpc AddWatcher (AddWatcherRequest) returns (AddWatcherResponse) {} + // Wait for the counter to change and then return the new value + rpc Watch (WatchRequest) returns (GetResponse) { + option (dev.restate.sdk.go.handler_type) = SHARED; + } } message HelloRequest { @@ -35,3 +43,13 @@ message GetRequest {} message GetResponse { int64 value = 1; } + +message AddWatcherRequest { + string awakeable_id = 1; +} + +message AddWatcherResponse {} + +message WatchRequest { + int64 timeout_millis = 1; +} diff --git a/examples/codegen/proto/helloworld_restate.pb.go b/examples/codegen/proto/helloworld_restate.pb.go index 5d9d506..dd83bb5 100644 --- a/examples/codegen/proto/helloworld_restate.pb.go +++ b/examples/codegen/proto/helloworld_restate.pb.go @@ -78,8 +78,14 @@ func NewGreeterServer(srv GreeterServer, opts ...sdk_go.ServiceOption) sdk_go.Se // CounterClient is the client API for Counter service. type CounterClient interface { + // Mutate the value Add(opts ...sdk_go.CallOption) sdk_go.TypedCallClient[*AddRequest, *GetResponse] + // Get the current value Get(opts ...sdk_go.CallOption) sdk_go.TypedCallClient[*GetRequest, *GetResponse] + // Internal method to store an awakeable ID for the Watch method + AddWatcher(opts ...sdk_go.CallOption) sdk_go.TypedCallClient[*AddWatcherRequest, *AddWatcherResponse] + // Wait for the counter to change and then return the new value + Watch(opts ...sdk_go.CallOption) sdk_go.TypedCallClient[*WatchRequest, *GetResponse] } type counterClient struct { @@ -112,12 +118,34 @@ func (c *counterClient) Get(opts ...sdk_go.CallOption) sdk_go.TypedCallClient[*G return sdk_go.NewTypedCallClient[*GetRequest, *GetResponse](c.ctx.Object("Counter", c.key, "Get", cOpts...)) } +func (c *counterClient) AddWatcher(opts ...sdk_go.CallOption) sdk_go.TypedCallClient[*AddWatcherRequest, *AddWatcherResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]sdk_go.CallOption{}, cOpts...), opts...) + } + return sdk_go.NewTypedCallClient[*AddWatcherRequest, *AddWatcherResponse](c.ctx.Object("Counter", c.key, "AddWatcher", cOpts...)) +} + +func (c *counterClient) Watch(opts ...sdk_go.CallOption) sdk_go.TypedCallClient[*WatchRequest, *GetResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]sdk_go.CallOption{}, cOpts...), opts...) + } + return sdk_go.NewTypedCallClient[*WatchRequest, *GetResponse](c.ctx.Object("Counter", c.key, "Watch", cOpts...)) +} + // CounterServer is the server API for Counter service. // All implementations should embed UnimplementedCounterServer // for forward compatibility. type CounterServer interface { + // Mutate the value Add(ctx sdk_go.ObjectContext, req *AddRequest) (*GetResponse, error) + // Get the current value Get(ctx sdk_go.ObjectSharedContext, req *GetRequest) (*GetResponse, error) + // Internal method to store an awakeable ID for the Watch method + AddWatcher(ctx sdk_go.ObjectContext, req *AddWatcherRequest) (*AddWatcherResponse, error) + // Wait for the counter to change and then return the new value + Watch(ctx sdk_go.ObjectSharedContext, req *WatchRequest) (*GetResponse, error) } // UnimplementedCounterServer should be embedded to have @@ -133,6 +161,12 @@ func (UnimplementedCounterServer) Add(ctx sdk_go.ObjectContext, req *AddRequest) func (UnimplementedCounterServer) Get(ctx sdk_go.ObjectSharedContext, req *GetRequest) (*GetResponse, error) { return nil, sdk_go.TerminalError(fmt.Errorf("method Get not implemented"), 501) } +func (UnimplementedCounterServer) AddWatcher(ctx sdk_go.ObjectContext, req *AddWatcherRequest) (*AddWatcherResponse, error) { + return nil, sdk_go.TerminalError(fmt.Errorf("method AddWatcher not implemented"), 501) +} +func (UnimplementedCounterServer) Watch(ctx sdk_go.ObjectSharedContext, req *WatchRequest) (*GetResponse, error) { + return nil, sdk_go.TerminalError(fmt.Errorf("method Watch not implemented"), 501) +} func (UnimplementedCounterServer) testEmbeddedByValue() {} // UnsafeCounterServer may be embedded to opt out of forward compatibility for this service. @@ -154,5 +188,7 @@ func NewCounterServer(srv CounterServer, opts ...sdk_go.ObjectOption) sdk_go.Ser router := sdk_go.NewObject("Counter", sOpts...) router = router.Handler("Add", sdk_go.NewObjectHandler(srv.Add)) router = router.Handler("Get", sdk_go.NewObjectSharedHandler(srv.Get)) + router = router.Handler("AddWatcher", sdk_go.NewObjectHandler(srv.AddWatcher)) + router = router.Handler("Watch", sdk_go.NewObjectSharedHandler(srv.Watch)) return router }