Skip to content

Commit

Permalink
Refactor the code for overriding the Census implementation in tests.
Browse files Browse the repository at this point in the history
This commit adds a method to set the CensusStatsModule on
AbstractServerImplBuilder and AbstractManagedChannelImplBuilder.  It is simpler
to set the whole CensusStatsModule than pass in all necessary OpenCensus
implementation objects.
  • Loading branch information
sebright committed Nov 2, 2017
1 parent c724318 commit 1b464f6
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@
import io.grpc.NameResolver;
import io.grpc.NameResolverProvider;
import io.grpc.PickFirstBalancerFactory;
import io.opencensus.stats.Stats;
import io.opencensus.stats.StatsRecorder;
import io.opencensus.tags.Tagger;
import io.opencensus.tags.Tags;
import io.opencensus.tags.propagation.TagContextBinarySerializer;
import io.opencensus.trace.Tracing;
import java.net.SocketAddress;
import java.net.URI;
Expand Down Expand Up @@ -155,13 +150,7 @@ protected final int maxInboundMessageSize() {
private boolean tracingEnabled = true;

@Nullable
private Tagger tagger;

@Nullable
private TagContextBinarySerializer tagCtxSerializer;

@Nullable
private StatsRecorder statsRecorder;
private CensusStatsModule censusStatsOverride;

protected AbstractManagedChannelImplBuilder(String target) {
this.target = Preconditions.checkNotNull(target, "target");
Expand Down Expand Up @@ -296,11 +285,8 @@ public final T idleTimeout(long value, TimeUnit unit) {
* Override the default stats implementation.
*/
@VisibleForTesting
protected final T statsImplementation(
Tagger tagger, TagContextBinarySerializer tagCtxSerializer, StatsRecorder statsRecorder) {
this.tagger = tagger;
this.tagCtxSerializer = tagCtxSerializer;
this.statsRecorder = statsRecorder;
protected final T overrideCensusStatsModule(CensusStatsModule censusStats) {
this.censusStatsOverride = censusStats;
return thisT();
}

Expand Down Expand Up @@ -358,24 +344,13 @@ final List<ClientInterceptor> getEffectiveInterceptors() {
List<ClientInterceptor> effectiveInterceptors =
new ArrayList<ClientInterceptor>(this.interceptors);
if (statsEnabled) {
Tagger tagger = this.tagger != null ? this.tagger : Tags.getTagger();
TagContextBinarySerializer tagCtxSerializer =
this.tagCtxSerializer != null
? this.tagCtxSerializer
: Tags.getTagPropagationComponent().getBinarySerializer();
StatsRecorder statsRecorder =
this.statsRecorder != null ? this.statsRecorder : Stats.getStatsRecorder();
CensusStatsModule censusStats =
new CensusStatsModule(
tagger,
tagCtxSerializer,
statsRecorder,
GrpcUtil.STOPWATCH_SUPPLIER,
true,
recordStats);
CensusStatsModule censusStats = this.censusStatsOverride;
if (censusStats == null) {
censusStats = new CensusStatsModule(GrpcUtil.STOPWATCH_SUPPLIER, true);
}
// First interceptor runs last (see ClientInterceptors.intercept()), so that no
// other interceptor can override the tracer factory we set in CallOptions.
effectiveInterceptors.add(0, censusStats.getClientInterceptor());
effectiveInterceptors.add(0, censusStats.getClientInterceptor(recordStats));
}
if (tracingEnabled) {
CensusTracingModule censusTracing =
Expand Down
43 changes: 8 additions & 35 deletions core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerStreamTracer;
import io.grpc.ServerTransportFilter;
import io.opencensus.stats.Stats;
import io.opencensus.stats.StatsRecorder;
import io.opencensus.tags.Tagger;
import io.opencensus.tags.Tags;
import io.opencensus.tags.propagation.TagContextBinarySerializer;
import io.opencensus.trace.Tracing;
import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -100,13 +95,7 @@ public List<ServerServiceDefinition> getServices() {
CompressorRegistry compressorRegistry = DEFAULT_COMPRESSOR_REGISTRY;

@Nullable
private Tagger tagger;

@Nullable
private TagContextBinarySerializer tagCtxSerializer;

@Nullable
private StatsRecorder statsRecorder;
private CensusStatsModule censusStatsOverride;

private boolean statsEnabled = true;
private boolean recordStats = true;
Expand Down Expand Up @@ -193,13 +182,8 @@ public final T compressorRegistry(CompressorRegistry registry) {
* Override the default stats implementation.
*/
@VisibleForTesting
protected T statsImplementation(
final Tagger tagger,
TagContextBinarySerializer tagCtxSerializer,
StatsRecorder statsRecorder) {
this.tagger = tagger;
this.tagCtxSerializer = tagCtxSerializer;
this.statsRecorder = statsRecorder;
protected T overrideCensusStatsModule(CensusStatsModule censusStats) {
this.censusStatsOverride = censusStats;
return thisT();
}

Expand Down Expand Up @@ -242,22 +226,11 @@ final List<ServerStreamTracer.Factory> getTracerFactories() {
ArrayList<ServerStreamTracer.Factory> tracerFactories =
new ArrayList<ServerStreamTracer.Factory>();
if (statsEnabled) {
Tagger tagger = this.tagger != null ? this.tagger : Tags.getTagger();
TagContextBinarySerializer tagCtxSerializer =
this.tagCtxSerializer != null
? this.tagCtxSerializer
: Tags.getTagPropagationComponent().getBinarySerializer();
StatsRecorder statsRecorder =
this.statsRecorder != null ? this.statsRecorder : Stats.getStatsRecorder();
CensusStatsModule censusStats =
new CensusStatsModule(
tagger,
tagCtxSerializer,
statsRecorder,
GrpcUtil.STOPWATCH_SUPPLIER,
true,
recordStats);
tracerFactories.add(censusStats.getServerTracerFactory());
CensusStatsModule censusStats = this.censusStatsOverride;
if (censusStats == null) {
censusStats = new CensusStatsModule(GrpcUtil.STOPWATCH_SUPPLIER, true);
}
tracerFactories.add(censusStats.getServerTracerFactory(recordStats));
}
if (tracingEnabled) {
CensusTracingModule censusTracing =
Expand Down
75 changes: 58 additions & 17 deletions core/src/main/java/io/grpc/internal/CensusStatsModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@
import io.grpc.StreamTracer;
import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants;
import io.opencensus.stats.MeasureMap;
import io.opencensus.stats.Stats;
import io.opencensus.stats.StatsRecorder;
import io.opencensus.tags.TagContext;
import io.opencensus.tags.TagValue;
import io.opencensus.tags.Tagger;
import io.opencensus.tags.Tags;
import io.opencensus.tags.propagation.TagContextBinarySerializer;
import io.opencensus.tags.propagation.TagContextSerializationException;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -74,21 +76,33 @@ final class CensusStatsModule {
private final Supplier<Stopwatch> stopwatchSupplier;
@VisibleForTesting
final Metadata.Key<TagContext> statsHeader;
private final StatsClientInterceptor clientInterceptor = new StatsClientInterceptor();
private final ServerTracerFactory serverTracerFactory = new ServerTracerFactory();
private final boolean propagateTags;
private final boolean recordStats;

/**
* Creates a {@link CensusStatsModule} with the default OpenCensus implementation.
*/
CensusStatsModule(Supplier<Stopwatch> stopwatchSupplier, boolean propagateTags) {
this(
Tags.getTagger(),
Tags.getTagPropagationComponent().getBinarySerializer(),
Stats.getStatsRecorder(),
stopwatchSupplier,
propagateTags);
}

/**
* Creates a {@link CensusStatsModule} with the given OpenCensus implementation.
*/
CensusStatsModule(
final Tagger tagger,
final TagContextBinarySerializer tagCtxSerializer,
StatsRecorder statsRecorder, Supplier<Stopwatch> stopwatchSupplier,
boolean propagateTags, boolean recordStats) {
boolean propagateTags) {
this.tagger = checkNotNull(tagger, "tagger");
this.statsRecorder = checkNotNull(statsRecorder, "statsRecorder");
checkNotNull(tagCtxSerializer, "tagCtxSerializer");
this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier");
this.propagateTags = propagateTags;
this.recordStats = recordStats;
this.statsHeader =
Metadata.Key.of("grpc-tags-bin", new Metadata.BinaryMarshaller<TagContext>() {
@Override
Expand Down Expand Up @@ -118,22 +132,23 @@ public TagContext parseBytes(byte[] serialized) {
* Creates a {@link ClientCallTracer} for a new call.
*/
@VisibleForTesting
ClientCallTracer newClientCallTracer(TagContext parentCtx, String fullMethodName) {
return new ClientCallTracer(this, parentCtx, fullMethodName);
ClientCallTracer newClientCallTracer(
TagContext parentCtx, String fullMethodName, boolean recordStats) {
return new ClientCallTracer(this, parentCtx, fullMethodName, recordStats);
}

/**
* Returns the server tracer factory.
*/
ServerStreamTracer.Factory getServerTracerFactory() {
return serverTracerFactory;
ServerStreamTracer.Factory getServerTracerFactory(boolean recordStats) {
return new ServerTracerFactory(recordStats);
}

/**
* Returns the client interceptor that facilitates Census-based stats reporting.
*/
ClientInterceptor getClientInterceptor() {
return clientInterceptor;
ClientInterceptor getClientInterceptor(boolean recordStats) {
return new StatsClientInterceptor(recordStats);
}

private static final class ClientTracer extends ClientStreamTracer {
Expand Down Expand Up @@ -206,12 +221,18 @@ static final class ClientCallTracer extends ClientStreamTracer.Factory {
private volatile ClientTracer streamTracer;
private volatile int callEnded;
private final TagContext parentCtx;
private final boolean recordStats;

ClientCallTracer(CensusStatsModule module, TagContext parentCtx, String fullMethodName) {
ClientCallTracer(
CensusStatsModule module,
TagContext parentCtx,
String fullMethodName,
boolean recordStats) {
this.module = module;
this.parentCtx = checkNotNull(parentCtx, "parentCtx");
this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName");
this.stopwatch = module.stopwatchSupplier.get().start();
this.recordStats = recordStats;
}

@Override
Expand Down Expand Up @@ -241,7 +262,7 @@ void callEnded(Status status) {
if (callEndedUpdater.getAndSet(this, 1) != 0) {
return;
}
if (!module.recordStats) {
if (!recordStats) {
return;
}
stopwatch.stop();
Expand Down Expand Up @@ -299,6 +320,7 @@ private static final class ServerTracer extends ServerStreamTracer {
private volatile int streamClosed;
private final Stopwatch stopwatch;
private final Tagger tagger;
private final boolean recordStats;
private volatile long outboundMessageCount;
private volatile long inboundMessageCount;
private volatile long outboundWireSize;
Expand All @@ -311,12 +333,14 @@ private static final class ServerTracer extends ServerStreamTracer {
String fullMethodName,
TagContext parentCtx,
Supplier<Stopwatch> stopwatchSupplier,
Tagger tagger) {
Tagger tagger,
boolean recordStats) {
this.module = module;
this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName");
this.parentCtx = checkNotNull(parentCtx, "parentCtx");
this.stopwatch = stopwatchSupplier.get().start();
this.tagger = tagger;
this.recordStats = recordStats;
}

@Override
Expand Down Expand Up @@ -360,7 +384,7 @@ public void streamClosed(Status status) {
if (streamClosedUpdater.getAndSet(this, 1) != 0) {
return;
}
if (!module.recordStats) {
if (!recordStats) {
return;
}
stopwatch.stop();
Expand Down Expand Up @@ -397,6 +421,12 @@ public Context filterContext(Context context) {

@VisibleForTesting
final class ServerTracerFactory extends ServerStreamTracer.Factory {
private final boolean recordStats;

ServerTracerFactory(boolean recordStats) {
this.recordStats = recordStats;
}

@Override
public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) {
TagContext parentCtx = headers.get(statsHeader);
Expand All @@ -409,19 +439,30 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata
.put(RpcMeasureConstants.RPC_METHOD, TagValue.create(fullMethodName))
.build();
return new ServerTracer(
CensusStatsModule.this, fullMethodName, parentCtx, stopwatchSupplier, tagger);
CensusStatsModule.this,
fullMethodName,
parentCtx,
stopwatchSupplier,
tagger,
recordStats);
}
}

@VisibleForTesting
final class StatsClientInterceptor implements ClientInterceptor {
private final boolean recordStats;

StatsClientInterceptor(boolean recordStats) {
this.recordStats = recordStats;
}

@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
// New RPCs on client-side inherit the tag context from the current Context.
TagContext parentCtx = tagger.getCurrentTagContext();
final ClientCallTracer tracerFactory =
newClientCallTracer(parentCtx, method.getFullMethodName());
newClientCallTracer(parentCtx, method.getFullMethodName(), recordStats);
ClientCall<ReqT, RespT> call =
next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory));
return new SimpleForwardingClientCall<ReqT, RespT>(call) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,24 @@ public void idleTimeout() {
static class Builder extends AbstractManagedChannelImplBuilder<Builder> {
Builder(String target) {
super(target);
statsImplementation(DUMMY_TAGGER, DUMMY_TAG_CONTEXT_BINARY_SERIALIZER, DUMMY_STATS_RECORDER);
overrideCensusStatsModule(
new CensusStatsModule(
DUMMY_TAGGER,
DUMMY_TAG_CONTEXT_BINARY_SERIALIZER,
DUMMY_STATS_RECORDER,
GrpcUtil.STOPWATCH_SUPPLIER,
true));
}

Builder(SocketAddress directServerAddress, String authority) {
super(directServerAddress, authority);
statsImplementation(DUMMY_TAGGER, DUMMY_TAG_CONTEXT_BINARY_SERIALIZER, DUMMY_STATS_RECORDER);
overrideCensusStatsModule(
new CensusStatsModule(
DUMMY_TAGGER,
DUMMY_TAG_CONTEXT_BINARY_SERIALIZER,
DUMMY_STATS_RECORDER,
GrpcUtil.STOPWATCH_SUPPLIER,
true));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,13 @@ public void getTracerFactories_disableBoth() {

static class Builder extends AbstractServerImplBuilder<Builder> {
Builder() {
statsImplementation(DUMMY_TAGGER, DUMMY_TAG_CONTEXT_BINARY_SERIALIZER, DUMMY_STATS_RECORDER);
overrideCensusStatsModule(
new CensusStatsModule(
DUMMY_TAGGER,
DUMMY_TAG_CONTEXT_BINARY_SERIALIZER,
DUMMY_STATS_RECORDER,
GrpcUtil.STOPWATCH_SUPPLIER,
true));
}

@Override
Expand Down
Loading

0 comments on commit 1b464f6

Please sign in to comment.