Skip to content

Commit

Permalink
Fix HTTPS proxying (Azure#476)
Browse files Browse the repository at this point in the history
* Progress on HTTPS proxying

* Fix Azure compile error

* Simplify netty pipeline and reduce possibility of corruption

* Fix build error
  • Loading branch information
RikkiGibson authored and jianghaolu committed Aug 23, 2018
1 parent 45df0d7 commit 426c5f5
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ final class RefreshTokenClient {

private static HttpClient createHttpClient(Proxy proxy) {
return new NettyClient.Factory()
.create(new HttpClientConfiguration(proxy, false));
.create(new HttpClientConfiguration(proxy));
}

AuthenticationResult refreshToken(String tenant, String clientId, String resource, String refreshToken, boolean isMultipleResoureRefreshToken) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ public class AzureProxyToRestProxyWithNettyTests extends AzureProxyToRestProxyTe

@Override
protected HttpClient createHttpClient() {
return nettyClientFactory.create(new HttpClientConfiguration(null, false));
return nettyClientFactory.create(new HttpClientConfiguration(null));
}
}
4 changes: 4 additions & 0 deletions client-runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@
<groupId>io.netty</groupId>
<artifactId>netty-handler</artifactId>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-handler-proxy</artifactId>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-buffer</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
public class HttpClientConfiguration {
private final Proxy proxy;
private final boolean isProxyHTTPS;

/**
* @return The optional proxy to use.
Expand All @@ -22,22 +21,11 @@ public Proxy proxy() {
return proxy;
}

/**
* Indicates whether the connection to the proxy is via HTTP or HTTPS.
* This is unrelated to whether the final resource being accessed is over HTTP or HTTPS.
* @return true if the proxy should be connected via HTTPS
*/
public boolean isProxyHTTPS() {
return isProxyHTTPS;
}

/**
* Creates an HttpClientConfiguration.
* @param proxy The optional proxy to use.
* @param isProxyHTTPS true if the proxy should be connected via HTTPS
*/
public HttpClientConfiguration(Proxy proxy, boolean isProxyHTTPS) {
public HttpClientConfiguration(Proxy proxy) {
this.proxy = proxy;
this.isProxyHTTPS = isProxyHTTPS;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
package com.microsoft.rest.v2.http;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -43,8 +39,6 @@
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.concurrent.DefaultThreadFactory;
Expand Down Expand Up @@ -80,7 +74,7 @@ public final class NettyClient extends HttpClient {
*/
private NettyClient(HttpClientConfiguration configuration, NettyAdapter adapter) {
this.adapter = adapter;
this.configuration = configuration != null ? configuration : new HttpClientConfiguration(null, false);
this.configuration = configuration != null ? configuration : new HttpClientConfiguration(null);
}

@Override
Expand Down Expand Up @@ -168,8 +162,7 @@ private static SharedChannelPool createChannelPool(Bootstrap bootstrap, Transpor
public synchronized void channelCreated(Channel ch) throws Exception {
// Why is it necessary to have "synchronized" to prevent NRE in pipeline().get(Class<T>)?
// Is channelCreated not run on the eventLoop assigned to the channel?
ch.pipeline().addLast("HttpResponseDecoder", new HttpResponseDecoder());
ch.pipeline().addLast("HttpRequestEncoder", new HttpRequestEncoder());
ch.pipeline().addLast("HttpClientCodec", new HttpClientCodec());
ch.pipeline().addLast("HttpClientInboundHandler", new HttpClientInboundHandler());
}
}, poolSize);
Expand All @@ -188,19 +181,13 @@ private NettyAdapter(Bootstrap baseBootstrap, int eventLoopGroupSize, int channe
}

private Single<HttpResponse> sendRequestInternalAsync(final HttpRequest request, final HttpClientConfiguration configuration) {
final URI channelAddress;
try {
channelAddress = getChannelAddress(request, configuration);
} catch (URISyntaxException e) {
return Single.error(e);
}
addHeaders(request);

// Creates cold observable from an emitter
return Single.create((SingleEmitter<HttpResponse> responseEmitter) -> {
AcquisitionListener listener = new AcquisitionListener(channelPool, request, responseEmitter);
responseEmitter.setDisposable(listener);
channelPool.acquire(channelAddress).addListener(listener);
channelPool.acquire(request.url().toURI(), configuration.proxy()).addListener(listener);
});
}
}
Expand All @@ -211,21 +198,6 @@ private static void addHeaders(final HttpRequest request) {
io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE.toString());
}

private static URI getChannelAddress(final HttpRequest request, final HttpClientConfiguration configuration) throws URISyntaxException {
final Proxy proxy = configuration.proxy();
if (proxy == null) {
return request.url().toURI();
} else if (proxy.address() instanceof InetSocketAddress) {
InetSocketAddress address = (InetSocketAddress) proxy.address();
String scheme = configuration.isProxyHTTPS() ? "https" : "http";
String channelAddressString = scheme + "://" + address.getHostString() + ":" + address.getPort();
return new URI(channelAddressString);
} else {
throw new IllegalArgumentException(
"SocketAddress on java.net.Proxy must be an InetSocketAddress. Found proxy: " + proxy);
}
}

private static final class AcquisitionListener
implements GenericFutureListener<Future<? super Channel>>, Disposable {

Expand Down Expand Up @@ -256,7 +228,9 @@ private static final class AcquisitionListener
private volatile boolean finishedWritingRequestBody;
private volatile RequestSubscriber requestSubscriber;

AcquisitionListener(SharedChannelPool channelPool, final HttpRequest request,
AcquisitionListener(
SharedChannelPool channelPool,
HttpRequest request,
SingleEmitter<HttpResponse> responseEmitter) {
this.channelPool = channelPool;
this.request = request;
Expand Down Expand Up @@ -292,7 +266,6 @@ public void operationComplete(Future<? super Channel> cf) {
//TODO do we need a memory barrier here to ensure vis of responseEmitter in other threads?

try {
configurePipeline(channel, request);

final DefaultHttpRequest raw = createDefaultHttpRequest(request);

Expand Down Expand Up @@ -649,22 +622,6 @@ public void channelWritable(boolean writable) {

}

private static void configurePipeline(Channel channel, HttpRequest request) {
if (request.httpMethod() == com.microsoft.rest.v2.http.HttpMethod.HEAD) {
// Use HttpClientCodec for HEAD operations
if (channel.pipeline().get("HttpClientCodec") == null) {
channel.pipeline().remove(HttpRequestEncoder.class);
channel.pipeline().replace(HttpResponseDecoder.class, "HttpClientCodec", new HttpClientCodec());
}
} else {
// Use HttpResponseDecoder for other operations
if (channel.pipeline().get("HttpResponseDecoder") == null) {
channel.pipeline().replace(HttpClientCodec.class, "HttpResponseDecoder", new HttpResponseDecoder());
channel.pipeline().addAfter("HttpResponseDecoder", "HttpRequestEncoder", new HttpRequestEncoder());
}
}
}

private static DefaultHttpRequest createDefaultHttpRequest(HttpRequest request) {
final DefaultHttpRequest raw = new DefaultHttpRequest(HttpVersion.HTTP_1_1,
HttpMethod.valueOf(request.httpMethod().toString()), request.url().toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@
import io.netty.channel.ChannelInitializer;
import io.netty.channel.pool.ChannelPool;
import io.netty.channel.pool.ChannelPoolHandler;
import io.netty.handler.proxy.HttpProxyHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import io.reactivex.annotations.Nullable;
import io.reactivex.exceptions.Exceptions;

import javax.net.ssl.SSLException;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Queue;
Expand All @@ -36,7 +40,7 @@
* pool. All the internal pools for all the requests have a fixed size limit.
* This channel pool should be shared between multiple Netty adapters.
*/
public class SharedChannelPool implements ChannelPool {
class SharedChannelPool implements ChannelPool {
private static final AttributeKey<URI> CHANNEL_URI = AttributeKey.newInstance("channel-uri");
private final Bootstrap bootstrap;
private final ChannelPoolHandler handler;
Expand All @@ -62,7 +66,7 @@ private static boolean isChannelHealthy(Channel channel) {
* @param handler the handler to apply to the channels on creation, acquisition and release
* @param size the upper limit of total number of channels
*/
public SharedChannelPool(final Bootstrap bootstrap, final ChannelPoolHandler handler, int size) {
SharedChannelPool(final Bootstrap bootstrap, final ChannelPoolHandler handler, int size) {
this.bootstrap = bootstrap.clone().handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
Expand Down Expand Up @@ -97,7 +101,8 @@ protected void initChannel(Channel ch) throws Exception {
requests.wait();
}
}
request = requests.poll();
// requests must be non-empty based on the above condition
request = requests.remove();

synchronized (sync) {
while (leased.size() >= poolSize && !closed) {
Expand All @@ -108,12 +113,12 @@ protected void initChannel(Channel ch) throws Exception {
break;
}

if (available.containsKey(request.uri)) {
Channel channel = available.poll(request.uri);
if (available.containsKey(request.channelURI)) {
Channel channel = available.poll(request.channelURI);
if (isChannelHealthy(channel)) {
handler.channelAcquired(channel);
request.promise.setSuccess(channel);
leased.put(request.uri, channel);
leased.put(request.channelURI, channel);
continue;
}
}
Expand All @@ -122,30 +127,35 @@ protected void initChannel(Channel ch) throws Exception {
available.poll().close();
}
int port;
if (request.uri.getPort() < 0) {
port = "https".equals(request.uri.getScheme()) ? 443 : 80;
if (request.destinationURI.getPort() < 0) {
port = "https".equals(request.destinationURI.getScheme()) ? 443 : 80;
} else {
port = request.uri.getPort();
port = request.destinationURI.getPort();
}
channelFuture = SharedChannelPool.this.bootstrap.clone().connect(request.uri.getHost(), port);
channelFuture = SharedChannelPool.this.bootstrap.clone().connect(request.destinationURI.getHost(), port);
channelFuture.channel().eventLoop().execute(() -> {
channelFuture.channel().attr(CHANNEL_URI).set(request.channelURI);

channelFuture.channel().attr(CHANNEL_URI).set(request.uri);
// Apply SSL handler for https connections
if ("https".equalsIgnoreCase(request.destinationURI.getScheme())) {
channelFuture.channel().pipeline().addFirst(sslContext.newHandler(channelFuture.channel().alloc(), request.destinationURI.getHost(), port));
}

// Apply SSL handler for https connections
if ("https".equalsIgnoreCase(request.uri.getScheme())) {
channelFuture.channel().pipeline().addFirst(sslContext.newHandler(channelFuture.channel().alloc(), request.uri.getHost(), port));
}
if (request.proxy != null) {
channelFuture.channel().pipeline().addFirst("HttpProxyHandler", new HttpProxyHandler(request.proxy.address()));
}

leased.put(request.uri, channelFuture.channel());
channelFuture.addListener((ChannelFuture future) -> {
if (future.isSuccess()) {
handler.channelAcquired(future.channel());
request.promise.setSuccess(future.channel());
} else {
leased.remove(request.uri, future.channel());
leased.put(request.channelURI, channelFuture.channel());
channelFuture.addListener((ChannelFuture future) -> {
if (future.isSuccess()) {
handler.channelAcquired(future.channel());
request.promise.setSuccess(future.channel());
} else {
leased.remove(request.channelURI, future.channel());

request.promise.setFailure(future.cause());
}
request.promise.setFailure(future.cause());
}
});
});
}
} catch (Exception e) {
Expand All @@ -160,8 +170,8 @@ protected void initChannel(Channel ch) throws Exception {
* @param uri the URI the channel acquired should be connected to
* @return the future to a connected channel
*/
public Future<Channel> acquire(URI uri) {
return this.acquire(uri, this.bootstrap.config().group().next().<Channel>newPromise());
public Future<Channel> acquire(URI uri, @Nullable Proxy proxy) {
return this.acquire(uri, proxy, this.bootstrap.config().group().next().<Channel>newPromise());
}

/**
Expand All @@ -170,21 +180,30 @@ public Future<Channel> acquire(URI uri) {
* @param promise the writable future to a connected channel
* @return the future to a connected channel
*/
public Future<Channel> acquire(URI uri, final Promise<Channel> promise) {
public Future<Channel> acquire(URI uri, @Nullable Proxy proxy, final Promise<Channel> promise) {
if (closed) {
throw new RejectedExecutionException("SharedChannelPool is closed");
}

ChannelRequest channelRequest = new ChannelRequest();
channelRequest.promise = promise;
channelRequest.proxy = proxy;
int port;
if (uri.getPort() < 0) {
port = "https".equals(uri.getScheme()) ? 443 : 80;
} else {
port = uri.getPort();
}
try {
channelRequest.uri = new URI(String.format("%s://%s:%d", uri.getScheme(), uri.getHost(), port));
channelRequest.destinationURI = new URI(String.format("%s://%s:%d", uri.getScheme(), uri.getHost(), port));

if (proxy == null) {
channelRequest.channelURI = channelRequest.destinationURI;
} else {
InetSocketAddress address = (InetSocketAddress) proxy.address();
channelRequest.channelURI = new URI(String.format("%s://%s:%d", uri.getScheme(), address.getHostString(), address.getPort()));
}

requests.add(channelRequest);
synchronized (requests) {
requests.notify();
Expand Down Expand Up @@ -248,7 +267,9 @@ public void close() {
}

private static class ChannelRequest {
private URI uri;
private URI destinationURI;
private URI channelURI;
private Proxy proxy;
private Promise<Channel> promise;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1180,7 +1180,7 @@ public void service20GetBytes100OnlyHeaders() {
final HttpBinHeaders headers = response.headers();
assertNotNull(headers);
assertEquals(true, headers.accessControlAllowCredentials);
assertEquals("keep-alive", headers.connection);
assertEquals("keep-alive", headers.connection.toLowerCase());
assertNotNull(headers.date);
assertEquals("1.1 vegur", headers.via);
assertNotEquals(0, headers.xProcessedTime);
Expand Down Expand Up @@ -1236,7 +1236,7 @@ public void service20PutOnlyHeaders() {
final HttpBinHeaders headers = response.headers();
assertNotNull(headers);
assertEquals(true, headers.accessControlAllowCredentials);
assertEquals("keep-alive", headers.connection);
assertEquals("keep-alive", headers.connection.toLowerCase());
assertNotNull(headers.date);
assertEquals("1.1 vegur", headers.via);
assertNotEquals(0, headers.xProcessedTime);
Expand All @@ -1258,7 +1258,7 @@ public void service20PutBodyAndHeaders() {
final HttpBinHeaders headers = response.headers();
assertNotNull(headers);
assertEquals(true, headers.accessControlAllowCredentials);
assertEquals("keep-alive", headers.connection);
assertEquals("keep-alive", headers.connection.toLowerCase());
assertNotNull(headers.date);
assertEquals("1.1 vegur", headers.via);
assertNotEquals(0, headers.xProcessedTime);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ public class RestProxyWithHttpProxyNettyTests extends RestProxyTests {
protected HttpClient createHttpClient() {
InetSocketAddress address = new InetSocketAddress("127.0.0.1", 8888);
Proxy proxy = new Proxy(Proxy.Type.HTTP, address);
return nettyClientFactory.create(new HttpClientConfiguration(proxy, false));
return nettyClientFactory.create(new HttpClientConfiguration(proxy));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ public class RestProxyWithNettyTests extends RestProxyTests {

@Override
protected HttpClient createHttpClient() {
return nettyClientFactory.create(new HttpClientConfiguration(null, false));
return nettyClientFactory.create(new HttpClientConfiguration(null));
}
}
Loading

0 comments on commit 426c5f5

Please sign in to comment.