Skip to content

Commit

Permalink
adding stream middlware to handle gRPC stream middlewares
Browse files Browse the repository at this point in the history
  • Loading branch information
akoserwal committed Jul 29, 2024
1 parent 45955cf commit 576691b
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 65 deletions.
49 changes: 45 additions & 4 deletions internal/matcher/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,38 @@ import (
// Matcher is a middleware matcher.
type Matcher interface {
Use(ms ...middleware.Middleware)
UseStream(ms ...middleware.StreamMiddleware)
Add(selector string, ms ...middleware.Middleware)
Match(operation string) []middleware.Middleware
AddStream(selector string, ms ...middleware.StreamMiddleware)
MatchStream(operation string) []middleware.StreamMiddleware
}

// New new a middleware matcher.
func New() Matcher {
return &matcher{
matchs: make(map[string][]middleware.Middleware),
matchs: make(map[string][]middleware.Middleware),
streamMatchs: make(map[string][]middleware.StreamMiddleware),
}
}

type matcher struct {
prefix []string
defaults []middleware.Middleware
matchs map[string][]middleware.Middleware
prefix []string
streamPrefix []string
defaults []middleware.Middleware
streamDefaults []middleware.StreamMiddleware
matchs map[string][]middleware.Middleware
streamMatchs map[string][]middleware.StreamMiddleware
}

func (m *matcher) Use(ms ...middleware.Middleware) {
m.defaults = ms
}

func (m *matcher) UseStream(ms ...middleware.StreamMiddleware) {
m.streamDefaults = ms
}

func (m *matcher) Add(selector string, ms ...middleware.Middleware) {
if strings.HasSuffix(selector, "*") {
selector = strings.TrimSuffix(selector, "*")
Expand All @@ -45,6 +56,20 @@ func (m *matcher) Add(selector string, ms ...middleware.Middleware) {
m.matchs[selector] = ms
}

func (m *matcher) AddStream(selector string, ms ...middleware.StreamMiddleware) {
if strings.HasSuffix(selector, "*") {
selector = strings.TrimSuffix(selector, "*")
m.streamPrefix = append(m.streamPrefix, selector)
// sort the prefix:
// - /foo/bar
// - /foo
sort.Slice(m.streamPrefix, func(i, j int) bool {
return m.streamPrefix[i] > m.streamPrefix[j]
})
}
m.streamMatchs[selector] = ms
}

func (m *matcher) Match(operation string) []middleware.Middleware {
ms := make([]middleware.Middleware, 0, len(m.defaults))
if len(m.defaults) > 0 {
Expand All @@ -60,3 +85,19 @@ func (m *matcher) Match(operation string) []middleware.Middleware {
}
return ms
}

func (m *matcher) MatchStream(operation string) []middleware.StreamMiddleware {
ms := make([]middleware.StreamMiddleware, 0, len(m.streamDefaults))
if len(m.streamDefaults) > 0 {
ms = append(ms, m.streamDefaults...)
}
if next, ok := m.streamMatchs[operation]; ok {
return append(ms, next...)
}
for _, prefix := range m.streamPrefix {
if strings.HasPrefix(operation, prefix) {
return append(ms, m.streamMatchs[prefix]...)
}
}
return ms
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,56 +3,59 @@ package logging
import (
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
"google.golang.org/grpc"
"time"
)

// StreamServerInterceptor is the logging middleware for gRPC streams.
func StreamServerInterceptor(logger log.Logger) grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
var (
code int32
reason string
kind string
operation string
)
ctx := ss.Context()
startTime := time.Now()
if info, ok := transport.FromClientContext(ctx); ok {
kind = info.Kind().String()
operation = info.Operation()
}
wrappedStream := &loggingServerStream{
ServerStream: ss,
logger: logger,
}
err := handler(srv, wrappedStream)
if se := errors.FromError(err); se != nil {
code = se.Code
reason = se.Reason
}
level, stack := extractError(err)

log.NewHelper(logger).Log(level,
"kind", kind,
"component", kind,
"operation", operation,
"args", extractArgs(wrappedStream.req),
"code", code,
"reason", reason,
"stack", stack,
"latency", time.Since(startTime).Seconds())
return err
}
}

type loggingServerStream struct {
req any
grpc.ServerStream
logger log.Logger
}

// StreamServer is a server logging middleware for gRPC streams.
func StreamServer(logger log.Logger) middleware.StreamMiddleware {
return func(handler middleware.StreamHandler) middleware.StreamHandler {
return func(srv interface{}, stream grpc.ServerStream) error {
var (
code int32
reason string
kind string
operation string
)
ctx := stream.Context()
startTime := time.Now()
if info, ok := transport.FromClientContext(ctx); ok {
kind = info.Kind().String()
operation = info.Operation()
}
wrappedStream := &loggingServerStream{
ServerStream: stream,
logger: logger,
}
err := handler(srv, wrappedStream)
if se := errors.FromError(err); se != nil {
code = se.Code
reason = se.Reason
}
level, stack := extractError(err)

log.NewHelper(logger).Log(level,
"kind", kind,
"component", kind,
"operation", operation,
"args", extractArgs(wrappedStream.req),
"code", code,
"reason", reason,
"stack", stack,
"latency", time.Since(startTime).Seconds())
return err
}
}
}

func (ss *loggingServerStream) RecvMsg(m interface{}) error {
var (
code int32
Expand Down
17 changes: 17 additions & 0 deletions middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@ package middleware

import (
"context"
"google.golang.org/grpc"
)

// Handler defines the handler invoked by Middleware.
type Handler func(ctx context.Context, req interface{}) (interface{}, error)

// StreamHandler defines the handler invoked by Middleware for stream calls.
type StreamHandler func(srv interface{}, stream grpc.ServerStream) error

// Middleware is HTTP/gRPC transport middleware.
type Middleware func(Handler) Handler

// StreamMiddleware is gRPC stream transport middleware.
type StreamMiddleware func(StreamHandler) StreamHandler

// Chain returns a Middleware that specifies the chained handler for endpoint.
func Chain(m ...Middleware) Middleware {
return func(next Handler) Handler {
Expand All @@ -19,3 +26,13 @@ func Chain(m ...Middleware) Middleware {
return next
}
}

// ChainStream returns a StreamMiddleware that specifies the chained handler for endpoint.
func ChainStream(m ...StreamMiddleware) StreamMiddleware {
return func(next StreamHandler) StreamHandler {
for i := len(m) - 1; i >= 0; i-- {
next = m[i](next)
}
return next
}
}
6 changes: 6 additions & 0 deletions transport/grpc/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
})

ws := NewWrappedStream(ctx, ss)
h := func(srv interface{}, stream grpc.ServerStream) error {
return handler(srv, stream)
}
if next := s.streamMiddleware.MatchStream(info.FullMethod); len(next) > 0 {
h = middleware.ChainStream(next...)(h)
}

err := handler(srv, ws)
if len(replyHeader) > 0 {
Expand Down
56 changes: 34 additions & 22 deletions transport/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ func Middleware(m ...middleware.Middleware) ServerOption {
}
}

func StreamMiddleware(m ...middleware.StreamMiddleware) ServerOption {
return func(s *Server) {
s.streamMiddleware.UseStream(m...)
}
}

// CustomHealth Checks server.
func CustomHealth() ServerOption {
return func(s *Server) {
Expand Down Expand Up @@ -117,33 +123,35 @@ func Options(opts ...grpc.ServerOption) ServerOption {
// Server is a gRPC server wrapper.
type Server struct {
*grpc.Server
baseCtx context.Context
tlsConf *tls.Config
lis net.Listener
err error
network string
address string
endpoint *url.URL
timeout time.Duration
middleware matcher.Matcher
unaryInts []grpc.UnaryServerInterceptor
streamInts []grpc.StreamServerInterceptor
grpcOpts []grpc.ServerOption
health *health.Server
customHealth bool
metadata *apimd.Server
adminClean func()
baseCtx context.Context
tlsConf *tls.Config
lis net.Listener
err error
network string
address string
endpoint *url.URL
timeout time.Duration
middleware matcher.Matcher
streamMiddleware matcher.Matcher
unaryInts []grpc.UnaryServerInterceptor
streamInts []grpc.StreamServerInterceptor
grpcOpts []grpc.ServerOption
health *health.Server
customHealth bool
metadata *apimd.Server
adminClean func()
}

// NewServer creates a gRPC server by options.
func NewServer(opts ...ServerOption) *Server {
srv := &Server{
baseCtx: context.Background(),
network: "tcp",
address: ":0",
timeout: 1 * time.Second,
health: health.NewServer(),
middleware: matcher.New(),
baseCtx: context.Background(),
network: "tcp",
address: ":0",
timeout: 1 * time.Second,
health: health.NewServer(),
middleware: matcher.New(),
streamMiddleware: matcher.New(),
}
for _, o := range opts {
o(srv)
Expand Down Expand Up @@ -192,6 +200,10 @@ func (s *Server) Use(selector string, m ...middleware.Middleware) {
s.middleware.Add(selector, m...)
}

func (s *Server) UseStream(selector string, m ...middleware.StreamMiddleware) {
s.middleware.AddStream(selector, m...)
}

// Endpoint return a real address to registry endpoint.
// examples:
//
Expand Down

0 comments on commit 576691b

Please sign in to comment.