Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding stream interceptor for logging middleware #3359

Merged
merged 11 commits into from
Sep 18, 2024
Merged
77 changes: 74 additions & 3 deletions transport/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
grpcinsecure "google.golang.org/grpc/credentials/insecure"
grpcmd "google.golang.org/grpc/metadata"

"github.com/go-kratos/kratos/v2/internal/matcher"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/registry"
Expand Down Expand Up @@ -132,6 +133,7 @@ type clientOptions struct {
timeout time.Duration
discovery registry.Discovery
middleware []middleware.Middleware
streamMiddleware []middleware.Middleware
ints []grpc.UnaryClientInterceptor
streamInts []grpc.StreamClientInterceptor
grpcOpts []grpc.DialOption
Expand Down Expand Up @@ -166,7 +168,7 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien
unaryClientInterceptor(options.middleware, options.timeout, options.filters),
}
sints := []grpc.StreamClientInterceptor{
streamClientInterceptor(options.filters),
streamClientInterceptor(options.streamMiddleware, options.filters),
}

if len(options.ints) > 0 {
Expand Down Expand Up @@ -239,7 +241,54 @@ func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, f
}
}

func streamClientInterceptor(filters []selector.NodeFilter) grpc.StreamClientInterceptor {
// wrappedClientStream wraps the grpc.ClientStream and applies middleware
type wrappedClientStream struct {
grpc.ClientStream
ctx context.Context
middleware matcher.Matcher
}

func (w *wrappedClientStream) Context() context.Context {
return w.ctx
}

func (w *wrappedClientStream) SendMsg(m interface{}) error {
h := func(ctx context.Context, req interface{}) (interface{}, error) {
return req, w.ClientStream.SendMsg(m)
}

info, ok := transport.FromClientContext(w.ctx)
if !ok {
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
}

if next := w.middleware.Match(info.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}

_, err := h(w.ctx, m)
return err
}

func (w *wrappedClientStream) RecvMsg(m interface{}) error {
h := func(ctx context.Context, req interface{}) (interface{}, error) {
return req, w.ClientStream.RecvMsg(m)
}

info, ok := transport.FromClientContext(w.ctx)
if !ok {
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
}

if next := w.middleware.Match(info.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}

_, err := h(w.ctx, m)
return err
}

func streamClientInterceptor(ms []middleware.Middleware, filters []selector.NodeFilter) grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { // nolint
ctx = transport.NewClientContext(ctx, &Transport{
endpoint: cc.Target(),
Expand All @@ -249,6 +298,28 @@ func streamClientInterceptor(filters []selector.NodeFilter) grpc.StreamClientInt
})
var p selector.Peer
ctx = selector.NewPeerContext(ctx, &p)
return streamer(ctx, desc, cc, method, opts...)

clientStream, err := streamer(ctx, desc, cc, method, opts...)
if err != nil {
return nil, err
}

h := func(ctx context.Context, req interface{}) (interface{}, error) {
return streamer, nil
}

m := matcher.New()
if len(ms) > 0 {
m.Use(ms...)
middleware.Chain(ms...)(h)
}

wrappedStream := &wrappedClientStream{
ClientStream: clientStream,
ctx: ctx,
middleware: m,
}

return wrappedStream, nil
}
}
67 changes: 64 additions & 3 deletions transport/grpc/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package grpc

import (
"context"
"fmt"

"google.golang.org/grpc"
grpcmd "google.golang.org/grpc/metadata"

ic "github.com/go-kratos/kratos/v2/internal/context"
"github.com/go-kratos/kratos/v2/internal/matcher"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
)
Expand Down Expand Up @@ -48,13 +50,15 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
// wrappedStream is rewrite grpc stream's context
type wrappedStream struct {
grpc.ServerStream
ctx context.Context
ctx context.Context
middleware matcher.Matcher
}

func NewWrappedStream(ctx context.Context, stream grpc.ServerStream) grpc.ServerStream {
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream, m matcher.Matcher) grpc.ServerStream {
return &wrappedStream{
ServerStream: stream,
ctx: ctx,
middleware: m,
}
}

Expand All @@ -76,7 +80,19 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
replyHeader: headerCarrier(replyHeader),
})

ws := NewWrappedStream(ctx, ss)
h := func(ctx context.Context, req interface{}) (interface{}, error) {
return handler(srv, ss), nil
}

if next := s.streamMiddleware.Match(info.FullMethod); len(next) > 0 {
middleware.Chain(next...)(h)
}

ctx = context.WithValue(ctx, stream{
ServerStream: ss,
streamMiddleware: s.streamMiddleware,
}, ss)
ws := NewWrappedStream(ctx, ss, s.streamMiddleware)

err := handler(srv, ws)
if len(replyHeader) > 0 {
Expand All @@ -85,3 +101,48 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
return err
}
}

type stream struct {
grpc.ServerStream
streamMiddleware matcher.Matcher
}

