diff --git a/client-runtime/src/main/java/com/microsoft/rest/v2/annotations/HeaderCollection.java b/client-runtime/src/main/java/com/microsoft/rest/v2/annotations/HeaderCollection.java new file mode 100644 index 000000000000..9867788a4f39 --- /dev/null +++ b/client-runtime/src/main/java/com/microsoft/rest/v2/annotations/HeaderCollection.java @@ -0,0 +1,27 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for + * license information. + */ + +package com.microsoft.rest.v2.annotations; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +/** + * Marker on a deserialized header type that indicates that the property should + * be treated as a header collection with the provided prefix. + */ +@Retention(RUNTIME) +@Target(FIELD) +public @interface HeaderCollection { + /** + * The header collection prefix. + * @return The header collection prefix. + */ + String value(); +} \ No newline at end of file diff --git a/client-runtime/src/main/java/com/microsoft/rest/v2/protocol/HttpResponseDecoder.java b/client-runtime/src/main/java/com/microsoft/rest/v2/protocol/HttpResponseDecoder.java index b29f911bb6ca..cceb15cf32d9 100644 --- a/client-runtime/src/main/java/com/microsoft/rest/v2/protocol/HttpResponseDecoder.java +++ b/client-runtime/src/main/java/com/microsoft/rest/v2/protocol/HttpResponseDecoder.java @@ -14,6 +14,8 @@ import com.microsoft.rest.v2.RestResponse; import com.microsoft.rest.v2.SwaggerMethodParser; import com.microsoft.rest.v2.UnixTime; +import com.microsoft.rest.v2.annotations.HeaderCollection; +import com.microsoft.rest.v2.http.HttpHeader; import com.microsoft.rest.v2.http.HttpHeaders; import com.microsoft.rest.v2.http.HttpMethod; import com.microsoft.rest.v2.http.HttpResponse; @@ -26,8 +28,10 @@ import org.joda.time.DateTime; import java.io.IOException; +import java.lang.reflect.Field; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -292,6 +296,45 @@ private Object deserializeHeaders(HttpHeaders headers) throws IOException { } else { final String headersJsonString = serializer.serialize(headers, SerializerEncoding.JSON); Object deserializedHeaders = serializer.deserialize(headersJsonString, deserializedHeadersType, SerializerEncoding.JSON); + + final Class deserializedHeadersClass = TypeToken.of(deserializedHeadersType).getRawType(); + final Field[] declaredFields = deserializedHeadersClass.getDeclaredFields(); + for (final Field declaredField : declaredFields) { + if (declaredField.isAnnotationPresent(HeaderCollection.class)) { + final Type declaredFieldType = declaredField.getGenericType(); + if (TypeToken.of(declaredField.getType()).isSubtypeOf(Map.class)) { + final Type[] mapTypeArguments = getTypeArguments(declaredFieldType); + if (mapTypeArguments.length == 2 && mapTypeArguments[0] == String.class && mapTypeArguments[1] == String.class) { + final HeaderCollection headerCollectionAnnotation = declaredField.getAnnotation(HeaderCollection.class); + final String headerCollectionPrefix = headerCollectionAnnotation.value().toLowerCase(); + final int headerCollectionPrefixLength = headerCollectionPrefix.length(); + if (headerCollectionPrefixLength > 0) { + final Map headerCollection = new HashMap<>(); + for (final HttpHeader header : headers) { + final String headerName = header.name(); + if (headerName.toLowerCase().startsWith(headerCollectionPrefix)) { + headerCollection.put(headerName.substring(headerCollectionPrefixLength), header.value()); + } + } + + final boolean declaredFieldAccessibleBackup = declaredField.isAccessible(); + try { + if (!declaredFieldAccessibleBackup) { + declaredField.setAccessible(true); + } + declaredField.set(deserializedHeaders, headerCollection); + } catch (IllegalAccessException ignored) { + } finally { + if (!declaredFieldAccessibleBackup) { + declaredField.setAccessible(declaredFieldAccessibleBackup); + } + } + } + } + } + } + } + return deserializedHeaders; } } diff --git a/client-runtime/src/test/java/com/microsoft/rest/v2/RestProxyTests.java b/client-runtime/src/test/java/com/microsoft/rest/v2/RestProxyTests.java index 95cf6e73f57d..ee72492b9db6 100644 --- a/client-runtime/src/test/java/com/microsoft/rest/v2/RestProxyTests.java +++ b/client-runtime/src/test/java/com/microsoft/rest/v2/RestProxyTests.java @@ -1403,6 +1403,10 @@ public void service24Put() { // Helpers protected T createService(Class serviceClass) { final HttpClient httpClient = createHttpClient(); + return createService(serviceClass, httpClient); + } + + protected T createService(Class serviceClass, HttpClient httpClient) { final HttpPipeline httpPipeline = HttpPipeline.build(httpClient, new DecodingPolicyFactory()); return RestProxy.create(serviceClass, httpPipeline, serializer); } diff --git a/client-runtime/src/test/java/com/microsoft/rest/v2/RestProxyWithMockTests.java b/client-runtime/src/test/java/com/microsoft/rest/v2/RestProxyWithMockTests.java index 7643acfcf803..30677ecac0bb 100644 --- a/client-runtime/src/test/java/com/microsoft/rest/v2/RestProxyWithMockTests.java +++ b/client-runtime/src/test/java/com/microsoft/rest/v2/RestProxyWithMockTests.java @@ -3,18 +3,29 @@ import com.google.common.base.Charsets; import com.microsoft.rest.v2.annotations.ExpectedResponses; import com.microsoft.rest.v2.annotations.GET; +import com.microsoft.rest.v2.annotations.HeaderCollection; import com.microsoft.rest.v2.annotations.Host; import com.microsoft.rest.v2.annotations.ReturnValueWireType; import com.microsoft.rest.v2.entities.HttpBinJSON; -import com.microsoft.rest.v2.http.*; +import com.microsoft.rest.v2.http.HttpClient; +import com.microsoft.rest.v2.http.HttpHeaders; +import com.microsoft.rest.v2.http.HttpPipeline; +import com.microsoft.rest.v2.http.HttpRequest; +import com.microsoft.rest.v2.http.HttpResponse; +import com.microsoft.rest.v2.http.MockHttpClient; +import com.microsoft.rest.v2.http.MockHttpResponse; import io.reactivex.Single; import org.joda.time.DateTime; import org.junit.Test; +import java.util.HashMap; import java.util.List; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class RestProxyWithMockTests extends RestProxyTests { @Override @@ -237,6 +248,139 @@ public Single sendRequestAsync(HttpRequest request) { } } + private static class HeaderCollectionTypePublicFields { + public String name; + + @HeaderCollection("header-collection-prefix-") + public Map headerCollection; + } + + private static class HeaderCollectionTypeProtectedFields { + protected String name; + + @HeaderCollection("header-collection-prefix-") + protected Map headerCollection; + } + + private static class HeaderCollectionTypePrivateFields { + private String name; + + @HeaderCollection("header-collection-prefix-") + private Map headerCollection; + } + + private static class HeaderCollectionTypePackagePrivateFields { + String name; + + @HeaderCollection("header-collection-prefix-") + Map headerCollection; + } + + @Host("https://www.example.com") + interface ServiceHeaderCollections { + @GET("url/path") + RestResponse publicFields(); + + @GET("url/path") + RestResponse protectedFields(); + + @GET("url/path") + RestResponse privateFields(); + + @GET("url/path") + RestResponse packagePrivateFields(); + } + + private static final HttpClient headerCollectionHttpClient = new MockHttpClient() { + @Override + public Single sendRequestAsync(HttpRequest request) { + final HttpHeaders headers = new HttpHeaders(); + headers.set("name", "Phillip"); + headers.set("header-collection-prefix-one", "1"); + headers.set("header-collection-prefix-two", "2"); + headers.set("header-collection-prefix-three", "3"); + final MockHttpResponse response = new MockHttpResponse(200, headers); + return Single.just(response); + } + }; + + private ServiceHeaderCollections createHeaderCollectionsService() { + return createService(ServiceHeaderCollections.class, headerCollectionHttpClient); + } + + private static void assertHeaderCollectionsRawHeaders(RestResponse response) { + final HttpHeaders responseRawHeaders = new HttpHeaders(response.rawHeaders()); + assertEquals("Phillip", responseRawHeaders.value("name")); + assertEquals("1", responseRawHeaders.value("header-collection-prefix-one")); + assertEquals("2", responseRawHeaders.value("header-collection-prefix-two")); + assertEquals("3", responseRawHeaders.value("header-collection-prefix-three")); + assertEquals(4, responseRawHeaders.size()); + } + + private static void assertHeaderCollections(Map headerCollections) { + final Map expectedHeaderCollections = new HashMap<>(); + expectedHeaderCollections.put("one", "1"); + expectedHeaderCollections.put("two", "2"); + expectedHeaderCollections.put("three", "3"); + + for (final String key : headerCollections.keySet()) { + assertEquals(expectedHeaderCollections.get(key), headerCollections.get(key)); + } + assertEquals(expectedHeaderCollections.size(), headerCollections.size()); + } + + @Test + public void serviceHeaderCollectionPublicFields() { + final RestResponse response = createHeaderCollectionsService() + .publicFields(); + assertNotNull(response); + assertHeaderCollectionsRawHeaders(response); + + final HeaderCollectionTypePublicFields responseHeaders = response.headers(); + assertNotNull(responseHeaders); + assertEquals("Phillip", responseHeaders.name); + assertHeaderCollections(responseHeaders.headerCollection); + } + + @Test + public void serviceHeaderCollectionProtectedFields() { + final RestResponse response = createHeaderCollectionsService() + .protectedFields(); + assertNotNull(response); + assertHeaderCollectionsRawHeaders(response); + + final HeaderCollectionTypeProtectedFields responseHeaders = response.headers(); + assertNotNull(responseHeaders); + assertEquals("Phillip", responseHeaders.name); + assertHeaderCollections(responseHeaders.headerCollection); + } + + @Test + public void serviceHeaderCollectionPrivateFields() { + final RestResponse response = createHeaderCollectionsService() + .privateFields(); + assertNotNull(response); + assertHeaderCollectionsRawHeaders(response); + + final HeaderCollectionTypePrivateFields responseHeaders = response.headers(); + assertNotNull(responseHeaders); + assertEquals("Phillip", responseHeaders.name); + assertHeaderCollections(responseHeaders.headerCollection); + } + + @Test + public void serviceHeaderCollectionPackagePrivateFields() { + final RestResponse response = createHeaderCollectionsService() + .packagePrivateFields(); + assertNotNull(response); + assertHeaderCollectionsRawHeaders(response); + + final HeaderCollectionTypePackagePrivateFields responseHeaders = response.headers(); + assertNotNull(responseHeaders); + assertEquals("Phillip", responseHeaders.name); + assertHeaderCollections(responseHeaders.headerCollection); + } + private static void assertContains(String value, String expectedSubstring) { assertTrue("Expected \"" + value + "\" to contain \"" + expectedSubstring + "\".", value.contains(expectedSubstring)); } diff --git a/client-runtime/src/test/java/com/microsoft/rest/v2/http/MockHttpResponse.java b/client-runtime/src/test/java/com/microsoft/rest/v2/http/MockHttpResponse.java index ed9717f183d6..1416b1efae05 100644 --- a/client-runtime/src/test/java/com/microsoft/rest/v2/http/MockHttpResponse.java +++ b/client-runtime/src/test/java/com/microsoft/rest/v2/http/MockHttpResponse.java @@ -42,6 +42,10 @@ public MockHttpResponse(int statusCode, String string) { this(statusCode, new HttpHeaders(), string == null ? new byte[0] : string.getBytes()); } + public MockHttpResponse(int statusCode, HttpHeaders headers) { + this(statusCode, headers, null); + } + public MockHttpResponse(int statusCode, HttpHeaders headers, Object serializable) { this(statusCode, headers, serialize(serializable)); }