diff --git a/internal/matcher/middleware.go b/internal/matcher/middleware.go index 8d56818200e..a6cb1e47463 100644 --- a/internal/matcher/middleware.go +++ b/internal/matcher/middleware.go @@ -10,27 +10,38 @@ import ( // Matcher is a middleware matcher. type Matcher interface { Use(ms ...middleware.Middleware) + UseStream(ms ...middleware.StreamMiddleware) Add(selector string, ms ...middleware.Middleware) Match(operation string) []middleware.Middleware + AddStream(selector string, ms ...middleware.StreamMiddleware) + MatchStream(operation string) []middleware.StreamMiddleware } // New new a middleware matcher. func New() Matcher { return &matcher{ - matchs: make(map[string][]middleware.Middleware), + matchs: make(map[string][]middleware.Middleware), + streamMatchs: make(map[string][]middleware.StreamMiddleware), } } type matcher struct { - prefix []string - defaults []middleware.Middleware - matchs map[string][]middleware.Middleware + prefix []string + streamPrefix []string + defaults []middleware.Middleware + streamDefaults []middleware.StreamMiddleware + matchs map[string][]middleware.Middleware + streamMatchs map[string][]middleware.StreamMiddleware } func (m *matcher) Use(ms ...middleware.Middleware) { m.defaults = ms } +func (m *matcher) UseStream(ms ...middleware.StreamMiddleware) { + m.streamDefaults = ms +} + func (m *matcher) Add(selector string, ms ...middleware.Middleware) { if strings.HasSuffix(selector, "*") { selector = strings.TrimSuffix(selector, "*") @@ -45,6 +56,20 @@ func (m *matcher) Add(selector string, ms ...middleware.Middleware) { m.matchs[selector] = ms } +func (m *matcher) AddStream(selector string, ms ...middleware.StreamMiddleware) { + if strings.HasSuffix(selector, "*") { + selector = strings.TrimSuffix(selector, "*") + m.streamPrefix = append(m.streamPrefix, selector) + // sort the prefix: + // - /foo/bar + // - /foo + sort.Slice(m.streamPrefix, func(i, j int) bool { + return m.streamPrefix[i] > m.streamPrefix[j] + }) + } + m.streamMatchs[selector] = ms +} + func (m *matcher) Match(operation string) []middleware.Middleware { ms := make([]middleware.Middleware, 0, len(m.defaults)) if len(m.defaults) > 0 { @@ -60,3 +85,19 @@ func (m *matcher) Match(operation string) []middleware.Middleware { } return ms } + +func (m *matcher) MatchStream(operation string) []middleware.StreamMiddleware { + ms := make([]middleware.StreamMiddleware, 0, len(m.streamDefaults)) + if len(m.streamDefaults) > 0 { + ms = append(ms, m.streamDefaults...) + } + if next, ok := m.streamMatchs[operation]; ok { + return append(ms, next...) + } + for _, prefix := range m.streamPrefix { + if strings.HasPrefix(operation, prefix) { + return append(ms, m.streamMatchs[prefix]...) + } + } + return ms +} diff --git a/middleware/logging/intercepter.go b/middleware/logging/logging_stream.go similarity index 58% rename from middleware/logging/intercepter.go rename to middleware/logging/logging_stream.go index 356d7744bf9..1d8fed1ab7e 100644 --- a/middleware/logging/intercepter.go +++ b/middleware/logging/logging_stream.go @@ -3,56 +3,59 @@ package logging import ( "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/log" + "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" "google.golang.org/grpc" "time" ) -// StreamServerInterceptor is the logging middleware for gRPC streams. -func StreamServerInterceptor(logger log.Logger) grpc.StreamServerInterceptor { - return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - var ( - code int32 - reason string - kind string - operation string - ) - ctx := ss.Context() - startTime := time.Now() - if info, ok := transport.FromClientContext(ctx); ok { - kind = info.Kind().String() - operation = info.Operation() - } - wrappedStream := &loggingServerStream{ - ServerStream: ss, - logger: logger, - } - err := handler(srv, wrappedStream) - if se := errors.FromError(err); se != nil { - code = se.Code - reason = se.Reason - } - level, stack := extractError(err) - - log.NewHelper(logger).Log(level, - "kind", kind, - "component", kind, - "operation", operation, - "args", extractArgs(wrappedStream.req), - "code", code, - "reason", reason, - "stack", stack, - "latency", time.Since(startTime).Seconds()) - return err - } -} - type loggingServerStream struct { req any grpc.ServerStream logger log.Logger } +// StreamServer is a server logging middleware for gRPC streams. +func StreamServer(logger log.Logger) middleware.StreamMiddleware { + return func(handler middleware.StreamHandler) middleware.StreamHandler { + return func(srv interface{}, stream grpc.ServerStream) error { + var ( + code int32 + reason string + kind string + operation string + ) + ctx := stream.Context() + startTime := time.Now() + if info, ok := transport.FromClientContext(ctx); ok { + kind = info.Kind().String() + operation = info.Operation() + } + wrappedStream := &loggingServerStream{ + ServerStream: stream, + logger: logger, + } + err := handler(srv, wrappedStream) + if se := errors.FromError(err); se != nil { + code = se.Code + reason = se.Reason + } + level, stack := extractError(err) + + log.NewHelper(logger).Log(level, + "kind", kind, + "component", kind, + "operation", operation, + "args", extractArgs(wrappedStream.req), + "code", code, + "reason", reason, + "stack", stack, + "latency", time.Since(startTime).Seconds()) + return err + } + } +} + func (ss *loggingServerStream) RecvMsg(m interface{}) error { var ( code int32 diff --git a/middleware/middleware.go b/middleware/middleware.go index 8a514ad97fa..aca0bb86b99 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -2,14 +2,21 @@ package middleware import ( "context" + "google.golang.org/grpc" ) // Handler defines the handler invoked by Middleware. type Handler func(ctx context.Context, req interface{}) (interface{}, error) +// StreamHandler defines the handler invoked by Middleware for stream calls. +type StreamHandler func(srv interface{}, stream grpc.ServerStream) error + // Middleware is HTTP/gRPC transport middleware. type Middleware func(Handler) Handler +// StreamMiddleware is gRPC stream transport middleware. +type StreamMiddleware func(StreamHandler) StreamHandler + // Chain returns a Middleware that specifies the chained handler for endpoint. func Chain(m ...Middleware) Middleware { return func(next Handler) Handler { @@ -19,3 +26,13 @@ func Chain(m ...Middleware) Middleware { return next } } + +// ChainStream returns a StreamMiddleware that specifies the chained handler for endpoint. +func ChainStream(m ...StreamMiddleware) StreamMiddleware { + return func(next StreamHandler) StreamHandler { + for i := len(m) - 1; i >= 0; i-- { + next = m[i](next) + } + return next + } +} diff --git a/transport/grpc/interceptor.go b/transport/grpc/interceptor.go index 6cc331547c6..29db152d251 100644 --- a/transport/grpc/interceptor.go +++ b/transport/grpc/interceptor.go @@ -77,6 +77,12 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor { }) ws := NewWrappedStream(ctx, ss) + h := func(srv interface{}, stream grpc.ServerStream) error { + return handler(srv, stream) + } + if next := s.streamMiddleware.MatchStream(info.FullMethod); len(next) > 0 { + h = middleware.ChainStream(next...)(h) + } err := handler(srv, ws) if len(replyHeader) > 0 { diff --git a/transport/grpc/server.go b/transport/grpc/server.go index c8c3f4f3bcc..479d424ec0b 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -72,6 +72,12 @@ func Middleware(m ...middleware.Middleware) ServerOption { } } +func StreamMiddleware(m ...middleware.StreamMiddleware) ServerOption { + return func(s *Server) { + s.streamMiddleware.UseStream(m...) + } +} + // CustomHealth Checks server. func CustomHealth() ServerOption { return func(s *Server) { @@ -117,33 +123,35 @@ func Options(opts ...grpc.ServerOption) ServerOption { // Server is a gRPC server wrapper. type Server struct { *grpc.Server - baseCtx context.Context - tlsConf *tls.Config - lis net.Listener - err error - network string - address string - endpoint *url.URL - timeout time.Duration - middleware matcher.Matcher - unaryInts []grpc.UnaryServerInterceptor - streamInts []grpc.StreamServerInterceptor - grpcOpts []grpc.ServerOption - health *health.Server - customHealth bool - metadata *apimd.Server - adminClean func() + baseCtx context.Context + tlsConf *tls.Config + lis net.Listener + err error + network string + address string + endpoint *url.URL + timeout time.Duration + middleware matcher.Matcher + streamMiddleware matcher.Matcher + unaryInts []grpc.UnaryServerInterceptor + streamInts []grpc.StreamServerInterceptor + grpcOpts []grpc.ServerOption + health *health.Server + customHealth bool + metadata *apimd.Server + adminClean func() } // NewServer creates a gRPC server by options. func NewServer(opts ...ServerOption) *Server { srv := &Server{ - baseCtx: context.Background(), - network: "tcp", - address: ":0", - timeout: 1 * time.Second, - health: health.NewServer(), - middleware: matcher.New(), + baseCtx: context.Background(), + network: "tcp", + address: ":0", + timeout: 1 * time.Second, + health: health.NewServer(), + middleware: matcher.New(), + streamMiddleware: matcher.New(), } for _, o := range opts { o(srv) @@ -192,6 +200,10 @@ func (s *Server) Use(selector string, m ...middleware.Middleware) { s.middleware.Add(selector, m...) } +func (s *Server) UseStream(selector string, m ...middleware.StreamMiddleware) { + s.middleware.AddStream(selector, m...) +} + // Endpoint return a real address to registry endpoint. // examples: //