diff --git a/go.mod b/go.mod index 975a6a1..2ea12cc 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,9 @@ module github.com/restatedev/sdk-go go 1.22.0 require ( + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 + github.com/mr-tron/base58 v1.2.0 github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0 github.com/stretchr/testify v1.9.0 github.com/vmihailenco/msgpack/v5 v5.4.1 diff --git a/go.sum b/go.sum index d3a47ec..5e2ee5a 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,13 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= +github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0 h1:zZg03nifrj6ayWNaDO8tNj57tqrOIKDwiBaLkhtK7Kk= diff --git a/internal/identity/identity.go b/internal/identity/identity.go new file mode 100644 index 0000000..afc0ded --- /dev/null +++ b/internal/identity/identity.go @@ -0,0 +1,30 @@ +package identity + +import "fmt" + +const SIGNATURE_SCHEME_HEADER = "X-Restate-Signature-Scheme" + +type SignatureScheme string + +var ( + SchemeUnsigned SignatureScheme = "unsigned" + errMissingIdentity = fmt.Errorf("request has no identity") +) + +func ValidateRequestIdentity(keySet KeySetV1, path string, headers map[string][]string) error { + switch len(headers[SIGNATURE_SCHEME_HEADER]) { + case 0: + return errMissingIdentity + case 1: + switch SignatureScheme(headers[SIGNATURE_SCHEME_HEADER][0]) { + case SchemeV1: + return validateV1(keySet, path, headers) + case SchemeUnsigned: + return errMissingIdentity + default: + return fmt.Errorf("unexpected signature scheme %v, allowed values are [%s %s]", headers[SIGNATURE_SCHEME_HEADER][0], SchemeUnsigned, SchemeV1) + } + default: + return fmt.Errorf("unexpected multi-value signature scheme header: %v", headers[SIGNATURE_SCHEME_HEADER]) + } +} diff --git a/internal/identity/v1.go b/internal/identity/v1.go new file mode 100644 index 0000000..e97afdb --- /dev/null +++ b/internal/identity/v1.go @@ -0,0 +1,83 @@ +package identity + +import ( + "crypto/ed25519" + "fmt" + "strings" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/mr-tron/base58" +) + +const ( + JWT_HEADER = "X-Restate-Jwt-V1" + SchemeV1 SignatureScheme = "v1" +) + +type KeySetV1 = map[string]ed25519.PublicKey + +func validateV1(keySet KeySetV1, path string, headers map[string][]string) error { + switch len(headers[JWT_HEADER]) { + case 0: + return fmt.Errorf("v1 signature scheme expects the following headers: [%s]", JWT_HEADER) + case 1: + default: + return fmt.Errorf("unexpected multi-value JWT header: %v", headers[JWT_HEADER]) + } + + token, err := jwt.Parse(headers[JWT_HEADER][0], func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok { + return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + } + + kid, ok := token.Header["kid"] + if !ok { + return nil, fmt.Errorf("Token missing 'kid' header field") + } + + kidS, ok := kid.(string) + if !ok { + return nil, fmt.Errorf("Token 'kid' header field was not a string: %v", kid) + } + + key, ok := keySet[kidS] + if !ok { + return nil, fmt.Errorf("Key ID %s is not present in key set", kid) + } + + return key, nil + }, jwt.WithValidMethods([]string{"EdDSA"}), jwt.WithAudience(path), jwt.WithExpirationRequired()) + if err != nil { + return fmt.Errorf("failed to validate v1 request identity jwt: %w", err) + } + + nbf, _ := token.Claims.GetNotBefore() + if nbf == nil { + // jwt library only validates nbf if its present, so we should check it was present + return fmt.Errorf("'nbf' claim is missing in v1 request identity jwt") + } + + return nil +} + +func ParseKeySetV1(keys []string) (KeySetV1, error) { + out := make(KeySetV1, len(keys)) + for _, key := range keys { + if !strings.HasPrefix(key, "publickeyv1_") { + return nil, fmt.Errorf("v1 public key must start with 'publickeyv1_'") + } + + pubBytes, err := base58.Decode(key[len("publickeyv1_"):]) + if err != nil { + return nil, fmt.Errorf("v1 public key must be valid base58: %w", err) + } + + if len(pubBytes) != ed25519.PublicKeySize { + return nil, fmt.Errorf("v1 public key must have exactly %d bytes, found %d", ed25519.PublicKeySize, len(pubBytes)) + } + + out[key] = ed25519.PublicKey(pubBytes) + } + + return out, nil +} diff --git a/internal/log/log.go b/internal/log/log.go index 538ff5d..b72bdf8 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -31,7 +31,7 @@ func (t stringerValue[T]) LogValue() slog.Value { } func Stringer[T fmt.Stringer](key string, value T) slog.Attr { - return slog.Any(key, slog.AnyValue(stringerValue[T]{value})) + return slog.Any(key, stringerValue[T]{value}) } func Error(err error) slog.Attr { diff --git a/server/restate.go b/server/restate.go index 88354bc..5294d7a 100644 --- a/server/restate.go +++ b/server/restate.go @@ -15,6 +15,7 @@ import ( "github.com/restatedev/sdk-go/generated/proto/discovery" "github.com/restatedev/sdk-go/generated/proto/protocol" "github.com/restatedev/sdk-go/internal" + "github.com/restatedev/sdk-go/internal/identity" "github.com/restatedev/sdk-go/internal/log" "github.com/restatedev/sdk-go/internal/state" "golang.org/x/net/http2" @@ -45,6 +46,8 @@ type Restate struct { dropReplayLogs bool systemLog *slog.Logger routers map[string]restate.Router + keyIDs []string + keySet identity.KeySetV1 } // NewRestate creates a new instance of Restate server @@ -69,6 +72,11 @@ func (r *Restate) WithLogger(h slog.Handler, dropReplayLogs bool) *Restate { return r } +func (r *Restate) WithIdentityV1(keys ...string) *Restate { + r.keyIDs = append(r.keyIDs, keys...) + return r +} + func (r *Restate) Bind(router restate.Router) *Restate { if _, ok := r.routers[router.Name()]; ok { // panic because this is a programming error @@ -120,8 +128,8 @@ func (r *Restate) discoverHandler(writer http.ResponseWriter, req *http.Request) acceptVersionsString := req.Header.Get("accept") if acceptVersionsString == "" { - writer.Write([]byte("missing accept header")) writer.WriteHeader(http.StatusUnsupportedMediaType) + writer.Write([]byte("missing accept header")) return } @@ -129,29 +137,28 @@ func (r *Restate) discoverHandler(writer http.ResponseWriter, req *http.Request) serviceDiscoveryProtocolVersion := selectSupportedServiceDiscoveryProtocolVersion(acceptVersionsString) if serviceDiscoveryProtocolVersion == discovery.ServiceDiscoveryProtocolVersion_SERVICE_DISCOVERY_PROTOCOL_VERSION_UNSPECIFIED { - writer.Write([]byte(fmt.Sprintf("Unsupported service discovery protocol version '%s'", acceptVersionsString))) writer.WriteHeader(http.StatusUnsupportedMediaType) + writer.Write([]byte(fmt.Sprintf("Unsupported service discovery protocol version '%s'", acceptVersionsString))) return } response, err := r.discover() if err != nil { - writer.Write([]byte(err.Error())) writer.WriteHeader(http.StatusInternalServerError) + writer.Write([]byte(err.Error())) return } bytes, err := json.Marshal(response) if err != nil { - writer.Write([]byte(err.Error())) writer.WriteHeader(http.StatusInternalServerError) + writer.Write([]byte(err.Error())) return } writer.Header().Add("Content-Type", serviceDiscoveryProtocolVersionToHeaderValue(serviceDiscoveryProtocolVersion)) - writer.WriteHeader(200) if _, err := writer.Write(bytes); err != nil { r.systemLog.LogAttrs(req.Context(), slog.LevelError, "Failed to write discovery information", log.Error(err)) } @@ -252,6 +259,17 @@ func (r *Restate) callHandler(serviceProtocolVersion protocol.ServiceProtocolVer } func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) { + if r.keySet != nil { + if err := identity.ValidateRequestIdentity(r.keySet, request.RequestURI, request.Header); err != nil { + r.systemLog.LogAttrs(request.Context(), slog.LevelError, "Rejecting request as its JWT did not validate", log.Error(err)) + + writer.WriteHeader(http.StatusUnauthorized) + writer.Write([]byte("Unauthorized")) + + return + } + } + if request.RequestURI == "/discover" { r.discoverHandler(writer, request) return @@ -261,8 +279,8 @@ func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) { if serviceProtocolVersionString == "" { r.systemLog.ErrorContext(request.Context(), "Missing content-type header") - writer.Write([]byte("missing content-type header")) writer.WriteHeader(http.StatusUnsupportedMediaType) + writer.Write([]byte("missing content-type header")) return } @@ -272,8 +290,8 @@ func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) { if !isServiceProtocolVersionSupported(serviceProtocolVersion) { r.systemLog.LogAttrs(request.Context(), slog.LevelError, "Unsupported service protocol version", slog.String("version", serviceProtocolVersionString)) - writer.Write([]byte(fmt.Sprintf("Unsupported service protocol version '%s'", serviceProtocolVersionString))) writer.WriteHeader(http.StatusUnsupportedMediaType) + writer.Write([]byte(fmt.Sprintf("Unsupported service protocol version '%s'", serviceProtocolVersionString))) return } @@ -297,6 +315,16 @@ func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) { } func (r *Restate) Start(ctx context.Context, address string) error { + if r.keyIDs == nil { + r.systemLog.WarnContext(ctx, "Accepting requests without validating request signatures; handler access must be restricted") + } else { + ks, err := identity.ParseKeySetV1(r.keyIDs) + if err != nil { + return fmt.Errorf("invalid request identity keys: %w", err) + } + r.keySet = ks + r.systemLog.LogAttrs(ctx, slog.LevelInfo, "Validating requests using signing keys", slog.Any("keys", r.keyIDs)) + } listener, err := net.Listen("tcp", address) if err != nil {