Skip to content

Commit

Permalink
refactor: reuse rate limiting panic handler between rpc and grpc flows
Browse files Browse the repository at this point in the history
  • Loading branch information
JadhavPoonam committed Jan 24, 2023
1 parent 312f635 commit 774d962
Show file tree
Hide file tree
Showing 12 changed files with 41 additions and 66 deletions.
20 changes: 20 additions & 0 deletions agent/consul/rate/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,23 @@ func (nullRequestLimitsHandler) Run(ctx context.Context) {}
func (nullRequestLimitsHandler) UpdateConfig(cfg HandlerConfig) {}

func (nullRequestLimitsHandler) Register(leaderStatusProvider LeaderStatusProvider) {}

func NewPanicHandler(logger Logger) RecoveryHandlerFunc {
return func(p interface{}) (err error) {
// Log the panic and the stack trace of the Goroutine that caused the panic.
stacktrace := hclog.Stacktrace()
logger.Error("panic serving request",
"panic", p,
"stack", stacktrace,
)

return fmt.Errorf("panic serving request in rate limiter")
}
}

type RecoveryHandlerFunc func(p interface{}) (err error)

type Logger interface {
Error(string, ...interface{})
Warn(string, ...interface{})
}
2 changes: 1 addition & 1 deletion agent/consul/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom
}

rpcServerOpts := []func(*rpc.Server){
rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter, middleware.NewPanicHandler(s.logger))),
rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter, rpcRate.NewPanicHandler(s.logger))),
}

