From 95365c53a72ae46f73bd195b38188d9346f4b782 Mon Sep 17 00:00:00 2001 From: Kristen Kozak Date: Wed, 1 Nov 2017 17:24:13 -0700 Subject: [PATCH] Refactor the code for overriding the Census implementation in tests. This commit adds a method to set the CensusStatsModule on AbstractServerImplBuilder and AbstractManagedChannelImplBuilder. It is simpler to set the whole CensusStatsModule than pass in all necessary OpenCensus implementation objects. --- .../AbstractManagedChannelImplBuilder.java | 41 ++-------- .../internal/AbstractServerImplBuilder.java | 43 ++--------- .../io/grpc/internal/CensusStatsModule.java | 75 ++++++++++++++----- ...AbstractManagedChannelImplBuilderTest.java | 16 +++- .../AbstractServerImplBuilderTest.java | 8 +- .../io/grpc/internal/CensusModulesTest.java | 19 +++-- .../io/grpc/internal/TestingAccessor.java | 8 +- 7 files changed, 110 insertions(+), 100 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java index 0652cf402682..91851fd2ac9c 100644 --- a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java @@ -32,11 +32,6 @@ import io.grpc.NameResolver; import io.grpc.NameResolverProvider; import io.grpc.PickFirstBalancerFactory; -import io.opencensus.stats.Stats; -import io.opencensus.stats.StatsRecorder; -import io.opencensus.tags.Tagger; -import io.opencensus.tags.Tags; -import io.opencensus.tags.propagation.TagContextBinarySerializer; import io.opencensus.trace.Tracing; import java.net.SocketAddress; import java.net.URI; @@ -155,13 +150,7 @@ protected final int maxInboundMessageSize() { private boolean tracingEnabled = true; @Nullable - private Tagger tagger; - - @Nullable - private TagContextBinarySerializer tagCtxSerializer; - - @Nullable - private StatsRecorder statsRecorder; + private CensusStatsModule censusStatsOverride; protected AbstractManagedChannelImplBuilder(String target) { this.target = Preconditions.checkNotNull(target, "target"); @@ -296,11 +285,8 @@ public final T idleTimeout(long value, TimeUnit unit) { * Override the default stats implementation. */ @VisibleForTesting - protected final T statsImplementation( - Tagger tagger, TagContextBinarySerializer tagCtxSerializer, StatsRecorder statsRecorder) { - this.tagger = tagger; - this.tagCtxSerializer = tagCtxSerializer; - this.statsRecorder = statsRecorder; + protected final T overrideCensusStatsModule(CensusStatsModule censusStats) { + this.censusStatsOverride = censusStats; return thisT(); } @@ -358,24 +344,13 @@ final List getEffectiveInterceptors() { List effectiveInterceptors = new ArrayList(this.interceptors); if (statsEnabled) { - Tagger tagger = this.tagger != null ? this.tagger : Tags.getTagger(); - TagContextBinarySerializer tagCtxSerializer = - this.tagCtxSerializer != null - ? this.tagCtxSerializer - : Tags.getTagPropagationComponent().getBinarySerializer(); - StatsRecorder statsRecorder = - this.statsRecorder != null ? this.statsRecorder : Stats.getStatsRecorder(); - CensusStatsModule censusStats = - new CensusStatsModule( - tagger, - tagCtxSerializer, - statsRecorder, - GrpcUtil.STOPWATCH_SUPPLIER, - true, - recordStats); + CensusStatsModule censusStats = this.censusStatsOverride; + if (censusStats == null) { + censusStats = new CensusStatsModule(GrpcUtil.STOPWATCH_SUPPLIER, true); + } // First interceptor runs last (see ClientInterceptors.intercept()), so that no // other interceptor can override the tracer factory we set in CallOptions. - effectiveInterceptors.add(0, censusStats.getClientInterceptor()); + effectiveInterceptors.add(0, censusStats.getClientInterceptor(recordStats)); } if (tracingEnabled) { CensusTracingModule censusTracing = diff --git a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java index 2d6b00c1cc45..29768d150379 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java @@ -34,11 +34,6 @@ import io.grpc.ServerServiceDefinition; import io.grpc.ServerStreamTracer; import io.grpc.ServerTransportFilter; -import io.opencensus.stats.Stats; -import io.opencensus.stats.StatsRecorder; -import io.opencensus.tags.Tagger; -import io.opencensus.tags.Tags; -import io.opencensus.tags.propagation.TagContextBinarySerializer; import io.opencensus.trace.Tracing; import java.util.ArrayList; import java.util.Collections; @@ -100,13 +95,7 @@ public List getServices() { CompressorRegistry compressorRegistry = DEFAULT_COMPRESSOR_REGISTRY; @Nullable - private Tagger tagger; - - @Nullable - private TagContextBinarySerializer tagCtxSerializer; - - @Nullable - private StatsRecorder statsRecorder; + private CensusStatsModule censusStatsOverride; private boolean statsEnabled = true; private boolean recordStats = true; @@ -193,13 +182,8 @@ public final T compressorRegistry(CompressorRegistry registry) { * Override the default stats implementation. */ @VisibleForTesting - protected T statsImplementation( - final Tagger tagger, - TagContextBinarySerializer tagCtxSerializer, - StatsRecorder statsRecorder) { - this.tagger = tagger; - this.tagCtxSerializer = tagCtxSerializer; - this.statsRecorder = statsRecorder; + protected T overrideCensusStatsModule(CensusStatsModule censusStats) { + this.censusStatsOverride = censusStats; return thisT(); } @@ -242,22 +226,11 @@ final List getTracerFactories() { ArrayList tracerFactories = new ArrayList(); if (statsEnabled) { - Tagger tagger = this.tagger != null ? this.tagger : Tags.getTagger(); - TagContextBinarySerializer tagCtxSerializer = - this.tagCtxSerializer != null - ? this.tagCtxSerializer - : Tags.getTagPropagationComponent().getBinarySerializer(); - StatsRecorder statsRecorder = - this.statsRecorder != null ? this.statsRecorder : Stats.getStatsRecorder(); - CensusStatsModule censusStats = - new CensusStatsModule( - tagger, - tagCtxSerializer, - statsRecorder, - GrpcUtil.STOPWATCH_SUPPLIER, - true, - recordStats); - tracerFactories.add(censusStats.getServerTracerFactory()); + CensusStatsModule censusStats = this.censusStatsOverride; + if (censusStats == null) { + censusStats = new CensusStatsModule(GrpcUtil.STOPWATCH_SUPPLIER, true); + } + tracerFactories.add(censusStats.getServerTracerFactory(recordStats)); } if (tracingEnabled) { CensusTracingModule censusTracing = diff --git a/core/src/main/java/io/grpc/internal/CensusStatsModule.java b/core/src/main/java/io/grpc/internal/CensusStatsModule.java index acdaf458a982..781722004cd9 100644 --- a/core/src/main/java/io/grpc/internal/CensusStatsModule.java +++ b/core/src/main/java/io/grpc/internal/CensusStatsModule.java @@ -39,10 +39,12 @@ import io.grpc.StreamTracer; import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants; import io.opencensus.stats.MeasureMap; +import io.opencensus.stats.Stats; import io.opencensus.stats.StatsRecorder; import io.opencensus.tags.TagContext; import io.opencensus.tags.TagValue; import io.opencensus.tags.Tagger; +import io.opencensus.tags.Tags; import io.opencensus.tags.propagation.TagContextBinarySerializer; import io.opencensus.tags.propagation.TagContextSerializationException; import java.util.concurrent.TimeUnit; @@ -74,21 +76,33 @@ final class CensusStatsModule { private final Supplier stopwatchSupplier; @VisibleForTesting final Metadata.Key statsHeader; - private final StatsClientInterceptor clientInterceptor = new StatsClientInterceptor(); - private final ServerTracerFactory serverTracerFactory = new ServerTracerFactory(); private final boolean propagateTags; - private final boolean recordStats; + /** + * Creates a {@link CensusStatsModule} with the default OpenCensus implementation. + */ + CensusStatsModule(Supplier stopwatchSupplier, boolean propagateTags) { + this( + Tags.getTagger(), + Tags.getTagPropagationComponent().getBinarySerializer(), + Stats.getStatsRecorder(), + stopwatchSupplier, + propagateTags); + } + + /** + * Creates a {@link CensusStatsModule} with the given OpenCensus implementation. + */ CensusStatsModule( final Tagger tagger, final TagContextBinarySerializer tagCtxSerializer, StatsRecorder statsRecorder, Supplier stopwatchSupplier, - boolean propagateTags, boolean recordStats) { + boolean propagateTags) { this.tagger = checkNotNull(tagger, "tagger"); this.statsRecorder = checkNotNull(statsRecorder, "statsRecorder"); + checkNotNull(tagCtxSerializer, "tagCtxSerializer"); this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier"); this.propagateTags = propagateTags; - this.recordStats = recordStats; this.statsHeader = Metadata.Key.of("grpc-tags-bin", new Metadata.BinaryMarshaller() { @Override @@ -118,22 +132,23 @@ public TagContext parseBytes(byte[] serialized) { * Creates a {@link ClientCallTracer} for a new call. */ @VisibleForTesting - ClientCallTracer newClientCallTracer(TagContext parentCtx, String fullMethodName) { - return new ClientCallTracer(this, parentCtx, fullMethodName); + ClientCallTracer newClientCallTracer( + TagContext parentCtx, String fullMethodName, boolean recordStats) { + return new ClientCallTracer(this, parentCtx, fullMethodName, recordStats); } /** * Returns the server tracer factory. */ - ServerStreamTracer.Factory getServerTracerFactory() { - return serverTracerFactory; + ServerStreamTracer.Factory getServerTracerFactory(boolean recordStats) { + return new ServerTracerFactory(recordStats); } /** * Returns the client interceptor that facilitates Census-based stats reporting. */ - ClientInterceptor getClientInterceptor() { - return clientInterceptor; + ClientInterceptor getClientInterceptor(boolean recordStats) { + return new StatsClientInterceptor(recordStats); } private static final class ClientTracer extends ClientStreamTracer { @@ -206,12 +221,18 @@ static final class ClientCallTracer extends ClientStreamTracer.Factory { private volatile ClientTracer streamTracer; private volatile int callEnded; private final TagContext parentCtx; + private final boolean recordStats; - ClientCallTracer(CensusStatsModule module, TagContext parentCtx, String fullMethodName) { + ClientCallTracer( + CensusStatsModule module, + TagContext parentCtx, + String fullMethodName, + boolean recordStats) { this.module = module; this.parentCtx = checkNotNull(parentCtx, "parentCtx"); this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName"); this.stopwatch = module.stopwatchSupplier.get().start(); + this.recordStats = recordStats; } @Override @@ -241,7 +262,7 @@ void callEnded(Status status) { if (callEndedUpdater.getAndSet(this, 1) != 0) { return; } - if (!module.recordStats) { + if (!recordStats) { return; } stopwatch.stop(); @@ -299,6 +320,7 @@ private static final class ServerTracer extends ServerStreamTracer { private volatile int streamClosed; private final Stopwatch stopwatch; private final Tagger tagger; + private final boolean recordStats; private volatile long outboundMessageCount; private volatile long inboundMessageCount; private volatile long outboundWireSize; @@ -311,12 +333,14 @@ private static final class ServerTracer extends ServerStreamTracer { String fullMethodName, TagContext parentCtx, Supplier stopwatchSupplier, - Tagger tagger) { + Tagger tagger, + boolean recordStats) { this.module = module; this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName"); this.parentCtx = checkNotNull(parentCtx, "parentCtx"); this.stopwatch = stopwatchSupplier.get().start(); this.tagger = tagger; + this.recordStats = recordStats; } @Override @@ -360,7 +384,7 @@ public void streamClosed(Status status) { if (streamClosedUpdater.getAndSet(this, 1) != 0) { return; } - if (!module.recordStats) { + if (!recordStats) { return; } stopwatch.stop(); @@ -397,6 +421,12 @@ public Context filterContext(Context context) { @VisibleForTesting final class ServerTracerFactory extends ServerStreamTracer.Factory { + private final boolean recordStats; + + ServerTracerFactory(boolean recordStats) { + this.recordStats = recordStats; + } + @Override public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { TagContext parentCtx = headers.get(statsHeader); @@ -409,19 +439,30 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata .put(RpcMeasureConstants.RPC_METHOD, TagValue.create(fullMethodName)) .build(); return new ServerTracer( - CensusStatsModule.this, fullMethodName, parentCtx, stopwatchSupplier, tagger); + CensusStatsModule.this, + fullMethodName, + parentCtx, + stopwatchSupplier, + tagger, + recordStats); } } @VisibleForTesting final class StatsClientInterceptor implements ClientInterceptor { + private final boolean recordStats; + + StatsClientInterceptor(boolean recordStats) { + this.recordStats = recordStats; + } + @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { // New RPCs on client-side inherit the tag context from the current Context. TagContext parentCtx = tagger.getCurrentTagContext(); final ClientCallTracer tracerFactory = - newClientCallTracer(parentCtx, method.getFullMethodName()); + newClientCallTracer(parentCtx, method.getFullMethodName(), recordStats); ClientCall call = next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); return new SimpleForwardingClientCall(call) { diff --git a/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java index c9ec9110f183..58b367581c24 100644 --- a/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java @@ -393,12 +393,24 @@ public void idleTimeout() { static class Builder extends AbstractManagedChannelImplBuilder { Builder(String target) { super(target); - statsImplementation(DUMMY_TAGGER, DUMMY_TAG_CONTEXT_BINARY_SERIALIZER, DUMMY_STATS_RECORDER); + overrideCensusStatsModule( + new CensusStatsModule( + DUMMY_TAGGER, + DUMMY_TAG_CONTEXT_BINARY_SERIALIZER, + DUMMY_STATS_RECORDER, + GrpcUtil.STOPWATCH_SUPPLIER, + true)); } Builder(SocketAddress directServerAddress, String authority) { super(directServerAddress, authority); - statsImplementation(DUMMY_TAGGER, DUMMY_TAG_CONTEXT_BINARY_SERIALIZER, DUMMY_STATS_RECORDER); + overrideCensusStatsModule( + new CensusStatsModule( + DUMMY_TAGGER, + DUMMY_TAG_CONTEXT_BINARY_SERIALIZER, + DUMMY_STATS_RECORDER, + GrpcUtil.STOPWATCH_SUPPLIER, + true)); } @Override diff --git a/core/src/test/java/io/grpc/internal/AbstractServerImplBuilderTest.java b/core/src/test/java/io/grpc/internal/AbstractServerImplBuilderTest.java index d03d0302c498..a331ccf19865 100644 --- a/core/src/test/java/io/grpc/internal/AbstractServerImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractServerImplBuilderTest.java @@ -145,7 +145,13 @@ public void getTracerFactories_disableBoth() { static class Builder extends AbstractServerImplBuilder { Builder() { - statsImplementation(DUMMY_TAGGER, DUMMY_TAG_CONTEXT_BINARY_SERIALIZER, DUMMY_STATS_RECORDER); + overrideCensusStatsModule( + new CensusStatsModule( + DUMMY_TAGGER, + DUMMY_TAG_CONTEXT_BINARY_SERIALIZER, + DUMMY_STATS_RECORDER, + GrpcUtil.STOPWATCH_SUPPLIER, + true)); } @Override diff --git a/core/src/test/java/io/grpc/internal/CensusModulesTest.java b/core/src/test/java/io/grpc/internal/CensusModulesTest.java index 89c7896b78ff..d77b268ce35b 100644 --- a/core/src/test/java/io/grpc/internal/CensusModulesTest.java +++ b/core/src/test/java/io/grpc/internal/CensusModulesTest.java @@ -191,7 +191,7 @@ public void setUp() throws Exception { .thenReturn(fakeClientSpanContext); censusStats = new CensusStatsModule( - tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), true, true); + tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), true); censusTracing = new CensusTracingModule(tracer, mockTracingPropagationHandler); } @@ -240,7 +240,7 @@ public ClientCall interceptCall( Channel interceptedChannel = ClientInterceptors.intercept( grpcServerRule.getChannel(), callOptionsCaptureInterceptor, - censusStats.getClientInterceptor(), censusTracing.getClientInterceptor()); + censusStats.getClientInterceptor(true), censusTracing.getClientInterceptor()); ClientCall call; if (nonDefaultContext) { Context ctx = @@ -322,7 +322,7 @@ public ClientCall interceptCall( @Test public void clientBasicStatsDefaultContext() { CensusStatsModule.ClientCallTracer callTracer = - censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName()); + censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName(), true); Metadata headers = new Metadata(); ClientStreamTracer tracer = callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers); @@ -433,7 +433,7 @@ public void clientTracingSampledToLocalSpanStore() { public void clientStreamNeverCreatedStillRecordStats() { CensusStatsModule.ClientCallTracer callTracer = censusStats.newClientCallTracer( - tagger.empty(), method.getFullMethodName()); + tagger.empty(), method.getFullMethodName(), true); fakeClock.forwardTime(3000, MILLISECONDS); callTracer.callEnded(Status.DEADLINE_EXCEEDED.withDescription("3 seconds")); @@ -510,11 +510,10 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), - propagate, - recordStats); + propagate); Metadata headers = new Metadata(); CensusStatsModule.ClientCallTracer callTracer = - census.newClientCallTracer(clientCtx, method.getFullMethodName()); + census.newClientCallTracer(clientCtx, method.getFullMethodName(), recordStats); // This propagates clientCtx to headers if propagates==true callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers); if (propagate) { @@ -525,7 +524,7 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS } ServerStreamTracer serverTracer = - census.getServerTracerFactory().newServerStreamTracer( + census.getServerTracerFactory(recordStats).newServerStreamTracer( method.getFullMethodName(), headers); // Server tracer deserializes clientCtx from the headers, so that it records stats with the // propagated tags. @@ -578,7 +577,7 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS @Test public void statsHeadersNotPropagateDefaultContext() { CensusStatsModule.ClientCallTracer callTracer = - censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName()); + censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName(), true); Metadata headers = new Metadata(); callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers); assertFalse(headers.containsKey(censusStats.statsHeader)); @@ -659,7 +658,7 @@ public void traceHeaderMalformed() throws Exception { @Test public void serverBasicStatsNoHeaders() { - ServerStreamTracer.Factory tracerFactory = censusStats.getServerTracerFactory(); + ServerStreamTracer.Factory tracerFactory = censusStats.getServerTracerFactory(true); ServerStreamTracer tracer = tracerFactory.newServerStreamTracer(method.getFullMethodName(), new Metadata()); diff --git a/testing/src/main/java/io/grpc/internal/TestingAccessor.java b/testing/src/main/java/io/grpc/internal/TestingAccessor.java index b03967e7686e..0b97081a3470 100644 --- a/testing/src/main/java/io/grpc/internal/TestingAccessor.java +++ b/testing/src/main/java/io/grpc/internal/TestingAccessor.java @@ -32,7 +32,9 @@ public static void setStatsImplementation( Tagger tagger, TagContextBinarySerializer tagCtxSerializer, StatsRecorder statsRecorder) { - builder.statsImplementation(tagger, tagCtxSerializer, statsRecorder); + builder.overrideCensusStatsModule( + new CensusStatsModule( + tagger, tagCtxSerializer, statsRecorder, GrpcUtil.STOPWATCH_SUPPLIER, true)); } /** @@ -43,7 +45,9 @@ public static void setStatsImplementation( Tagger tagger, TagContextBinarySerializer tagCtxSerializer, StatsRecorder statsRecorder) { - builder.statsImplementation(tagger, tagCtxSerializer, statsRecorder); + builder.overrideCensusStatsModule( + new CensusStatsModule( + tagger, tagCtxSerializer, statsRecorder, GrpcUtil.STOPWATCH_SUPPLIER, true)); } private TestingAccessor() {