Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: agent propagation #3654

Merged
merged 7 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions agent/client/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net"
"time"

"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.opentelemetry.io/otel/propagation"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -73,6 +75,12 @@ func (c *Client) connect(ctx context.Context) error {
grpc.WithTransportCredentials(transportCredentials),
grpc.WithDefaultServiceConfig(retryPolicy),
grpc.WithIdleTimeout(0), // disable grpc idle timeout
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
grpc.WithUnaryInterceptor(otelgrpc.UnaryClientInterceptor(
otelgrpc.WithPropagators(
propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}),
),
)),
)
if err != nil {
return fmt.Errorf("could not connect to server: %w", err)
Expand Down
118 changes: 79 additions & 39 deletions agent/client/mocks/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,41 @@ import (
"github.com/avast/retry-go"
"github.com/kubeshop/tracetest/agent/client"
"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.opentelemetry.io/otel/propagation"
"google.golang.org/grpc"
)

type GrpcServerMock struct {
proto.UnimplementedOrchestratorServer
port int
triggerChannel chan *proto.TriggerRequest
pollingChannel chan *proto.PollingRequest
otlpConnectionTestChannel chan *proto.OTLPConnectionTestRequest
terminationChannel chan *proto.ShutdownRequest
dataStoreTestChannel chan *proto.DataStoreConnectionTestRequest
triggerChannel chan Message[*proto.TriggerRequest]
pollingChannel chan Message[*proto.PollingRequest]
otlpConnectionTestChannel chan Message[*proto.OTLPConnectionTestRequest]
terminationChannel chan Message[*proto.ShutdownRequest]
dataStoreTestChannel chan Message[*proto.DataStoreConnectionTestRequest]

lastTriggerResponse *proto.TriggerResponse
lastPollingResponse *proto.PollingResponse
lastOtlpConnectionResponse *proto.OTLPConnectionTestResponse
lastDataStoreConnectionResponse *proto.DataStoreConnectionTestResponse
lastTriggerResponse Message[*proto.TriggerResponse]
lastPollingResponse Message[*proto.PollingResponse]
lastOtlpConnectionResponse Message[*proto.OTLPConnectionTestResponse]
lastDataStoreConnectionResponse Message[*proto.DataStoreConnectionTestResponse]

server *grpc.Server
}

type Message[T any] struct {
Context context.Context
Data T
}

func NewGrpcServer() *GrpcServerMock {
server := &GrpcServerMock{
triggerChannel: make(chan *proto.TriggerRequest),
pollingChannel: make(chan *proto.PollingRequest),
terminationChannel: make(chan *proto.ShutdownRequest),
dataStoreTestChannel: make(chan *proto.DataStoreConnectionTestRequest),
otlpConnectionTestChannel: make(chan *proto.OTLPConnectionTestRequest),
triggerChannel: make(chan Message[*proto.TriggerRequest]),
pollingChannel: make(chan Message[*proto.PollingRequest]),
terminationChannel: make(chan Message[*proto.ShutdownRequest]),
dataStoreTestChannel: make(chan Message[*proto.DataStoreConnectionTestRequest]),
otlpConnectionTestChannel: make(chan Message[*proto.OTLPConnectionTestRequest]),
}
var wg sync.WaitGroup
wg.Add(1)
Expand Down Expand Up @@ -65,7 +73,13 @@ func (s *GrpcServerMock) start(wg *sync.WaitGroup, port int) error {

s.port = lis.Addr().(*net.TCPAddr).Port

server := grpc.NewServer()
server := grpc.NewServer(
grpc.UnaryInterceptor(otelgrpc.UnaryServerInterceptor(
otelgrpc.WithPropagators(
propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}),
),
)),
)
proto.RegisterOrchestratorServer(server, s)

s.server = server
Expand Down Expand Up @@ -107,7 +121,12 @@ func (s *GrpcServerMock) RegisterTriggerAgent(id *proto.AgentIdentification, str

for {
triggerRequest := <-s.triggerChannel
err := stream.Send(triggerRequest)
err := telemetry.InjectContextIntoStream(triggerRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(triggerRequest.Data)
if err != nil {
log.Println("could not send trigger request to agent: %w", err)
}
Expand All @@ -120,7 +139,7 @@ func (s *GrpcServerMock) SendTriggerResult(ctx context.Context, result *proto.Tr
return nil, fmt.Errorf("could not validate token")
}

s.lastTriggerResponse = result
s.lastTriggerResponse = Message[*proto.TriggerResponse]{Data: result, Context: ctx}
return &proto.Empty{}, nil
}

Expand All @@ -131,7 +150,12 @@ func (s *GrpcServerMock) RegisterPollerAgent(id *proto.AgentIdentification, stre

for {
pollerRequest := <-s.pollingChannel
err := stream.Send(pollerRequest)
err := telemetry.InjectContextIntoStream(pollerRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(pollerRequest.Data)
if err != nil {
log.Println("could not send polling request to agent: %w", err)
}
Expand All @@ -145,7 +169,12 @@ func (s *GrpcServerMock) RegisterDataStoreConnectionTestAgent(id *proto.AgentIde

for {
dsTestRequest := <-s.dataStoreTestChannel
err := stream.Send(dsTestRequest)
err := telemetry.InjectContextIntoStream(dsTestRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(dsTestRequest.Data)
if err != nil {
log.Println("could not send polling request to agent: %w", err)
}
Expand All @@ -159,7 +188,12 @@ func (s *GrpcServerMock) RegisterOTLPConnectionTestListener(id *proto.AgentIdent

for {
testRequest := <-s.otlpConnectionTestChannel
err := stream.Send(testRequest)
err := telemetry.InjectContextIntoStream(testRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(testRequest.Data)
if err != nil {
log.Println("could not send polling request to agent: %w", err)
}
Expand All @@ -171,7 +205,7 @@ func (s *GrpcServerMock) SendOTLPConnectionTestResult(ctx context.Context, resul
return nil, fmt.Errorf("could not validate token")
}

s.lastOtlpConnectionResponse = result
s.lastOtlpConnectionResponse = Message[*proto.OTLPConnectionTestResponse]{Data: result, Context: ctx}
return &proto.Empty{}, nil
}

Expand All @@ -180,7 +214,7 @@ func (s *GrpcServerMock) SendDataStoreConnectionTestResult(ctx context.Context,
return nil, fmt.Errorf("could not validate token")
}

s.lastDataStoreConnectionResponse = result
s.lastDataStoreConnectionResponse = Message[*proto.DataStoreConnectionTestResponse]{Data: result, Context: ctx}
return &proto.Empty{}, nil
}

Expand All @@ -189,14 +223,19 @@ func (s *GrpcServerMock) SendPolledSpans(ctx context.Context, result *proto.Poll
return nil, fmt.Errorf("could not validate token")
}

s.lastPollingResponse = result
s.lastPollingResponse = Message[*proto.PollingResponse]{Data: result, Context: ctx}
return &proto.Empty{}, nil
}

func (s *GrpcServerMock) RegisterShutdownListener(_ *proto.AgentIdentification, stream proto.Orchestrator_RegisterShutdownListenerServer) error {
for {
shutdownRequest := <-s.terminationChannel
err := stream.Send(shutdownRequest)
err := telemetry.InjectContextIntoStream(shutdownRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(shutdownRequest.Data)
if err != nil {
log.Println("could not send polling request to agent: %w", err)
}
Expand All @@ -205,41 +244,42 @@ func (s *GrpcServerMock) RegisterShutdownListener(_ *proto.AgentIdentification,

// Test methods

func (s *GrpcServerMock) SendTriggerRequest(request *proto.TriggerRequest) {
s.triggerChannel <- request
func (s *GrpcServerMock) SendTriggerRequest(ctx context.Context, request *proto.TriggerRequest) {
s.triggerChannel <- Message[*proto.TriggerRequest]{Context: ctx, Data: request}
}

func (s *GrpcServerMock) SendPollingRequest(request *proto.PollingRequest) {
s.pollingChannel <- request
func (s *GrpcServerMock) SendPollingRequest(ctx context.Context, request *proto.PollingRequest) {
s.pollingChannel <- Message[*proto.PollingRequest]{Context: ctx, Data: request}
}

func (s *GrpcServerMock) SendDataStoreConnectionTestRequest(request *proto.DataStoreConnectionTestRequest) {
s.dataStoreTestChannel <- request
func (s *GrpcServerMock) SendDataStoreConnectionTestRequest(ctx context.Context, request *proto.DataStoreConnectionTestRequest) {
s.dataStoreTestChannel <- Message[*proto.DataStoreConnectionTestRequest]{Context: ctx, Data: request}
}

func (s *GrpcServerMock) SendOTLPConnectionTestRequest(request *proto.OTLPConnectionTestRequest) {
s.otlpConnectionTestChannel <- request
func (s *GrpcServerMock) SendOTLPConnectionTestRequest(ctx context.Context, request *proto.OTLPConnectionTestRequest) {
s.otlpConnectionTestChannel <- Message[*proto.OTLPConnectionTestRequest]{Context: ctx, Data: request}
}

func (s *GrpcServerMock) GetLastTriggerResponse() *proto.TriggerResponse {
func (s *GrpcServerMock) GetLastTriggerResponse() Message[*proto.TriggerResponse] {
return s.lastTriggerResponse
}

func (s *GrpcServerMock) GetLastPollingResponse() *proto.PollingResponse {
func (s *GrpcServerMock) GetLastPollingResponse() Message[*proto.PollingResponse] {
return s.lastPollingResponse
}

func (s *GrpcServerMock) GetLastOTLPConnectionResponse() *proto.OTLPConnectionTestResponse {
func (s *GrpcServerMock) GetLastOTLPConnectionResponse() Message[*proto.OTLPConnectionTestResponse] {
return s.lastOtlpConnectionResponse
}

func (s *GrpcServerMock) GetLastDataStoreConnectionResponse() *proto.DataStoreConnectionTestResponse {
func (s *GrpcServerMock) GetLastDataStoreConnectionResponse() Message[*proto.DataStoreConnectionTestResponse] {
return s.lastDataStoreConnectionResponse
}

func (s *GrpcServerMock) TerminateConnection(reason string) {
s.terminationChannel <- &proto.ShutdownRequest{
Reason: reason,
func (s *GrpcServerMock) TerminateConnection(ctx context.Context, reason string) {
s.terminationChannel <- Message[*proto.ShutdownRequest]{
Context: ctx,
Data: &proto.ShutdownRequest{Reason: reason},
}
}

Expand Down
9 changes: 7 additions & 2 deletions agent/client/workflow_listen_for_ds_connection_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
)

func (c *Client) startDataStoreConnectionTestListener(ctx context.Context) error {
Expand Down Expand Up @@ -36,8 +37,12 @@ func (c *Client) startDataStoreConnectionTestListener(ctx context.Context) error
continue
}

// TODO: Get ctx from request
err = c.dataStoreConnectionListener(context.Background(), &req)
ctx, err := telemetry.ExtractContextFromStream(stream)
if err != nil {
log.Println("could not extract context from stream %w", err)
}

err = c.dataStoreConnectionListener(ctx, &req)
if err != nil {
fmt.Println(err.Error())
}
Expand Down
7 changes: 4 additions & 3 deletions agent/client/workflow_listen_for_ds_connection_tests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ import (
)

func TestDataStoreConnectionTestWorkflow(t *testing.T) {
ctx := context.Background()
server := mocks.NewGrpcServer()
defer server.Stop()

client, err := client.Connect(context.Background(), server.Addr())
client, err := client.Connect(ctx, server.Addr())
require.NoError(t, err)

var receivedConnectionTestRequest *proto.DataStoreConnectionTestRequest
Expand All @@ -25,14 +26,14 @@ func TestDataStoreConnectionTestWorkflow(t *testing.T) {
return nil
})

err = client.Start(context.Background())
err = client.Start(ctx)
require.NoError(t, err)

connectionTestRequest := &proto.DataStoreConnectionTestRequest{
RequestID: "request-id",
}

server.SendDataStoreConnectionTestRequest(connectionTestRequest)
server.SendDataStoreConnectionTestRequest(ctx, connectionTestRequest)

// ensures there's enough time for networking between server and client
time.Sleep(1 * time.Second)
Expand Down
9 changes: 7 additions & 2 deletions agent/client/workflow_listen_for_otlp_connection_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
)

func (c *Client) startOTLPConnectionTestListener(ctx context.Context) error {
Expand Down Expand Up @@ -36,8 +37,12 @@ func (c *Client) startOTLPConnectionTestListener(ctx context.Context) error {
continue
}

// TODO: Get ctx from request
err = c.otlpConnectionTestListener(context.Background(), &req)
ctx, err := telemetry.ExtractContextFromStream(stream)
if err != nil {
log.Println("could not extract context from stream %w", err)
}

err = c.otlpConnectionTestListener(ctx, &req)
if err != nil {
fmt.Println(err.Error())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ import (
)

func TestOtlpConnectionTestWorkflow(t *testing.T) {
ctx := context.Background()
server := mocks.NewGrpcServer()
defer server.Stop()

client, err := client.Connect(context.Background(), server.Addr())
client, err := client.Connect(ctx, server.Addr())
require.NoError(t, err)

var receivedConnectionTestRequest *proto.OTLPConnectionTestRequest
Expand All @@ -25,14 +26,14 @@ func TestOtlpConnectionTestWorkflow(t *testing.T) {
return nil
})

err = client.Start(context.Background())
err = client.Start(ctx)
require.NoError(t, err)

connectionTestRequest := &proto.OTLPConnectionTestRequest{
RequestID: "request-id",
}

server.SendOTLPConnectionTestRequest(connectionTestRequest)
server.SendOTLPConnectionTestRequest(ctx, connectionTestRequest)

// ensures there's enough time for networking between server and client
time.Sleep(1 * time.Second)
Expand Down
9 changes: 7 additions & 2 deletions agent/client/workflow_listen_for_poll_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
)

func (c *Client) startPollerListener(ctx context.Context) error {
Expand Down Expand Up @@ -36,8 +37,12 @@ func (c *Client) startPollerListener(ctx context.Context) error {
continue
}

// TODO: Get ctx from request
err = c.pollListener(context.Background(), &resp)
ctx, err := telemetry.ExtractContextFromStream(stream)
if err != nil {
log.Println("could not extract context from stream %w", err)
}

err = c.pollListener(ctx, &resp)
if err != nil {
fmt.Println(err.Error())
}
Expand Down
Loading
Loading