diff --git a/core/src/main/java/apoc/util/s3/S3ParamsExtractor.java b/core/src/main/java/apoc/util/s3/S3ParamsExtractor.java index d41bd73216..294cfa68fe 100644 --- a/core/src/main/java/apoc/util/s3/S3ParamsExtractor.java +++ b/core/src/main/java/apoc/util/s3/S3ParamsExtractor.java @@ -3,6 +3,7 @@ import apoc.util.Util; import com.amazonaws.regions.Regions; +import java.net.URI; import java.net.URL; import java.util.Map; import java.util.Objects; @@ -15,9 +16,15 @@ public class S3ParamsExtractor { private static final String SESSION_TOKEN = "sessionToken"; public static S3Params extract(URL url) throws IllegalArgumentException { + return extract(url.toString()); + } + + public static S3Params extract(String url) throws IllegalArgumentException { + + URI uri = URI.create(url); - if (!PROTOCOL.equals(url.getProtocol())) { - throw new IllegalArgumentException("Unsupported protocol '" + url.getProtocol() + "'"); + if (!PROTOCOL.equals(uri.getScheme())) { + throw new IllegalArgumentException("Unsupported protocol '" + uri.getScheme() + "'"); } //aws credentials @@ -25,8 +32,8 @@ public static S3Params extract(URL url) throws IllegalArgumentException { String secretKey = null; String sessionToken = null; - if (url.getUserInfo() != null) { - String[] credentials = url.getUserInfo().split(":"); + if (uri.getUserInfo() != null) { + String[] credentials = uri.getUserInfo().split(":"); if (credentials.length > 1) { accessKey = credentials[0]; secretKey = credentials[1]; @@ -36,7 +43,7 @@ public static S3Params extract(URL url) throws IllegalArgumentException { } // User info part cannot contain session token. } else { - Map params = Util.getRequestParameter(url.getQuery()); + Map params = Util.getRequestParameter(uri.getQuery()); if(Objects.nonNull(params)) { if(params.containsKey(ACCESS_KEY)){accessKey = params.get(ACCESS_KEY);} if(params.containsKey(SECRET_KEY)){secretKey = params.get(SECRET_KEY);} @@ -45,17 +52,17 @@ public static S3Params extract(URL url) throws IllegalArgumentException { } // endpoint - String endpoint = url.getHost(); + String endpoint = uri.getHost(); - Integer slashIndex = url.getPath().lastIndexOf("/"); + Integer slashIndex = uri.getPath().indexOf("/", 1); String key; String bucket ; if(slashIndex > 0){ // key - key = url.getPath().substring(slashIndex + 1); + key = uri.getPath().substring(slashIndex + 1); // bucket - bucket = url.getPath().substring(1, slashIndex); + bucket = uri.getPath().substring(1, slashIndex); } else{ throw new IllegalArgumentException("Invalid url. Must be:\n's3://accessKey:secretKey@endpoint:port/bucket/key' or\n's3://endpoint:port/bucket/key?accessKey=accessKey&secretKey=secretKey'"); @@ -87,8 +94,8 @@ public static S3Params extract(URL url) throws IllegalArgumentException { } } - if (url.getPort() != 80 && url.getPort() != 443 && url.getPort() > 0) { - endpoint += ":" + url.getPort(); + if (uri.getPort() != 80 && uri.getPort() != 443 && uri.getPort() > 0) { + endpoint += ":" + uri.getPort(); } if (Objects.nonNull(endpoint) && endpoint.isEmpty()) { diff --git a/core/src/test/java/apoc/util/s3/S3ParamsExtractorTest.java b/core/src/test/java/apoc/util/s3/S3ParamsExtractorTest.java new file mode 100644 index 0000000000..b64495ec0c --- /dev/null +++ b/core/src/test/java/apoc/util/s3/S3ParamsExtractorTest.java @@ -0,0 +1,42 @@ +package apoc.util.s3; + +import org.junit.Test; +import static org.junit.Assert.*; + +public class S3ParamsExtractorTest { + + @Test + public void testEncodedS3Url() throws Exception { + S3Params params = S3ParamsExtractor.extract("s3://accessKeyId:some%2Fsecret%2Fkey:some%2Fsession%2Ftoken@s3.us-east-2.amazonaws.com:1234/bucket/path/to/key"); + assertEquals("some/secret/key", params.getSecretKey()); + assertEquals("some/session/token", params.getSessionToken()); + assertEquals("accessKeyId", params.getAccessKey()); + assertEquals("bucket", params.getBucket()); + assertEquals("path/to/key", params.getKey()); + assertEquals("s3.us-east-2.amazonaws.com:1234", params.getEndpoint()); + assertEquals("us-east-2", params.getRegion()); + } + + @Test + public void testEncodedS3UrlQueryParams() throws Exception { + S3Params params = S3ParamsExtractor.extract("s3://s3.us-east-2.amazonaws.com:1234/bucket/path/to/key?accessKey=accessKeyId&secretKey=some%2Fsecret%2Fkey&sessionToken=some%2Fsession%2Ftoken"); + assertEquals("some/secret/key", params.getSecretKey()); + assertEquals("some/session/token", params.getSessionToken()); + assertEquals("accessKeyId", params.getAccessKey()); + assertEquals("bucket", params.getBucket()); + assertEquals("path/to/key", params.getKey()); + assertEquals("s3.us-east-2.amazonaws.com:1234", params.getEndpoint()); + } + + @Test + public void testExtractEndpointPort() throws Exception { + assertEquals("s3.amazonaws.com", S3ParamsExtractor.extract("s3://s3.amazonaws.com:80/bucket/path/to/key").getEndpoint()); + assertEquals("s3.amazonaws.com:1234", S3ParamsExtractor.extract("s3://s3.amazonaws.com:1234/bucket/path/to/key").getEndpoint()); + } + + @Test + public void testExtractRegion() throws Exception { + assertEquals("us-east-2", S3ParamsExtractor.extract("s3://s3.us-east-2.amazonaws.com:80/bucket/path/to/key").getRegion()); + assertNull(S3ParamsExtractor.extract("s3://s3.amazonaws.com:80/bucket/path/to/key").getRegion()); + } +}