diff --git a/sdk/core/azure-core/src/main/java/com/azure/core/http/policy/HttpLoggingPolicy.java b/sdk/core/azure-core/src/main/java/com/azure/core/http/policy/HttpLoggingPolicy.java index fd1f8acf85422..50f6a5f196ad5 100644 --- a/sdk/core/azure-core/src/main/java/com/azure/core/http/policy/HttpLoggingPolicy.java +++ b/sdk/core/azure-core/src/main/java/com/azure/core/http/policy/HttpLoggingPolicy.java @@ -13,13 +13,13 @@ import com.azure.core.implementation.LogLevel; import com.azure.core.implementation.LoggingUtil; import com.azure.core.util.CoreUtils; -import com.azure.core.util.FluxUtil; import com.azure.core.util.UrlBuilder; import com.azure.core.util.logging.ClientLogger; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.SerializationFeature; import reactor.core.publisher.Mono; +import java.io.ByteArrayOutputStream; import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.Collections; @@ -100,7 +100,7 @@ public Mono process(HttpPipelineCallContext context, HttpPipelineN * @param request HTTP request being sent to Azure. * @return A Mono which will emit the string to log. */ - private Mono logRequest(final ClientLogger logger, final HttpRequest request) { + private Mono logRequest(final ClientLogger logger, final HttpRequest request) { int numericLogLevel = LoggingUtil.getEnvironmentLoggingLevel().toNumeric(); if (shouldLoggingBeSkipped(numericLogLevel)) { return Mono.empty(); @@ -117,44 +117,59 @@ private Mono logRequest(final ClientLogger logger, final HttpRequest req addHeadersToLogMessage(request.getHeaders(), requestLogMessage, numericLogLevel); - Mono requestLoggingMono = Mono.defer(() -> Mono.just(requestLogMessage.toString())); + if (!httpLogDetailLevel.shouldLogBody()) { + return logAndReturn(logger, requestLogMessage, null); + } - if (httpLogDetailLevel.shouldLogBody()) { - if (request.getBody() == null) { - requestLogMessage.append("(empty body)") - .append(System.lineSeparator()) - .append("--> END ") - .append(request.getHttpMethod()) - .append(System.lineSeparator()); - } else { - String contentType = request.getHeaders().getValue("Content-Type"); - long contentLength = getContentLength(logger, request.getHeaders()); + if (request.getBody() == null) { + requestLogMessage.append("(empty body)") + .append(System.lineSeparator()) + .append("--> END ") + .append(request.getHttpMethod()) + .append(System.lineSeparator()); - if (shouldBodyBeLogged(contentType, contentLength)) { - requestLoggingMono = FluxUtil.collectBytesInByteBufferStream(request.getBody()).flatMap(bytes -> { + return logAndReturn(logger, requestLogMessage, null); + } + + String contentType = request.getHeaders().getValue("Content-Type"); + long contentLength = getContentLength(logger, request.getHeaders()); + + if (shouldBodyBeLogged(contentType, contentLength)) { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream((int) contentLength); + + // Add non-mutating operators to the data stream. + request.setBody( + request.getBody() + .doOnNext(byteBuffer -> { + for (int i = byteBuffer.position(); i < byteBuffer.limit(); i++) { + outputStream.write(byteBuffer.get(i)); + } + }) + .doFinally(ignored -> { requestLogMessage.append(contentLength) .append("-byte body:") .append(System.lineSeparator()) - .append(prettyPrintIfNeeded(logger, contentType, new String(bytes, StandardCharsets.UTF_8))) + .append(prettyPrintIfNeeded(logger, contentType, + new String(outputStream.toByteArray(), StandardCharsets.UTF_8))) .append(System.lineSeparator()) .append("--> END ") .append(request.getHttpMethod()) .append(System.lineSeparator()); - return Mono.just(requestLogMessage.toString()); - }); - } else { - requestLogMessage.append(contentLength) - .append("-byte body: (content not logged)") - .append(System.lineSeparator()) - .append("--> END ") - .append(request.getHttpMethod()) - .append(System.lineSeparator()); - } - } - } + logger.info(requestLogMessage.toString()); + })); - return requestLoggingMono.doOnNext(logger::info); + return Mono.empty(); + } else { + requestLogMessage.append(contentLength) + .append("-byte body: (content not logged)") + .append(System.lineSeparator()) + .append("--> END ") + .append(request.getHttpMethod()) + .append(System.lineSeparator()); + + return logAndReturn(logger, requestLogMessage, null); + } } /* @@ -194,32 +209,45 @@ private Mono logResponse(final ClientLogger logger, final HttpResp addHeadersToLogMessage(response.getHeaders(), responseLogMessage, numericLogLevel); - Mono responseLoggingMono = Mono.defer(() -> Mono.just(responseLogMessage.toString())); - - if (httpLogDetailLevel.shouldLogBody()) { - final String contentTypeHeader = response.getHeaderValue("Content-Type"); + if (!httpLogDetailLevel.shouldLogBody()) { + responseLogMessage.append("<-- END HTTP"); + return logAndReturn(logger, responseLogMessage, response); + } - if (shouldBodyBeLogged(contentTypeHeader, getContentLength(logger, response.getHeaders()))) { - final HttpResponse bufferedResponse = response.buffer(); - responseLoggingMono = bufferedResponse.getBodyAsString().flatMap(body -> { + String contentTypeHeader = response.getHeaderValue("Content-Type"); + long contentLength = getContentLength(logger, response.getHeaders()); + + if (shouldBodyBeLogged(contentTypeHeader, contentLength)) { + HttpResponse bufferedResponse = response.buffer(); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream((int) contentLength); + return bufferedResponse.getBody() + .doOnNext(byteBuffer -> { + for (int i = byteBuffer.position(); i < byteBuffer.limit(); i++) { + outputStream.write(byteBuffer.get(i)); + } + }) + .doFinally(ignored -> { responseLogMessage.append("Response body:") .append(System.lineSeparator()) - .append(prettyPrintIfNeeded(logger, contentTypeHeader, body)) + .append(prettyPrintIfNeeded(logger, contentTypeHeader, + new String(outputStream.toByteArray(), StandardCharsets.UTF_8))) .append(System.lineSeparator()) .append("<-- END HTTP"); - return Mono.just(responseLogMessage.toString()); - }).switchIfEmpty(responseLoggingMono); - } else { - responseLogMessage.append("(body content not logged)") - .append(System.lineSeparator()) - .append("<-- END HTTP"); - } + logger.info(responseLogMessage.toString()); + }).then(Mono.just(bufferedResponse)); } else { - responseLogMessage.append("<-- END HTTP"); + responseLogMessage.append("(body content not logged)") + .append(System.lineSeparator()) + .append("<-- END HTTP"); + + return logAndReturn(logger, responseLogMessage, response); } + } - return responseLoggingMono.doOnNext(logger::info).thenReturn(response); + private Mono logAndReturn(ClientLogger logger, StringBuilder logMessageBuilder, T data) { + logger.info(logMessageBuilder.toString()); + return Mono.justOrEmpty(data); } /* diff --git a/sdk/core/azure-core/src/main/java/com/azure/core/implementation/http/BufferedHttpResponse.java b/sdk/core/azure-core/src/main/java/com/azure/core/implementation/http/BufferedHttpResponse.java index 02dbe1eda46cb..532f5a9e274ed 100644 --- a/sdk/core/azure-core/src/main/java/com/azure/core/implementation/http/BufferedHttpResponse.java +++ b/sdk/core/azure-core/src/main/java/com/azure/core/implementation/http/BufferedHttpResponse.java @@ -5,6 +5,7 @@ import com.azure.core.http.HttpHeaders; import com.azure.core.http.HttpResponse; +import com.azure.core.util.FluxUtil; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -17,7 +18,7 @@ */ public final class BufferedHttpResponse extends HttpResponse { private final HttpResponse innerHttpResponse; - private final Mono cachedBody; + private final Flux cachedBody; /** * Creates a buffered HTTP response. @@ -27,7 +28,7 @@ public final class BufferedHttpResponse extends HttpResponse { public BufferedHttpResponse(HttpResponse innerHttpResponse) { super(innerHttpResponse.getRequest()); this.innerHttpResponse = innerHttpResponse; - this.cachedBody = innerHttpResponse.getBodyAsByteArray().cache(); + this.cachedBody = innerHttpResponse.getBody().cache(); } @Override @@ -46,13 +47,13 @@ public HttpHeaders getHeaders() { } @Override - public Mono getBodyAsByteArray() { + public Flux getBody() { return cachedBody; } @Override - public Flux getBody() { - return getBodyAsByteArray().flatMapMany(bytes -> Flux.just(ByteBuffer.wrap(bytes))); + public Mono getBodyAsByteArray() { + return FluxUtil.collectBytesInByteBufferStream(cachedBody.map(ByteBuffer::duplicate)); } @Override diff --git a/sdk/core/azure-core/src/test/java/com/azure/core/http/policy/HttpLoggingPolicyTests.java b/sdk/core/azure-core/src/test/java/com/azure/core/http/policy/HttpLoggingPolicyTests.java index 90da7d462c2a3..df464a9dbf018 100644 --- a/sdk/core/azure-core/src/test/java/com/azure/core/http/policy/HttpLoggingPolicyTests.java +++ b/sdk/core/azure-core/src/test/java/com/azure/core/http/policy/HttpLoggingPolicyTests.java @@ -3,34 +3,48 @@ package com.azure.core.http.policy; +import com.azure.core.http.ContentType; +import com.azure.core.http.HttpHeaders; import com.azure.core.http.HttpMethod; import com.azure.core.http.HttpPipeline; import com.azure.core.http.HttpPipelineBuilder; import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; import com.azure.core.http.clients.NoOpHttpClient; import com.azure.core.util.Configuration; +import com.azure.core.util.Context; import com.azure.core.util.CoreUtils; +import com.azure.core.util.FluxUtil; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import java.io.ByteArrayOutputStream; import java.io.PrintStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.HashSet; import java.util.Set; import java.util.stream.Stream; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + /** * This class contains tests for {@link HttpLoggingPolicy}. */ public class HttpLoggingPolicyTests { private static final String REDACTED = "REDACTED"; + private static final Context CONTEXT = new Context("caller-method", "HttpLoggingPolicyTests"); private String originalLogLevel; private PrintStream originalErr; @@ -42,6 +56,13 @@ public void prepareForTest() { originalLogLevel = System.getProperty(Configuration.PROPERTY_AZURE_LOG_LEVEL); System.setProperty(Configuration.PROPERTY_AZURE_LOG_LEVEL, "2"); + /* + * Indicate to SLF4J to enable trace level logging for a logger named + * com.azure.core.util.logging.ClientLoggerTests. Trace is the maximum level of logging supported by the + * ClientLogger. + */ + System.setProperty("org.slf4j.simpleLogger.log.com.azure.core.util.logging.HttpLoggingPolicyTests", "trace"); + // Override System.err as that is where SLF4J will log by default. originalErr = System.err; logCaptureStream = new ByteArrayOutputStream(); @@ -57,6 +78,8 @@ public void cleanupAfterTest() { System.setProperty(Configuration.PROPERTY_AZURE_LOG_LEVEL, originalLogLevel); } + System.clearProperty("org.slf4j.simpleLogger.log.com.azure.core.util.logging.HttpLoggingPolicyTests"); + // Reset System.err to the original PrintStream. System.setErr(originalErr); } @@ -75,7 +98,7 @@ public void redactQueryParameters(String requestUrl, String expectedQueryString, .httpClient(new NoOpHttpClient()) .build(); - StepVerifier.create(pipeline.send(new HttpRequest(HttpMethod.POST, requestUrl))) + StepVerifier.create(pipeline.send(new HttpRequest(HttpMethod.POST, requestUrl), CONTEXT)) .verifyComplete(); String logString = new String(logCaptureStream.toByteArray(), StandardCharsets.UTF_8); @@ -105,4 +128,145 @@ private static Stream redactQueryParametersSupplier() { Arguments.of(requestUrl, fullyAllowedQueryString, allQueryParameters) ); } + + /** + * Tests that logging the request body doesn't consume the stream before it is sent over the network. + */ + @ParameterizedTest(name = "[{index}] {displayName}") + @MethodSource("validateLoggingDoesNotConsumeSupplier") + public void validateLoggingDoesNotConsumeRequest(Flux stream, byte[] data, int contentLength) + throws MalformedURLException { + URL requestUrl = new URL("https://test.com"); + HttpHeaders requestHeaders = new HttpHeaders() + .put("Content-Type", ContentType.APPLICATION_JSON) + .put("Content-Length", Integer.toString(contentLength)); + + HttpPipeline pipeline = new HttpPipelineBuilder() + .policies(new HttpLoggingPolicy(new HttpLogOptions().setLogLevel(HttpLogDetailLevel.BODY))) + .httpClient(request -> FluxUtil.collectBytesInByteBufferStream(request.getBody()) + .doOnSuccess(bytes -> assertArrayEquals(data, bytes)) + .then(Mono.empty())) + .build(); + + StepVerifier.create(pipeline.send(new HttpRequest(HttpMethod.POST, requestUrl, requestHeaders, stream), + CONTEXT)) + .verifyComplete(); + + String logString = new String(logCaptureStream.toByteArray(), StandardCharsets.UTF_8); + System.out.println(logString); + Assertions.assertTrue(logString.contains(new String(data, StandardCharsets.UTF_8))); + } + + /** + * Tests that logging the response body doesn't consume the stream before it is returned from the service call. + */ + @ParameterizedTest(name = "[{index}] {displayName}") + @MethodSource("validateLoggingDoesNotConsumeSupplier") + public void validateLoggingDoesNotConsumeResponse(Flux stream, byte[] data, int contentLength) { + HttpRequest request = new HttpRequest(HttpMethod.GET, "https://test.com"); + HttpHeaders responseHeaders = new HttpHeaders() + .put("Content-Type", ContentType.APPLICATION_JSON) + .put("Content-Length", Integer.toString(contentLength)); + + HttpPipeline pipeline = new HttpPipelineBuilder() + .policies(new HttpLoggingPolicy(new HttpLogOptions().setLogLevel(HttpLogDetailLevel.BODY))) + .httpClient(ignored -> Mono.just(new MockHttpResponse(ignored, responseHeaders, stream))) + .build(); + + StepVerifier.create(pipeline.send(request, CONTEXT)) + .assertNext(response -> StepVerifier.create(FluxUtil.collectBytesInByteBufferStream(response.getBody())) + .assertNext(bytes -> assertArrayEquals(data, bytes)) + .verifyComplete()) + .verifyComplete(); + + String logString = new String(logCaptureStream.toByteArray(), StandardCharsets.UTF_8); + System.out.println(logString); + Assertions.assertTrue(logString.contains(new String(data, StandardCharsets.UTF_8))); + } + + private static Stream validateLoggingDoesNotConsumeSupplier() { + byte[] data = "this is a test".getBytes(StandardCharsets.UTF_8); + byte[] repeatingData = new byte[data.length * 3]; + for (int i = 0; i < 3; i++) { + System.arraycopy(data, 0, repeatingData, i * data.length, data.length); + } + + return Stream.of( + // Single emission cold flux. + Arguments.of(Flux.just(ByteBuffer.wrap(data)), data, data.length), + + // Single emission Stream based Flux. + Arguments.of(Flux.fromStream(Stream.of(ByteBuffer.wrap(data))), data, data.length), + + // Single emission hot flux. + Arguments.of(Flux.just(ByteBuffer.wrap(data)).publish().autoConnect(), data, data.length), + + // Multiple emission cold flux. + Arguments.of(Flux.fromArray(new ByteBuffer[]{ + ByteBuffer.wrap(data), + ByteBuffer.wrap(data), + ByteBuffer.wrap(data) + }), repeatingData, repeatingData.length), + + // Multiple emission Stream based flux. + Arguments.of(Flux.fromStream(Stream.of( + ByteBuffer.wrap(data), + ByteBuffer.wrap(data), + ByteBuffer.wrap(data) + )), repeatingData, repeatingData.length), + + // Multiple emission hot flux. + Arguments.of(Flux.just( + ByteBuffer.wrap(data), + ByteBuffer.wrap(data), + ByteBuffer.wrap(data) + ).publish().autoConnect(), repeatingData, repeatingData.length) + ); + } + + private static class MockHttpResponse extends HttpResponse { + private final HttpHeaders headers; + private final Flux body; + + MockHttpResponse(HttpRequest request, HttpHeaders headers, Flux body) { + super(request); + this.headers = headers; + this.body = body; + } + + @Override + public int getStatusCode() { + return 200; + } + + @Override + public String getHeaderValue(String name) { + return headers.getValue(name); + } + + @Override + public HttpHeaders getHeaders() { + return headers; + } + + @Override + public Flux getBody() { + return body; + } + + @Override + public Mono getBodyAsByteArray() { + return FluxUtil.collectBytesInByteBufferStream(body); + } + + @Override + public Mono getBodyAsString() { + return getBodyAsByteArray().map(String::new); + } + + @Override + public Mono getBodyAsString(Charset charset) { + return getBodyAsByteArray().map(bytes -> new String(bytes, StandardCharsets.UTF_8)); + } + } }