func GetStream(ctx context.Context) grpc.ServerStream {
return ctx.Value(stream{}).(grpc.ServerStream)
}

func (w *wrappedStream) SendMsg(m interface{}) error {
h := func(_ context.Context, req interface{}) (interface{}, error) {
return req, w.ServerStream.SendMsg(m)
}

info, ok := transport.FromServerContext(w.ctx)
if !ok {
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
}

if next := w.middleware.Match(info.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}

_, err := h(w.ctx, m)
return err
}

func (w *wrappedStream) RecvMsg(m interface{}) error {
h := func(_ context.Context, req interface{}) (interface{}, error) {
return req, w.ServerStream.RecvMsg(m)
}

info, ok := transport.FromServerContext(w.ctx)
if !ok {
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
}

if next := w.middleware.Match(info.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}

_, err := h(w.ctx, m)
return err
}
52 changes: 30 additions & 22 deletions transport/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ func Middleware(m ...middleware.Middleware) ServerOption {
}
}

func StreamMiddleware(m ...middleware.Middleware) ServerOption {
return func(s *Server) {
s.streamMiddleware.Use(m...)
}
}

// CustomHealth Checks server.
func CustomHealth() ServerOption {
return func(s *Server) {
Expand Down Expand Up @@ -117,33 +123,35 @@ func Options(opts ...grpc.ServerOption) ServerOption {
// Server is a gRPC server wrapper.
type Server struct {
*grpc.Server
baseCtx context.Context
tlsConf *tls.Config
lis net.Listener
err error
network string
address string
endpoint *url.URL
timeout time.Duration
middleware matcher.Matcher
unaryInts []grpc.UnaryServerInterceptor
streamInts []grpc.StreamServerInterceptor
grpcOpts []grpc.ServerOption
health *health.Server
customHealth bool
metadata *apimd.Server
adminClean func()
baseCtx context.Context
tlsConf *tls.Config
lis net.Listener
err error
network string
address string
endpoint *url.URL
timeout time.Duration
middleware matcher.Matcher
streamMiddleware matcher.Matcher
unaryInts []grpc.UnaryServerInterceptor
streamInts []grpc.StreamServerInterceptor
grpcOpts []grpc.ServerOption
health *health.Server
customHealth bool
metadata *apimd.Server
adminClean func()
}

// NewServer creates a gRPC server by options.
func NewServer(opts ...ServerOption) *Server {
srv := &Server{
baseCtx: context.Background(),
network: "tcp",
address: ":0",
timeout: 1 * time.Second,
health: health.NewServer(),
middleware: matcher.New(),
baseCtx: context.Background(),
network: "tcp",
address: ":0",
timeout: 1 * time.Second,
health: health.NewServer(),
middleware: matcher.New(),
streamMiddleware: matcher.New(),
}
for _, o := range opts {
o(srv)
Expand Down
77 changes: 77 additions & 0 deletions transport/grpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/internal/matcher"
Expand Down Expand Up @@ -280,6 +281,82 @@ func TestServer_unaryServerInterceptor(t *testing.T) {
}
}

type mockServerStream struct {
ctx context.Context
sentMsg interface{}
recvMsg interface{}
metadata metadata.MD
grpc.ServerStream
}

func (m *mockServerStream) SetHeader(md metadata.MD) error {
m.metadata = md
return nil
}

func (m *mockServerStream) SendHeader(md metadata.MD) error {
m.metadata = md
return nil
}

func (m *mockServerStream) SetTrailer(md metadata.MD) {
m.metadata = md
}

func (m *mockServerStream) Context() context.Context {
return m.ctx
}

func (m *mockServerStream) SendMsg(msg interface{}) error {
m.sentMsg = msg
return nil
}

func (m *mockServerStream) RecvMsg(msg interface{}) error {
m.recvMsg = msg
return nil
}

func TestServer_streamServerInterceptor(t *testing.T) {
u, err := url.Parse("grpc://hello/world")
if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}
srv := &Server{
baseCtx: context.Background(),
endpoint: u,
timeout: time.Duration(10),
middleware: matcher.New(),
streamMiddleware: matcher.New(),
}

srv.streamMiddleware.Use(EmptyMiddleware())

mockStream := &mockServerStream{
ctx: srv.baseCtx,
}

handler := func(_ interface{}, stream grpc.ServerStream) error {
resp := &testResp{Data: "stream hi"}
return stream.SendMsg(resp)
}

info := &grpc.StreamServerInfo{
FullMethod: "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
}

err = srv.streamServerInterceptor()(nil, mockStream, info, handler)
if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}

// Check response
resp := mockStream.sentMsg.(*testResp)
if !reflect.DeepEqual("stream hi", resp.Data) {
t.Errorf("expect %s, got %s", "stream hi", resp.Data)
}
}

func TestListener(t *testing.T) {
lis, err := net.Listen("tcp", ":0")
if err != nil {
Expand Down
Loading