if flat.GetNetRPCInterceptorFunc != nil {
Expand Down
6 changes: 3 additions & 3 deletions agent/grpc-external/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"

"github.com/hashicorp/consul/agent/consul/rate"
rate "github.com/hashicorp/consul/agent/consul/rate"
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/tlsutil"
)
Expand All @@ -24,7 +24,7 @@ var (

// NewServer constructs a gRPC server for the external gRPC port, to which
// handlers can be registered.
func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator, limiter rate.RequestLimitsHandler) *grpc.Server {
func NewServer(logger rate.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator, limiter rate.RequestLimitsHandler) *grpc.Server {
if metricsObj == nil {
metricsObj = metrics.Default()
}
Expand All @@ -49,7 +49,7 @@ func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *
opts := []grpc.ServerOption{
grpc.MaxConcurrentStreams(2048),
grpc.MaxRecvMsgSize(50 * 1024 * 1024),
grpc.InTapHandle(agentmiddleware.ServerRateLimiterMiddleware(limiter, agentmiddleware.NewPanicHandler(logger), logger)),
grpc.InTapHandle(agentmiddleware.ServerRateLimiterMiddleware(limiter, rate.NewPanicHandler(logger), logger)),
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
middleware.WithUnaryServerChain(unaryInterceptors...),
middleware.WithStreamServerChain(streamInterceptors...),
Expand Down
4 changes: 2 additions & 2 deletions agent/grpc-internal/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

middleware "github.com/grpc-ecosystem/go-grpc-middleware"
recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"github.com/hashicorp/consul/agent/consul/rate"
rate "github.com/hashicorp/consul/agent/consul/rate"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
)
Expand All @@ -36,7 +36,7 @@ func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server)
recoveryOpts := agentmiddleware.PanicHandlerMiddlewareOpts(logger)

opts := []grpc.ServerOption{
grpc.InTapHandle(agentmiddleware.ServerRateLimiterMiddleware(rateLimiter, agentmiddleware.NewPanicHandler(logger), logger)),
grpc.InTapHandle(agentmiddleware.ServerRateLimiterMiddleware(rateLimiter, rate.NewPanicHandler(logger), logger)),
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
middleware.WithUnaryServerChain(
// Add middlware interceptors to recover in case of panics.
Expand Down
3 changes: 2 additions & 1 deletion agent/grpc-middleware/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"strings"

rate "github.com/hashicorp/consul/agent/consul/rate"
"github.com/hashicorp/consul/tlsutil"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand All @@ -20,7 +21,7 @@ const AllowedPeerEndpointPrefix = "/hashicorp.consul.internal.peerstream.PeerStr
// connection will be allowed to proceed.
type AuthInterceptor struct {
TLS *tlsutil.Configurator
Logger Logger
Logger rate.Logger
}

// InterceptUnary prevents non-streaming gRPC calls from calling certain endpoints,
Expand Down
5 changes: 3 additions & 2 deletions agent/grpc-middleware/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net"

rate "github.com/hashicorp/consul/agent/consul/rate"
"google.golang.org/grpc/credentials"
)

Expand Down Expand Up @@ -44,10 +45,10 @@ var _ credentials.TransportCredentials = (*optionalTransportCredentials)(nil)
// based on metadata extracted from the underlying connection object.
type optionalTransportCredentials struct {
credentials.TransportCredentials
logger Logger
logger rate.Logger
}

func NewOptionalTransportCredentials(creds credentials.TransportCredentials, logger Logger) credentials.TransportCredentials {
func NewOptionalTransportCredentials(creds credentials.TransportCredentials, logger rate.Logger) credentials.TransportCredentials {
return &optionalTransportCredentials{creds, logger}
}

Expand Down
7 changes: 3 additions & 4 deletions agent/grpc-middleware/rate.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,20 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/grpc/tap"

recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"

"github.com/hashicorp/consul/agent/consul/rate"
)

// ServerRateLimiterMiddleware implements a ServerInHandle function to perform
// RPC rate limiting at the cheapest possible point (before the full request has
// been decoded).
func ServerRateLimiterMiddleware(limiter rate.RequestLimitsHandler, panicHandler recovery.RecoveryHandlerFunc, logger Logger) tap.ServerInHandle {
func ServerRateLimiterMiddleware(limiter rate.RequestLimitsHandler, panicHandler rate.RecoveryHandlerFunc, logger rate.Logger) tap.ServerInHandle {
return func(ctx context.Context, info *tap.Info) (_ context.Context, retErr error) {
// This function is called before unary and stream RPC interceptors, so we
// must handle our own panics here.
defer func() {
if r := recover(); r != nil {
retErr = panicHandler(r)
err := panicHandler(r)
retErr = status.Errorf(codes.Internal, err.Error())
}
}()

Expand Down
2 changes: 1 addition & 1 deletion agent/grpc-middleware/rate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestServerRateLimiterMiddleware_Integration(t *testing.T) {

logger := hclog.NewNullLogger()
server := grpc.NewServer(
grpc.InTapHandle(ServerRateLimiterMiddleware(limiter, NewPanicHandler(logger), logger)),
grpc.InTapHandle(ServerRateLimiterMiddleware(limiter, rate.NewPanicHandler(logger), logger)),
)
pbacl.RegisterACLServiceServer(server, mockACLServer{})

Expand Down
28 changes: 3 additions & 25 deletions agent/grpc-middleware/recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,13 @@ package middleware

import (
recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"github.com/hashicorp/go-hclog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
rate "github.com/hashicorp/consul/agent/consul/rate"
)

// PanicHandlerMiddlewareOpts returns the []recovery.Option containing
// recovery handler function.
func PanicHandlerMiddlewareOpts(logger Logger) []recovery.Option {
func PanicHandlerMiddlewareOpts(logger rate.Logger) []recovery.Option {
return []recovery.Option{
recovery.WithRecoveryHandler(NewPanicHandler(logger)),
recovery.WithRecoveryHandler(recovery.RecoveryHandlerFunc(rate.NewPanicHandler(logger))),
}
}

// NewPanicHandler returns a recovery.RecoveryHandlerFunc closure function
// to handle panic in GRPC server's handlers.
func NewPanicHandler(logger Logger) recovery.RecoveryHandlerFunc {
return func(p interface{}) (err error) {
// Log the panic and the stack trace of the Goroutine that caused the panic.
stacktrace := hclog.Stacktrace()
logger.Error("panic serving grpc request",
"panic", p,
"stack", stacktrace,
)

return status.Errorf(codes.Internal, "grpc: panic serving request")
}
}

type Logger interface {
Error(string, ...interface{})
Warn(string, ...interface{})
}
2 changes: 1 addition & 1 deletion agent/rpc/middleware/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func GetNetRPCInterceptor(recorder *RequestRecorder) rpc.ServerServiceCallInterc
}
}

func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler, panicHandler RecoveryHandlerFunc) rpc.PreBodyInterceptor {
func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler, panicHandler rpcRate.RecoveryHandlerFunc) rpc.PreBodyInterceptor {

return func(reqServiceMethod string, sourceAddr net.Addr) (retErr error) {

Expand Down
4 changes: 2 additions & 2 deletions agent/rpc/middleware/interceptors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func TestGetNetRPCRateLimitingInterceptor(t *testing.T) {
limiter := rate.NewMockRequestLimitsHandler(t)

logger := hclog.NewNullLogger()
var rateLimitInterceptor = GetNetRPCRateLimitingInterceptor(limiter, NewPanicHandler(logger))
var rateLimitInterceptor = GetNetRPCRateLimitingInterceptor(limiter, rate.NewPanicHandler(logger))

listener, _ := net.Listen("tcp", "127.0.0.1:0")

Expand Down Expand Up @@ -306,6 +306,6 @@ func TestGetNetRPCRateLimitingInterceptor(t *testing.T) {
err := rateLimitInterceptor("Status.Leader", listener.Addr())

require.Error(t, err)
require.Equal(t, "rpc: panic serving request", err.Error())
require.Equal(t, "panic serving request in rate limiter", err.Error())
})
}
24 changes: 0 additions & 24 deletions agent/rpc/middleware/recovery.go

This file was deleted.

0 comments on commit 774d962

Please sign in to comment.