From ff36dc7233baef23dbc4d41969787074421c0045 Mon Sep 17 00:00:00 2001 From: Matheus Nogueira Date: Fri, 16 Feb 2024 11:10:06 -0300 Subject: [PATCH 1/5] add otelgrpc to agent and update grpc server to accept context --- agent/client/connector.go | 2 ++ agent/client/mocks/grpc_server.go | 11 ++++++----- .../workflow_listen_for_ds_connection_tests_test.go | 7 ++++--- .../workflow_listen_for_otlp_connection_tests_test.go | 7 ++++--- .../client/workflow_listen_for_poll_requests_test.go | 9 +++++---- .../workflow_listen_for_trigger_requests_test.go | 9 +++++---- agent/workers/poller_test.go | 8 ++++---- agent/workers/trigger_test.go | 9 ++++++--- go.mod | 2 +- go.sum | 2 ++ 10 files changed, 39 insertions(+), 27 deletions(-) diff --git a/agent/client/connector.go b/agent/client/connector.go index bf52526e8b..3a0a66faaa 100644 --- a/agent/client/connector.go +++ b/agent/client/connector.go @@ -7,6 +7,7 @@ import ( "net" "time" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -73,6 +74,7 @@ func (c *Client) connect(ctx context.Context) error { grpc.WithTransportCredentials(transportCredentials), grpc.WithDefaultServiceConfig(retryPolicy), grpc.WithIdleTimeout(0), // disable grpc idle timeout + grpc.WithStatsHandler(otelgrpc.NewClientHandler()), ) if err != nil { return fmt.Errorf("could not connect to server: %w", err) diff --git a/agent/client/mocks/grpc_server.go b/agent/client/mocks/grpc_server.go index ca6dca7b38..59fa9dd4c7 100644 --- a/agent/client/mocks/grpc_server.go +++ b/agent/client/mocks/grpc_server.go @@ -10,6 +10,7 @@ import ( "github.com/avast/retry-go" "github.com/kubeshop/tracetest/agent/client" "github.com/kubeshop/tracetest/agent/proto" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "google.golang.org/grpc" ) @@ -65,7 +66,7 @@ func (s *GrpcServerMock) start(wg *sync.WaitGroup, port int) error { s.port = lis.Addr().(*net.TCPAddr).Port - server := grpc.NewServer() + server := grpc.NewServer(grpc.StatsHandler(otelgrpc.NewServerHandler())) proto.RegisterOrchestratorServer(server, s) s.server = server @@ -205,19 +206,19 @@ func (s *GrpcServerMock) RegisterShutdownListener(_ *proto.AgentIdentification, // Test methods -func (s *GrpcServerMock) SendTriggerRequest(request *proto.TriggerRequest) { +func (s *GrpcServerMock) SendTriggerRequest(ctx context.Context, request *proto.TriggerRequest) { s.triggerChannel <- request } -func (s *GrpcServerMock) SendPollingRequest(request *proto.PollingRequest) { +func (s *GrpcServerMock) SendPollingRequest(ctx context.Context, request *proto.PollingRequest) { s.pollingChannel <- request } -func (s *GrpcServerMock) SendDataStoreConnectionTestRequest(request *proto.DataStoreConnectionTestRequest) { +func (s *GrpcServerMock) SendDataStoreConnectionTestRequest(ctx context.Context, request *proto.DataStoreConnectionTestRequest) { s.dataStoreTestChannel <- request } -func (s *GrpcServerMock) SendOTLPConnectionTestRequest(request *proto.OTLPConnectionTestRequest) { +func (s *GrpcServerMock) SendOTLPConnectionTestRequest(ctx context.Context, request *proto.OTLPConnectionTestRequest) { s.otlpConnectionTestChannel <- request } diff --git a/agent/client/workflow_listen_for_ds_connection_tests_test.go b/agent/client/workflow_listen_for_ds_connection_tests_test.go index 7aa5090b83..d0efb0d1bc 100644 --- a/agent/client/workflow_listen_for_ds_connection_tests_test.go +++ b/agent/client/workflow_listen_for_ds_connection_tests_test.go @@ -13,10 +13,11 @@ import ( ) func TestDataStoreConnectionTestWorkflow(t *testing.T) { + ctx := context.Background() server := mocks.NewGrpcServer() defer server.Stop() - client, err := client.Connect(context.Background(), server.Addr()) + client, err := client.Connect(ctx, server.Addr()) require.NoError(t, err) var receivedConnectionTestRequest *proto.DataStoreConnectionTestRequest @@ -25,14 +26,14 @@ func TestDataStoreConnectionTestWorkflow(t *testing.T) { return nil }) - err = client.Start(context.Background()) + err = client.Start(ctx) require.NoError(t, err) connectionTestRequest := &proto.DataStoreConnectionTestRequest{ RequestID: "request-id", } - server.SendDataStoreConnectionTestRequest(connectionTestRequest) + server.SendDataStoreConnectionTestRequest(ctx, connectionTestRequest) // ensures there's enough time for networking between server and client time.Sleep(1 * time.Second) diff --git a/agent/client/workflow_listen_for_otlp_connection_tests_test.go b/agent/client/workflow_listen_for_otlp_connection_tests_test.go index 9041a13457..64ad665782 100644 --- a/agent/client/workflow_listen_for_otlp_connection_tests_test.go +++ b/agent/client/workflow_listen_for_otlp_connection_tests_test.go @@ -13,10 +13,11 @@ import ( ) func TestOtlpConnectionTestWorkflow(t *testing.T) { + ctx := context.Background() server := mocks.NewGrpcServer() defer server.Stop() - client, err := client.Connect(context.Background(), server.Addr()) + client, err := client.Connect(ctx, server.Addr()) require.NoError(t, err) var receivedConnectionTestRequest *proto.OTLPConnectionTestRequest @@ -25,14 +26,14 @@ func TestOtlpConnectionTestWorkflow(t *testing.T) { return nil }) - err = client.Start(context.Background()) + err = client.Start(ctx) require.NoError(t, err) connectionTestRequest := &proto.OTLPConnectionTestRequest{ RequestID: "request-id", } - server.SendOTLPConnectionTestRequest(connectionTestRequest) + server.SendOTLPConnectionTestRequest(ctx, connectionTestRequest) // ensures there's enough time for networking between server and client time.Sleep(1 * time.Second) diff --git a/agent/client/workflow_listen_for_poll_requests_test.go b/agent/client/workflow_listen_for_poll_requests_test.go index 2019dab411..a629798f70 100644 --- a/agent/client/workflow_listen_for_poll_requests_test.go +++ b/agent/client/workflow_listen_for_poll_requests_test.go @@ -13,10 +13,11 @@ import ( ) func TestPollWorkflow(t *testing.T) { + ctx := context.Background() server := mocks.NewGrpcServer() defer server.Stop() - client, err := client.Connect(context.Background(), server.Addr()) + client, err := client.Connect(ctx, server.Addr()) require.NoError(t, err) var receivedPollingRequest *proto.PollingRequest @@ -25,7 +26,7 @@ func TestPollWorkflow(t *testing.T) { return nil }) - err = client.Start(context.Background()) + err = client.Start(ctx) require.NoError(t, err) pollingRequest := &proto.PollingRequest{ @@ -44,7 +45,7 @@ func TestPollWorkflow(t *testing.T) { }, } - server.SendPollingRequest(pollingRequest) + server.SendPollingRequest(ctx, pollingRequest) // ensures there's enough time for networking between server and client time.Sleep(1 * time.Second) @@ -78,7 +79,7 @@ func TestPollWorkflow(t *testing.T) { }, } - server.SendPollingRequest(anotherPollingRequest) + server.SendPollingRequest(ctx, anotherPollingRequest) // ensures there's enough time for networking between server and client time.Sleep(1 * time.Second) diff --git a/agent/client/workflow_listen_for_trigger_requests_test.go b/agent/client/workflow_listen_for_trigger_requests_test.go index 082f29bea6..117ae78e66 100644 --- a/agent/client/workflow_listen_for_trigger_requests_test.go +++ b/agent/client/workflow_listen_for_trigger_requests_test.go @@ -13,10 +13,11 @@ import ( ) func TestTriggerWorkflow(t *testing.T) { + ctx := context.Background() server := mocks.NewGrpcServer() defer server.Stop() - client, err := client.Connect(context.Background(), server.Addr()) + client, err := client.Connect(ctx, server.Addr()) require.NoError(t, err) var receivedTrigger *proto.TriggerRequest @@ -25,7 +26,7 @@ func TestTriggerWorkflow(t *testing.T) { return nil }) - err = client.Start(context.Background()) + err = client.Start(ctx) require.NoError(t, err) triggerRequest := &proto.TriggerRequest{ @@ -43,7 +44,7 @@ func TestTriggerWorkflow(t *testing.T) { }, } - server.SendTriggerRequest(triggerRequest) + server.SendTriggerRequest(ctx, triggerRequest) // ensures there's enough time for networking between server and client time.Sleep(1 * time.Second) @@ -75,7 +76,7 @@ func TestTriggerWorkflow(t *testing.T) { }, } - server.SendTriggerRequest(anotherTriggerRequest) + server.SendTriggerRequest(ctx, anotherTriggerRequest) // ensures there's enough time for networking between server and client time.Sleep(1 * time.Second) diff --git a/agent/workers/poller_test.go b/agent/workers/poller_test.go index f55d855d28..64ffacf54e 100644 --- a/agent/workers/poller_test.go +++ b/agent/workers/poller_test.go @@ -52,7 +52,7 @@ func TestPollerWorker(t *testing.T) { }, } - controlPlane.SendPollingRequest(&pollingRequest) + controlPlane.SendPollingRequest(ctx, &pollingRequest) time.Sleep(1 * time.Second) @@ -146,7 +146,7 @@ func TestPollerWorkerWithInmemoryDatastore(t *testing.T) { }, } - controlPlane.SendPollingRequest(&pollingRequest) + controlPlane.SendPollingRequest(ctx, &pollingRequest) time.Sleep(1 * time.Second) @@ -165,7 +165,7 @@ func TestPollerWorkerWithInmemoryDatastore(t *testing.T) { {Name: "span 2", ParentSpanId: span1ID[:], SpanId: span2ID[:], TraceId: traceID[:]}, }) - controlPlane.SendPollingRequest(&pollingRequest) + controlPlane.SendPollingRequest(ctx, &pollingRequest) time.Sleep(1 * time.Second) @@ -208,7 +208,7 @@ func TestPollerWithInvalidDataStore(t *testing.T) { }, } - controlPlane.SendPollingRequest(&pollingRequest) + controlPlane.SendPollingRequest(ctx, &pollingRequest) time.Sleep(1 * time.Second) diff --git a/agent/workers/trigger_test.go b/agent/workers/trigger_test.go index 143c1944fd..dc46aceed4 100644 --- a/agent/workers/trigger_test.go +++ b/agent/workers/trigger_test.go @@ -41,6 +41,7 @@ func setupTriggerWorker(t *testing.T) (*mocks.GrpcServerMock, collector.TraceCac } func TestTrigger(t *testing.T) { + ctx := context.Background() controlPlane, cache := setupTriggerWorker(t) targetServer := createHelloWorldApi() @@ -63,7 +64,7 @@ func TestTrigger(t *testing.T) { } // make the control plane send a trigger request to the agent - controlPlane.SendTriggerRequest(triggerRequest) + controlPlane.SendTriggerRequest(ctx, triggerRequest) time.Sleep(1 * time.Second) response := controlPlane.GetLastTriggerResponse() @@ -78,6 +79,7 @@ func TestTrigger(t *testing.T) { } func TestTriggerAgainstGoogle(t *testing.T) { + ctx := context.Background() controlPlane, _ := setupTriggerWorker(t) traceID := "42a2c381da1a5b3a32bc4988bf2431b0" @@ -99,7 +101,7 @@ func TestTriggerAgainstGoogle(t *testing.T) { } // make the control plane send a trigger request to the agent - controlPlane.SendTriggerRequest(triggerRequest) + controlPlane.SendTriggerRequest(ctx, triggerRequest) time.Sleep(1 * time.Second) response := controlPlane.GetLastTriggerResponse() @@ -110,6 +112,7 @@ func TestTriggerAgainstGoogle(t *testing.T) { } func TestTriggerInexistentAPI(t *testing.T) { + ctx := context.Background() controlPlane, _ := setupTriggerWorker(t) traceID := "42a2c381da1a5b3a32bc4988bf2431b0" @@ -131,7 +134,7 @@ func TestTriggerInexistentAPI(t *testing.T) { } // make the control plane send a trigger request to the agent - controlPlane.SendTriggerRequest(triggerRequest) + controlPlane.SendTriggerRequest(ctx, triggerRequest) time.Sleep(1 * time.Second) response := controlPlane.GetLastTriggerResponse() diff --git a/go.mod b/go.mod index e481c05224..eb6e2f901e 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,7 @@ require ( go.opentelemetry.io/collector/config/configtls v0.80.0 go.opentelemetry.io/collector/pdata v1.0.0-rcv0015 go.opentelemetry.io/collector/semconv v0.80.0 - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.44.0 + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.45.0 go.opentelemetry.io/contrib/propagators/aws v1.5.0 go.opentelemetry.io/contrib/propagators/b3 v1.17.0 go.opentelemetry.io/contrib/propagators/jaeger v1.5.0 diff --git a/go.sum b/go.sum index 890bb6d9ef..05d7f9e558 100644 --- a/go.sum +++ b/go.sum @@ -1950,6 +1950,8 @@ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.2 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.28.0/go.mod h1:vEhqr0m4eTc+DWxfsXoXue2GBgV2uUwVznkGIHW/e5w= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.44.0 h1:b8xjZxHbLrXAum4SxJd1Rlm7Y/fKaB+6ACI7/e5EfSA= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.44.0/go.mod h1:1ei0a32xOGkFoySu7y1DAHfcuIhC0pNZpvY2huXuMy4= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.45.0 h1:RsQi0qJ2imFfCvZabqzM9cNXBG8k6gXMv1A0cXRmH6A= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.45.0/go.mod h1:vsh3ySueQCiKPxFLvjWC4Z135gIa34TQ/NSqkDTZYUM= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.20.0/go.mod h1:2AboqHi0CiIZU0qwhtUfCYD1GeUzvvIXWNkhDt7ZMG4= go.opentelemetry.io/contrib/propagators/aws v1.5.0 h1:PC3EZMPaYYQH7aDmpcc59ZgEVkgqYmqNOdpF1sfLDrc= go.opentelemetry.io/contrib/propagators/aws v1.5.0/go.mod h1:8+Ak9s8R1Je6HOczkmNJc3FV66T9r15YVW0n1jNV+tg= From 068543b8a41d9ca246db95750f3c084fbe7097b2 Mon Sep 17 00:00:00 2001 From: Matheus Nogueira Date: Fri, 16 Feb 2024 14:12:18 -0300 Subject: [PATCH 2/5] make context propagation work for trigger request --- agent/client/connector.go | 6 ++ agent/client/mocks/grpc_server.go | 65 ++++++++++++------- .../workflow_listen_for_trigger_requests.go | 9 ++- agent/client/workflow_shutdown_test.go | 7 +- agent/workers/trigger_test.go | 46 +++++++++++++ server/telemetry/grpc.go | 62 ++++++++++++++++++ 6 files changed, 167 insertions(+), 28 deletions(-) create mode 100644 server/telemetry/grpc.go diff --git a/agent/client/connector.go b/agent/client/connector.go index 3a0a66faaa..b0447e70f3 100644 --- a/agent/client/connector.go +++ b/agent/client/connector.go @@ -8,6 +8,7 @@ import ( "time" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "go.opentelemetry.io/otel/propagation" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -75,6 +76,11 @@ func (c *Client) connect(ctx context.Context) error { grpc.WithDefaultServiceConfig(retryPolicy), grpc.WithIdleTimeout(0), // disable grpc idle timeout grpc.WithStatsHandler(otelgrpc.NewClientHandler()), + grpc.WithUnaryInterceptor(otelgrpc.UnaryClientInterceptor( + otelgrpc.WithPropagators( + propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}), + ), + )), ) if err != nil { return fmt.Errorf("could not connect to server: %w", err) diff --git a/agent/client/mocks/grpc_server.go b/agent/client/mocks/grpc_server.go index 59fa9dd4c7..467ace1dad 100644 --- a/agent/client/mocks/grpc_server.go +++ b/agent/client/mocks/grpc_server.go @@ -10,18 +10,20 @@ import ( "github.com/avast/retry-go" "github.com/kubeshop/tracetest/agent/client" "github.com/kubeshop/tracetest/agent/proto" + "github.com/kubeshop/tracetest/server/telemetry" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "go.opentelemetry.io/otel/propagation" "google.golang.org/grpc" ) type GrpcServerMock struct { proto.UnimplementedOrchestratorServer port int - triggerChannel chan *proto.TriggerRequest - pollingChannel chan *proto.PollingRequest - otlpConnectionTestChannel chan *proto.OTLPConnectionTestRequest - terminationChannel chan *proto.ShutdownRequest - dataStoreTestChannel chan *proto.DataStoreConnectionTestRequest + triggerChannel chan Message[*proto.TriggerRequest] + pollingChannel chan Message[*proto.PollingRequest] + otlpConnectionTestChannel chan Message[*proto.OTLPConnectionTestRequest] + terminationChannel chan Message[*proto.ShutdownRequest] + dataStoreTestChannel chan Message[*proto.DataStoreConnectionTestRequest] lastTriggerResponse *proto.TriggerResponse lastPollingResponse *proto.PollingResponse @@ -31,13 +33,18 @@ type GrpcServerMock struct { server *grpc.Server } +type Message[T any] struct { + Context context.Context + Data T +} + func NewGrpcServer() *GrpcServerMock { server := &GrpcServerMock{ - triggerChannel: make(chan *proto.TriggerRequest), - pollingChannel: make(chan *proto.PollingRequest), - terminationChannel: make(chan *proto.ShutdownRequest), - dataStoreTestChannel: make(chan *proto.DataStoreConnectionTestRequest), - otlpConnectionTestChannel: make(chan *proto.OTLPConnectionTestRequest), + triggerChannel: make(chan Message[*proto.TriggerRequest]), + pollingChannel: make(chan Message[*proto.PollingRequest]), + terminationChannel: make(chan Message[*proto.ShutdownRequest]), + dataStoreTestChannel: make(chan Message[*proto.DataStoreConnectionTestRequest]), + otlpConnectionTestChannel: make(chan Message[*proto.OTLPConnectionTestRequest]), } var wg sync.WaitGroup wg.Add(1) @@ -66,7 +73,13 @@ func (s *GrpcServerMock) start(wg *sync.WaitGroup, port int) error { s.port = lis.Addr().(*net.TCPAddr).Port - server := grpc.NewServer(grpc.StatsHandler(otelgrpc.NewServerHandler())) + server := grpc.NewServer( + grpc.UnaryInterceptor(otelgrpc.UnaryServerInterceptor( + otelgrpc.WithPropagators( + propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}), + ), + )), + ) proto.RegisterOrchestratorServer(server, s) s.server = server @@ -108,7 +121,12 @@ func (s *GrpcServerMock) RegisterTriggerAgent(id *proto.AgentIdentification, str for { triggerRequest := <-s.triggerChannel - err := stream.Send(triggerRequest) + err := telemetry.InjectContextIntoStream(triggerRequest.Context, stream) + if err != nil { + log.Println(err.Error()) + } + + err = stream.Send(triggerRequest.Data) if err != nil { log.Println("could not send trigger request to agent: %w", err) } @@ -132,7 +150,7 @@ func (s *GrpcServerMock) RegisterPollerAgent(id *proto.AgentIdentification, stre for { pollerRequest := <-s.pollingChannel - err := stream.Send(pollerRequest) + err := stream.Send(pollerRequest.Data) if err != nil { log.Println("could not send polling request to agent: %w", err) } @@ -146,7 +164,7 @@ func (s *GrpcServerMock) RegisterDataStoreConnectionTestAgent(id *proto.AgentIde for { dsTestRequest := <-s.dataStoreTestChannel - err := stream.Send(dsTestRequest) + err := stream.Send(dsTestRequest.Data) if err != nil { log.Println("could not send polling request to agent: %w", err) } @@ -160,7 +178,7 @@ func (s *GrpcServerMock) RegisterOTLPConnectionTestListener(id *proto.AgentIdent for { testRequest := <-s.otlpConnectionTestChannel - err := stream.Send(testRequest) + err := stream.Send(testRequest.Data) if err != nil { log.Println("could not send polling request to agent: %w", err) } @@ -197,7 +215,7 @@ func (s *GrpcServerMock) SendPolledSpans(ctx context.Context, result *proto.Poll func (s *GrpcServerMock) RegisterShutdownListener(_ *proto.AgentIdentification, stream proto.Orchestrator_RegisterShutdownListenerServer) error { for { shutdownRequest := <-s.terminationChannel - err := stream.Send(shutdownRequest) + err := stream.Send(shutdownRequest.Data) if err != nil { log.Println("could not send polling request to agent: %w", err) } @@ -207,19 +225,19 @@ func (s *GrpcServerMock) RegisterShutdownListener(_ *proto.AgentIdentification, // Test methods func (s *GrpcServerMock) SendTriggerRequest(ctx context.Context, request *proto.TriggerRequest) { - s.triggerChannel <- request + s.triggerChannel <- Message[*proto.TriggerRequest]{Context: ctx, Data: request} } func (s *GrpcServerMock) SendPollingRequest(ctx context.Context, request *proto.PollingRequest) { - s.pollingChannel <- request + s.pollingChannel <- Message[*proto.PollingRequest]{Context: ctx, Data: request} } func (s *GrpcServerMock) SendDataStoreConnectionTestRequest(ctx context.Context, request *proto.DataStoreConnectionTestRequest) { - s.dataStoreTestChannel <- request + s.dataStoreTestChannel <- Message[*proto.DataStoreConnectionTestRequest]{Context: ctx, Data: request} } func (s *GrpcServerMock) SendOTLPConnectionTestRequest(ctx context.Context, request *proto.OTLPConnectionTestRequest) { - s.otlpConnectionTestChannel <- request + s.otlpConnectionTestChannel <- Message[*proto.OTLPConnectionTestRequest]{Context: ctx, Data: request} } func (s *GrpcServerMock) GetLastTriggerResponse() *proto.TriggerResponse { @@ -238,9 +256,10 @@ func (s *GrpcServerMock) GetLastDataStoreConnectionResponse() *proto.DataStoreCo return s.lastDataStoreConnectionResponse } -func (s *GrpcServerMock) TerminateConnection(reason string) { - s.terminationChannel <- &proto.ShutdownRequest{ - Reason: reason, +func (s *GrpcServerMock) TerminateConnection(ctx context.Context, reason string) { + s.terminationChannel <- Message[*proto.ShutdownRequest]{ + Context: ctx, + Data: &proto.ShutdownRequest{Reason: reason}, } } diff --git a/agent/client/workflow_listen_for_trigger_requests.go b/agent/client/workflow_listen_for_trigger_requests.go index ade503c227..2d609a11aa 100644 --- a/agent/client/workflow_listen_for_trigger_requests.go +++ b/agent/client/workflow_listen_for_trigger_requests.go @@ -7,6 +7,7 @@ import ( "time" "github.com/kubeshop/tracetest/agent/proto" + "github.com/kubeshop/tracetest/server/telemetry" ) func (c *Client) startTriggerListener(ctx context.Context) error { @@ -36,8 +37,12 @@ func (c *Client) startTriggerListener(ctx context.Context) error { continue } - // TODO: get context from request - err = c.triggerListener(context.Background(), &resp) + ctx, err := telemetry.ExtractContextFromStream(stream) + if err != nil { + log.Println("could not extract context from stream %w", err) + } + + err = c.triggerListener(ctx, &resp) if err != nil { fmt.Println(err.Error()) } diff --git a/agent/client/workflow_shutdown_test.go b/agent/client/workflow_shutdown_test.go index 60666efb95..26df95b3e1 100644 --- a/agent/client/workflow_shutdown_test.go +++ b/agent/client/workflow_shutdown_test.go @@ -13,10 +13,11 @@ import ( ) func TestShutdownFlow(t *testing.T) { + ctx := context.Background() server := mocks.NewGrpcServer() defer server.Stop() - client, err := client.Connect(context.Background(), server.Addr()) + client, err := client.Connect(ctx, server.Addr()) require.NoError(t, err) var called bool = false @@ -27,12 +28,12 @@ func TestShutdownFlow(t *testing.T) { return nil }) - err = client.Start(context.Background()) + err = client.Start(ctx) require.NoError(t, err) time.Sleep(1 * time.Second) - server.TerminateConnection("shutdown requested by user") + server.TerminateConnection(ctx, "shutdown requested by user") time.Sleep(1 * time.Second) assert.True(t, called, "client.OnConnectionClosed should have been called") diff --git a/agent/workers/trigger_test.go b/agent/workers/trigger_test.go index dc46aceed4..57253f0ac5 100644 --- a/agent/workers/trigger_test.go +++ b/agent/workers/trigger_test.go @@ -14,6 +14,8 @@ import ( "github.com/kubeshop/tracetest/agent/workers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + tracesdk "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" ) func setupTriggerWorker(t *testing.T) (*mocks.GrpcServerMock, collector.TraceCache) { @@ -144,9 +146,53 @@ func TestTriggerInexistentAPI(t *testing.T) { assert.Contains(t, response.TriggerResult.Error.Message, "connection refused") } +func TestTriggerWorkerTracePropagation(t *testing.T) { + ctx, span := getTracer().Start(context.Background(), "root span") + defer span.End() + + controlPlane, cache := setupTriggerWorker(t) + + targetServer := createHelloWorldApi() + traceID := "42a2c381da1a5b3a32bc4988bf2431b0" + + triggerRequest := &proto.TriggerRequest{ + TestID: "my test", + RunID: 1, + TraceID: traceID, + Trigger: &proto.Trigger{ + Type: "http", + Http: &proto.HttpRequest{ + Method: "GET", + Url: targetServer.URL, + Headers: []*proto.HttpHeader{ + {Key: "Content-Type", Value: "application/json"}, + }, + }, + }, + } + + // make the control plane send a trigger request to the agent + controlPlane.SendTriggerRequest(ctx, triggerRequest) + time.Sleep(1 * time.Second) + + response := controlPlane.GetLastTriggerResponse() + + require.NotNil(t, response) + assert.Equal(t, "http", response.TriggerResult.Type) + assert.Equal(t, int32(http.StatusOK), response.TriggerResult.Http.StatusCode) + assert.JSONEq(t, `{"hello": "world"}`, string(response.TriggerResult.Http.Body)) + + _, traceIdIsWatched := cache.Get(traceID) + assert.True(t, traceIdIsWatched) +} + func createHelloWorldApi() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`{"hello": "world"}`)) w.WriteHeader(http.StatusOK) })) } + +func getTracer() trace.Tracer { + return tracesdk.NewTracerProvider().Tracer("asd") +} diff --git a/server/telemetry/grpc.go b/server/telemetry/grpc.go new file mode 100644 index 0000000000..d8fea665a8 --- /dev/null +++ b/server/telemetry/grpc.go @@ -0,0 +1,62 @@ +package telemetry + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel/propagation" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +var propagator = propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}) + +func InjectContextIntoStream(ctx context.Context, stream grpc.ServerStream) error { + header := make(metadata.MD) + propagator.Inject(ctx, &metadataSupplier{metadata: &header}) + + err := stream.SetHeader(header) + if err != nil { + return fmt.Errorf("could not set header: %w", err) + } + + return nil +} + +func ExtractContextFromStream(stream grpc.ClientStream) (context.Context, error) { + ctx := stream.Context() + header, err := stream.Header() + if err != nil { + return ctx, fmt.Errorf("coult not get header from stream: %w", err) + } + + ctx = propagator.Extract(ctx, &metadataSupplier{metadata: &header}) + return ctx, nil +} + +type metadataSupplier struct { + metadata *metadata.MD +} + +// assert that metadataSupplier implements the TextMapCarrier interface. +var _ propagation.TextMapCarrier = &metadataSupplier{} + +func (s *metadataSupplier) Get(key string) string { + values := s.metadata.Get(key) + if len(values) == 0 { + return "" + } + return values[0] +} + +func (s *metadataSupplier) Set(key string, value string) { + s.metadata.Set(key, value) +} + +func (s *metadataSupplier) Keys() []string { + out := make([]string, 0, len(*s.metadata)) + for key := range *s.metadata { + out = append(out, key) + } + return out +} From f6632bb72828941476d91a812ba7ea370f163878 Mon Sep 17 00:00:00 2001 From: Matheus Nogueira Date: Fri, 16 Feb 2024 14:16:08 -0300 Subject: [PATCH 3/5] return context in test method --- agent/client/mocks/grpc_server.go | 24 +++++++++---------- ...workflow_send_ds_connection_result_test.go | 2 +- ...rkflow_send_otlp_connection_result_test.go | 6 ++--- agent/client/workflow_send_trace_test.go | 24 +++++++++---------- .../workflow_send_trigger_response_test.go | 14 +++++------ agent/workers/poller_test.go | 14 +++++------ agent/workers/trigger_test.go | 20 ++++++++-------- 7 files changed, 52 insertions(+), 52 deletions(-) diff --git a/agent/client/mocks/grpc_server.go b/agent/client/mocks/grpc_server.go index 467ace1dad..3edd950f03 100644 --- a/agent/client/mocks/grpc_server.go +++ b/agent/client/mocks/grpc_server.go @@ -25,10 +25,10 @@ type GrpcServerMock struct { terminationChannel chan Message[*proto.ShutdownRequest] dataStoreTestChannel chan Message[*proto.DataStoreConnectionTestRequest] - lastTriggerResponse *proto.TriggerResponse - lastPollingResponse *proto.PollingResponse - lastOtlpConnectionResponse *proto.OTLPConnectionTestResponse - lastDataStoreConnectionResponse *proto.DataStoreConnectionTestResponse + lastTriggerResponse Message[*proto.TriggerResponse] + lastPollingResponse Message[*proto.PollingResponse] + lastOtlpConnectionResponse Message[*proto.OTLPConnectionTestResponse] + lastDataStoreConnectionResponse Message[*proto.DataStoreConnectionTestResponse] server *grpc.Server } @@ -139,7 +139,7 @@ func (s *GrpcServerMock) SendTriggerResult(ctx context.Context, result *proto.Tr return nil, fmt.Errorf("could not validate token") } - s.lastTriggerResponse = result + s.lastTriggerResponse = Message[*proto.TriggerResponse]{Data: result, Context: ctx} return &proto.Empty{}, nil } @@ -190,7 +190,7 @@ func (s *GrpcServerMock) SendOTLPConnectionTestResult(ctx context.Context, resul return nil, fmt.Errorf("could not validate token") } - s.lastOtlpConnectionResponse = result + s.lastOtlpConnectionResponse = Message[*proto.OTLPConnectionTestResponse]{Data: result, Context: ctx} return &proto.Empty{}, nil } @@ -199,7 +199,7 @@ func (s *GrpcServerMock) SendDataStoreConnectionTestResult(ctx context.Context, return nil, fmt.Errorf("could not validate token") } - s.lastDataStoreConnectionResponse = result + s.lastDataStoreConnectionResponse = Message[*proto.DataStoreConnectionTestResponse]{Data: result, Context: ctx} return &proto.Empty{}, nil } @@ -208,7 +208,7 @@ func (s *GrpcServerMock) SendPolledSpans(ctx context.Context, result *proto.Poll return nil, fmt.Errorf("could not validate token") } - s.lastPollingResponse = result + s.lastPollingResponse = Message[*proto.PollingResponse]{Data: result, Context: ctx} return &proto.Empty{}, nil } @@ -240,19 +240,19 @@ func (s *GrpcServerMock) SendOTLPConnectionTestRequest(ctx context.Context, requ s.otlpConnectionTestChannel <- Message[*proto.OTLPConnectionTestRequest]{Context: ctx, Data: request} } -func (s *GrpcServerMock) GetLastTriggerResponse() *proto.TriggerResponse { +func (s *GrpcServerMock) GetLastTriggerResponse() Message[*proto.TriggerResponse] { return s.lastTriggerResponse } -func (s *GrpcServerMock) GetLastPollingResponse() *proto.PollingResponse { +func (s *GrpcServerMock) GetLastPollingResponse() Message[*proto.PollingResponse] { return s.lastPollingResponse } -func (s *GrpcServerMock) GetLastOTLPConnectionResponse() *proto.OTLPConnectionTestResponse { +func (s *GrpcServerMock) GetLastOTLPConnectionResponse() Message[*proto.OTLPConnectionTestResponse] { return s.lastOtlpConnectionResponse } -func (s *GrpcServerMock) GetLastDataStoreConnectionResponse() *proto.DataStoreConnectionTestResponse { +func (s *GrpcServerMock) GetLastDataStoreConnectionResponse() Message[*proto.DataStoreConnectionTestResponse] { return s.lastDataStoreConnectionResponse } diff --git a/agent/client/workflow_send_ds_connection_result_test.go b/agent/client/workflow_send_ds_connection_result_test.go index 648a714c08..76573b1664 100644 --- a/agent/client/workflow_send_ds_connection_result_test.go +++ b/agent/client/workflow_send_ds_connection_result_test.go @@ -37,7 +37,7 @@ func TestDataStoreConnectionResult(t *testing.T) { receivedResponse := server.GetLastDataStoreConnectionResponse() - assert.Equal(t, result.RequestID, receivedResponse.RequestID) + assert.Equal(t, result.RequestID, receivedResponse.Data.RequestID) assert.True(t, result.Successful) assert.True(t, result.Steps.PortCheck.Passed) } diff --git a/agent/client/workflow_send_otlp_connection_result_test.go b/agent/client/workflow_send_otlp_connection_result_test.go index 9fdd739ddf..c533649f8f 100644 --- a/agent/client/workflow_send_otlp_connection_result_test.go +++ b/agent/client/workflow_send_otlp_connection_result_test.go @@ -35,7 +35,7 @@ func TestOTLPConnectionResultTrace(t *testing.T) { receivedResponse := server.GetLastOTLPConnectionResponse() - assert.Equal(t, result.RequestID, receivedResponse.RequestID) - assert.Equal(t, result.SpanCount, receivedResponse.SpanCount) - assert.Equal(t, result.LastSpanTimestamp, receivedResponse.LastSpanTimestamp) + assert.Equal(t, result.RequestID, receivedResponse.Data.RequestID) + assert.Equal(t, result.SpanCount, receivedResponse.Data.SpanCount) + assert.Equal(t, result.LastSpanTimestamp, receivedResponse.Data.LastSpanTimestamp) } diff --git a/agent/client/workflow_send_trace_test.go b/agent/client/workflow_send_trace_test.go index 9c396fffa8..4446564117 100644 --- a/agent/client/workflow_send_trace_test.go +++ b/agent/client/workflow_send_trace_test.go @@ -45,21 +45,21 @@ func TestSendTrace(t *testing.T) { receivedPollingResponse := server.GetLastPollingResponse() - assert.Equal(t, pollingRequest.TestID, receivedPollingResponse.TestID) - assert.Equal(t, pollingRequest.RunID, receivedPollingResponse.RunID) - assert.Equal(t, pollingRequest.TraceID, receivedPollingResponse.TraceID) + assert.Equal(t, pollingRequest.TestID, receivedPollingResponse.Data.TestID) + assert.Equal(t, pollingRequest.RunID, receivedPollingResponse.Data.RunID) + assert.Equal(t, pollingRequest.TraceID, receivedPollingResponse.Data.TraceID) - require.Len(t, receivedPollingResponse.Spans, len(pollingRequest.Spans)) + require.Len(t, receivedPollingResponse.Data.Spans, len(pollingRequest.Spans)) for i, span := range pollingRequest.Spans { - assert.Equal(t, span.Id, receivedPollingResponse.Spans[i].Id) - assert.Equal(t, span.Name, receivedPollingResponse.Spans[i].Name) - assert.Equal(t, span.Kind, receivedPollingResponse.Spans[i].Kind) - assert.Equal(t, span.ParentId, receivedPollingResponse.Spans[i].ParentId) - assert.Equal(t, span.StartTime, receivedPollingResponse.Spans[i].StartTime) - assert.Equal(t, span.EndTime, receivedPollingResponse.Spans[i].EndTime) + assert.Equal(t, span.Id, receivedPollingResponse.Data.Spans[i].Id) + assert.Equal(t, span.Name, receivedPollingResponse.Data.Spans[i].Name) + assert.Equal(t, span.Kind, receivedPollingResponse.Data.Spans[i].Kind) + assert.Equal(t, span.ParentId, receivedPollingResponse.Data.Spans[i].ParentId) + assert.Equal(t, span.StartTime, receivedPollingResponse.Data.Spans[i].StartTime) + assert.Equal(t, span.EndTime, receivedPollingResponse.Data.Spans[i].EndTime) for j := range span.Attributes { - assert.Equal(t, span.Attributes[i].Key, receivedPollingResponse.Spans[i].Attributes[j].Key) - assert.Equal(t, span.Attributes[i].Value, receivedPollingResponse.Spans[i].Attributes[j].Value) + assert.Equal(t, span.Attributes[i].Key, receivedPollingResponse.Data.Spans[i].Attributes[j].Key) + assert.Equal(t, span.Attributes[i].Value, receivedPollingResponse.Data.Spans[i].Attributes[j].Value) } } } diff --git a/agent/client/workflow_send_trigger_response_test.go b/agent/client/workflow_send_trigger_response_test.go index 4581ce529c..7dac798da3 100644 --- a/agent/client/workflow_send_trigger_response_test.go +++ b/agent/client/workflow_send_trigger_response_test.go @@ -40,11 +40,11 @@ func TestSendTriggerResult(t *testing.T) { receivedTriggerResponse := server.GetLastTriggerResponse() - assert.Equal(t, triggerResponse.TestID, receivedTriggerResponse.TestID) - assert.Equal(t, triggerResponse.RunID, receivedTriggerResponse.RunID) - assert.Equal(t, triggerResponse.TriggerResult.Type, receivedTriggerResponse.TriggerResult.Type) - assert.Equal(t, triggerResponse.TriggerResult.Http.StatusCode, receivedTriggerResponse.TriggerResult.Http.StatusCode) - assert.Equal(t, triggerResponse.TriggerResult.Http.Status, receivedTriggerResponse.TriggerResult.Http.Status) - assert.Equal(t, len(triggerResponse.TriggerResult.Http.Headers), len(receivedTriggerResponse.TriggerResult.Http.Headers)) - assert.Equal(t, triggerResponse.TriggerResult.Http.Body, receivedTriggerResponse.TriggerResult.Http.Body) + assert.Equal(t, triggerResponse.TestID, receivedTriggerResponse.Data.TestID) + assert.Equal(t, triggerResponse.RunID, receivedTriggerResponse.Data.RunID) + assert.Equal(t, triggerResponse.TriggerResult.Type, receivedTriggerResponse.Data.TriggerResult.Type) + assert.Equal(t, triggerResponse.TriggerResult.Http.StatusCode, receivedTriggerResponse.Data.TriggerResult.Http.StatusCode) + assert.Equal(t, triggerResponse.TriggerResult.Http.Status, receivedTriggerResponse.Data.TriggerResult.Http.Status) + assert.Equal(t, len(triggerResponse.TriggerResult.Http.Headers), len(receivedTriggerResponse.Data.TriggerResult.Http.Headers)) + assert.Equal(t, triggerResponse.TriggerResult.Http.Body, receivedTriggerResponse.Data.TriggerResult.Http.Body) } diff --git a/agent/workers/poller_test.go b/agent/workers/poller_test.go index 64ffacf54e..913934c347 100644 --- a/agent/workers/poller_test.go +++ b/agent/workers/poller_test.go @@ -63,7 +63,7 @@ func TestPollerWorker(t *testing.T) { // Very rudimentar sorting algorithm for only two items in the array // first item is always the root span, second is it's child var spans = make([]*proto.Span, 2) - for _, span := range pollingResponse.Spans { + for _, span := range pollingResponse.Data.Spans { if span.ParentId == "" { spans[0] = span } else { @@ -154,8 +154,8 @@ func TestPollerWorkerWithInmemoryDatastore(t *testing.T) { pollingResponse := controlPlane.GetLastPollingResponse() require.NotNil(t, pollingResponse, "agent did not send polling response back to server") - assert.False(t, pollingResponse.TraceFound) - assert.Len(t, pollingResponse.Spans, 0) + assert.False(t, pollingResponse.Data.TraceFound) + assert.Len(t, pollingResponse.Data.Spans, 0) span1ID := id.NewRandGenerator().SpanID() span2ID := id.NewRandGenerator().SpanID() @@ -173,8 +173,8 @@ func TestPollerWorkerWithInmemoryDatastore(t *testing.T) { pollingResponse = controlPlane.GetLastPollingResponse() require.NotNil(t, pollingResponse, "agent did not send polling response back to server") - assert.True(t, pollingResponse.TraceFound) - assert.Len(t, pollingResponse.Spans, 2) + assert.True(t, pollingResponse.Data.TraceFound) + assert.Len(t, pollingResponse.Data.Spans, 2) } func TestPollerWithInvalidDataStore(t *testing.T) { @@ -214,6 +214,6 @@ func TestPollerWithInvalidDataStore(t *testing.T) { pollingResponse := controlPlane.GetLastPollingResponse() require.NotNil(t, pollingResponse, "agent did not send polling response back to server") - require.NotNil(t, pollingResponse.Error) - assert.Contains(t, pollingResponse.Error.Message, "connection refused") + require.NotNil(t, pollingResponse.Data.Error) + assert.Contains(t, pollingResponse.Data.Error.Message, "connection refused") } diff --git a/agent/workers/trigger_test.go b/agent/workers/trigger_test.go index 57253f0ac5..cdec850205 100644 --- a/agent/workers/trigger_test.go +++ b/agent/workers/trigger_test.go @@ -72,9 +72,9 @@ func TestTrigger(t *testing.T) { response := controlPlane.GetLastTriggerResponse() require.NotNil(t, response) - assert.Equal(t, "http", response.TriggerResult.Type) - assert.Equal(t, int32(http.StatusOK), response.TriggerResult.Http.StatusCode) - assert.JSONEq(t, `{"hello": "world"}`, string(response.TriggerResult.Http.Body)) + assert.Equal(t, "http", response.Data.TriggerResult.Type) + assert.Equal(t, int32(http.StatusOK), response.Data.TriggerResult.Http.StatusCode) + assert.JSONEq(t, `{"hello": "world"}`, string(response.Data.TriggerResult.Http.Body)) _, traceIdIsWatched := cache.Get(traceID) assert.True(t, traceIdIsWatched) @@ -109,8 +109,8 @@ func TestTriggerAgainstGoogle(t *testing.T) { response := controlPlane.GetLastTriggerResponse() require.NotNil(t, response) - assert.Equal(t, "http", response.TriggerResult.Type) - assert.Equal(t, int32(http.StatusOK), response.TriggerResult.Http.StatusCode) + assert.Equal(t, "http", response.Data.TriggerResult.Type) + assert.Equal(t, int32(http.StatusOK), response.Data.TriggerResult.Http.StatusCode) } func TestTriggerInexistentAPI(t *testing.T) { @@ -142,8 +142,8 @@ func TestTriggerInexistentAPI(t *testing.T) { response := controlPlane.GetLastTriggerResponse() require.NotNil(t, response) - assert.NotNil(t, response.TriggerResult.Error) - assert.Contains(t, response.TriggerResult.Error.Message, "connection refused") + assert.NotNil(t, response.Data.TriggerResult.Error) + assert.Contains(t, response.Data.TriggerResult.Error.Message, "connection refused") } func TestTriggerWorkerTracePropagation(t *testing.T) { @@ -178,9 +178,9 @@ func TestTriggerWorkerTracePropagation(t *testing.T) { response := controlPlane.GetLastTriggerResponse() require.NotNil(t, response) - assert.Equal(t, "http", response.TriggerResult.Type) - assert.Equal(t, int32(http.StatusOK), response.TriggerResult.Http.StatusCode) - assert.JSONEq(t, `{"hello": "world"}`, string(response.TriggerResult.Http.Body)) + assert.Equal(t, "http", response.Data.TriggerResult.Type) + assert.Equal(t, int32(http.StatusOK), response.Data.TriggerResult.Http.StatusCode) + assert.JSONEq(t, `{"hello": "world"}`, string(response.Data.TriggerResult.Http.Body)) _, traceIdIsWatched := cache.Get(traceID) assert.True(t, traceIdIsWatched) From dc64eb5e2aecc55e42f107feac3127dc324e647b Mon Sep 17 00:00:00 2001 From: Matheus Nogueira Date: Fri, 16 Feb 2024 14:22:39 -0300 Subject: [PATCH 4/5] test that trigger request/response share the same traceID --- agent/workers/telemetry_test.go | 27 +++++++++++++++++++++++++++ agent/workers/trigger_test.go | 7 ++++--- 2 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 agent/workers/telemetry_test.go diff --git a/agent/workers/telemetry_test.go b/agent/workers/telemetry_test.go new file mode 100644 index 0000000000..74d9821992 --- /dev/null +++ b/agent/workers/telemetry_test.go @@ -0,0 +1,27 @@ +package workers_test + +import ( + "context" + + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/sdk/trace" +) + +var propagator = propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}) + +func ContextWithTracingEnabled() context.Context { + ctx, span := trace.NewTracerProvider().Tracer("tracer").Start(context.Background(), "root span") + defer span.End() + + return ctx +} + +func SameTraceID(ctx1, ctx2 context.Context) bool { + header1 := make(propagation.HeaderCarrier) + header2 := make(propagation.HeaderCarrier) + + propagator.Inject(ctx1, header1) + propagator.Inject(ctx2, header2) + + return header1.Get("traceparent") == header2.Get("traceparent") +} diff --git a/agent/workers/trigger_test.go b/agent/workers/trigger_test.go index cdec850205..a9c74c1e3b 100644 --- a/agent/workers/trigger_test.go +++ b/agent/workers/trigger_test.go @@ -147,9 +147,6 @@ func TestTriggerInexistentAPI(t *testing.T) { } func TestTriggerWorkerTracePropagation(t *testing.T) { - ctx, span := getTracer().Start(context.Background(), "root span") - defer span.End() - controlPlane, cache := setupTriggerWorker(t) targetServer := createHelloWorldApi() @@ -171,6 +168,8 @@ func TestTriggerWorkerTracePropagation(t *testing.T) { }, } + ctx := ContextWithTracingEnabled() + // make the control plane send a trigger request to the agent controlPlane.SendTriggerRequest(ctx, triggerRequest) time.Sleep(1 * time.Second) @@ -184,6 +183,8 @@ func TestTriggerWorkerTracePropagation(t *testing.T) { _, traceIdIsWatched := cache.Get(traceID) assert.True(t, traceIdIsWatched) + + assert.True(t, SameTraceID(ctx, response.Context)) } func createHelloWorldApi() *httptest.Server { From 84389a40c79300dd49869fe0327934032dca2520 Mon Sep 17 00:00:00 2001 From: Matheus Nogueira Date: Fri, 16 Feb 2024 14:38:10 -0300 Subject: [PATCH 5/5] feat: add trace propagation to the rest of the workers --- agent/client/mocks/grpc_server.go | 28 ++++++++++-- ...workflow_listen_for_ds_connection_tests.go | 9 +++- ...rkflow_listen_for_otlp_connection_tests.go | 9 +++- .../workflow_listen_for_poll_requests.go | 9 +++- .../workflow_listen_for_stop_requests.go | 1 + agent/workers/poller_test.go | 4 +- agent/workers/trigger_test.go | 45 ++----------------- 7 files changed, 52 insertions(+), 53 deletions(-) diff --git a/agent/client/mocks/grpc_server.go b/agent/client/mocks/grpc_server.go index 3edd950f03..c927f39d85 100644 --- a/agent/client/mocks/grpc_server.go +++ b/agent/client/mocks/grpc_server.go @@ -150,7 +150,12 @@ func (s *GrpcServerMock) RegisterPollerAgent(id *proto.AgentIdentification, stre for { pollerRequest := <-s.pollingChannel - err := stream.Send(pollerRequest.Data) + err := telemetry.InjectContextIntoStream(pollerRequest.Context, stream) + if err != nil { + log.Println(err.Error()) + } + + err = stream.Send(pollerRequest.Data) if err != nil { log.Println("could not send polling request to agent: %w", err) } @@ -164,7 +169,12 @@ func (s *GrpcServerMock) RegisterDataStoreConnectionTestAgent(id *proto.AgentIde for { dsTestRequest := <-s.dataStoreTestChannel - err := stream.Send(dsTestRequest.Data) + err := telemetry.InjectContextIntoStream(dsTestRequest.Context, stream) + if err != nil { + log.Println(err.Error()) + } + + err = stream.Send(dsTestRequest.Data) if err != nil { log.Println("could not send polling request to agent: %w", err) } @@ -178,7 +188,12 @@ func (s *GrpcServerMock) RegisterOTLPConnectionTestListener(id *proto.AgentIdent for { testRequest := <-s.otlpConnectionTestChannel - err := stream.Send(testRequest.Data) + err := telemetry.InjectContextIntoStream(testRequest.Context, stream) + if err != nil { + log.Println(err.Error()) + } + + err = stream.Send(testRequest.Data) if err != nil { log.Println("could not send polling request to agent: %w", err) } @@ -215,7 +230,12 @@ func (s *GrpcServerMock) SendPolledSpans(ctx context.Context, result *proto.Poll func (s *GrpcServerMock) RegisterShutdownListener(_ *proto.AgentIdentification, stream proto.Orchestrator_RegisterShutdownListenerServer) error { for { shutdownRequest := <-s.terminationChannel - err := stream.Send(shutdownRequest.Data) + err := telemetry.InjectContextIntoStream(shutdownRequest.Context, stream) + if err != nil { + log.Println(err.Error()) + } + + err = stream.Send(shutdownRequest.Data) if err != nil { log.Println("could not send polling request to agent: %w", err) } diff --git a/agent/client/workflow_listen_for_ds_connection_tests.go b/agent/client/workflow_listen_for_ds_connection_tests.go index d7d91545a0..992d58c702 100644 --- a/agent/client/workflow_listen_for_ds_connection_tests.go +++ b/agent/client/workflow_listen_for_ds_connection_tests.go @@ -7,6 +7,7 @@ import ( "time" "github.com/kubeshop/tracetest/agent/proto" + "github.com/kubeshop/tracetest/server/telemetry" ) func (c *Client) startDataStoreConnectionTestListener(ctx context.Context) error { @@ -36,8 +37,12 @@ func (c *Client) startDataStoreConnectionTestListener(ctx context.Context) error continue } - // TODO: Get ctx from request - err = c.dataStoreConnectionListener(context.Background(), &req) + ctx, err := telemetry.ExtractContextFromStream(stream) + if err != nil { + log.Println("could not extract context from stream %w", err) + } + + err = c.dataStoreConnectionListener(ctx, &req) if err != nil { fmt.Println(err.Error()) } diff --git a/agent/client/workflow_listen_for_otlp_connection_tests.go b/agent/client/workflow_listen_for_otlp_connection_tests.go index 04b534a5db..18c9cce077 100644 --- a/agent/client/workflow_listen_for_otlp_connection_tests.go +++ b/agent/client/workflow_listen_for_otlp_connection_tests.go @@ -7,6 +7,7 @@ import ( "time" "github.com/kubeshop/tracetest/agent/proto" + "github.com/kubeshop/tracetest/server/telemetry" ) func (c *Client) startOTLPConnectionTestListener(ctx context.Context) error { @@ -36,8 +37,12 @@ func (c *Client) startOTLPConnectionTestListener(ctx context.Context) error { continue } - // TODO: Get ctx from request - err = c.otlpConnectionTestListener(context.Background(), &req) + ctx, err := telemetry.ExtractContextFromStream(stream) + if err != nil { + log.Println("could not extract context from stream %w", err) + } + + err = c.otlpConnectionTestListener(ctx, &req) if err != nil { fmt.Println(err.Error()) } diff --git a/agent/client/workflow_listen_for_poll_requests.go b/agent/client/workflow_listen_for_poll_requests.go index dee9eae065..1e308a40f4 100644 --- a/agent/client/workflow_listen_for_poll_requests.go +++ b/agent/client/workflow_listen_for_poll_requests.go @@ -7,6 +7,7 @@ import ( "time" "github.com/kubeshop/tracetest/agent/proto" + "github.com/kubeshop/tracetest/server/telemetry" ) func (c *Client) startPollerListener(ctx context.Context) error { @@ -36,8 +37,12 @@ func (c *Client) startPollerListener(ctx context.Context) error { continue } - // TODO: Get ctx from request - err = c.pollListener(context.Background(), &resp) + ctx, err := telemetry.ExtractContextFromStream(stream) + if err != nil { + log.Println("could not extract context from stream %w", err) + } + + err = c.pollListener(ctx, &resp) if err != nil { fmt.Println(err.Error()) } diff --git a/agent/client/workflow_listen_for_stop_requests.go b/agent/client/workflow_listen_for_stop_requests.go index 86b74f4782..77928821b1 100644 --- a/agent/client/workflow_listen_for_stop_requests.go +++ b/agent/client/workflow_listen_for_stop_requests.go @@ -9,6 +9,7 @@ import ( "github.com/kubeshop/tracetest/agent/proto" ) +// TODO: fix this and add test func (c *Client) startStopListener(ctx context.Context) error { client := proto.NewOrchestratorClient(c.conn) diff --git a/agent/workers/poller_test.go b/agent/workers/poller_test.go index 913934c347..1eb47d0f48 100644 --- a/agent/workers/poller_test.go +++ b/agent/workers/poller_test.go @@ -20,7 +20,7 @@ import ( ) func TestPollerWorker(t *testing.T) { - ctx := context.Background() + ctx := ContextWithTracingEnabled() controlPlane := mocks.NewGrpcServer() client, err := client.Connect(ctx, controlPlane.Addr()) @@ -74,6 +74,8 @@ func TestPollerWorker(t *testing.T) { assert.Len(t, spans, 2) assert.Equal(t, "", spans[0].ParentId) assert.Equal(t, spans[0].Id, spans[1].ParentId) + + assert.True(t, SameTraceID(ctx, pollingResponse.Context)) } func createTempoFakeApi() *httptest.Server { diff --git a/agent/workers/trigger_test.go b/agent/workers/trigger_test.go index a9c74c1e3b..bddd1431d6 100644 --- a/agent/workers/trigger_test.go +++ b/agent/workers/trigger_test.go @@ -43,7 +43,6 @@ func setupTriggerWorker(t *testing.T) (*mocks.GrpcServerMock, collector.TraceCac } func TestTrigger(t *testing.T) { - ctx := context.Background() controlPlane, cache := setupTriggerWorker(t) targetServer := createHelloWorldApi() @@ -65,6 +64,8 @@ func TestTrigger(t *testing.T) { }, } + ctx := ContextWithTracingEnabled() + // make the control plane send a trigger request to the agent controlPlane.SendTriggerRequest(ctx, triggerRequest) time.Sleep(1 * time.Second) @@ -78,6 +79,7 @@ func TestTrigger(t *testing.T) { _, traceIdIsWatched := cache.Get(traceID) assert.True(t, traceIdIsWatched) + assert.True(t, SameTraceID(ctx, response.Context)) } func TestTriggerAgainstGoogle(t *testing.T) { @@ -146,47 +148,6 @@ func TestTriggerInexistentAPI(t *testing.T) { assert.Contains(t, response.Data.TriggerResult.Error.Message, "connection refused") } -func TestTriggerWorkerTracePropagation(t *testing.T) { - controlPlane, cache := setupTriggerWorker(t) - - targetServer := createHelloWorldApi() - traceID := "42a2c381da1a5b3a32bc4988bf2431b0" - - triggerRequest := &proto.TriggerRequest{ - TestID: "my test", - RunID: 1, - TraceID: traceID, - Trigger: &proto.Trigger{ - Type: "http", - Http: &proto.HttpRequest{ - Method: "GET", - Url: targetServer.URL, - Headers: []*proto.HttpHeader{ - {Key: "Content-Type", Value: "application/json"}, - }, - }, - }, - } - - ctx := ContextWithTracingEnabled() - - // make the control plane send a trigger request to the agent - controlPlane.SendTriggerRequest(ctx, triggerRequest) - time.Sleep(1 * time.Second) - - response := controlPlane.GetLastTriggerResponse() - - require.NotNil(t, response) - assert.Equal(t, "http", response.Data.TriggerResult.Type) - assert.Equal(t, int32(http.StatusOK), response.Data.TriggerResult.Http.StatusCode) - assert.JSONEq(t, `{"hello": "world"}`, string(response.Data.TriggerResult.Http.Body)) - - _, traceIdIsWatched := cache.Get(traceID) - assert.True(t, traceIdIsWatched) - - assert.True(t, SameTraceID(ctx, response.Context)) -} - func createHelloWorldApi() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`{"hello": "world"}`))