Skip to content

Commit

Permalink
fix: stabilize nats connection (#889)
Browse files Browse the repository at this point in the history
Signed-off-by: Yashash H L <[email protected]>
Signed-off-by: Vigith Maurice <[email protected]>
Co-authored-by: Vigith Maurice <[email protected]>
  • Loading branch information
yhl25 and vigith authored Jul 26, 2023
1 parent d4f8f59 commit 85360f6
Show file tree
Hide file tree
Showing 36 changed files with 593 additions and 764 deletions.
17 changes: 14 additions & 3 deletions pkg/daemon/server/daemon_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,24 @@ func NewDaemonServer(pl *v1alpha1.Pipeline, isbSvcType v1alpha1.ISBSvcType) *dae

func (ds *daemonServer) Run(ctx context.Context) error {
log := logging.FromContext(ctx)
var isbSvcClient isbsvc.ISBService
var err error
var (
isbSvcClient isbsvc.ISBService
err error
natsClientPool *jsclient.ClientPool
)

natsClientPool, err = jsclient.NewClientPool(ctx, jsclient.WithClientPoolSize(1))
defer natsClientPool.CloseAll()

if err != nil {
log.Errorw("Failed to get a NATS client pool.", zap.Error(err))
return err
}
switch ds.isbSvcType {
case v1alpha1.ISBSvcTypeRedis:
isbSvcClient = isbsvc.NewISBRedisSvc(redisclient.NewInClusterRedisClient())
case v1alpha1.ISBSvcTypeJetStream:
isbSvcClient, err = isbsvc.NewISBJetStreamSvc(ds.pipeline.Name, isbsvc.WithJetStreamClient(jsclient.NewInClusterJetStreamClient()))
isbSvcClient, err = isbsvc.NewISBJetStreamSvc(ds.pipeline.Name, isbsvc.WithJetStreamClient(natsClientPool.NextAvailableClient()))
if err != nil {
log.Errorw("Failed to get an ISB Service client.", zap.Error(err))
return err
Expand Down
2 changes: 2 additions & 0 deletions pkg/forward/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ func (isdf *InterStepDataForward) forwardAChunk(ctx context.Context) {
// at-least-once semantics for reading, during restart we will have to reprocess all unacknowledged messages. It is the
// responsibility of the Read function to do that.
readMessages, err := isdf.fromBufferPartition.Read(ctx, isdf.opts.readBatchSize)
isdf.opts.logger.Debugw("Read from buffer", zap.String("bufferFrom", isdf.fromBufferPartition.GetName()), zap.Int64("length", int64(len(readMessages))))
if err != nil {
isdf.opts.logger.Warnw("failed to read fromBufferPartition", zap.Error(err))
readMessagesError.With(map[string]string{metrics.LabelVertex: isdf.vertexName, metrics.LabelPipeline: isdf.pipelineName, metrics.LabelPartitionName: isdf.fromBufferPartition.GetName()}).Inc()
Expand Down Expand Up @@ -336,6 +337,7 @@ func (isdf *InterStepDataForward) forwardAChunk(ctx context.Context) {
isdf.fromBufferPartition.NoAck(ctx, readOffsets)
return
}
isdf.opts.logger.Debugw("writeToBuffers completed")
} else {
writeOffsets, err = isdf.streamMessage(ctx, dataMessages, processorWM)
if err != nil {
Expand Down
88 changes: 18 additions & 70 deletions pkg/isb/stores/jetstream/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,17 @@ import (
"time"

"github.com/nats-io/nats.go"
"go.uber.org/zap"
"k8s.io/apimachinery/pkg/util/wait"

"github.com/numaproj/numaflow/pkg/isb"
jsclient "github.com/numaproj/numaflow/pkg/shared/clients/nats"
"github.com/numaproj/numaflow/pkg/shared/logging"
"go.uber.org/zap"
)

type jetStreamReader struct {
name string
stream string
subject string
conn *jsclient.NatsConn
js *jsclient.JetStreamContext
client *jsclient.NATSClient
sub *nats.Subscription
opts *readOptions
inProgressTickDuration time.Duration
Expand All @@ -47,7 +44,7 @@ type jetStreamReader struct {
}

// NewJetStreamBufferReader is used to provide a new JetStream buffer reader connection
func NewJetStreamBufferReader(ctx context.Context, client jsclient.JetStreamClient, name, stream, subject string, partitionIdx int32, opts ...ReadOption) (isb.BufferReader, error) {
func NewJetStreamBufferReader(ctx context.Context, client *jsclient.NATSClient, name, stream, subject string, partitionIdx int32, opts ...ReadOption) (isb.BufferReader, error) {
log := logging.FromContext(ctx).With("bufferReader", name).With("stream", stream).With("subject", subject)
o := defaultReadOptions()
for _, opt := range opts {
Expand All @@ -57,82 +54,40 @@ func NewJetStreamBufferReader(ctx context.Context, client jsclient.JetStreamClie
}
}
}
result := &jetStreamReader{
reader := &jetStreamReader{
name: name,
stream: stream,
subject: subject,
client: client,
partitionIdx: partitionIdx,
opts: o,
log: log,
}

connectAndSubscribe := func() (*jsclient.NatsConn, *jsclient.JetStreamContext, *nats.Subscription, error) {
conn, err := client.Connect(ctx, jsclient.ReconnectHandler(func(c *jsclient.NatsConn) {
if result.js == nil {
log.Error("JetStreamContext is nil")
return
}
var e error
_ = wait.ExponentialBackoffWithContext(ctx, wait.Backoff{
Steps: 5,
Duration: 1 * time.Second,
Factor: 1.0,
Jitter: 0.1,
}, func() (bool, error) {
var s *nats.Subscription
if s, e = result.js.PullSubscribe(subject, stream, nats.Bind(stream, stream)); e != nil {
log.Errorw("Failed to re-subscribe to the stream after reconnection, will retry if the limit is not reached", zap.Error(e))
return false, nil
} else {
result.sub = s
log.Info("Re-subscribed to the stream successfully")
return true, nil
}
})
if e != nil {
// Let it panic to start over
log.Fatalw("Failed to re-subscribe after retries", zap.Error(e))
}
}), jsclient.DisconnectErrHandler(func(nc *jsclient.NatsConn, err error) {
log.Errorw("Nats JetStream connection lost", zap.Error(err))
}))
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get nats connection, %w", err)
}
js, err := conn.JetStream()
if err != nil {
conn.Close()
return nil, nil, nil, fmt.Errorf("failed to get jetstream context, %w", err)
}
sub, err := js.PullSubscribe(subject, stream, nats.Bind(stream, stream))
if err != nil {
conn.Close()
return nil, nil, nil, fmt.Errorf("failed to subscribe jet stream subject %q, %w", subject, err)
}
return conn, js, sub, nil
}

conn, js, sub, err := connectAndSubscribe()
jsContext, err := reader.client.JetStreamContext()
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get JetStream context, %w", err)
}

consumer, err := js.ConsumerInfo(stream, stream)
consumer, err := jsContext.ConsumerInfo(stream, stream)
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to get consumer info, %w", err)
}

// If ackWait is 3s, ticks every 2s.
inProgessTickSeconds := int64(consumer.Config.AckWait.Seconds() * 2 / 3)
if inProgessTickSeconds < 1 {
inProgessTickSeconds = 1
}

result.conn = conn
result.js = js
result.sub = sub
result.inProgressTickDuration = time.Duration(inProgessTickSeconds * int64(time.Second))
return result, nil
sub, err := reader.client.Subscribe(subject, stream, nats.Bind(stream, stream))
if err != nil {
return nil, fmt.Errorf("failed to subscribe to subject %q, %w", subject, err)
}

reader.sub = sub
reader.inProgressTickDuration = time.Duration(inProgessTickSeconds * int64(time.Second))
return reader, nil
}

func (jr *jetStreamReader) GetName() string {
Expand All @@ -149,18 +104,11 @@ func (jr *jetStreamReader) Close() error {
jr.log.Errorw("Failed to unsubscribe", zap.Error(err))
}
}
if jr.conn != nil && !jr.conn.IsClosed() {
jr.conn.Close()
}
return nil
}

func (jr *jetStreamReader) Pending(_ context.Context) (int64, error) {
c, err := jr.js.ConsumerInfo(jr.stream, jr.stream)
if err != nil {
return isb.PendingNotAvailable, fmt.Errorf("failed to get consumer info, %w", err)
}
return int64(c.NumPending) + int64(c.NumAckPending), nil
return jr.client.PendingForStream(jr.stream, jr.stream)
}

func (jr *jetStreamReader) Read(_ context.Context, count int64) ([]*isb.ReadMessage, error) {
Expand Down
25 changes: 9 additions & 16 deletions pkg/isb/stores/jetstream/reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (

"github.com/numaproj/numaflow/pkg/isb"
"github.com/numaproj/numaflow/pkg/isb/testutils"
jsclient "github.com/numaproj/numaflow/pkg/shared/clients/nats"
natstest "github.com/numaproj/numaflow/pkg/shared/clients/nats/test"
)

Expand All @@ -46,10 +45,8 @@ func TestJetStreamBufferRead(t *testing.T) {
defer cancel()

defaultJetStreamClient := natstest.JetStreamClient(t, s)
conn, err := defaultJetStreamClient.Connect(ctx)
assert.NoError(t, err)
defer conn.Close()
js, err := conn.JetStream()
defer defaultJetStreamClient.Close()
js, err := defaultJetStreamClient.JetStreamContext()
assert.NoError(t, err)

streamName := "testJetStreamBufferReader"
Expand Down Expand Up @@ -102,7 +99,7 @@ func TestJetStreamBufferRead(t *testing.T) {
}
assert.Equal(t, 20, len(offsetsInsideReadMessages))

fromStepJs, err := fromStep.conn.JetStream()
fromStepJs, err := fromStep.client.JetStreamContext()
assert.NoError(t, err)
streamInfo, err := fromStepJs.StreamInfo(streamName)
assert.NoError(t, err)
Expand Down Expand Up @@ -141,11 +138,9 @@ func TestGetName(t *testing.T) {
defer cancel()

defaultJetStreamClient := natstest.JetStreamClient(t, s)
conn, err := defaultJetStreamClient.Connect(ctx)
assert.NoError(t, err)
js, err := conn.JetStream()
js, err := defaultJetStreamClient.JetStreamContext()
assert.NoError(t, err)
defer conn.Close()
defer defaultJetStreamClient.Close()

streamName := "getName"
addStream(t, js, streamName)
Expand All @@ -168,10 +163,8 @@ func TestClose(t *testing.T) {
defer cancel()

defaultJetStreamClient := natstest.JetStreamClient(t, s)
conn, err := defaultJetStreamClient.Connect(ctx)
assert.NoError(t, err)
defer conn.Close()
js, err := conn.JetStream()
defer defaultJetStreamClient.Close()
js, err := defaultJetStreamClient.JetStreamContext()
assert.NoError(t, err)

streamName := "close"
Expand All @@ -186,7 +179,7 @@ func TestClose(t *testing.T) {

}

func addStream(t *testing.T, js *jsclient.JetStreamContext, streamName string) {
func addStream(t *testing.T, js nats.JetStreamContext, streamName string) {

_, err := js.AddStream(&nats.StreamConfig{
Name: streamName,
Expand All @@ -210,7 +203,7 @@ func addStream(t *testing.T, js *jsclient.JetStreamContext, streamName string) {

}

func deleteStream(js *jsclient.JetStreamContext, streamName string) {
func deleteStream(js nats.JetStreamContext, streamName string) {
_ = js.DeleteConsumer(streamName, streamName)
_ = js.DeleteStream(streamName)
}
22 changes: 8 additions & 14 deletions pkg/isb/stores/jetstream/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ type jetStreamWriter struct {
partitionIdx int32
stream string
subject string
conn *jsclient.NatsConn
js *jsclient.JetStreamContext
client *jsclient.NATSClient
js nats.JetStreamContext
opts *writeOptions
isFull *atomic.Bool
log *zap.SugaredLogger
}

// NewJetStreamBufferWriter is used to provide a new instance of JetStreamBufferWriter
func NewJetStreamBufferWriter(ctx context.Context, client jsclient.JetStreamClient, name, stream, subject string, partitionIdx int32, opts ...WriteOption) (isb.BufferWriter, error) {
func NewJetStreamBufferWriter(ctx context.Context, client *jsclient.NATSClient, name, stream, subject string, partitionIdx int32, opts ...WriteOption) (isb.BufferWriter, error) {
o := defaultWriteOptions()
for _, opt := range opts {
if opt != nil {
Expand All @@ -55,14 +55,10 @@ func NewJetStreamBufferWriter(ctx context.Context, client jsclient.JetStreamClie
}
}
}
conn, err := client.Connect(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get nats connection, %w", err)
}

js, err := conn.JetStream(nats.PublishAsyncMaxPending(1024))
js, err := client.JetStreamContext(nats.PublishAsyncMaxPending(1024))

if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to get JetStream context for writer")
}

Expand All @@ -71,7 +67,7 @@ func NewJetStreamBufferWriter(ctx context.Context, client jsclient.JetStreamClie
partitionIdx: partitionIdx,
stream: stream,
subject: subject,
conn: conn,
client: client,
js: js,
opts: o,
isFull: atomic.NewBool(true),
Expand All @@ -85,7 +81,7 @@ func NewJetStreamBufferWriter(ctx context.Context, client jsclient.JetStreamClie
func (jw *jetStreamWriter) runStatusChecker(ctx context.Context) {
labels := map[string]string{"buffer": jw.GetName()}
// Use a separated JetStream context for status checker
js, err := jw.conn.JetStream()
js, err := jw.client.JetStreamContext()
if err != nil {
// Let it exit if it fails to start the status checker
jw.log.Fatal("Failed to get Jet Stream context, %w", err)
Expand Down Expand Up @@ -150,10 +146,8 @@ func (jw *jetStreamWriter) GetPartitionIdx() int32 {
return jw.partitionIdx
}

// Close doesn't have to do anything for JetStreamBufferWriter, client will be closed by the caller.
func (jw *jetStreamWriter) Close() error {
if jw.conn != nil && !jw.conn.IsClosed() {
jw.conn.Close()
}
return nil
}

Expand Down
Loading

0 comments on commit 85360f6

Please sign in to comment.