diff --git a/.travis.yml b/.travis.yml index 9b48214fd8ec..227521b43994 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,7 +5,8 @@ env: - MAVEN_OPTS="-Xmx512M -XX:+ExitOnOutOfMemoryError" - MAVEN_SKIP_CHECKS_AND_DOCS="-Dair.check.skip-all=true -Dmaven.javadoc.skip=true" - MAVEN_FAST_INSTALL="-DskipTests $MAVEN_SKIP_CHECKS_AND_DOCS -B -q -T C1" - - ARTIFACTS_UPLOAD_PATH=travis_build_artifacts/${TRAVIS_REPO_SLUG}/${TRAVIS_BRANCH}/${TRAVIS_BUILD_NUMBER} + - ARTIFACTS_UPLOAD_PATH_BRANCH=travis_build_artifacts/${TRAVIS_REPO_SLUG}/${TRAVIS_BRANCH}/${TRAVIS_BUILD_NUMBER} + - ARTIFACTS_UPLOAD_PATH_PR=travis_build_artifacts_pr/${TRAVIS_REPO_SLUG}/${TRAVIS_BRANCH}/${TRAVIS_BUILD_NUMBER} - TEST_FLAGS="" matrix: - MAVEN_CHECKS=true @@ -16,11 +17,17 @@ env: - TEST_SPECIFIC_MODULES=presto-cassandra - TEST_SPECIFIC_MODULES=presto-hive - TEST_OTHER_MODULES=!presto-tests,!presto-raptor,!presto-accumulo,!presto-cassandra,!presto-hive,!presto-docs,!presto-server,!presto-server-rpm - - PRODUCT_TESTS=true + - PRODUCT_TESTS_BASIC_ENVIRONMENT=true + - PRODUCT_TESTS_SPECIFIC_ENVIRONMENT=true - HIVE_TESTS=true sudo: required dist: trusty +group: deprecated-2017Q2 +addons: + apt: + packages: + - oracle-java8-installer cache: directories: @@ -40,7 +47,7 @@ install: ./mvnw install $MAVEN_FAST_INSTALL -pl '!presto-docs,!presto-server,!presto-server-rpm' fi - | - if [[ -v PRODUCT_TESTS ]]; then + if [[ -v PRODUCT_TESTS_BASIC_ENVIRONMENT || -v PRODUCT_TESTS_SPECIFIC_ENVIRONMENT ]]; then ./mvnw install $MAVEN_FAST_INSTALL -pl '!presto-docs,!presto-server-rpm' fi - | @@ -48,6 +55,13 @@ install: ./mvnw install $MAVEN_FAST_INSTALL -pl presto-hive-hadoop2 -am fi +before_script: + - | + export ARTIFACTS_UPLOAD_PATH=${ARTIFACTS_UPLOAD_PATH_BRANCH} + if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then + export ARTIFACTS_UPLOAD_PATH=${ARTIFACTS_UPLOAD_PATH_PR} + fi + script: - | if [[ -v MAVEN_CHECKS ]]; then @@ -62,28 +76,33 @@ script: ./mvnw test $MAVEN_SKIP_CHECKS_AND_DOCS -B -pl $TEST_OTHER_MODULES fi - | - if [[ -v PRODUCT_TESTS ]]; then + if [[ -v PRODUCT_TESTS_BASIC_ENVIRONMENT ]]; then presto-product-tests/bin/run_on_docker.sh \ multinode -x quarantine,big_query,storage_formats,profile_specific_tests,tpcds fi - | - if [[ -v PRODUCT_TESTS ]]; then + if [[ -v PRODUCT_TESTS_SPECIFIC_ENVIRONMENT ]]; then presto-product-tests/bin/run_on_docker.sh \ singlenode-kerberos-hdfs-impersonation -g storage_formats,cli,hdfs_impersonation,authorization fi - | - if [[ -v PRODUCT_TESTS ]]; then + if [[ -v PRODUCT_TESTS_SPECIFIC_ENVIRONMENT ]]; then presto-product-tests/bin/run_on_docker.sh \ - singlenode-ldap -g ldap_cli + singlenode-ldap -g ldap -x simba_jdbc fi # SQL server image sporadically hangs during the startup # TODO: Uncomment it once issue is fixed # https://github.com/Microsoft/mssql-docker/issues/76 # - | -# if [[ -v PRODUCT_TESTS ]]; then +# if [[ -v PRODUCT_TESTS_SPECIFIC_ENVIRONMENT ]]; then # presto-product-tests/bin/run_on_docker.sh \ # singlenode-sqlserver -g sqlserver # fi + - | + if [[ -v PRODUCT_TESTS_SPECIFIC_ENVIRONMENT ]]; then + presto-product-tests/bin/run_on_docker.sh \ + multinode-tls -g smoke,cli,group-by,join,tls + fi - | if [[ -v HIVE_TESTS ]]; then presto-hive-hadoop2/bin/run_on_docker.sh diff --git a/README.md b/README.md index 7df5380c68c4..f7428dba837d 100644 --- a/README.md +++ b/README.md @@ -15,13 +15,13 @@ See the [User Manual](https://prestodb.io/docs/current/) for deployment instruct Presto is a standard Maven project. Simply run the following command from the project root directory: - mvn clean install + ./mvnw clean install On the first build, Maven will download all the dependencies from the internet and cache them in the local repository (`~/.m2/repository`), which can take a considerable amount of time. Subsequent builds will be faster. Presto has a comprehensive set of unit tests that can take several minutes to run. You can disable the tests when building: - mvn clean install -DskipTests + ./mvnw clean install -DskipTests ## Running Presto in your IDE @@ -83,3 +83,4 @@ We recommend you use IntelliJ as your IDE. The code style template for the proje * Consider using String formatting (printf style formatting using the Java `Formatter` class): `format("Session property %s is invalid: %s", name, value)` (note that `format()` should always be statically imported). Sometimes, if you only need to append something, consider using the `+` operator. * Avoid using the ternary operator except for trivial expressions. * Use an assertion from Airlift's `Assertions` class if there is one that covers your case rather than writing the assertion by hand. Over time we may move over to more fluent assertions like AssertJ. +* When writing a Git commit message, follow these [guidelines](https://chris.beams.io/posts/git-commit/). diff --git a/pom.xml b/pom.xml index 35ab6862c142..ddaf55607f08 100644 --- a/pom.xml +++ b/pom.xml @@ -10,7 +10,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 pom presto-root @@ -30,7 +30,7 @@ scm:git:git://github.com/twitter-forks/presto.git https://github.com/twitter-forks/presto - 0.179-tw-0.36 + 0.181-tw-0.37 @@ -51,11 +51,11 @@ ${dep.airlift.version} 0.29 1.11.30 + 3.8.1 1.31 6.10 - - true - None + 0.15.1 + 0.15.2 Asia/Katmandu @@ -112,6 +112,9 @@ presto-plugin-toolkit presto-resource-group-managers presto-benchto-benchmarks + presto-thrift-connector-api + presto-thrift-testing-server + presto-thrift-connector @@ -344,6 +347,32 @@ 3 + + com.facebook.presto + presto-thrift-connector-api + ${project.version} + + + + com.facebook.presto + presto-thrift-connector-api + ${project.version} + test-jar + + + + com.facebook.presto + presto-thrift-testing-server + ${project.version} + + + + com.facebook.presto + presto-thrift-connector + ${project.version} + zip + + com.facebook.hive hive-dwrf @@ -353,7 +382,7 @@ io.airlift aircompressor - 0.7 + 0.8 @@ -604,6 +633,54 @@ 2.78 + + com.squareup.okhttp3 + okhttp + ${dep.okhttp.version} + + + + com.squareup.okhttp3 + mockwebserver + ${dep.okhttp.version} + + + + com.facebook.swift + swift-annotations + ${dep.swift.version} + + + + com.facebook.swift + swift-codec + ${dep.swift.version} + + + + com.facebook.swift + swift-service + ${dep.swift.version} + + + + com.facebook.swift + swift-javadoc + ${dep.swift.version} + + + + com.facebook.nifty + nifty-core + ${dep.nifty.version} + + + + com.facebook.nifty + nifty-client + ${dep.nifty.version} + + org.apache.thrift libthrift @@ -693,7 +770,7 @@ io.airlift testing-postgresql-server - 9.6.1-1 + 9.6.3-1 @@ -890,7 +967,7 @@ org.codehaus.mojo exec-maven-plugin - 1.2.1 + 1.6.0 @@ -912,6 +989,27 @@ + + com.ning.maven.plugins + maven-dependency-versions-check-plugin + + + + com.google.inject + guice + 4.0-beta5 + 4.0 + + + com.google.inject.extensions + guice-multibindings + 4.0-beta5 + 4.0 + + + + + @@ -1087,26 +1185,6 @@ - - - cli - - - - org.codehaus.mojo - exec-maven-plugin - - ${cli.skip-execute} - ${java.home}/bin/java - ${cli.main-class} - - --debug - - - - - - eclipse-compiler diff --git a/presto-accumulo/pom.xml b/presto-accumulo/pom.xml index a429b78ed102..c7b0cb709357 100644 --- a/presto-accumulo/pom.xml +++ b/presto-accumulo/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-accumulo diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/IndexLookup.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/IndexLookup.java index 8c0ee86ab263..f9edfe46d5cf 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/IndexLookup.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/IndexLookup.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Multimap; +import io.airlift.concurrent.BoundedExecutor; import io.airlift.log.Logger; import io.airlift.units.Duration; import org.apache.accumulo.core.client.AccumuloException; @@ -36,17 +37,25 @@ import org.apache.accumulo.core.security.Authorizations; import org.apache.hadoop.io.Text; +import javax.annotation.PreDestroy; import javax.inject.Inject; +import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Map.Entry; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import static com.facebook.presto.accumulo.AccumuloClient.getRangesFromDomain; +import static com.facebook.presto.accumulo.AccumuloErrorCode.UNEXPECTED_ACCUMULO_ERROR; import static com.facebook.presto.accumulo.conf.AccumuloSessionProperties.getIndexCardinalityCachePollingDuration; import static com.facebook.presto.accumulo.conf.AccumuloSessionProperties.getIndexSmallCardThreshold; import static com.facebook.presto.accumulo.conf.AccumuloSessionProperties.getIndexThreshold; @@ -57,13 +66,13 @@ import static com.facebook.presto.accumulo.index.Indexer.CARDINALITY_CQ_AS_TEXT; import static com.facebook.presto.accumulo.index.Indexer.METRICS_TABLE_ROWID_AS_TEXT; import static com.facebook.presto.accumulo.index.Indexer.METRICS_TABLE_ROWS_CF_AS_TEXT; -import static com.facebook.presto.accumulo.index.Indexer.getIndexColumnFamily; import static com.facebook.presto.accumulo.index.Indexer.getIndexTableName; import static com.facebook.presto.accumulo.index.Indexer.getMetricsTableName; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static com.google.common.base.Preconditions.checkArgument; -import static java.nio.charset.StandardCharsets.UTF_8; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newCachedThreadPool; /** * Class to assist the Presto connector, and maybe external applications, @@ -77,12 +86,24 @@ public class IndexLookup private static final Range METRICS_TABLE_ROWID_RANGE = new Range(METRICS_TABLE_ROWID_AS_TEXT); private final ColumnCardinalityCache cardinalityCache; private final Connector connector; + private final ExecutorService coreExecutor; + private final BoundedExecutor executorService; @Inject public IndexLookup(Connector connector, ColumnCardinalityCache cardinalityCache) { this.connector = requireNonNull(connector, "connector is null"); this.cardinalityCache = requireNonNull(cardinalityCache, "cardinalityCache is null"); + + // Create a bounded executor with a pool size at 4x number of processors + this.coreExecutor = newCachedThreadPool(daemonThreadsNamed("cardinality-lookup-%s")); + this.executorService = new BoundedExecutor(coreExecutor, 4 * Runtime.getRuntime().availableProcessors()); + } + + @PreDestroy + public void shutdown() + { + coreExecutor.shutdownNow(); } /** @@ -222,7 +243,7 @@ private boolean getRangesWithMetrics( if (cardinalities.size() == 1) { long numEntries = lowestCardinality.getKey(); double ratio = ((double) numEntries / (double) numRows); - LOG.debug("Use of index would scan %d of %d rows, ratio %s. Threshold %2f, Using for table? %b", numEntries, numRows, ratio, threshold, ratio < threshold); + LOG.debug("Use of index would scan %s of %s rows, ratio %s. Threshold %2f, Using for index table? %s", numEntries, numRows, ratio, threshold, ratio < threshold); if (ratio >= threshold) { return false; } @@ -297,55 +318,59 @@ private long getNumRowsInTable(String metricsTable, Authorizations auths) } private List getIndexRanges(String indexTable, Multimap constraintRanges, Collection rowIDRanges, Authorizations auths) - throws TableNotFoundException + throws TableNotFoundException, InterruptedException { - Set finalRanges = null; - // For each column/constraint pair + Set finalRanges = new HashSet<>(); + // For each column/constraint pair we submit a task to scan the index ranges + List>> tasks = new ArrayList<>(); + CompletionService> executor = new ExecutorCompletionService<>(executorService); for (Entry> constraintEntry : constraintRanges.asMap().entrySet()) { - // Create a batch scanner against the index table, setting the ranges - BatchScanner scanner = connector.createBatchScanner(indexTable, auths, 10); - scanner.setRanges(constraintEntry.getValue()); - - // Fetch the column family for this specific column - Text family = new Text(getIndexColumnFamily(constraintEntry.getKey().getFamily().getBytes(UTF_8), constraintEntry.getKey().getQualifier().getBytes(UTF_8)).array()); - scanner.fetchColumnFamily(family); - - // For each entry in the scanner - Text tmpQualifier = new Text(); - Set columnRanges = new HashSet<>(); - for (Entry entry : scanner) { - entry.getKey().getColumnQualifier(tmpQualifier); - - // Add to our column ranges if it is in one of the row ID ranges - if (inRange(tmpQualifier, rowIDRanges)) { - columnRanges.add(new Range(tmpQualifier)); + tasks.add(executor.submit(() -> { + // Create a batch scanner against the index table, setting the ranges + BatchScanner scan = connector.createBatchScanner(indexTable, auths, 10); + scan.setRanges(constraintEntry.getValue()); + + // Fetch the column family for this specific column + scan.fetchColumnFamily(new Text(Indexer.getIndexColumnFamily(constraintEntry.getKey().getFamily().getBytes(), constraintEntry.getKey().getQualifier().getBytes()).array())); + + // For each entry in the scanner + Text tmpQualifier = new Text(); + Set columnRanges = new HashSet<>(); + for (Entry entry : scan) { + entry.getKey().getColumnQualifier(tmpQualifier); + + // Add to our column ranges if it is in one of the row ID ranges + if (inRange(tmpQualifier, rowIDRanges)) { + columnRanges.add(new Range(tmpQualifier)); + } } - } - - LOG.debug("Retrieved %d ranges for column %s", columnRanges.size(), constraintEntry.getKey().getName()); - // If finalRanges is null, we have not yet added any column ranges - if (finalRanges == null) { - finalRanges = new HashSet<>(); - finalRanges.addAll(columnRanges); + LOG.debug("Retrieved %d ranges for index column %s", columnRanges.size(), constraintEntry.getKey().getName()); + scan.close(); + return columnRanges; + })); + } + tasks.forEach(future -> + { + try { + // If finalRanges is null, we have not yet added any column ranges + if (finalRanges.isEmpty()) { + finalRanges.addAll(future.get()); + } + else { + // Retain only the row IDs for this column that have already been added + // This is your set intersection operation! + finalRanges.retainAll(future.get()); + } } - else { - // Retain only the row IDs for this column that have already been added - // This is your set intersection operation! - finalRanges.retainAll(columnRanges); + catch (ExecutionException | InterruptedException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new PrestoException(UNEXPECTED_ACCUMULO_ERROR, "Exception when getting index ranges", e.getCause()); } - - // Close the scanner - scanner.close(); - } - - // Return the final ranges for all constraint pairs - if (finalRanges != null) { - return ImmutableList.copyOf(finalRanges); - } - else { - return ImmutableList.of(); - } + }); + return ImmutableList.copyOf(finalRanges); } private static void binRanges(int numRangesPerBin, List splitRanges, List prestoSplits) diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/serializers/AccumuloRowSerializer.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/serializers/AccumuloRowSerializer.java index 21b060e586c8..affd8de9ba50 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/serializers/AccumuloRowSerializer.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/serializers/AccumuloRowSerializer.java @@ -17,7 +17,6 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeUtils; import com.facebook.presto.spi.type.VarcharType; @@ -556,13 +555,16 @@ static Block getBlockFromMap(Type mapType, Map map) Type keyType = mapType.getTypeParameters().get(0); Type valueType = mapType.getTypeParameters().get(1); - BlockBuilder builder = new InterleavedBlockBuilder(ImmutableList.of(keyType, valueType), new BlockBuilderStatus(), map.size() * 2); + BlockBuilder mapBlockBuilder = mapType.createBlockBuilder(new BlockBuilderStatus(), 1); + BlockBuilder builder = mapBlockBuilder.beginBlockEntry(); for (Entry entry : map.entrySet()) { writeObject(builder, keyType, entry.getKey()); writeObject(builder, valueType, entry.getValue()); } - return builder.build(); + + mapBlockBuilder.closeEntry(); + return (Block) mapType.getObject(mapBlockBuilder, 0); } /** diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloDistributedQueries.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloDistributedQueries.java index 27f754d2f8df..f24bdc681d73 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloDistributedQueries.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloDistributedQueries.java @@ -49,6 +49,12 @@ public void testAddColumn() // Adding columns via SQL are not supported until adding columns with comments are supported } + @Override + public void testDropColumn() + { + // Dropping columns are not supported by the connector + } + @Override public void testCreateTableAsSelect() { diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/index/TestIndexer.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/index/TestIndexer.java index 6160553aac84..1282ca71ef43 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/index/TestIndexer.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/index/TestIndexer.java @@ -17,8 +17,8 @@ import com.facebook.presto.accumulo.model.AccumuloColumnHandle; import com.facebook.presto.accumulo.serializers.AccumuloRowSerializer; import com.facebook.presto.accumulo.serializers.LexicoderRowSerializer; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import org.apache.accumulo.core.client.BatchWriterConfig; import org.apache.accumulo.core.client.Connector; diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestField.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestField.java index 3600d6b0c8f7..55db1b1e98e2 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestField.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestField.java @@ -17,12 +17,12 @@ import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestRow.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestRow.java index e45bf6a7cd9a..596c81b0c4ea 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestRow.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestRow.java @@ -14,7 +14,7 @@ package com.facebook.presto.accumulo.model; import com.facebook.presto.accumulo.serializers.AccumuloRowSerializer; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/serializers/AbstractTestAccumuloRowSerializer.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/serializers/AbstractTestAccumuloRowSerializer.java index d4ed16987e08..dec82d6ac219 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/serializers/AbstractTestAccumuloRowSerializer.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/serializers/AbstractTestAccumuloRowSerializer.java @@ -15,12 +15,12 @@ import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-array/pom.xml b/presto-array/pom.xml index 5a62826f3e01..a34e9edf331c 100644 --- a/presto-array/pom.xml +++ b/presto-array/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-array @@ -21,6 +21,11 @@ slice + + it.unimi.dsi + fastutil + + com.facebook.presto presto-spi @@ -30,5 +35,12 @@ org.openjdk.jol jol-core + + + + org.testng + testng + test + diff --git a/presto-array/src/main/java/com/facebook/presto/array/BlockBigArray.java b/presto-array/src/main/java/com/facebook/presto/array/BlockBigArray.java index da6e0d972c36..f6bb1106315b 100644 --- a/presto-array/src/main/java/com/facebook/presto/array/BlockBigArray.java +++ b/presto-array/src/main/java/com/facebook/presto/array/BlockBigArray.java @@ -20,6 +20,7 @@ public final class BlockBigArray { private static final int INSTANCE_SIZE = ClassLayout.parseClass(BlockBigArray.class).instanceSize(); private final ObjectBigArray array; + private final ReferenceCountMap trackedObjects = new ReferenceCountMap(); private long sizeOfBlocks; public BlockBigArray() @@ -37,7 +38,7 @@ public BlockBigArray(Block block) */ public long sizeOf() { - return INSTANCE_SIZE + array.sizeOf() + sizeOfBlocks; + return INSTANCE_SIZE + array.sizeOf() + sizeOfBlocks + trackedObjects.sizeOf(); } /** @@ -60,10 +61,30 @@ public void set(long index, Block value) { Block currentValue = array.get(index); if (currentValue != null) { - sizeOfBlocks -= currentValue.getRetainedSizeInBytes(); + currentValue.retainedBytesForEachPart((object, size) -> { + if (currentValue == object) { + // track instance size separately as the reference count for an instance is always 1 + sizeOfBlocks -= size; + return; + } + if (trackedObjects.decrementReference(object) == 0) { + // decrement the size only when it is the last reference + sizeOfBlocks -= size; + } + }); } if (value != null) { - sizeOfBlocks += value.getRetainedSizeInBytes(); + value.retainedBytesForEachPart((object, size) -> { + if (value == object) { + // track instance size separately as the reference count for an instance is always 1 + sizeOfBlocks += size; + return; + } + if (trackedObjects.incrementReference(object) == 1) { + // increment the size only when it is the first reference + sizeOfBlocks += size; + } + }); } array.set(index, value); } diff --git a/presto-array/src/main/java/com/facebook/presto/array/ReferenceCountMap.java b/presto-array/src/main/java/com/facebook/presto/array/ReferenceCountMap.java new file mode 100644 index 000000000000..b0c2130442ac --- /dev/null +++ b/presto-array/src/main/java/com/facebook/presto/array/ReferenceCountMap.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.array; + +import io.airlift.slice.SizeOf; +import it.unimi.dsi.fastutil.objects.Object2IntOpenCustomHashMap; +import org.openjdk.jol.info.ClassLayout; + +public final class ReferenceCountMap + extends Object2IntOpenCustomHashMap +{ + private static final int INSTANCE_SIZE = ClassLayout.parseClass(ReferenceCountMap.class).instanceSize(); + + /** + * Two different blocks can share the same underlying data + * Use the map to avoid memory over counting + */ + public ReferenceCountMap() + { + super(new ObjectStrategy()); + } + + /** + * Increments the reference count of an object by 1 and returns the updated reference count + */ + public int incrementReference(Object key) + { + return addTo(key, 1) + 1; + } + + /** + * Decrements the reference count of an object by 1 and returns the updated reference count + */ + public int decrementReference(Object key) + { + int previousCount = addTo(key, -1); + if (previousCount == 1) { + remove(key); + } + return previousCount - 1; + } + + /** + * Returns the size of this map in bytes. + */ + public long sizeOf() + { + return INSTANCE_SIZE + SizeOf.sizeOf(key) + SizeOf.sizeOf(value) + SizeOf.sizeOf(used); + } + + private static final class ObjectStrategy + implements Strategy + { + @Override + public int hashCode(Object object) + { + return System.identityHashCode(object); + } + + @Override + public boolean equals(Object left, Object right) + { + return left == right; + } + } +} diff --git a/presto-array/src/test/java/com/facebook/presto/array/TestBlockBigArray.java b/presto-array/src/test/java/com/facebook/presto/array/TestBlockBigArray.java new file mode 100644 index 000000000000..60d0a0f826c5 --- /dev/null +++ b/presto-array/src/test/java/com/facebook/presto/array/TestBlockBigArray.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.array; + +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.block.IntArrayBlockBuilder; +import org.openjdk.jol.info.ClassLayout; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + +public class TestBlockBigArray +{ + @Test + public void testRetainedSizeWithOverlappingBlocks() + { + int entries = 123; + BlockBuilder blockBuilder = new IntArrayBlockBuilder(new BlockBuilderStatus(), entries); + for (int i = 0; i < entries; i++) { + blockBuilder.writeInt(i); + } + Block block = blockBuilder.build(); + + // Verify we do not over count + int arraySize = 456; + int blocks = 7890; + BlockBigArray blockBigArray = new BlockBigArray(); + blockBigArray.ensureCapacity(arraySize); + for (int i = 0; i < blocks; i++) { + blockBigArray.set(i % arraySize, block.getRegion(0, entries)); + } + + ReferenceCountMap referenceCountMap = new ReferenceCountMap(); + referenceCountMap.incrementReference(block); + long expectedSize = ClassLayout.parseClass(BlockBigArray.class).instanceSize() + + referenceCountMap.sizeOf() + + (new ObjectBigArray()).sizeOf() + + block.getRetainedSizeInBytes() + (arraySize - 1) * ClassLayout.parseClass(block.getClass()).instanceSize(); + assertEquals(blockBigArray.sizeOf(), expectedSize); + } +} diff --git a/presto-atop/pom.xml b/presto-atop/pom.xml index 4fc38060c34d..3e312dd41255 100644 --- a/presto-atop/pom.xml +++ b/presto-atop/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-atop diff --git a/presto-base-jdbc/pom.xml b/presto-base-jdbc/pom.xml index b0e99e42c24f..4c5e0f90e74b 100644 --- a/presto-base-jdbc/pom.xml +++ b/presto-base-jdbc/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-base-jdbc diff --git a/presto-benchmark-driver/pom.xml b/presto-benchmark-driver/pom.xml index dde5572e18d0..19b38c1c5ff9 100644 --- a/presto-benchmark-driver/pom.xml +++ b/presto-benchmark-driver/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-benchmark-driver @@ -72,6 +72,11 @@ commons-math3 + + com.squareup.okhttp3 + okhttp + + org.testng diff --git a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkDriverOptions.java b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkDriverOptions.java index 0685b71820cb..7e97b21ade27 100644 --- a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkDriverOptions.java +++ b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkDriverOptions.java @@ -110,7 +110,7 @@ private static URI parseServer(String server) HostAndPort host = HostAndPort.fromString(server); try { - return new URI("http", null, host.getHostText(), host.getPortOrDefault(80), null, null, null); + return new URI("http", null, host.getHost(), host.getPortOrDefault(80), null, null, null); } catch (URISyntaxException e) { throw new IllegalArgumentException(e); diff --git a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryRunner.java b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryRunner.java index 2ca469875c64..1f82be9d18f2 100644 --- a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryRunner.java +++ b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryRunner.java @@ -15,7 +15,6 @@ import com.facebook.presto.client.ClientSession; import com.facebook.presto.client.QueryError; -import com.facebook.presto.client.QueryResults; import com.facebook.presto.client.StatementClient; import com.facebook.presto.client.StatementStats; import com.google.common.base.Throwables; @@ -28,8 +27,8 @@ import io.airlift.http.client.JsonResponseHandler; import io.airlift.http.client.Request; import io.airlift.http.client.jetty.JettyHttpClient; -import io.airlift.json.JsonCodec; import io.airlift.units.Duration; +import okhttp3.OkHttpClient; import java.io.Closeable; import java.net.URI; @@ -39,6 +38,7 @@ import static com.facebook.presto.benchmark.driver.BenchmarkQueryResult.failResult; import static com.facebook.presto.benchmark.driver.BenchmarkQueryResult.passResult; +import static com.facebook.presto.client.OkHttpUtil.setupSocksProxy; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; @@ -59,8 +59,8 @@ public class BenchmarkQueryRunner private final int maxFailures; private final HttpClient httpClient; + private final OkHttpClient okHttpClient; private final List nodes; - private final JsonCodec queryResultsCodec; private int failures; @@ -77,8 +77,6 @@ public BenchmarkQueryRunner(int warm, int runs, boolean debug, int maxFailures, this.debug = debug; - this.queryResultsCodec = jsonCodec(QueryResults.class); - requireNonNull(socksProxy, "socksProxy is null"); HttpClientConfig httpClientConfig = new HttpClientConfig(); if (socksProxy.isPresent()) { @@ -87,6 +85,10 @@ public BenchmarkQueryRunner(int warm, int runs, boolean debug, int maxFailures, this.httpClient = new JettyHttpClient(httpClientConfig.setConnectTimeout(new Duration(10, TimeUnit.SECONDS))); + OkHttpClient.Builder builder = new OkHttpClient.Builder(); + setupSocksProxy(builder, socksProxy); + this.okHttpClient = builder.build(); + nodes = getAllNodes(requireNonNull(serverUri, "serverUri is null")); } @@ -149,7 +151,7 @@ public List getSchemas(ClientSession session) failures = 0; while (true) { // start query - StatementClient client = new StatementClient(httpClient, queryResultsCodec, session, "show schemas"); + StatementClient client = new StatementClient(okHttpClient, session, "show schemas"); // read query output ImmutableList.Builder schemas = ImmutableList.builder(); @@ -190,7 +192,7 @@ public List getSchemas(ClientSession session) private StatementStats execute(ClientSession session, String name, String query) { // start query - StatementClient client = new StatementClient(httpClient, queryResultsCodec, session, query); + StatementClient client = new StatementClient(okHttpClient, session, query); // read query output while (client.isValid() && client.advance()) { diff --git a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/Suite.java b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/Suite.java index 8d15446a3e19..282236e91ab7 100644 --- a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/Suite.java +++ b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/Suite.java @@ -26,11 +26,11 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.regex.Pattern; -import java.util.stream.Collectors; -import java.util.stream.StreamSupport; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Streams.stream; import static io.airlift.json.JsonCodec.mapJsonCodec; import static java.util.Objects.requireNonNull; @@ -93,11 +93,11 @@ public List selectQueries(Iterable queries) return ImmutableList.copyOf(queries); } - List filteredQueries = StreamSupport.stream(queries.spliterator(), false) + List filteredQueries = stream(queries) .filter(query -> getQueryNamePatterns().stream().anyMatch(pattern -> pattern.matcher(query.getName()).matches())) - .collect(Collectors.toList()); + .collect(toImmutableList()); - return ImmutableList.copyOf(filteredQueries); + return filteredQueries; } @Override diff --git a/presto-benchmark/pom.xml b/presto-benchmark/pom.xml index 5a59fd82848e..99942fab22ca 100644 --- a/presto-benchmark/pom.xml +++ b/presto-benchmark/pom.xml @@ -5,7 +5,7 @@ presto-root com.facebook.presto - 0.179-tw-0.36 + 0.181-tw-0.37 presto-benchmark diff --git a/presto-benchto-benchmarks/pom.xml b/presto-benchto-benchmarks/pom.xml index 4302c3313dcc..45cac6b08720 100644 --- a/presto-benchto-benchmarks/pom.xml +++ b/presto-benchto-benchmarks/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-benchto-benchmarks diff --git a/presto-blackhole/pom.xml b/presto-blackhole/pom.xml index 26677923e7d0..d3fa947ee109 100644 --- a/presto-blackhole/pom.xml +++ b/presto-blackhole/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-blackhole diff --git a/presto-bytecode/pom.xml b/presto-bytecode/pom.xml index 6aaca8aded5a..75e1d227acf7 100644 --- a/presto-bytecode/pom.xml +++ b/presto-bytecode/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-bytecode diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/MethodDefinition.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/MethodDefinition.java index 0d0c50b015af..d56a84c43eab 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/MethodDefinition.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/MethodDefinition.java @@ -15,6 +15,7 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import org.objectweb.asm.ClassVisitor; import org.objectweb.asm.MethodVisitor; @@ -30,6 +31,7 @@ import static com.facebook.presto.bytecode.Access.STATIC; import static com.facebook.presto.bytecode.Access.toAccessModifier; import static com.facebook.presto.bytecode.ParameterizedType.type; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Iterables.transform; import static org.objectweb.asm.Opcodes.RETURN; @@ -70,6 +72,8 @@ public MethodDefinition( Iterable parameters ) { + checkArgument(Iterables.size(parameters) <= 254, "Too many parameters for method"); + this.declaringClass = declaringClass; body = new BytecodeBlock(); diff --git a/presto-cassandra/pom.xml b/presto-cassandra/pom.xml index 7d2de9976131..2a499ac16bbb 100644 --- a/presto-cassandra/pom.xml +++ b/presto-cassandra/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-cassandra diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClusteringPredicatesExtractor.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClusteringPredicatesExtractor.java index eecbbf6ae2f0..6693baff17b0 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClusteringPredicatesExtractor.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClusteringPredicatesExtractor.java @@ -13,66 +13,41 @@ */ package com.facebook.presto.cassandra; +import com.datastax.driver.core.VersionNumber; import com.facebook.presto.cassandra.util.CassandraCqlUtils; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.spi.predicate.Range; import com.facebook.presto.spi.predicate.TupleDomain; +import com.google.common.base.Joiner; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import static com.facebook.presto.cassandra.util.CassandraCqlUtils.toCQLCompatibleString; -import static com.google.common.collect.Sets.cartesianProduct; import static java.util.Objects.requireNonNull; public class CassandraClusteringPredicatesExtractor { private final List clusteringColumns; - private final TupleDomain predicates; private final ClusteringPushDownResult clusteringPushDownResult; + private final TupleDomain predicates; - public CassandraClusteringPredicatesExtractor(List clusteringColumns, TupleDomain predicates) + public CassandraClusteringPredicatesExtractor(List clusteringColumns, TupleDomain predicates, VersionNumber cassandraVersion) { - this.clusteringColumns = ImmutableList.copyOf(requireNonNull(clusteringColumns, "clusteringColumns is null")); + this.clusteringColumns = ImmutableList.copyOf(clusteringColumns); this.predicates = requireNonNull(predicates, "predicates is null"); - this.clusteringPushDownResult = getClusteringKeysSet(clusteringColumns, predicates); + this.clusteringPushDownResult = getClusteringKeysSet(clusteringColumns, predicates, requireNonNull(cassandraVersion, "cassandraVersion is null")); } - public List getClusteringKeyPredicates() + public String getClusteringKeyPredicates() { - Set> pushedDownDomainValues = clusteringPushDownResult.getDomainValues(); - - if (pushedDownDomainValues.isEmpty()) { - return ImmutableList.of(); - } - - ImmutableList.Builder clusteringPredicates = ImmutableList.builder(); - for (List clusteringKeys : pushedDownDomainValues) { - if (clusteringKeys.isEmpty()) { - continue; - } - - StringBuilder stringBuilder = new StringBuilder(); - - for (int i = 0; i < clusteringKeys.size(); i++) { - if (i > 0) { - stringBuilder.append(" AND "); - } - - stringBuilder.append(CassandraCqlUtils.validColumnName(clusteringColumns.get(i).getName())); - stringBuilder.append(" = "); - stringBuilder.append(CassandraCqlUtils.cqlValue(toCQLCompatibleString(clusteringKeys.get(i)), clusteringColumns.get(i).getCassandraType())); - } - - clusteringPredicates.add(stringBuilder.toString()); - } - return clusteringPredicates.build(); + return clusteringPushDownResult.getDomainQuery(); } public TupleDomain getUnenforcedConstraints() @@ -87,65 +62,133 @@ public TupleDomain getUnenforcedConstraints() return TupleDomain.withColumnDomains(notPushedDown); } - private static ClusteringPushDownResult getClusteringKeysSet(List clusteringColumns, TupleDomain predicates) + private static ClusteringPushDownResult getClusteringKeysSet(List clusteringColumns, TupleDomain predicates, VersionNumber cassandraVersion) { ImmutableMap.Builder domainsBuilder = ImmutableMap.builder(); - ImmutableList.Builder> clusteringColumnValues = ImmutableList.builder(); + ImmutableList.Builder clusteringColumnSql = ImmutableList.builder(); + int currentClusteringColumn = 0; for (CassandraColumnHandle columnHandle : clusteringColumns) { Domain domain = predicates.getDomains().get().get(columnHandle); - if (domain == null) { break; } - if (domain.isNullAllowed()) { - return new ClusteringPushDownResult(domainsBuilder.build(), ImmutableSet.of()); + break; } - - Set values = domain.getValues().getValuesProcessor().transform( + String predicateString = null; + predicateString = domain.getValues().getValuesProcessor().transform( ranges -> { - ImmutableSet.Builder columnValues = ImmutableSet.builder(); - for (Range range : ranges.getOrderedRanges()) { - if (!range.isSingleValue()) { - return ImmutableSet.of(); + List singleValues = new ArrayList<>(); + List rangeConjuncts = new ArrayList<>(); + String predicate = null; + + for (Range range : ranges.getOrderedRanges()) { + if (range.isAll()) { + return null; + } + if (range.isSingleValue()) { + singleValues.add(CassandraCqlUtils.cqlValue(toCQLCompatibleString(range.getSingleValue()), + columnHandle.getCassandraType())); + } + else { + if (!range.getLow().isLowerUnbounded()) { + switch (range.getLow().getBound()) { + case ABOVE: + rangeConjuncts.add(CassandraCqlUtils.validColumnName(columnHandle.getName()) + " > " + + CassandraCqlUtils.cqlValue(toCQLCompatibleString(range.getLow().getValue()), + columnHandle.getCassandraType())); + break; + case EXACTLY: + rangeConjuncts.add(CassandraCqlUtils.validColumnName(columnHandle.getName()) + " >= " + + CassandraCqlUtils.cqlValue(toCQLCompatibleString(range.getLow().getValue()), + columnHandle.getCassandraType())); + break; + case BELOW: + throw new VerifyException("Low Marker should never use BELOW bound"); + default: + throw new AssertionError("Unhandled bound: " + range.getLow().getBound()); } - /* TODO add code to handle a range of values for the last column - * Prior to Cassandra 2.2, only the last clustering column can have a range of values - * Take a look at how this is done in PreparedStatementBuilder.java - */ - - Object value = range.getSingleValue(); - - CassandraType valueType = columnHandle.getCassandraType(); - columnValues.add(valueType.validateClusteringKey(value)); } - return columnValues.build(); - }, - discreteValues -> { - if (discreteValues.isWhiteList()) { - return ImmutableSet.copyOf(discreteValues.getValues()); + if (!range.getHigh().isUpperUnbounded()) { + switch (range.getHigh().getBound()) { + case ABOVE: + throw new VerifyException("High Marker should never use ABOVE bound"); + case EXACTLY: + rangeConjuncts.add(CassandraCqlUtils.validColumnName(columnHandle.getName()) + " <= " + + CassandraCqlUtils.cqlValue(toCQLCompatibleString(range.getHigh().getValue()), + columnHandle.getCassandraType())); + break; + case BELOW: + rangeConjuncts.add(CassandraCqlUtils.validColumnName(columnHandle.getName()) + " < " + + CassandraCqlUtils.cqlValue(toCQLCompatibleString(range.getHigh().getValue()), + columnHandle.getCassandraType())); + break; + default: + throw new AssertionError("Unhandled bound: " + range.getHigh().getBound()); + } } - return ImmutableSet.of(); - }, - allOrNone -> ImmutableSet.of()); + } + } + + if (!singleValues.isEmpty() && !rangeConjuncts.isEmpty()) { + return null; + } + if (!singleValues.isEmpty()) { + if (singleValues.size() == 1) { + predicate = CassandraCqlUtils.validColumnName(columnHandle.getName()) + " = " + singleValues.get(0); + } + else { + predicate = CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" + + Joiner.on(",").join(singleValues) + ")"; + } + } + else if (!rangeConjuncts.isEmpty()) { + predicate = Joiner.on(" AND ").join(rangeConjuncts); + } + return predicate; + }, discreteValues -> { + if (discreteValues.isWhiteList()) { + ImmutableList.Builder discreteValuesList = ImmutableList.builder(); + for (Object discreteValue : discreteValues.getValues()) { + discreteValuesList.add(CassandraCqlUtils.cqlValue(toCQLCompatibleString(discreteValue), + columnHandle.getCassandraType())); + } + String predicate = CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" + + Joiner.on(",").join(discreteValuesList.build()) + ")"; + return predicate; + } + return null; + }, allOrNone -> null); - if (!values.isEmpty()) { - clusteringColumnValues.add(values); - domainsBuilder.put(columnHandle, domain); + if (predicateString == null) { + break; + } + // IN restriction only on last clustering column for Cassandra version = 2.1 + if (predicateString.contains(" IN (") && cassandraVersion.compareTo(VersionNumber.parse("2.2.0")) < 0 && currentClusteringColumn != (clusteringColumns.size() - 1)) { + break; + } + clusteringColumnSql.add(predicateString); + domainsBuilder.put(columnHandle, domain); + // Check for last clustering column should only be restricted by range condition + if (predicateString.contains(">") || predicateString.contains("<")) { + break; } + currentClusteringColumn++; } - return new ClusteringPushDownResult(domainsBuilder.build(), cartesianProduct(clusteringColumnValues.build())); + List clusteringColumnPredicates = clusteringColumnSql.build(); + + return new ClusteringPushDownResult(domainsBuilder.build(), Joiner.on(" AND ").join(clusteringColumnPredicates)); } private static class ClusteringPushDownResult { private final Map domains; - private final Set> domainValues; + private final String domainQuery; - public ClusteringPushDownResult(Map domains, Set> domainValues) + public ClusteringPushDownResult(Map domains, String domainQuery) { this.domains = requireNonNull(ImmutableMap.copyOf(domains)); - this.domainValues = requireNonNull(ImmutableSet.copyOf(domainValues)); + this.domainQuery = requireNonNull(domainQuery); } public Map getDomains() @@ -153,9 +196,9 @@ public Map getDomains() return domains; } - public Set> getDomainValues() + public String getDomainQuery() { - return domainValues; + return domainQuery; } } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnectorRecordSinkProvider.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnectorRecordSinkProvider.java index 671fb8e8ebad..fcd4dcd5fc86 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnectorRecordSinkProvider.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnectorRecordSinkProvider.java @@ -55,6 +55,16 @@ public RecordSink getRecordSink(ConnectorTransactionHandle transaction, Connecto @Override public RecordSink getRecordSink(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorInsertTableHandle tableHandle) { - throw new UnsupportedOperationException(); + requireNonNull(tableHandle, "tableHandle is null"); + checkArgument(tableHandle instanceof CassandraInsertTableHandle, "tableHandle is not an instance of ConnectorInsertTableHandle"); + CassandraInsertTableHandle handle = (CassandraInsertTableHandle) tableHandle; + + return new CassandraRecordSink( + cassandraSession, + handle.getSchemaName(), + handle.getTableName(), + handle.getColumnNames(), + handle.getColumnTypes(), + false); } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraErrorCode.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraErrorCode.java index 362739abca76..70d415ebebba 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraErrorCode.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraErrorCode.java @@ -22,7 +22,7 @@ public enum CassandraErrorCode implements ErrorCodeSupplier { - CASSANDRA_METADATA_ERROR(0, EXTERNAL); + CASSANDRA_METADATA_ERROR(0, EXTERNAL), CASSANDRA_VERSION_ERROR(1, EXTERNAL); private final ErrorCode errorCode; diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraHandleResolver.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraHandleResolver.java index f2fe188af719..c82269ea0f22 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraHandleResolver.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraHandleResolver.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.ConnectorTableHandle; @@ -59,4 +60,10 @@ public Class getTransactionHandleClass() { return CassandraTransactionHandle.class; } + + @Override + public Class getInsertTableHandleClass() + { + return CassandraInsertTableHandle.class; + } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraInsertTableHandle.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraInsertTableHandle.java new file mode 100644 index 000000000000..57d1e55abeb0 --- /dev/null +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraInsertTableHandle.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cassandra; + +import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.type.Type; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class CassandraInsertTableHandle + implements ConnectorInsertTableHandle +{ + private final String connectorId; + private final String schemaName; + private final String tableName; + private final List columnNames; + private final List columnTypes; + + @JsonCreator + public CassandraInsertTableHandle( + @JsonProperty("connectorId") String connectorId, + @JsonProperty("schemaName") String schemaName, + @JsonProperty("tableName") String tableName, + @JsonProperty("columnNames") List columnNames, + @JsonProperty("columnTypes") List columnTypes) + { + this.connectorId = requireNonNull(connectorId, "clientId is null"); + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); + + requireNonNull(columnNames, "columnNames is null"); + requireNonNull(columnTypes, "columnTypes is null"); + checkArgument(columnNames.size() == columnTypes.size(), "columnNames and columnTypes sizes don't match"); + this.columnNames = ImmutableList.copyOf(columnNames); + this.columnTypes = ImmutableList.copyOf(columnTypes); + } + + @JsonProperty + public String getConnectorId() + { + return connectorId; + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public List getColumnNames() + { + return columnNames; + } + + @JsonProperty + public List getColumnTypes() + { + return columnTypes; + } + + @Override + public String toString() + { + return "cassandra:" + schemaName + "." + tableName; + } +} diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java index d68cc32448e5..3307f00f75bc 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java @@ -16,6 +16,7 @@ import com.facebook.presto.cassandra.util.CassandraCqlUtils; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorNewTableLayout; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSession; @@ -47,8 +48,11 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; import static com.facebook.presto.cassandra.CassandraType.toCassandraType; +import static com.facebook.presto.cassandra.util.CassandraCqlUtils.validSchemaName; +import static com.facebook.presto.cassandra.util.CassandraCqlUtils.validTableName; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.StandardErrorCode.PERMISSION_DENIED; import static com.google.common.base.MoreObjects.toStringHelper; @@ -199,16 +203,16 @@ public List getTableLayouts(ConnectorSession session CassandraTableHandle handle = (CassandraTableHandle) table; CassandraPartitionResult partitionResult = partitionManager.getPartitions(handle, constraint.getSummary()); - List clusteringKeyPredicates; + String clusteringKeyPredicates = ""; TupleDomain unenforcedConstraint; if (partitionResult.isUnpartitioned()) { - clusteringKeyPredicates = ImmutableList.of(); unenforcedConstraint = partitionResult.getUnenforcedConstraint(); } else { CassandraClusteringPredicatesExtractor clusteringPredicatesExtractor = new CassandraClusteringPredicatesExtractor( cassandraSession.getTable(getTableName(handle)).getClusteringKeyColumns(), - partitionResult.getUnenforcedConstraint()); + partitionResult.getUnenforcedConstraint(), + cassandraSession.getCassandraVersion()); clusteringKeyPredicates = clusteringPredicatesExtractor.getClusteringKeyPredicates(); unenforcedConstraint = clusteringPredicatesExtractor.getUnenforcedConstraints(); } @@ -311,4 +315,27 @@ public Optional finishCreateTable(ConnectorSession sess { return Optional.empty(); } + + @Override + public ConnectorInsertTableHandle beginInsert(ConnectorSession session, ConnectorTableHandle tableHandle) + { + CassandraTableHandle table = (CassandraTableHandle) tableHandle; + SchemaTableName schemaTableName = new SchemaTableName(table.getSchemaName(), table.getTableName()); + List columns = cassandraSession.getTable(schemaTableName).getColumns(); + List columnNames = columns.stream().map(CassandraColumnHandle::getName).map(CassandraCqlUtils::validColumnName).collect(Collectors.toList()); + List columnTypes = columns.stream().map(CassandraColumnHandle::getType).collect(Collectors.toList()); + + return new CassandraInsertTableHandle( + connectorId, + validSchemaName(table.getSchemaName()), + validTableName(table.getTableName()), + columnNames, + columnTypes); + } + + @Override + public Optional finishInsert(ConnectorSession session, ConnectorInsertTableHandle insertHandle, Collection fragments) + { + return Optional.empty(); + } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSession.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSession.java index 139bd458fa3b..fa5d3d5274b3 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSession.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSession.java @@ -19,6 +19,7 @@ import com.datastax.driver.core.ResultSet; import com.datastax.driver.core.Statement; import com.datastax.driver.core.TokenRange; +import com.datastax.driver.core.VersionNumber; import com.facebook.presto.spi.SchemaNotFoundException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableNotFoundException; @@ -31,6 +32,8 @@ public interface CassandraSession { String PRESTO_COMMENT_METADATA = "Presto Metadata:"; + VersionNumber getCassandraVersion(); + String getPartitioner(); Set getTokenRanges(); diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSplitManager.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSplitManager.java index 7580a59fa7ba..d0ce4dbbbb13 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSplitManager.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSplitManager.java @@ -105,7 +105,7 @@ private static String buildTokenCondition(String tokenExpression, String startTo return tokenExpression + " > " + startToken + " AND " + tokenExpression + " <= " + endToken; } - private List getSplitsForPartitions(CassandraTableHandle cassTableHandle, List partitions, List clusteringPredicates) + private List getSplitsForPartitions(CassandraTableHandle cassTableHandle, List partitions, String clusteringPredicates) { String schema = cassTableHandle.getSchemaName(); HostAddressFactory hostAddressFactory = new HostAddressFactory(); @@ -148,7 +148,7 @@ private List getSplitsForPartitions(CassandraTableHandle cassTab hostMap.put(hostAddresses, addresses); } else { - builder.addAll(createSplitsForClusteringPredicates(cassTableHandle, cassandraPartition.getPartitionId(), addresses, clusteringPredicates)); + builder.add(createSplitForClusteringPredicates(cassTableHandle, cassandraPartition.getPartitionId(), addresses, clusteringPredicates)); } } if (singlePartitionKeyColumn) { @@ -163,7 +163,7 @@ private List getSplitsForPartitions(CassandraTableHandle cassTab size++; if (size > partitionSizeForBatchSelect) { String partitionId = String.format("%s in (%s)", partitionKeyColumnName, sb.toString()); - builder.addAll(createSplitsForClusteringPredicates(cassTableHandle, partitionId, hostMap.get(entry.getKey()), clusteringPredicates)); + builder.add(createSplitForClusteringPredicates(cassTableHandle, partitionId, hostMap.get(entry.getKey()), clusteringPredicates)); size = 0; sb.setLength(0); sb.trimToSize(); @@ -171,31 +171,27 @@ private List getSplitsForPartitions(CassandraTableHandle cassTab } if (size > 0) { String partitionId = String.format("%s in (%s)", partitionKeyColumnName, sb.toString()); - builder.addAll(createSplitsForClusteringPredicates(cassTableHandle, partitionId, hostMap.get(entry.getKey()), clusteringPredicates)); + builder.add(createSplitForClusteringPredicates(cassTableHandle, partitionId, hostMap.get(entry.getKey()), clusteringPredicates)); } } } return builder.build(); } - private List createSplitsForClusteringPredicates( + private CassandraSplit createSplitForClusteringPredicates( CassandraTableHandle tableHandle, String partitionId, List hosts, - List clusteringPredicates) + String clusteringPredicates) { String schema = tableHandle.getSchemaName(); String table = tableHandle.getTableName(); if (clusteringPredicates.isEmpty()) { - return ImmutableList.of(new CassandraSplit(connectorId, schema, table, partitionId, null, hosts)); + return new CassandraSplit(connectorId, schema, table, partitionId, null, hosts); } - ImmutableList.Builder builder = ImmutableList.builder(); - for (String clusteringPredicate : clusteringPredicates) { - builder.add(new CassandraSplit(connectorId, schema, table, partitionId, clusteringPredicate, hosts)); - } - return builder.build(); + return new CassandraSplit(connectorId, schema, table, partitionId, clusteringPredicates, hosts); } @Override diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTableLayoutHandle.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTableLayoutHandle.java index dfafda61ef6d..fe2da9a38c01 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTableLayoutHandle.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTableLayoutHandle.java @@ -28,19 +28,19 @@ public final class CassandraTableLayoutHandle { private final CassandraTableHandle table; private final List partitions; - private final List clusteringPredicates; + private final String clusteringPredicates; @JsonCreator public CassandraTableLayoutHandle(@JsonProperty("table") CassandraTableHandle table) { - this(table, ImmutableList.of(), ImmutableList.of()); + this(table, ImmutableList.of(), ""); } - public CassandraTableLayoutHandle(CassandraTableHandle table, List partitions, List clusteringPredicates) + public CassandraTableLayoutHandle(CassandraTableHandle table, List partitions, String clusteringPredicates) { this.table = requireNonNull(table, "table is null"); this.partitions = ImmutableList.copyOf(requireNonNull(partitions, "partition is null")); - this.clusteringPredicates = ImmutableList.copyOf(requireNonNull(clusteringPredicates, "clusteringPredicates is null")); + this.clusteringPredicates = requireNonNull(clusteringPredicates, "clusteringPredicates is null"); } @JsonProperty @@ -56,7 +56,7 @@ public List getPartitions() } @JsonIgnore - public List getClusteringPredicates() + public String getClusteringPredicates() { return clusteringPredicates; } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/NativeCassandraSession.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/NativeCassandraSession.java index 80c84e487b7f..bd08ff21fdbf 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/NativeCassandraSession.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/NativeCassandraSession.java @@ -27,6 +27,7 @@ import com.datastax.driver.core.Statement; import com.datastax.driver.core.TableMetadata; import com.datastax.driver.core.TokenRange; +import com.datastax.driver.core.VersionNumber; import com.datastax.driver.core.exceptions.NoHostAvailableException; import com.datastax.driver.core.policies.ReconnectionPolicy; import com.datastax.driver.core.policies.ReconnectionPolicy.ReconnectionSchedule; @@ -62,6 +63,7 @@ import static com.datastax.driver.core.querybuilder.QueryBuilder.eq; import static com.datastax.driver.core.querybuilder.QueryBuilder.select; import static com.datastax.driver.core.querybuilder.Select.Where; +import static com.facebook.presto.cassandra.CassandraErrorCode.CASSANDRA_VERSION_ERROR; import static com.facebook.presto.cassandra.util.CassandraCqlUtils.validSchemaName; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.base.Preconditions.checkState; @@ -100,6 +102,19 @@ public NativeCassandraSession(String connectorId, JsonCodec session.execute("select release_version from system.local")); + Row versionRow = result.one(); + if (versionRow == null) { + throw new PrestoException(CASSANDRA_VERSION_ERROR, "The cluster version is not available. " + + "Please make sure that the Cassandra cluster is up and running, " + + "and that the contact points are specified correctly."); + } + return VersionNumber.parse(versionRow.getString("release_version")); + } + @Override public String getPartitioner() { diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraTestingUtils.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraTestingUtils.java index e312b2d3feaa..26334b8434fc 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraTestingUtils.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraTestingUtils.java @@ -33,21 +33,25 @@ public class CassandraTestingUtils { public static final String TABLE_ALL_TYPES = "table_all_types"; + public static final String TABLE_ALL_TYPES_INSERT = "table_all_types_insert"; public static final String TABLE_ALL_TYPES_PARTITION_KEY = "table_all_types_partition_key"; public static final String TABLE_CLUSTERING_KEYS = "table_clustering_keys"; public static final String TABLE_CLUSTERING_KEYS_LARGE = "table_clustering_keys_large"; public static final String TABLE_MULTI_PARTITION_CLUSTERING_KEYS = "table_multi_partition_clustering_keys"; + public static final String TABLE_CLUSTERING_KEYS_INEQUALITY = "table_clustering_keys_inequality"; private CassandraTestingUtils() {} public static void createTestTables(CassandraSession cassandraSession, String keyspace, Date date) { createKeyspace(cassandraSession, keyspace); - createTableAllTypes(cassandraSession, new SchemaTableName(keyspace, TABLE_ALL_TYPES), date); + createTableAllTypes(cassandraSession, new SchemaTableName(keyspace, TABLE_ALL_TYPES), date, 9); + createTableAllTypes(cassandraSession, new SchemaTableName(keyspace, TABLE_ALL_TYPES_INSERT), date, 0); createTableAllTypesPartitionKey(cassandraSession, new SchemaTableName(keyspace, TABLE_ALL_TYPES_PARTITION_KEY), date); createTableClusteringKeys(cassandraSession, new SchemaTableName(keyspace, TABLE_CLUSTERING_KEYS), 9); createTableClusteringKeys(cassandraSession, new SchemaTableName(keyspace, TABLE_CLUSTERING_KEYS_LARGE), 1000); createTableMultiPartitionClusteringKeys(cassandraSession, new SchemaTableName(keyspace, TABLE_MULTI_PARTITION_CLUSTERING_KEYS)); + createTableClusteringKeysInequality(cassandraSession, new SchemaTableName(keyspace, TABLE_CLUSTERING_KEYS_INEQUALITY), date, 4); } public static void createKeyspace(CassandraSession session, String keyspaceName) @@ -111,7 +115,34 @@ public static void insertIntoTableMultiPartitionClusteringKeys(CassandraSession assertEquals(session.execute("SELECT COUNT(*) FROM " + table).all().get(0).getLong(0), 9); } - public static void createTableAllTypes(CassandraSession session, SchemaTableName table, Date date) + public static void createTableClusteringKeysInequality(CassandraSession session, SchemaTableName table, Date date, int rowsCount) + { + session.execute("DROP TABLE IF EXISTS " + table); + session.execute("CREATE TABLE " + table + " (" + + "key text, " + + "clust_one text, " + + "clust_two int, " + + "clust_three timestamp, " + + "data text, " + + "PRIMARY KEY((key), clust_one, clust_two, clust_three) " + + ")"); + insertIntoTableClusteringKeysInequality(session, table, date, rowsCount); + } + + public static void insertIntoTableClusteringKeysInequality(CassandraSession session, SchemaTableName table, Date date, int rowsCount) + { + for (Integer rowNumber = 1; rowNumber <= rowsCount; rowNumber++) { + Insert insert = QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) + .value("key", "key_1") + .value("clust_one", "clust_one") + .value("clust_two", rowNumber) + .value("clust_three", date.getTime() + rowNumber * 10); + session.execute(insert); + } + assertEquals(session.execute("SELECT COUNT(*) FROM " + table).all().get(0).getLong(0), rowsCount); + } + + public static void createTableAllTypes(CassandraSession session, SchemaTableName table, Date date, int rowsCount) { session.execute("DROP TABLE IF EXISTS " + table); session.execute("CREATE TABLE " + table + " (" + @@ -134,7 +165,7 @@ public static void createTableAllTypes(CassandraSession session, SchemaTableName " typemap map, " + " typeset set, " + ")"); - insertTestData(session, table, date); + insertTestData(session, table, date, rowsCount); } public static void createTableAllTypesPartitionKey(CassandraSession session, SchemaTableName table, Date date) @@ -186,12 +217,12 @@ public static void createTableAllTypesPartitionKey(CassandraSession session, Sch " ))" + ")"); - insertTestData(session, table, date); + insertTestData(session, table, date, 9); } - private static void insertTestData(CassandraSession session, SchemaTableName table, Date date) + private static void insertTestData(CassandraSession session, SchemaTableName table, Date date, int rowsCount) { - for (Integer rowNumber = 1; rowNumber < 10; rowNumber++) { + for (Integer rowNumber = 1; rowNumber <= rowsCount; rowNumber++) { Insert insert = QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) .value("key", "key " + rowNumber.toString()) .value("typeuuid", UUID.fromString(String.format("00000000-0000-0000-0000-%012d", rowNumber))) @@ -214,6 +245,6 @@ private static void insertTestData(CassandraSession session, SchemaTableName tab session.execute(insert); } - assertEquals(session.execute("SELECT COUNT(*) FROM " + table).all().get(0).getLong(0), 9); + assertEquals(session.execute("SELECT COUNT(*) FROM " + table).all().get(0).getLong(0), rowsCount); } } diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/EmbeddedCassandra.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/EmbeddedCassandra.java index 1d9df2ae4f27..dd6b0f9e9f61 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/EmbeddedCassandra.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/EmbeddedCassandra.java @@ -81,6 +81,7 @@ public static synchronized void start() .withClusterName("TestCluster") .addContactPointsWithPorts(ImmutableList.of( new InetSocketAddress(HOST, PORT))) + .withMaxSchemaAgreementWaitSeconds(30) .build(); CassandraSession session = new NativeCassandraSession( diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraDistributed.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraDistributed.java index 572a2018df1f..54d44a5cfa42 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraDistributed.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraDistributed.java @@ -75,6 +75,12 @@ public void testRenameColumn() // Cassandra does not support renaming columns } + @Override + public void testDropColumn() + { + // Cassandra does not support dropping columns + } + @Override public void testInsert() { diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntegrationSmokeTest.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntegrationSmokeTest.java index 81eb2ebc6664..7506ceffc03e 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntegrationSmokeTest.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntegrationSmokeTest.java @@ -33,8 +33,10 @@ import static com.datastax.driver.core.utils.Bytes.toRawHexString; import static com.facebook.presto.cassandra.CassandraQueryRunner.createCassandraSession; import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_ALL_TYPES; +import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_ALL_TYPES_INSERT; import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_ALL_TYPES_PARTITION_KEY; import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_CLUSTERING_KEYS; +import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_CLUSTERING_KEYS_INEQUALITY; import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_CLUSTERING_KEYS_LARGE; import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_MULTI_PARTITION_CLUSTERING_KEYS; import static com.facebook.presto.cassandra.CassandraTestingUtils.createTestTables; @@ -201,6 +203,54 @@ public void testClusteringKeyOnlyPushdown() assertEquals(execute(sql).getRowCount(), 1); sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two='clust_two_2' AND clust_three='clust_three_2'"; assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two='clust_two_2' AND clust_three IN ('clust_three_1', 'clust_three_2', 'clust_three_3')"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_1','clust_two_2') AND clust_three IN ('clust_three_1', 'clust_three_2', 'clust_three_3')"; + assertEquals(execute(sql).getRowCount(), 2); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two > 'clust_two_998'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two > 'clust_two_997' AND clust_two < 'clust_two_999'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_1','clust_two_2') AND clust_three > 'clust_three_998'"; + assertEquals(execute(sql).getRowCount(), 0); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_1','clust_two_2') AND clust_three < 'clust_three_3'"; + assertEquals(execute(sql).getRowCount(), 2); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_1','clust_two_2') AND clust_three > 'clust_three_1' AND clust_three < 'clust_three_3'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_1','clust_two_2','clust_two_3') AND clust_two < 'clust_two_2'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_997','clust_two_998','clust_two_999') AND clust_two > 'clust_two_998'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_1','clust_two_2','clust_two_3') AND clust_two = 'clust_two_2'"; + assertEquals(execute(sql).getRowCount(), 1); + } + + @Test + public void testClusteringKeyPushdownInequality() + throws Exception + { + String sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one'"; + assertEquals(execute(sql).getRowCount(), 4); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two=2"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two=2 AND clust_three = timestamp '1970-01-01 03:04:05.020'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two=2 AND clust_three = timestamp '1970-01-01 03:04:05.010'"; + assertEquals(execute(sql).getRowCount(), 0); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two IN (1,2)"; + assertEquals(execute(sql).getRowCount(), 2); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two > 1 AND clust_two < 3"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two=2 AND clust_three >= timestamp '1970-01-01 03:04:05.010' AND clust_three <= timestamp '1970-01-01 03:04:05.020'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two IN (1,2) AND clust_three >= timestamp '1970-01-01 03:04:05.010' AND clust_three <= timestamp '1970-01-01 03:04:05.020'"; + assertEquals(execute(sql).getRowCount(), 2); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two IN (1,2,3) AND clust_two < 2"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two IN (1,2,3) AND clust_two > 2"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two IN (1,2,3) AND clust_two = 2"; + assertEquals(execute(sql).getRowCount(), 1); } @Test @@ -225,8 +275,7 @@ public void testUpperCaseNameUnescapedInCassandra() .row("column_1", "bigint", "", "") .build()); - // TODO replace with the Presto INSERT INTO once implemented - session.execute("INSERT INTO keyspace_1.table_1 (column_1) VALUES (1)"); + execute("INSERT INTO keyspace_1.table_1 (column_1) VALUES (1)"); assertEquals(execute("SELECT column_1 FROM cassandra.keyspace_1.table_1").getRowCount(), 1); assertUpdate("DROP TABLE cassandra.keyspace_1.table_1"); @@ -257,8 +306,7 @@ public void testUppercaseNameEscaped() .row("column_2", "bigint", "", "") .build()); - // TODO replace with the Presto INSERT INTO once implemented - session.execute("INSERT INTO \"KEYSPACE_2\".\"TABLE_2\" (\"COLUMN_2\") VALUES (1)"); + execute("INSERT INTO \"KEYSPACE_2\".\"TABLE_2\" (\"COLUMN_2\") VALUES (1)"); assertEquals(execute("SELECT column_2 FROM cassandra.keyspace_2.table_2").getRowCount(), 1); assertUpdate("DROP TABLE cassandra.keyspace_2.table_2"); @@ -283,8 +331,10 @@ public void testKeyspaceNameAmbiguity() .build(), new Duration(1, MINUTES)); // There is no way to figure out what the exactly keyspace we want to retrieve tables from - assertQueryFails("SHOW TABLES FROM cassandra.keyspace_3", - "More than one keyspace has been found for the case insensitive schema name: keyspace_3 -> \\(KeYsPaCe_3, kEySpAcE_3\\)"); + assertQueryFailsEventually( + "SHOW TABLES FROM cassandra.keyspace_3", + "More than one keyspace has been found for the case insensitive schema name: keyspace_3 -> \\(KeYsPaCe_3, kEySpAcE_3\\)", + new Duration(1, MINUTES)); session.execute("DROP KEYSPACE \"KeYsPaCe_3\""); session.execute("DROP KEYSPACE \"kEySpAcE_3\""); @@ -311,11 +361,14 @@ public void testTableNameAmbiguity() .build(), new Duration(1, MINUTES)); // There is no way to figure out what the exactly table is being queried - assertQueryFails("SHOW COLUMNS FROM cassandra.keyspace_4.table_4", - "More than one table has been found for the case insensitive table name: table_4 -> \\(TaBlE_4, tAbLe_4\\)"); - assertQueryFails("SELECT * FROM cassandra.keyspace_4.table_4", - "More than one table has been found for the case insensitive table name: table_4 -> \\(TaBlE_4, tAbLe_4\\)"); - + assertQueryFailsEventually( + "SHOW COLUMNS FROM cassandra.keyspace_4.table_4", + "More than one table has been found for the case insensitive table name: table_4 -> \\(TaBlE_4, tAbLe_4\\)", + new Duration(1, MINUTES)); + assertQueryFailsEventually( + "SELECT * FROM cassandra.keyspace_4.table_4", + "More than one table has been found for the case insensitive table name: table_4 -> \\(TaBlE_4, tAbLe_4\\)", + new Duration(1, MINUTES)); session.execute("DROP KEYSPACE keyspace_4"); } @@ -333,14 +386,123 @@ public void testColumnNameAmbiguity() .row("table_5") .build(), new Duration(1, MINUTES)); - assertQueryFails("SHOW COLUMNS FROM cassandra.keyspace_5.table_5", - "More than one column has been found for the case insensitive column name: column_5 -> \\(CoLuMn_5, cOlUmN_5\\)"); - assertQueryFails("SELECT * FROM cassandra.keyspace_5.table_5", - "More than one column has been found for the case insensitive column name: column_5 -> \\(CoLuMn_5, cOlUmN_5\\)"); + assertQueryFailsEventually( + "SHOW COLUMNS FROM cassandra.keyspace_5.table_5", + "More than one column has been found for the case insensitive column name: column_5 -> \\(CoLuMn_5, cOlUmN_5\\)", + new Duration(1, MINUTES)); + assertQueryFailsEventually( + "SELECT * FROM cassandra.keyspace_5.table_5", + "More than one column has been found for the case insensitive column name: column_5 -> \\(CoLuMn_5, cOlUmN_5\\)", + new Duration(1, MINUTES)); session.execute("DROP KEYSPACE keyspace_5"); } + @Test + public void testInsert() + { + String sql = "SELECT key, typeuuid, typeinteger, typelong, typebytes, typetimestamp, typeansi, typeboolean, typedecimal, " + + "typedouble, typefloat, typeinet, typevarchar, typevarint, typetimeuuid, typelist, typemap, typeset" + + " FROM " + TABLE_ALL_TYPES_INSERT; + assertEquals(execute(sql).getRowCount(), 0); + + // TODO Following types are not supported now. We need to change null into the value after fixing it + // blob, frozen>, inet, list, map, set, timeuuid, decimal, uuid, varint + // timestamp can be inserted but the expected and actual values are not same + execute("INSERT INTO " + TABLE_ALL_TYPES_INSERT + " (" + + "key," + + "typeuuid," + + "typeinteger," + + "typelong," + + "typebytes," + + "typetimestamp," + + "typeansi," + + "typeboolean," + + "typedecimal," + + "typedouble," + + "typefloat," + + "typeinet," + + "typevarchar," + + "typevarint," + + "typetimeuuid," + + "typelist," + + "typemap," + + "typeset" + + ") VALUES (" + + "'key1', " + + "null, " + + "1, " + + "1000, " + + "null, " + + "timestamp '1970-01-01 08:34:05.0', " + + "'ansi1', " + + "true, " + + "null, " + + "0.3, " + + "cast('0.4' as real), " + + "null, " + + "'varchar1', " + + "null, " + + "null, " + + "null, " + + "null, " + + "null " + + ")"); + + MaterializedResult result = execute(sql); + int rowCount = result.getRowCount(); + assertEquals(rowCount, 1); + assertEquals(result.getMaterializedRows().get(0), new MaterializedRow(DEFAULT_PRECISION, + "key1", + null, + 1, + 1000L, + null, + Timestamp.valueOf("1970-01-01 14:04:05.0"), + "ansi1", + true, + null, + 0.3, + (float) 0.4, + null, + "varchar1", + null, + null, + null, + null, + null + )); + + // insert null for all datatypes + execute("INSERT INTO " + TABLE_ALL_TYPES_INSERT + " (" + + "key, typeuuid, typeinteger, typelong, typebytes, typetimestamp, typeansi, typeboolean, typedecimal," + + "typedouble, typefloat, typeinet, typevarchar, typevarint, typetimeuuid, typelist, typemap, typeset" + + ") VALUES (" + + "'key2', null, null, null, null, null, null, null, null," + + "null, null, null, null, null, null, null, null, null)"); + sql = "SELECT key, typeuuid, typeinteger, typelong, typebytes, typetimestamp, typeansi, typeboolean, typedecimal, " + + "typedouble, typefloat, typeinet, typevarchar, typevarint, typetimeuuid, typelist, typemap, typeset" + + " FROM " + TABLE_ALL_TYPES_INSERT + " WHERE key = 'key2'"; + result = execute(sql); + rowCount = result.getRowCount(); + assertEquals(rowCount, 1); + assertEquals(result.getMaterializedRows().get(0), new MaterializedRow(DEFAULT_PRECISION, + "key2", null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null)); + + // insert into only a subset of columns + execute("INSERT INTO " + TABLE_ALL_TYPES_INSERT + " (" + + "key, typeinteger, typeansi, typeboolean) VALUES (" + + "'key3', 999, 'ansi', false)"); + sql = "SELECT key, typeuuid, typeinteger, typelong, typebytes, typetimestamp, typeansi, typeboolean, typedecimal, " + + "typedouble, typefloat, typeinet, typevarchar, typevarint, typetimeuuid, typelist, typemap, typeset" + + " FROM " + TABLE_ALL_TYPES_INSERT + " WHERE key = 'key3'"; + result = execute(sql); + rowCount = result.getRowCount(); + assertEquals(rowCount, 1); + assertEquals(result.getMaterializedRows().get(0), new MaterializedRow(DEFAULT_PRECISION, + "key3", null, 999, null, null, null, "ansi", false, null, null, null, null, null, null, null, null, null, null)); + } + private void assertSelect(String tableName, boolean createdByPresto) { Type uuidType = createdByPresto ? createUnboundedVarcharType() : createVarcharType(36); diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/util/TestCassandraClusteringPredicatesExtractor.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/util/TestCassandraClusteringPredicatesExtractor.java index 5a56a88d3d47..78e5c82610b7 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/util/TestCassandraClusteringPredicatesExtractor.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/util/TestCassandraClusteringPredicatesExtractor.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.cassandra.util; +import com.datastax.driver.core.VersionNumber; import com.facebook.presto.cassandra.CassandraClusteringPredicatesExtractor; import com.facebook.presto.cassandra.CassandraColumnHandle; import com.facebook.presto.cassandra.CassandraTable; @@ -26,8 +27,6 @@ import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; -import java.util.List; - import static com.facebook.presto.spi.type.BigintType.BIGINT; import static org.testng.Assert.assertEquals; @@ -38,6 +37,7 @@ public class TestCassandraClusteringPredicatesExtractor private static CassandraColumnHandle col3; private static CassandraColumnHandle col4; private static CassandraTable cassandraTable; + private static VersionNumber cassandraVersion; @BeforeTest void setUp() @@ -50,6 +50,8 @@ void setUp() cassandraTable = new CassandraTable( new CassandraTableHandle("cassandra", "test", "records"), ImmutableList.of(col1, col2, col3, col4)); + + cassandraVersion = VersionNumber.parse("2.1.5"); } @Test @@ -60,9 +62,9 @@ public void testBuildClusteringPredicate() col1, Domain.singleValue(BIGINT, 23L), col2, Domain.singleValue(BIGINT, 34L), col4, Domain.singleValue(BIGINT, 26L))); - CassandraClusteringPredicatesExtractor predicatesExtractor = new CassandraClusteringPredicatesExtractor(cassandraTable.getClusteringKeyColumns(), tupleDomain); - List predicate = predicatesExtractor.getClusteringKeyPredicates(); - assertEquals(predicate.get(0), new StringBuilder("\"clusteringKey1\" = 34").toString()); + CassandraClusteringPredicatesExtractor predicatesExtractor = new CassandraClusteringPredicatesExtractor(cassandraTable.getClusteringKeyColumns(), tupleDomain, cassandraVersion); + String predicate = predicatesExtractor.getClusteringKeyPredicates(); + assertEquals(predicate, new StringBuilder("\"clusteringKey1\" = 34").toString()); } @Test @@ -72,7 +74,7 @@ public void testGetUnenforcedPredicates() ImmutableMap.of( col2, Domain.singleValue(BIGINT, 34L), col4, Domain.singleValue(BIGINT, 26L))); - CassandraClusteringPredicatesExtractor predicatesExtractor = new CassandraClusteringPredicatesExtractor(cassandraTable.getClusteringKeyColumns(), tupleDomain); + CassandraClusteringPredicatesExtractor predicatesExtractor = new CassandraClusteringPredicatesExtractor(cassandraTable.getClusteringKeyColumns(), tupleDomain, cassandraVersion); TupleDomain unenforcedPredicates = TupleDomain.withColumnDomains(ImmutableMap.of(col4, Domain.singleValue(BIGINT, 26L))); assertEquals(predicatesExtractor.getUnenforcedConstraints(), unenforcedPredicates); } diff --git a/presto-cli/pom.xml b/presto-cli/pom.xml index a70dfd7fc429..59ff3b03bb86 100644 --- a/presto-cli/pom.xml +++ b/presto-cli/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-cli @@ -14,8 +14,6 @@ ${project.parent.basedir} com.facebook.presto.cli.Presto - false - ${main-class} @@ -39,16 +37,6 @@ concurrent - - io.airlift - http-client - - - - io.airlift - json - - io.airlift log @@ -89,6 +77,11 @@ opencsv + + com.squareup.okhttp3 + okhttp + + org.testng diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/ClientOptions.java b/presto-cli/src/main/java/com/facebook/presto/cli/ClientOptions.java index ce73ad7b7de3..79bf95a058ef 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/ClientOptions.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/ClientOptions.java @@ -19,10 +19,8 @@ import com.google.common.net.HostAndPort; import com.sun.security.auth.module.UnixSystem; import io.airlift.airline.Option; -import io.airlift.http.client.spnego.KerberosConfig; import io.airlift.units.Duration; -import java.io.File; import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.CharsetEncoder; @@ -34,6 +32,7 @@ import java.util.Optional; import java.util.TimeZone; +import static com.facebook.presto.client.KerberosUtil.defaultCredentialCachePath; import static com.google.common.base.Preconditions.checkArgument; import static java.nio.charset.StandardCharsets.US_ASCII; import static java.util.Collections.emptyMap; @@ -59,7 +58,7 @@ public class ClientOptions public String krb5KeytabPath = "/etc/krb5.keytab"; @Option(name = "--krb5-credential-cache-path", title = "krb5 credential cache path", description = "Kerberos credential cache path") - public String krb5CredentialCachePath = defaultCredentialCachePath(); + public String krb5CredentialCachePath = defaultCredentialCachePath().orElse(null); @Option(name = "--krb5-principal", title = "krb5 principal", description = "Kerberos principal to be used") public String krb5Principal; @@ -116,6 +115,9 @@ public class ClientOptions @Option(name = "--socks-proxy", title = "socks-proxy", description = "SOCKS proxy to use for server connections") public HostAndPort socksProxy; + @Option(name = "--http-proxy", title = "http-proxy", description = "HTTP proxy to use for server connections") + public HostAndPort httpProxy; + @Option(name = "--client-request-timeout", title = "client request timeout", description = "Client request timeout (default: 2m)") public Duration clientRequestTimeout = new Duration(2, MINUTES); @@ -148,22 +150,6 @@ public ClientSession toClientSession() clientRequestTimeout); } - public KerberosConfig toKerberosConfig() - { - KerberosConfig config = new KerberosConfig(); - if (krb5ConfigPath != null) { - config.setConfig(new File(krb5ConfigPath)); - } - if (krb5KeytabPath != null) { - config.setKeytab(new File(krb5KeytabPath)); - } - if (krb5CredentialCachePath != null) { - config.setCredentialCache(new File(krb5CredentialCachePath)); - } - config.setUseCanonicalHostname(!krb5DisableRemoteServiceHostnameCanonicalization); - return config; - } - public static URI parseServer(String server) { server = server.toLowerCase(ENGLISH); @@ -173,7 +159,7 @@ public static URI parseServer(String server) HostAndPort host = HostAndPort.fromString(server); try { - return new URI("http", null, host.getHostText(), host.getPortOrDefault(80), null, null, null); + return new URI("http", null, host.getHost(), host.getPortOrDefault(80), null, null, null); } catch (URISyntaxException e) { throw new IllegalArgumentException(e); @@ -193,15 +179,6 @@ public static Map toProperties(List sessi return builder.build(); } - private static String defaultCredentialCachePath() - { - String value = System.getenv("KRB5CCNAME"); - if (value != null && value.startsWith("FILE:")) { - return value.substring("FILE:".length()); - } - return value; - } - public static final class ClientSessionProperty { private static final Splitter NAME_VALUE_SPLITTER = Splitter.on('=').limit(2); diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/Console.java b/presto-cli/src/main/java/com/facebook/presto/cli/Console.java index 03b10369ebd9..ef72bdad3d7a 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/Console.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/Console.java @@ -27,7 +27,6 @@ import com.google.common.io.Files; import io.airlift.airline.Command; import io.airlift.airline.HelpOption; -import io.airlift.http.client.spnego.KerberosConfig; import io.airlift.log.Logging; import io.airlift.log.LoggingConfiguration; import io.airlift.units.Duration; @@ -94,7 +93,6 @@ public class Console public void run() { ClientSession session = clientOptions.toClientSession(); - KerberosConfig kerberosConfig = clientOptions.toKerberosConfig(); boolean hasQuery = !Strings.isNullOrEmpty(clientOptions.execute); boolean isFromFile = !Strings.isNullOrEmpty(clientOptions.file); @@ -125,9 +123,10 @@ public void run() AtomicBoolean exiting = new AtomicBoolean(); interruptThreadOnExit(Thread.currentThread(), exiting); - try (QueryRunner queryRunner = QueryRunner.create( + try (QueryRunner queryRunner = new QueryRunner( session, Optional.ofNullable(clientOptions.socksProxy), + Optional.ofNullable(clientOptions.httpProxy), Optional.ofNullable(clientOptions.keystorePath), Optional.ofNullable(clientOptions.keystorePassword), Optional.ofNullable(clientOptions.truststorePath), @@ -136,8 +135,11 @@ public void run() clientOptions.password ? Optional.of(getPassword()) : Optional.empty(), Optional.ofNullable(clientOptions.krb5Principal), Optional.ofNullable(clientOptions.krb5RemoteServiceName), - clientOptions.authenticationEnabled, - kerberosConfig)) { + Optional.ofNullable(clientOptions.krb5ConfigPath), + Optional.ofNullable(clientOptions.krb5KeytabPath), + Optional.ofNullable(clientOptions.krb5CredentialCachePath), + !clientOptions.krb5DisableRemoteServiceHostnameCanonicalization, + clientOptions.authenticationEnabled)) { if (hasQuery) { executeCommand(queryRunner, query, clientOptions.outputFormat); } diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/LdapRequestFilter.java b/presto-cli/src/main/java/com/facebook/presto/cli/LdapRequestFilter.java deleted file mode 100644 index de40f0e32cc8..000000000000 --- a/presto-cli/src/main/java/com/facebook/presto/cli/LdapRequestFilter.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.cli; - -import com.google.common.net.HttpHeaders; -import io.airlift.http.client.HttpRequestFilter; -import io.airlift.http.client.Request; - -import java.util.Base64; - -import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.http.client.Request.Builder.fromRequest; -import static java.nio.charset.StandardCharsets.ISO_8859_1; -import static java.util.Objects.requireNonNull; - -public class LdapRequestFilter - implements HttpRequestFilter -{ - private final String user; - private final String password; - - public LdapRequestFilter(String user, String password) - { - this.user = requireNonNull(user, "user is null"); - checkArgument(!user.contains(":"), "Illegal character ':' found in username"); - this.password = requireNonNull(password, "password is null"); - } - - @Override - public Request filterRequest(Request request) - { - String value = "Basic " + Base64.getEncoder().encodeToString((user + ":" + password).getBytes(ISO_8859_1)); - return fromRequest(request) - .addHeader(HttpHeaders.AUTHORIZATION, value) - .build(); - } -} diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/QueryRunner.java b/presto-cli/src/main/java/com/facebook/presto/cli/QueryRunner.java index 80fa038eafab..83031ba99574 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/QueryRunner.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/QueryRunner.java @@ -13,39 +13,37 @@ */ package com.facebook.presto.cli; +import com.facebook.presto.client.ClientException; import com.facebook.presto.client.ClientSession; -import com.facebook.presto.client.QueryResults; import com.facebook.presto.client.StatementClient; -import com.google.common.collect.ImmutableList; import com.google.common.net.HostAndPort; -import io.airlift.http.client.HttpClient; -import io.airlift.http.client.HttpClientConfig; -import io.airlift.http.client.HttpRequestFilter; -import io.airlift.http.client.jetty.JettyHttpClient; -import io.airlift.http.client.spnego.KerberosConfig; -import io.airlift.json.JsonCodec; -import io.airlift.units.Duration; +import okhttp3.OkHttpClient; import java.io.Closeable; +import java.io.File; import java.util.Optional; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import static com.facebook.presto.client.OkHttpUtil.basicAuth; +import static com.facebook.presto.client.OkHttpUtil.setupHttpProxy; +import static com.facebook.presto.client.OkHttpUtil.setupKerberos; +import static com.facebook.presto.client.OkHttpUtil.setupSocksProxy; +import static com.facebook.presto.client.OkHttpUtil.setupSsl; +import static com.facebook.presto.client.OkHttpUtil.setupTimeouts; import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.json.JsonCodec.jsonCodec; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; public class QueryRunner implements Closeable { - private final JsonCodec queryResultsCodec; private final AtomicReference session; - private final HttpClient httpClient; + private final OkHttpClient httpClient; public QueryRunner( ClientSession session, - JsonCodec queryResultsCodec, Optional socksProxy, + Optional httpProxy, Optional keystorePath, Optional keystorePassword, Optional truststorePath, @@ -54,24 +52,34 @@ public QueryRunner( Optional password, Optional kerberosPrincipal, Optional kerberosRemoteServiceName, - boolean authenticationEnabled, - KerberosConfig kerberosConfig) + Optional kerberosConfigPath, + Optional kerberosKeytabPath, + Optional kerberosCredentialCachePath, + boolean kerberosUseCanonicalHostname, + boolean kerberosEnabled) { this.session = new AtomicReference<>(requireNonNull(session, "session is null")); - this.queryResultsCodec = requireNonNull(queryResultsCodec, "queryResultsCodec is null"); - this.httpClient = new JettyHttpClient( - getHttpClientConfig( - socksProxy, - keystorePath, - keystorePassword, - truststorePath, - truststorePassword, - kerberosPrincipal, - kerberosRemoteServiceName, - authenticationEnabled), - kerberosConfig, - Optional.empty(), - getRequestFilters(session, user, password)); + + OkHttpClient.Builder builder = new OkHttpClient.Builder(); + + setupTimeouts(builder, 5, SECONDS); + setupSocksProxy(builder, socksProxy); + setupHttpProxy(builder, httpProxy); + setupSsl(builder, keystorePath, keystorePassword, truststorePath, truststorePassword); + setupBasicAuth(builder, session, user, password); + + if (kerberosEnabled) { + setupKerberos( + builder, + kerberosRemoteServiceName.orElseThrow(() -> new ClientException("Kerberos remote service name must be set")), + kerberosUseCanonicalHostname, + kerberosPrincipal, + kerberosConfigPath.map(File::new), + kerberosKeytabPath.map(File::new), + kerberosCredentialCachePath.map(File::new)); + } + + this.httpClient = builder.build(); } public ClientSession getSession() @@ -91,79 +99,26 @@ public Query startQuery(String query) public StatementClient startInternalQuery(String query) { - return new StatementClient(httpClient, queryResultsCodec, session.get(), query); + return new StatementClient(httpClient, session.get(), query); } @Override public void close() { - httpClient.close(); + httpClient.dispatcher().executorService().shutdown(); + httpClient.connectionPool().evictAll(); } - public static QueryRunner create( + private static void setupBasicAuth( + OkHttpClient.Builder clientBuilder, ClientSession session, - Optional socksProxy, - Optional keystorePath, - Optional keystorePassword, - Optional truststorePath, - Optional truststorePassword, Optional user, - Optional password, - Optional kerberosPrincipal, - Optional kerberosRemoteServiceName, - boolean authenticationEnabled, - KerberosConfig kerberosConfig) - { - return new QueryRunner( - session, - jsonCodec(QueryResults.class), - socksProxy, - keystorePath, - keystorePassword, - truststorePath, - truststorePassword, - user, - password, - kerberosPrincipal, - kerberosRemoteServiceName, - authenticationEnabled, - kerberosConfig); - } - - private static HttpClientConfig getHttpClientConfig( - Optional socksProxy, - Optional keystorePath, - Optional keystorePassword, - Optional truststorePath, - Optional truststorePassword, - Optional kerberosPrincipal, - Optional kerberosRemoteServiceName, - boolean authenticationEnabled) - { - HttpClientConfig httpClientConfig = new HttpClientConfig() - .setConnectTimeout(new Duration(5, TimeUnit.SECONDS)) - .setRequestTimeout(new Duration(5, TimeUnit.SECONDS)); - - socksProxy.ifPresent(httpClientConfig::setSocksProxy); - - httpClientConfig.setAuthenticationEnabled(authenticationEnabled); - - keystorePath.ifPresent(httpClientConfig::setKeyStorePath); - keystorePassword.ifPresent(httpClientConfig::setKeyStorePassword); - truststorePath.ifPresent(httpClientConfig::setTrustStorePath); - truststorePassword.ifPresent(httpClientConfig::setTrustStorePassword); - kerberosPrincipal.ifPresent(httpClientConfig::setKerberosPrincipal); - kerberosRemoteServiceName.ifPresent(httpClientConfig::setKerberosRemoteServiceName); - - return httpClientConfig; - } - - private static Iterable getRequestFilters(ClientSession session, Optional user, Optional password) + Optional password) { if (user.isPresent() && password.isPresent()) { - checkArgument(session.getServer().getScheme().equalsIgnoreCase("https"), "Authentication using username/password requires HTTPS to be enabled"); - return ImmutableList.of(new LdapRequestFilter(user.get(), password.get())); + checkArgument(session.getServer().getScheme().equalsIgnoreCase("https"), + "Authentication using username/password requires HTTPS to be enabled"); + clientBuilder.addInterceptor(basicAuth(user.get(), password.get())); } - return ImmutableList.of(); } } diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/StatusPrinter.java b/presto-cli/src/main/java/com/facebook/presto/cli/StatusPrinter.java index 7d45a47cb7ae..80f65a834ec2 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/StatusPrinter.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/StatusPrinter.java @@ -99,7 +99,7 @@ public void printInitialStatusUpdates() // check for keyboard input int key = readKey(); if (key == CTRL_P) { - partialCancel(); + client.cancelLeafStage(); } else if (key == CTRL_C) { updateScreen(); @@ -406,16 +406,6 @@ private void printStageTree(StageStats stage, String indent, AtomicInteger stage } } - private void partialCancel() - { - try { - client.cancelLeafStage(new Duration(1, SECONDS)); - } - catch (RuntimeException e) { - log.debug(e, "error canceling leaf stage"); - } - } - private void reprintLine(String line) { console.reprintLine(line); diff --git a/presto-cli/src/test/java/com/facebook/presto/cli/TestTableNameCompleter.java b/presto-cli/src/test/java/com/facebook/presto/cli/TestTableNameCompleter.java index 3c1bf78c27cd..e9f4275cfd64 100644 --- a/presto-cli/src/test/java/com/facebook/presto/cli/TestTableNameCompleter.java +++ b/presto-cli/src/test/java/com/facebook/presto/cli/TestTableNameCompleter.java @@ -27,7 +27,12 @@ public class TestTableNameCompleter public void testAutoCompleteWithoutSchema() { ClientSession session = new ClientOptions().toClientSession(); - QueryRunner runner = QueryRunner.create(session, + QueryRunner runner = new QueryRunner( + session, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), @@ -38,7 +43,7 @@ public void testAutoCompleteWithoutSchema() Optional.empty(), Optional.empty(), false, - null); + false); TableNameCompleter completer = new TableNameCompleter(runner); assertEquals(completer.complete("SELECT is_infi", 14, ImmutableList.of()), 7); } diff --git a/presto-client/pom.xml b/presto-client/pom.xml index cfb1b002babd..564ef54c8b3d 100644 --- a/presto-client/pom.xml +++ b/presto-client/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-client @@ -46,11 +46,6 @@ jackson-databind - - io.airlift - http-client - - io.airlift json @@ -66,6 +61,11 @@ guava + + com.squareup.okhttp3 + okhttp + + org.testng diff --git a/presto-client/src/main/java/com/facebook/presto/client/ClientException.java b/presto-client/src/main/java/com/facebook/presto/client/ClientException.java new file mode 100644 index 000000000000..e9cc75ce70f4 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/ClientException.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client; + +public class ClientException + extends RuntimeException +{ + public ClientException(String message) + { + super(message); + } + + public ClientException(String message, Throwable cause) + { + super(message, cause); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/JsonResponse.java b/presto-client/src/main/java/com/facebook/presto/client/JsonResponse.java new file mode 100644 index 000000000000..12609e66296a --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/JsonResponse.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client; + +import io.airlift.json.JsonCodec; +import okhttp3.Headers; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.ResponseBody; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.net.HttpHeaders.LOCATION; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public final class JsonResponse +{ + private final int statusCode; + private final String statusMessage; + private final Headers headers; + private final String responseBody; + private final boolean hasValue; + private final T value; + private final IllegalArgumentException exception; + + private JsonResponse(int statusCode, String statusMessage, Headers headers, String responseBody) + { + this.statusCode = statusCode; + this.statusMessage = statusMessage; + this.headers = requireNonNull(headers, "headers is null"); + this.responseBody = requireNonNull(responseBody, "responseBody is null"); + + this.hasValue = false; + this.value = null; + this.exception = null; + } + + private JsonResponse(int statusCode, String statusMessage, Headers headers, String responseBody, JsonCodec jsonCodec) + { + this.statusCode = statusCode; + this.statusMessage = statusMessage; + this.headers = requireNonNull(headers, "headers is null"); + this.responseBody = requireNonNull(responseBody, "responseBody is null"); + + T value = null; + IllegalArgumentException exception = null; + try { + value = jsonCodec.fromJson(responseBody); + } + catch (IllegalArgumentException e) { + exception = new IllegalArgumentException(format("Unable to create %s from JSON response:\n[%s]", jsonCodec.getType(), responseBody), e); + } + this.hasValue = (exception == null); + this.value = value; + this.exception = exception; + } + + public int getStatusCode() + { + return statusCode; + } + + public String getStatusMessage() + { + return statusMessage; + } + + public Headers getHeaders() + { + return headers; + } + + public boolean hasValue() + { + return hasValue; + } + + public T getValue() + { + if (!hasValue) { + throw new IllegalStateException("Response does not contain a JSON value", exception); + } + return value; + } + + public String getResponseBody() + { + return responseBody; + } + + @Nullable + public IllegalArgumentException getException() + { + return exception; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("statusCode", statusCode) + .add("statusMessage", statusMessage) + .add("headers", headers.toMultimap()) + .add("hasValue", hasValue) + .add("value", value) + .omitNullValues() + .toString(); + } + + public static JsonResponse execute(JsonCodec codec, OkHttpClient client, Request request) + { + try (Response response = client.newCall(request).execute()) { + // TODO: fix in OkHttp: https://github.com/square/okhttp/issues/3111 + if ((response.code() == 307) || (response.code() == 308)) { + String location = response.header(LOCATION); + if (location != null) { + request = request.newBuilder().url(location).build(); + return execute(codec, client, request); + } + } + + ResponseBody responseBody = requireNonNull(response.body()); + String body = responseBody.string(); + if (isJson(responseBody.contentType())) { + return new JsonResponse<>(response.code(), response.message(), response.headers(), body, codec); + } + return new JsonResponse<>(response.code(), response.message(), response.headers(), body); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static boolean isJson(MediaType type) + { + return (type != null) && "application".equals(type.type()) && "json".equals(type.subtype()); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/KerberosUtil.java b/presto-client/src/main/java/com/facebook/presto/client/KerberosUtil.java new file mode 100644 index 000000000000..914def58b202 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/KerberosUtil.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client; + +import java.util.Optional; + +import static com.google.common.base.Strings.emptyToNull; +import static com.google.common.base.Strings.nullToEmpty; + +public final class KerberosUtil +{ + private static final String FILE_PREFIX = "FILE:"; + + private KerberosUtil() {} + + public static Optional defaultCredentialCachePath() + { + String value = nullToEmpty(System.getenv("KRB5CCNAME")); + if (value.startsWith(FILE_PREFIX)) { + value = value.substring(FILE_PREFIX.length()); + } + return Optional.ofNullable(emptyToNull(value)); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/OkHttpUtil.java b/presto-client/src/main/java/com/facebook/presto/client/OkHttpUtil.java new file mode 100644 index 000000000000..4b788a56535c --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/OkHttpUtil.java @@ -0,0 +1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client; + +import com.google.common.net.HostAndPort; +import okhttp3.Call; +import okhttp3.Callback; +import okhttp3.Credentials; +import okhttp3.Interceptor; +import okhttp3.OkHttpClient; +import okhttp3.Response; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.util.Arrays; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static com.google.common.net.HttpHeaders.AUTHORIZATION; +import static com.google.common.net.HttpHeaders.USER_AGENT; +import static java.net.Proxy.Type.HTTP; +import static java.net.Proxy.Type.SOCKS; +import static java.util.Objects.requireNonNull; + +public final class OkHttpUtil +{ + private OkHttpUtil() {} + + public static class NullCallback + implements Callback + { + @Override + public void onFailure(Call call, IOException e) {} + + @Override + public void onResponse(Call call, Response response) {} + } + + public static Interceptor userAgent(String userAgent) + { + return chain -> chain.proceed(chain.request().newBuilder() + .header(USER_AGENT, userAgent) + .build()); + } + + public static Interceptor basicAuth(String user, String password) + { + requireNonNull(user, "user is null"); + requireNonNull(password, "password is null"); + if (user.contains(":")) { + throw new ClientException("Illegal character ':' found in username"); + } + + String credential = Credentials.basic(user, password); + return chain -> chain.proceed(chain.request().newBuilder() + .header(AUTHORIZATION, credential) + .build()); + } + + public static void setupTimeouts(OkHttpClient.Builder clientBuilder, int timeout, TimeUnit unit) + { + clientBuilder + .connectTimeout(timeout, unit) + .readTimeout(timeout, unit) + .writeTimeout(timeout, unit); + } + + public static void setupSocksProxy(OkHttpClient.Builder clientBuilder, Optional socksProxy) + { + setupProxy(clientBuilder, socksProxy, SOCKS); + } + + public static void setupHttpProxy(OkHttpClient.Builder clientBuilder, Optional httpProxy) + { + setupProxy(clientBuilder, httpProxy, HTTP); + } + + public static void setupProxy(OkHttpClient.Builder clientBuilder, Optional proxy, Proxy.Type type) + { + proxy.map(OkHttpUtil::toUnresolvedAddress) + .map(address -> new Proxy(type, address)) + .ifPresent(clientBuilder::proxy); + } + + private static InetSocketAddress toUnresolvedAddress(HostAndPort address) + { + return InetSocketAddress.createUnresolved(address.getHost(), address.getPort()); + } + + public static void setupSsl( + OkHttpClient.Builder clientBuilder, + Optional keyStorePath, + Optional keyStorePassword, + Optional trustStorePath, + Optional trustStorePassword) + { + if (!keyStorePath.isPresent() && !trustStorePath.isPresent()) { + return; + } + + try { + // load KeyStore if configured and get KeyManagers + KeyStore keyStore = null; + KeyManager[] keyManagers = null; + if (keyStorePath.isPresent()) { + char[] keyPassword = keyStorePassword.map(String::toCharArray).orElse(null); + + keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + try (InputStream in = new FileInputStream(keyStorePath.get())) { + keyStore.load(in, keyPassword); + } + + KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, keyPassword); + keyManagers = keyManagerFactory.getKeyManagers(); + } + + // load TrustStore if configured, otherwise use KeyStore + KeyStore trustStore = keyStore; + if (trustStorePath.isPresent()) { + trustStore = KeyStore.getInstance(KeyStore.getDefaultType()); + try (InputStream in = new FileInputStream(trustStorePath.get())) { + trustStore.load(in, trustStorePassword.map(String::toCharArray).orElse(null)); + } + } + + // create TrustManagerFactory + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(trustStore); + + // get X509TrustManager + TrustManager[] trustManagers = trustManagerFactory.getTrustManagers(); + if ((trustManagers.length != 1) || !(trustManagers[0] instanceof X509TrustManager)) { + throw new RuntimeException("Unexpected default trust managers:" + Arrays.toString(trustManagers)); + } + X509TrustManager trustManager = (X509TrustManager) trustManagers[0]; + + // create SSLContext + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(keyManagers, new TrustManager[] {trustManager}, null); + + clientBuilder.sslSocketFactory(sslContext.getSocketFactory(), trustManager); + } + catch (GeneralSecurityException | IOException e) { + throw new ClientException("Error setting up SSL: " + e.getMessage(), e); + } + } + + public static void setupKerberos( + OkHttpClient.Builder clientBuilder, + String remoteServiceName, + boolean useCanonicalHostname, + Optional principal, + Optional kerberosConfig, + Optional keytab, + Optional credentialCache) + { + SpnegoHandler handler = new SpnegoHandler( + remoteServiceName, + useCanonicalHostname, + principal, + kerberosConfig, + keytab, + credentialCache); + clientBuilder.addInterceptor(handler); + clientBuilder.authenticator(handler); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/SpnegoHandler.java b/presto-client/src/main/java/com/facebook/presto/client/SpnegoHandler.java new file mode 100644 index 000000000000..cb08cb598392 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/SpnegoHandler.java @@ -0,0 +1,332 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client; + +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableMap; +import com.sun.security.auth.module.Krb5LoginModule; +import io.airlift.units.Duration; +import okhttp3.Authenticator; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.Route; +import org.ietf.jgss.GSSContext; +import org.ietf.jgss.GSSCredential; +import org.ietf.jgss.GSSException; +import org.ietf.jgss.GSSManager; +import org.ietf.jgss.Oid; + +import javax.annotation.concurrent.GuardedBy; +import javax.security.auth.Subject; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.LoginContext; +import javax.security.auth.login.LoginException; + +import java.io.File; +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.security.Principal; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.Base64; +import java.util.Locale; +import java.util.Optional; + +import static com.google.common.base.CharMatcher.whitespace; +import static com.google.common.base.Throwables.throwIfInstanceOf; +import static com.google.common.base.Throwables.throwIfUnchecked; +import static com.google.common.net.HttpHeaders.AUTHORIZATION; +import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE; +import static java.lang.Boolean.getBoolean; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; +import static javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED; +import static org.ietf.jgss.GSSContext.INDEFINITE_LIFETIME; +import static org.ietf.jgss.GSSCredential.DEFAULT_LIFETIME; +import static org.ietf.jgss.GSSCredential.INITIATE_ONLY; +import static org.ietf.jgss.GSSName.NT_HOSTBASED_SERVICE; +import static org.ietf.jgss.GSSName.NT_USER_NAME; + +// TODO: This class is similar to SpnegoAuthentication in Airlift. Consider extracting a library. +public class SpnegoHandler + implements Interceptor, Authenticator +{ + private static final String NEGOTIATE = "Negotiate"; + private static final Duration MIN_CREDENTIAL_LIFETIME = new Duration(60, SECONDS); + + private static final GSSManager GSS_MANAGER = GSSManager.getInstance(); + + private static final Oid SPNEGO_OID = createOid("1.3.6.1.5.5.2"); + private static final Oid KERBEROS_OID = createOid("1.2.840.113554.1.2.2"); + + private final String remoteServiceName; + private final boolean useCanonicalHostname; + private final Optional principal; + private final Optional keytab; + private final Optional credentialCache; + + @GuardedBy("this") + private Session clientSession; + + public SpnegoHandler( + String remoteServiceName, + boolean useCanonicalHostname, + Optional principal, + Optional kerberosConfig, + Optional keytab, + Optional credentialCache) + { + this.remoteServiceName = requireNonNull(remoteServiceName, "remoteServiceName is null"); + this.useCanonicalHostname = useCanonicalHostname; + this.principal = requireNonNull(principal, "principal is null"); + this.keytab = requireNonNull(keytab, "keytab is null"); + this.credentialCache = requireNonNull(credentialCache, "credentialCache is null"); + + kerberosConfig.ifPresent(file -> System.setProperty("java.security.krb5.conf", file.getAbsolutePath())); + } + + @Override + public Response intercept(Chain chain) + throws IOException + { + // eagerly send authentication if possible + try { + return chain.proceed(authenticate(chain.request())); + } + catch (ClientException ignored) { + return chain.proceed(chain.request()); + } + } + + @Override + public Request authenticate(Route route, Response response) + throws IOException + { + // skip if we already tried or were not asked for Kerberos + if (response.request().headers(AUTHORIZATION).stream().anyMatch(SpnegoHandler::isNegotiate) || + response.headers(WWW_AUTHENTICATE).stream().noneMatch(SpnegoHandler::isNegotiate)) { + return null; + } + + return authenticate(response.request()); + } + + private static boolean isNegotiate(String value) + { + return Splitter.on(whitespace()).split(value).iterator().next().equalsIgnoreCase(NEGOTIATE); + } + + private Request authenticate(Request request) + { + String hostName = request.url().host(); + String principal = makeServicePrincipal(remoteServiceName, hostName, useCanonicalHostname); + byte[] token = generateToken(principal); + + String credential = format("%s %s", NEGOTIATE, Base64.getEncoder().encodeToString(token)); + return request.newBuilder() + .header(AUTHORIZATION, credential) + .build(); + } + + private byte[] generateToken(String servicePrincipal) + { + GSSContext context = null; + try { + Session session = getSession(); + context = doAs(session.getLoginContext().getSubject(), () -> { + GSSContext result = GSS_MANAGER.createContext( + GSS_MANAGER.createName(servicePrincipal, NT_HOSTBASED_SERVICE), + SPNEGO_OID, + session.getClientCredential(), + INDEFINITE_LIFETIME); + + result.requestMutualAuth(true); + result.requestConf(true); + result.requestInteg(true); + result.requestCredDeleg(false); + return result; + }); + + byte[] token = context.initSecContext(new byte[0], 0, 0); + if (token == null) { + throw new LoginException("No token generated from GSS context"); + } + return token; + } + catch (GSSException | LoginException e) { + throw new ClientException(format("Kerberos error for [%s]: %s", servicePrincipal, e.getMessage()), e); + } + finally { + try { + if (context != null) { + context.dispose(); + } + } + catch (GSSException ignored) { + } + } + } + + private synchronized Session getSession() + throws LoginException, GSSException + { + if ((clientSession == null) || clientSession.needsRefresh()) { + clientSession = createSession(); + } + return clientSession; + } + + private Session createSession() + throws LoginException, GSSException + { + // TODO: do we need to call logout() on the LoginContext? + + LoginContext loginContext = new LoginContext("", null, null, new Configuration() + { + @Override + public AppConfigurationEntry[] getAppConfigurationEntry(String name) + { + ImmutableMap.Builder options = ImmutableMap.builder(); + options.put("refreshKrb5Config", "true"); + options.put("doNotPrompt", "true"); + options.put("useKeyTab", "true"); + + if (getBoolean("presto.client.debugKerberos")) { + options.put("debug", "true"); + } + + keytab.ifPresent(file -> options.put("keyTab", file.getAbsolutePath())); + + credentialCache.ifPresent(file -> { + options.put("ticketCache", file.getAbsolutePath()); + options.put("useTicketCache", "true"); + options.put("renewTGT", "true"); + }); + + principal.ifPresent(value -> options.put("principal", value)); + + return new AppConfigurationEntry[] { + new AppConfigurationEntry(Krb5LoginModule.class.getName(), REQUIRED, options.build()) + }; + } + }); + + loginContext.login(); + Subject subject = loginContext.getSubject(); + Principal clientPrincipal = subject.getPrincipals().iterator().next(); + GSSCredential clientCredential = doAs(subject, () -> GSS_MANAGER.createCredential( + GSS_MANAGER.createName(clientPrincipal.getName(), NT_USER_NAME), + DEFAULT_LIFETIME, + KERBEROS_OID, + INITIATE_ONLY)); + + return new Session(loginContext, clientCredential); + } + + private static String makeServicePrincipal(String serviceName, String hostName, boolean useCanonicalHostname) + { + String serviceHostName = hostName; + if (useCanonicalHostname) { + serviceHostName = canonicalizeServiceHostName(hostName); + } + return format("%s@%s", serviceName, serviceHostName.toLowerCase(Locale.US)); + } + + private static String canonicalizeServiceHostName(String hostName) + { + try { + InetAddress address = InetAddress.getByName(hostName); + String fullHostName; + if ("localhost".equalsIgnoreCase(address.getHostName())) { + fullHostName = InetAddress.getLocalHost().getCanonicalHostName(); + } + else { + fullHostName = address.getCanonicalHostName(); + } + if (fullHostName.equalsIgnoreCase("localhost")) { + throw new ClientException("Fully qualified name of localhost should not resolve to 'localhost'. System configuration error?"); + } + return fullHostName; + } + catch (UnknownHostException e) { + throw new ClientException("Failed to resolve host: " + hostName, e); + } + } + + private interface GssSupplier + { + T get() + throws GSSException; + } + + private static T doAs(Subject subject, GssSupplier action) + throws GSSException + { + try { + return Subject.doAs(subject, (PrivilegedExceptionAction) action::get); + } + catch (PrivilegedActionException e) { + Throwable t = e.getCause(); + throwIfInstanceOf(t, GSSException.class); + throwIfUnchecked(t); + throw new RuntimeException(t); + } + } + + private static Oid createOid(String value) + { + try { + return new Oid(value); + } + catch (GSSException e) { + throw new AssertionError(e); + } + } + + private static class Session + { + private final LoginContext loginContext; + private final GSSCredential clientCredential; + + public Session(LoginContext loginContext, GSSCredential clientCredential) + throws LoginException + { + requireNonNull(loginContext, "loginContext is null"); + requireNonNull(clientCredential, "gssCredential is null"); + + this.loginContext = loginContext; + this.clientCredential = clientCredential; + } + + public LoginContext getLoginContext() + { + return loginContext; + } + + public GSSCredential getClientCredential() + { + return clientCredential; + } + + public boolean needsRefresh() + throws GSSException + { + return clientCredential.getRemainingLifetime() < MIN_CREDENTIAL_LIFETIME.getValue(SECONDS); + } + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/StatementClient.java b/presto-client/src/main/java/com/facebook/presto/client/StatementClient.java index 7373f17338af..785a73b9ae38 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/StatementClient.java +++ b/presto-client/src/main/java/com/facebook/presto/client/StatementClient.java @@ -13,19 +13,19 @@ */ package com.facebook.presto.client; +import com.facebook.presto.client.OkHttpUtil.NullCallback; import com.facebook.presto.spi.type.TimeZoneKey; import com.google.common.base.Splitter; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; -import io.airlift.http.client.FullJsonResponseHandler; -import io.airlift.http.client.HttpClient; -import io.airlift.http.client.HttpClient.HttpResponseFuture; -import io.airlift.http.client.HttpStatus; -import io.airlift.http.client.Request; import io.airlift.json.JsonCodec; -import io.airlift.units.Duration; +import okhttp3.Headers; +import okhttp3.HttpUrl; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; import javax.annotation.concurrent.ThreadSafe; @@ -37,35 +37,36 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import static com.facebook.presto.client.PrestoHeaders.PRESTO_ADDED_PREPARE; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_CATALOG; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLEAR_SESSION; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLEAR_TRANSACTION_ID; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLIENT_INFO; import static com.facebook.presto.client.PrestoHeaders.PRESTO_DEALLOCATED_PREPARE; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_LANGUAGE; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_PREPARED_STATEMENT; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_SCHEMA; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_SESSION; import static com.facebook.presto.client.PrestoHeaders.PRESTO_SET_SESSION; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_SOURCE; import static com.facebook.presto.client.PrestoHeaders.PRESTO_STARTED_TRANSACTION_ID; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_TIME_ZONE; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_TRANSACTION_ID; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_USER; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkState; import static com.google.common.net.HttpHeaders.USER_AGENT; -import static io.airlift.http.client.FullJsonResponseHandler.JsonResponse; -import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler; -import static io.airlift.http.client.HttpStatus.Family; -import static io.airlift.http.client.HttpStatus.familyForStatusCode; -import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; -import static io.airlift.http.client.Request.Builder.prepareDelete; -import static io.airlift.http.client.Request.Builder.prepareGet; -import static io.airlift.http.client.Request.Builder.preparePost; -import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; -import static io.airlift.http.client.StatusResponseHandler.StatusResponse; -import static io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler; +import static io.airlift.json.JsonCodec.jsonCodec; import static java.lang.String.format; -import static java.nio.charset.StandardCharsets.UTF_8; +import static java.net.HttpURLConnection.HTTP_OK; +import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; +import static java.net.HttpURLConnection.HTTP_UNAVAILABLE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -74,13 +75,15 @@ public class StatementClient implements Closeable { + private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("application/json; charset=utf-8"); + private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); + private static final Splitter SESSION_HEADER_SPLITTER = Splitter.on('=').limit(2).trimResults(); private static final String USER_AGENT_VALUE = StatementClient.class.getSimpleName() + "/" + firstNonNull(StatementClient.class.getPackage().getImplementationVersion(), "unknown"); - private final HttpClient httpClient; - private final FullJsonResponseHandler responseHandler; + private final OkHttpClient httpClient; private final boolean debug; private final String query; private final AtomicReference currentResults = new AtomicReference<>(); @@ -97,15 +100,13 @@ public class StatementClient private final long requestTimeoutNanos; private final String user; - public StatementClient(HttpClient httpClient, JsonCodec queryResultsCodec, ClientSession session, String query) + public StatementClient(OkHttpClient httpClient, ClientSession session, String query) { requireNonNull(httpClient, "httpClient is null"); - requireNonNull(queryResultsCodec, "queryResultsCodec is null"); requireNonNull(session, "session is null"); requireNonNull(query, "query is null"); this.httpClient = httpClient; - this.responseHandler = createFullJsonResponseHandler(queryResultsCodec); this.debug = session.isDebug(); this.timeZone = session.getTimeZone(); this.query = query; @@ -113,48 +114,54 @@ public StatementClient(HttpClient httpClient, JsonCodec queryResul this.user = session.getUser(); Request request = buildQueryRequest(session, query); - JsonResponse response = httpClient.execute(request, responseHandler); - if (response.getStatusCode() != HttpStatus.OK.code() || !response.hasValue()) { + JsonResponse response = JsonResponse.execute(QUERY_RESULTS_CODEC, httpClient, request); + if ((response.getStatusCode() != HTTP_OK) || !response.hasValue()) { throw requestFailedException("starting query", request, response); } - processResponse(response); + processResponse(response.getHeaders(), response.getValue()); } private Request buildQueryRequest(ClientSession session, String query) { - Request.Builder builder = prepareRequest(preparePost(), uriBuilderFrom(session.getServer()).replacePath("/v1/statement").build()) - .setBodyGenerator(createStaticBodyGenerator(query, UTF_8)); + HttpUrl url = HttpUrl.get(session.getServer()); + if (url == null) { + throw new ClientException("Invalid server URL: " + session.getServer()); + } + url = url.newBuilder().encodedPath("/v1/statement").build(); + + Request.Builder builder = prepareRequest(url) + .post(RequestBody.create(MEDIA_TYPE_JSON, query)); if (session.getSource() != null) { - builder.setHeader(PrestoHeaders.PRESTO_SOURCE, session.getSource()); + builder.addHeader(PRESTO_SOURCE, session.getSource()); } if (session.getClientInfo() != null) { - builder.setHeader(PrestoHeaders.PRESTO_CLIENT_INFO, session.getClientInfo()); + builder.addHeader(PRESTO_CLIENT_INFO, session.getClientInfo()); } if (session.getCatalog() != null) { - builder.setHeader(PrestoHeaders.PRESTO_CATALOG, session.getCatalog()); + builder.addHeader(PRESTO_CATALOG, session.getCatalog()); } if (session.getSchema() != null) { - builder.setHeader(PrestoHeaders.PRESTO_SCHEMA, session.getSchema()); + builder.addHeader(PRESTO_SCHEMA, session.getSchema()); } - builder.setHeader(PrestoHeaders.PRESTO_TIME_ZONE, session.getTimeZone().getId()); + builder.addHeader(PRESTO_TIME_ZONE, session.getTimeZone().getId()); if (session.getLocale() != null) { - builder.setHeader(PrestoHeaders.PRESTO_LANGUAGE, session.getLocale().toLanguageTag()); + builder.addHeader(PRESTO_LANGUAGE, session.getLocale().toLanguageTag()); } Map property = session.getProperties(); for (Entry entry : property.entrySet()) { - builder.addHeader(PrestoHeaders.PRESTO_SESSION, entry.getKey() + "=" + entry.getValue()); + builder.addHeader(PRESTO_SESSION, entry.getKey() + "=" + entry.getValue()); } Map statements = session.getPreparedStatements(); for (Entry entry : statements.entrySet()) { - builder.addHeader(PrestoHeaders.PRESTO_PREPARED_STATEMENT, urlEncode(entry.getKey()) + "=" + urlEncode(entry.getValue())); + builder.addHeader(PRESTO_PREPARED_STATEMENT, urlEncode(entry.getKey()) + "=" + urlEncode(entry.getValue())); } - builder.setHeader(PrestoHeaders.PRESTO_TRANSACTION_ID, session.getTransactionId() == null ? "NONE" : session.getTransactionId()); + builder.addHeader(PRESTO_TRANSACTION_ID, session.getTransactionId() == null ? "NONE" : session.getTransactionId()); return builder.build(); } @@ -241,13 +248,12 @@ public boolean isValid() return valid.get() && (!isGone()) && (!isClosed()); } - private Request.Builder prepareRequest(Request.Builder builder, URI nextUri) + private Request.Builder prepareRequest(HttpUrl url) { - builder.setHeader(PrestoHeaders.PRESTO_USER, user); - builder.setHeader(USER_AGENT, USER_AGENT_VALUE) - .setUri(nextUri); - - return builder; + return new Request.Builder() + .addHeader(PRESTO_USER, user) + .addHeader(USER_AGENT, USER_AGENT_VALUE) + .url(url); } public boolean advance() @@ -258,7 +264,7 @@ public boolean advance() return false; } - Request request = prepareRequest(prepareGet(), nextUri).build(); + Request request = prepareRequest(HttpUrl.get(nextUri)).build(); Exception cause = null; long start = System.nanoTime(); @@ -284,19 +290,19 @@ public boolean advance() JsonResponse response; try { - response = httpClient.execute(request, responseHandler); + response = JsonResponse.execute(QUERY_RESULTS_CODEC, httpClient, request); } catch (RuntimeException e) { cause = e; continue; } - if (response.getStatusCode() == HttpStatus.OK.code() && response.hasValue()) { - processResponse(response); + if ((response.getStatusCode() == HTTP_OK) && response.hasValue()) { + processResponse(response.getHeaders(), response.getValue()); return true; } - if (response.getStatusCode() != HttpStatus.SERVICE_UNAVAILABLE.code()) { + if (response.getStatusCode() != HTTP_UNAVAILABLE) { throw requestFailedException("fetching next", request, response); } } @@ -306,77 +312,65 @@ public boolean advance() throw new RuntimeException("Error fetching next", cause); } - private void processResponse(JsonResponse response) + private void processResponse(Headers headers, QueryResults results) { - for (String setSession : response.getHeaders(PRESTO_SET_SESSION)) { + for (String setSession : headers.values(PRESTO_SET_SESSION)) { List keyValue = SESSION_HEADER_SPLITTER.splitToList(setSession); if (keyValue.size() != 2) { continue; } setSessionProperties.put(keyValue.get(0), keyValue.size() > 1 ? keyValue.get(1) : ""); } - for (String clearSession : response.getHeaders(PRESTO_CLEAR_SESSION)) { + for (String clearSession : headers.values(PRESTO_CLEAR_SESSION)) { resetSessionProperties.add(clearSession); } - for (String entry : response.getHeaders(PRESTO_ADDED_PREPARE)) { + for (String entry : headers.values(PRESTO_ADDED_PREPARE)) { List keyValue = SESSION_HEADER_SPLITTER.splitToList(entry); if (keyValue.size() != 2) { continue; } addedPreparedStatements.put(urlDecode(keyValue.get(0)), urlDecode(keyValue.get(1))); } - for (String entry : response.getHeaders(PRESTO_DEALLOCATED_PREPARE)) { + for (String entry : headers.values(PRESTO_DEALLOCATED_PREPARE)) { deallocatedPreparedStatements.add(urlDecode(entry)); } - String startedTransactionId = response.getHeader(PRESTO_STARTED_TRANSACTION_ID); + String startedTransactionId = headers.get(PRESTO_STARTED_TRANSACTION_ID); if (startedTransactionId != null) { this.startedtransactionId.set(startedTransactionId); } - if (response.getHeader(PRESTO_CLEAR_TRANSACTION_ID) != null) { + if (headers.values(PRESTO_CLEAR_TRANSACTION_ID) != null) { clearTransactionId.set(true); } - currentResults.set(response.getValue()); + currentResults.set(results); } private RuntimeException requestFailedException(String task, Request request, JsonResponse response) { gone.set(true); if (!response.hasValue()) { + if (response.getStatusCode() == HTTP_UNAUTHORIZED) { + return new ClientException("Authentication failed" + + Optional.ofNullable(response.getStatusMessage()) + .map(message -> ": " + message) + .orElse("")); + } return new RuntimeException( - format("Error %s at %s returned an invalid response: %s [Error: %s]", task, request.getUri(), response, response.getResponseBody()), + format("Error %s at %s returned an invalid response: %s [Error: %s]", task, request.url(), response, response.getResponseBody()), response.getException()); } - return new RuntimeException(format("Error %s at %s returned %s: %s", task, request.getUri(), response.getStatusCode(), response.getStatusMessage())); + return new RuntimeException(format("Error %s at %s returned HTTP %s", task, request.url(), response.getStatusCode())); } - public boolean cancelLeafStage(Duration timeout) + public void cancelLeafStage() { checkState(!isClosed(), "client is closed"); URI uri = current().getPartialCancelUri(); - if (uri == null) { - return false; - } - - Request request = prepareRequest(prepareDelete(), uri).build(); - - HttpResponseFuture response = httpClient.executeAsync(request, createStatusResponseHandler()); - try { - StatusResponse status = response.get(timeout.toMillis(), MILLISECONDS); - return familyForStatusCode(status.getStatusCode()) == Family.SUCCESSFUL; - } - catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw Throwables.propagate(e); - } - catch (ExecutionException e) { - throw Throwables.propagate(e.getCause()); - } - catch (TimeoutException e) { - return false; + if (uri != null) { + httpDelete(uri); } } @@ -386,12 +380,19 @@ public void close() if (!closed.getAndSet(true)) { URI uri = currentResults.get().getNextUri(); if (uri != null) { - Request request = prepareRequest(prepareDelete(), uri).build(); - httpClient.executeAsync(request, createStatusResponseHandler()); + httpDelete(uri); } } } + private void httpDelete(URI uri) + { + Request request = prepareRequest(HttpUrl.get(uri)) + .delete() + .build(); + httpClient.newCall(request).enqueue(new NullCallback()); + } + private static String urlEncode(String value) { try { diff --git a/presto-docs/pom.xml b/presto-docs/pom.xml index b01a12bf2df9..3382fdd23f25 100644 --- a/presto-docs/pom.xml +++ b/presto-docs/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-docs @@ -44,6 +44,80 @@ + + org.codehaus.mojo + exec-maven-plugin + + + validate-reserved + validate + + java + + + com.facebook.presto.sql.ReservedIdentifiers + + validateDocs + ${project.basedir}/src/main/sphinx/language/reserved.rst + + + + + generate-thrift-idl + validate + + java + + + com.facebook.swift.generator.swift2thrift.Main + + com.facebook.presto.connector.thrift.api.PrestoThriftService + -recursive + -out + ${project.build.directory}/PrestoThriftService.thrift + + + + + validate-thrift-idl + validate + + exec + + + diff + + -b + -c + ${project.basedir}/src/main/sphinx/include/PrestoThriftService.thrift + ${project.build.directory}/PrestoThriftService.thrift + + + + + + false + true + + + + com.facebook.presto + presto-parser + ${project.version} + + + com.facebook.presto + presto-thrift-connector-api + ${project.version} + + + com.facebook.swift + swift2thrift-generator-cli + ${dep.swift.version} + + + + io.airlift.maven.plugins sphinx-maven-plugin diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst index 89c4d087aa95..6284c52f07bf 100644 --- a/presto-docs/src/main/sphinx/admin/properties.rst +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -366,3 +366,51 @@ Optimizer Properties ``UNION ALL`` speed when write speed is not yet saturated. However, it may slow down queries in an already heavily loaded system. This can also be specified on a per-query basis using the ``push_table_write_through_union`` session property. + + +Regular Expression Function Properties +-------------------------------------- + +The following properties allow tuning the :doc:`/functions/regexp`. + +``regex-library`` +^^^^^^^^^^^^^^^^^ + + * **Type:** ``string`` + * **Allowed values:** ``JONI``, ``RE2J`` + * **Default value:** ``JONI`` + + Which library to use for regular expression functions. + ``JONI`` is generally faster for common usage, but can require exponential + time for certain expression patterns. ``RE2J`` uses a different algorithm + which guarantees linear time, but is often slower. + +``re2j.dfa-states-limit`` +^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Minimum value:** ``2`` + * **Default value:** ``2147483647`` + + The maximum number of states to use when RE2J builds the fast + but potentially memory intensive deterministic finite automaton (DFA) + for regular expression matching. If the limit is reached, RE2J will fall + back to the algorithm that uses the slower, but less memory intensive + non-deterministic finite automaton (NFA). Decreasing this value decreases the + maximum memory footprint of a regular expression search at the cost of speed. + +``re2j.dfa-retries`` +^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Minimum value:** ``0`` + * **Default value:** ``5`` + + The number of times that RE2J will retry the DFA algorithm when + it reaches a states limit before using the slower, but less memory + intensive NFA algorithm for all future inputs for that search. If hitting the + limit for a given input row is likely to be an outlier, you want to be able + to process subsequent rows using the faster DFA algorithm. If you are likely + to hit the limit on matches for subsequent rows as well, you want to use the + correct algorithm from the beginning so as not to waste time and resources. + The more rows you are processing, the larger this value should be. diff --git a/presto-docs/src/main/sphinx/admin/resource-groups.rst b/presto-docs/src/main/sphinx/admin/resource-groups.rst index f17881b387cb..a65bdc131700 100644 --- a/presto-docs/src/main/sphinx/admin/resource-groups.rst +++ b/presto-docs/src/main/sphinx/admin/resource-groups.rst @@ -68,6 +68,15 @@ Selector Properties * ``source`` (optional): regex to match against source string. Defaults to ``.*`` +* ``queryType`` (optional): string to match against the type of the query submitted. The query type can be: + * ``DATA_DEFINITION``: Queries that alter/create/drop the metadata of schemas/tables/views, and that manage + prepared statements, privileges, sessions, and transactions. + * ``DELETE``: ``DELETE`` queries. + * ``DESCRIBE``: ``DESCRIBE``, ``DESCRIBE INPUT``, ``DESCRIBE OUTPUT``, and ``SHOW`` queries. + * ``EXPLAIN``: ``EXPLAIN`` queries. + * ``INSERT``: ``INSERT`` and ``CREATE TABLE AS SELECT`` queries. + * ``SELECT``: ``SELECT`` queries. + * ``group`` (required): the group these queries will run in. Global Properties @@ -89,7 +98,11 @@ There are three selectors that define which queries run in which resource group: * The first selector places queries from ``bob`` into the admin group. - * The second selector states that all queries that come from a source that includes "pipeline" + * The second selector states that all data definition queries that come from a source that includes "pipeline" + should run in the user's personal data definition group, which belongs to the + ``globa.data_definition`` parent group. + + * The third selector states that all queries that come from a source that includes "pipeline" should run in the user's personal pipeline group, which belongs to the ``global.pipeline`` parent group. @@ -116,6 +129,13 @@ all other users are subject to the following limits: "schedulingPolicy": "weighted", "jmxExport": true, "subGroups": [ + { + "name": "data_definition_${USER}", + "softMemoryLimit": "10%", + "maxRunning": 3, + "maxQueued": 10, + "schedulingWeight": 1 + }, { "name": "adhoc_${USER}", "softMemoryLimit": "10%", @@ -157,6 +177,11 @@ all other users are subject to the following limits: "user": "bob", "group": "admin" }, + { + "source": ".*pipeline.*", + "queryType": "DATA_DEFINITION", + "group": "global.data_definition_${USER}" + }, { "source": ".*pipeline.*", "group": "global.pipeline.pipeline_${USER}" diff --git a/presto-docs/src/main/sphinx/connector.rst b/presto-docs/src/main/sphinx/connector.rst index ee10d9c29cf4..651e22ab96a0 100644 --- a/presto-docs/src/main/sphinx/connector.rst +++ b/presto-docs/src/main/sphinx/connector.rst @@ -25,3 +25,4 @@ from different data sources. connector/sqlserver connector/system connector/tpch + connector/thrift diff --git a/presto-docs/src/main/sphinx/connector/cassandra.rst b/presto-docs/src/main/sphinx/connector/cassandra.rst index 3572e068d40c..9c93f5bbb432 100644 --- a/presto-docs/src/main/sphinx/connector/cassandra.rst +++ b/presto-docs/src/main/sphinx/connector/cassandra.rst @@ -230,4 +230,3 @@ Limitations query with a partition key as a filter. * ``IN`` list filters are only allowed on index (that is, partition key or clustering key) columns. * Range (``<`` or ``>`` and ``BETWEEN``) filters can be applied only to the partition keys. -* Non-equality predicates on clustering keys are not pushed down (only ``=`` and ``IN`` are pushed down) . diff --git a/presto-docs/src/main/sphinx/connector/hive-security.rst b/presto-docs/src/main/sphinx/connector/hive-security.rst index cb09f5a7add4..26b1f1edf069 100644 --- a/presto-docs/src/main/sphinx/connector/hive-security.rst +++ b/presto-docs/src/main/sphinx/connector/hive-security.rst @@ -19,7 +19,8 @@ Property Value Description ================================================== ============================================================ ``legacy`` (default value) Few authorization checks are enforced, thus allowing most operations. The config properties ``hive.allow-drop-table``, - ``hive.allow-rename-table``, ``hive.allow-add-column`` and + ``hive.allow-rename-table``, ``hive.allow-add-column``, + ``hive.allow-drop-column`` and ``hive.allow-rename-column`` are used. ``read-only`` Operations that read data or metadata, such as ``SELECT``, diff --git a/presto-docs/src/main/sphinx/connector/thrift.rst b/presto-docs/src/main/sphinx/connector/thrift.rst new file mode 100644 index 000000000000..715f2bec2fec --- /dev/null +++ b/presto-docs/src/main/sphinx/connector/thrift.rst @@ -0,0 +1,96 @@ +================ +Thrift Connector +================ + +The Thrift connector makes it possible to integrate with external storage systems +without a custom Presto connector implementation. + +In order to use the Thrift connector with an external system, you need to implement +the ``PrestoThriftService`` interface, found below. Next, you configure the Thrift connector +to point to a set of machines, called Thrift servers, that implement the interface. +As part of the interface implementation, the Thrift servers will provide metadata, +splits and data. The Thrift server instances are assumed to be stateless and independent +from each other. + +Configuration +------------- + +To configure the Thrift connector, create a catalog properties file +``etc/catalog/thrift.properties`` with the following content, +replacing the properties as appropriate: + +.. code-block:: none + + connector.name=presto-thrift + static-location.hosts=host:port,host:port + +Multiple Thrift Systems +^^^^^^^^^^^^^^^^^^^^^^^ + +You can have as many catalogs as you need, so if you have additional +Thrift systems to connect to, simply add another properties file to ``etc/catalog`` +with a different name (making sure it ends in ``.properties``). + +Configuration Properties +------------------------ + +The following configuration properties are available: + +=========================================== ============================================================== +Property Name Description +=========================================== ============================================================== +``static-location.hosts`` Location of Thrift servers +``presto-thrift.max-response-size`` Maximum size of a response from thrift server +``presto-thrift.metadata-refresh-threads`` Number of refresh threads for metadata cache +=========================================== ============================================================== + +``static-location.hosts`` +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Comma-separated list of thrift servers in the form of ``host:port``. For example: + +.. code-block:: none + + static-location.hosts=192.0.2.3:7777,192.0.2.4:7779 + +This property is required; there is no default. + +``presto-thrift.max-response-size`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Maximum size of a data response that the connector accepts. This value is sent +by the connector to the Thrift server when requesting data, allowing it to size +the response appropriately. + +This property is optional; the default is ``16MB``. + +``presto-thrift.metadata-refresh-threads`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Number of refresh threads for metadata cache. + +This property is optional; the default is ``1``. + +Thrift Client Properties +^^^^^^^^^^^^^^^^^^^^^^^^ + +The following properties allow configuring the Thrift client used by the connector: + +===================================================== =================== ============= +Property Name Description Default Value +===================================================== =================== ============= +``PrestoThriftService.thrift.client.connect-timeout`` Connect timeout ``500ms`` +``PrestoThriftService.thrift.client.max-frame-size`` Max frame size ``16777216`` +``PrestoThriftService.thrift.client.read-timeout`` Read timeout ``10s`` +``PrestoThriftService.thrift.client.receive-timeout`` Receive timeout ``1m`` +``PrestoThriftService.thrift.client.socks-proxy`` Socks proxy address ``null`` +``PrestoThriftService.thrift.client.write-timeout`` Write timeout ``1m`` +===================================================== =================== ============= + +Thrift IDL File +--------------- + +The following IDL describes the ``PrestoThriftService`` that must be implemented: + +.. literalinclude:: /include/PrestoThriftService.thrift + :language: thrift diff --git a/presto-docs/src/main/sphinx/functions/binary.rst b/presto-docs/src/main/sphinx/functions/binary.rst index e0cb5897f514..a822a3ca0f70 100644 --- a/presto-docs/src/main/sphinx/functions/binary.rst +++ b/presto-docs/src/main/sphinx/functions/binary.rst @@ -1,6 +1,11 @@ -================ -Binary Functions -================ +============================== +Binary Functions and Operators +============================== + +Binary Operators +---------------- + +The ``||`` operator performs concatenation. Binary Functions ---------------- @@ -10,6 +15,13 @@ Binary Functions Returns the length of ``binary`` in bytes. +.. function:: concat(binary1, ..., binaryN) -> varbinary + :noindex: + + Returns the concatenation of ``binary1``, ``binary2``, ``...``, ``binaryN``. + This function provides the same functionality as the + SQL-standard concatenation operator (``||``). + .. function:: to_base64(binary) -> varchar Encodes ``binary`` into a base64 string representation. diff --git a/presto-docs/src/main/sphinx/functions/datetime.rst b/presto-docs/src/main/sphinx/functions/datetime.rst index defac4f571cf..728f5a26a329 100644 --- a/presto-docs/src/main/sphinx/functions/datetime.rst +++ b/presto-docs/src/main/sphinx/functions/datetime.rst @@ -53,6 +53,10 @@ Date and Time Functions Returns the current time zone in the format defined by IANA (e.g., ``America/Los_Angeles``) or as fixed offset from UTC (e.g., ``+08:35``) +.. function:: date(x) -> date + + This is an alias for ``CAST(x AS date)``. + .. function:: from_iso8601_timestamp(string) -> timestamp with time zone Parses the ISO 8601 formatted ``string`` into a ``timestamp with time zone``. diff --git a/presto-docs/src/main/sphinx/include/PrestoThriftService.thrift b/presto-docs/src/main/sphinx/include/PrestoThriftService.thrift new file mode 100644 index 000000000000..9d2052d6c92b --- /dev/null +++ b/presto-docs/src/main/sphinx/include/PrestoThriftService.thrift @@ -0,0 +1,325 @@ +namespace java.swift com.facebook.presto.connector.thrift.api + + +enum PrestoThriftBound { + BELOW=1, EXACTLY=2, ABOVE=3 +} + +exception PrestoThriftServiceException { + 1: string message; + 2: bool retryable; +} + +struct PrestoThriftNullableSchemaName { + 1: optional string schemaName; +} + +struct PrestoThriftSchemaTableName { + 1: string schemaName; + 2: string tableName; +} + +struct PrestoThriftColumnMetadata { + 1: string name; + 2: string type; + 3: optional string comment; + 4: bool hidden; +} + +struct PrestoThriftNullableColumnSet { + 1: optional set columns; +} + + +/** + * Set that either includes all values, or excludes all values. + */ +struct PrestoThriftAllOrNoneValueSet { + 1: bool all; +} + + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code ints} array are values for each row. If row is null then value is ignored. + */ +struct PrestoThriftInteger { + 1: optional list nulls; + 2: optional list ints; +} + + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code longs} array are values for each row. If row is null then value is ignored. + */ +struct PrestoThriftBigint { + 1: optional list nulls; + 2: optional list longs; +} + + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code doubles} array are values for each row. If row is null then value is ignored. + */ +struct PrestoThriftDouble { + 1: optional list nulls; + 2: optional list doubles; +} + + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Each elements of {@code sizes} array contains the length in bytes for the corresponding element. + * If row is null then the corresponding element in {@code sizes} is ignored. + * {@code bytes} array contains uft8 encoded byte values. + * Values for all rows are written to {@code bytes} array one after another. + * The total number of bytes must be equal to the sum of all sizes. + */ +struct PrestoThriftVarchar { + 1: optional list nulls; + 2: optional list sizes; + 3: optional binary bytes; +} + + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code booleans} array are values for each row. If row is null then value is ignored. + */ +struct PrestoThriftBoolean { + 1: optional list nulls; + 2: optional list booleans; +} + + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code dates} array are date values for each row represented as the number + * of days passed since 1970-01-01. + * If row is null then value is ignored. + */ +struct PrestoThriftDate { + 1: optional list nulls; + 2: optional list dates; +} + + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code timestamps} array are values for each row represented as the number + * of milliseconds passed since 1970-01-01T00:00:00 UTC. + * If row is null then value is ignored. + */ +struct PrestoThriftTimestamp { + 1: optional list nulls; + 2: optional list timestamps; +} + + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Each elements of {@code sizes} array contains the length in bytes for the corresponding element. + * If row is null then the corresponding element in {@code sizes} is ignored. + * {@code bytes} array contains uft8 encoded byte values for string representation of json. + * Values for all rows are written to {@code bytes} array one after another. + * The total number of bytes must be equal to the sum of all sizes. + */ +struct PrestoThriftJson { + 1: optional list nulls; + 2: optional list sizes; + 3: optional binary bytes; +} + + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Each elements of {@code sizes} array contains the length in bytes for the corresponding element. + * If row is null then the corresponding element in {@code sizes} is ignored. + * {@code bytes} array contains encoded byte values for HyperLogLog representation as defined in + * Airlift specification: href="https://github.com/airlift/airlift/blob/master/stats/docs/hll.md + * Values for all rows are written to {@code bytes} array one after another. + * The total number of bytes must be equal to the sum of all sizes. + */ +struct PrestoThriftHyperLogLog { + 1: optional list nulls; + 2: optional list sizes; + 3: optional binary bytes; +} + + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Each elements of {@code sizes} array contains the number of elements in the corresponding values array. + * If row is null then the corresponding element in {@code sizes} is ignored. + * {@code values} is a bigint block containing array elements one after another for all rows. + * The total number of elements in bigint block must be equal to the sum of all sizes. + */ +struct PrestoThriftBigintArray { + 1: optional list nulls; + 2: optional list sizes; + 3: optional PrestoThriftBigint values; +} + +struct PrestoThriftId { + 1: binary id; +} + +struct PrestoThriftHostAddress { + 1: string host; + 2: i32 port; +} + +struct PrestoThriftTableMetadata { + 1: PrestoThriftSchemaTableName schemaTableName; + 2: list columns; + 3: optional string comment; +} + +struct PrestoThriftBlock { + 1: optional PrestoThriftInteger integerData; + 2: optional PrestoThriftBigint bigintData; + 3: optional PrestoThriftDouble doubleData; + 4: optional PrestoThriftVarchar varcharData; + 5: optional PrestoThriftBoolean booleanData; + 6: optional PrestoThriftDate dateData; + 7: optional PrestoThriftTimestamp timestampData; + 8: optional PrestoThriftJson jsonData; + 9: optional PrestoThriftHyperLogLog hyperLogLogData; + 10: optional PrestoThriftBigintArray bigintArrayData; +} + + +/** + * LOWER UNBOUNDED is specified with an empty value and an ABOVE bound + * UPPER UNBOUNDED is specified with an empty value and a BELOW bound + */ +struct PrestoThriftMarker { + 1: optional PrestoThriftBlock value; + 2: PrestoThriftBound bound; +} + +struct PrestoThriftNullableToken { + 1: optional PrestoThriftId token; +} + +struct PrestoThriftSplit { + 1: PrestoThriftId splitId; + 2: list hosts; +} + +struct PrestoThriftPageResult { + /** + * Returns data in a columnar format. + * Columns in this list must be in the order they were requested by the engine. + */ + 1: list columnBlocks; + 2: i32 rowCount; + 3: optional PrestoThriftId nextToken; +} + +struct PrestoThriftNullableTableMetadata { + 1: optional PrestoThriftTableMetadata tableMetadata; +} + + +/** + * A set containing values that are uniquely identifiable. + * Assumes an infinite number of possible values. The values may be collectively included (aka whitelist) + * or collectively excluded (aka !whitelist). + * This structure is used with comparable, but not orderable types like "json", "map". + */ +struct PrestoThriftEquatableValueSet { + 1: bool whiteList; + 2: list values; +} + +struct PrestoThriftRange { + 1: PrestoThriftMarker low; + 2: PrestoThriftMarker high; +} + +struct PrestoThriftSplitBatch { + 1: list splits; + 2: optional PrestoThriftId nextToken; +} + + +/** + * A set containing zero or more Ranges of the same type over a continuous space of possible values. + * Ranges are coalesced into the most compact representation of non-overlapping Ranges. + * This structure is used with comparable and orderable types like bigint, integer, double, varchar, etc. + */ +struct PrestoThriftRangeValueSet { + 1: list ranges; +} + +struct PrestoThriftValueSet { + 1: optional PrestoThriftAllOrNoneValueSet allOrNoneValueSet; + 2: optional PrestoThriftEquatableValueSet equatableValueSet; + 3: optional PrestoThriftRangeValueSet rangeValueSet; +} + +struct PrestoThriftDomain { + 1: PrestoThriftValueSet valueSet; + 2: bool nullAllowed; +} + +struct PrestoThriftTupleDomain { + /** + * Return a map of column names to constraints. + */ + 1: optional map domains; +} + +/** + * Presto Thrift service definition. + * This thrift service needs to be implemented in order to be used with Thrift Connector. + */ +service PrestoThriftService { + /** + * Returns available schema names. + */ + list prestoListSchemaNames() throws (1: PrestoThriftServiceException ex1); + + /** + * Returns tables for the given schema name. + * + * @param schemaNameOrNull a structure containing schema name or {@literal null} + * @return a list of table names with corresponding schemas. If schema name is null then returns + * a list of tables for all schemas. Returns an empty list if a schema does not exist + */ + list prestoListTables(1: PrestoThriftNullableSchemaName schemaNameOrNull) throws (1: PrestoThriftServiceException ex1); + + /** + * Returns metadata for a given table. + * + * @param schemaTableName schema and table name + * @return metadata for a given table, or a {@literal null} value inside if it does not exist + */ + PrestoThriftNullableTableMetadata prestoGetTableMetadata(1: PrestoThriftSchemaTableName schemaTableName) throws (1: PrestoThriftServiceException ex1); + + /** + * Returns a batch of splits. + * + * @param schemaTableName schema and table name + * @param desiredColumns a superset of columns to return; empty set means "no columns", {@literal null} set means "all columns" + * @param outputConstraint constraint on the returned data + * @param maxSplitCount maximum number of splits to return + * @param nextToken token from a previous split batch or {@literal null} if it is the first call + * @return a batch of splits + */ + PrestoThriftSplitBatch prestoGetSplits(1: PrestoThriftSchemaTableName schemaTableName, 2: PrestoThriftNullableColumnSet desiredColumns, 3: PrestoThriftTupleDomain outputConstraint, 4: i32 maxSplitCount, 5: PrestoThriftNullableToken nextToken) throws (1: PrestoThriftServiceException ex1); + + /** + * Returns a batch of rows for the given split. + * + * @param splitId split id as returned in split batch + * @param columns a list of column names to return + * @param maxBytes maximum size of returned data in bytes + * @param nextToken token from a previous batch or {@literal null} if it is the first call + * @return a batch of table data + */ + PrestoThriftPageResult prestoGetRows(1: PrestoThriftId splitId, 2: list columns, 3: i64 maxBytes, 4: PrestoThriftNullableToken nextToken) throws (1: PrestoThriftServiceException ex1); +} diff --git a/presto-docs/src/main/sphinx/installation/jdbc.rst b/presto-docs/src/main/sphinx/installation/jdbc.rst index b02a98d3028a..c31202191c14 100644 --- a/presto-docs/src/main/sphinx/installation/jdbc.rst +++ b/presto-docs/src/main/sphinx/installation/jdbc.rst @@ -4,6 +4,20 @@ JDBC Driver Presto can be accessed from Java using the JDBC driver. Download :maven_download:`jdbc` and add it to the class path of your Java application. + +The driver is also available from Maven Central: + +.. parsed-literal:: + + + com.facebook.presto + presto-jdbc + \ |version|\ + + +Connecting +---------- + The following JDBC URL formats are supported: .. code-block:: none @@ -19,3 +33,61 @@ and the schema ``sales``: .. code-block:: none jdbc:presto://example.net:8080/hive/sales + +The above URL can be used as follows to create a connection: + +.. code-block:: java + + String url = "jdbc:presto://example.net:8080/hive/sales"; + Connection connection = DriverManager.getConnection(url, "test", null); + +Connection Parameters +--------------------- + +The driver supports various parameters that may be set as URL parameters +or as properties passed to ``DriverManager``. Both of the following +examples are equivalent: + +.. code-block:: java + + // URL parameters + String url = "jdbc:presto://example.net:8080/hive/sales"; + Properties properties = new Properties(); + properties.setProperty("user", "test"); + properties.setProperty("password", "secret"); + properties.setProperty("SSL", "true"); + Connection connection = DriverManager.getConnection(url, properties); + + // properties + String url = "jdbc:presto://example.net:8080/hive/sales?user=test&password=secret&SSL=true"; + Connection connection = DriverManager.getConnection(url); + +These methods may be mixed; some parameters may be specified in the URL +while others are specified using properties. However, the same parameter +may not be specified using both methods. + +Parameter Reference +------------------- + +================================= ======================================================================= +Name Description +================================= ======================================================================= +``user`` Username to use for authentication and authorization. +``password`` Password to use for LDAP authentication. +``socksProxy`` SOCKS proxy host and port. Example: ``localhost:1080`` +``httpProxy`` HTTP proxy host and port. Example: ``localhost:8888`` +``SSL`` Use HTTPS for connections +``SSLTrustStorePath`` The location of the Java TrustStore file that will be used + to validate HTTPS server certificates. +``SSLTrustStorePassword`` The password for the TrustStore. +``KerberosRemoteServiceName`` Presto coordinator Kerberos service name. This parameter is + required for Kerberos authentiation. +``KerberosPrincipal`` The principal to use when authenticating to the Presto coordinator. +``KerberosUseCanonicalHostname`` Use the canonical hostname of the Presto coordinator for the Kerberos + service principal by first resolving the hostname to an IP address + and then doing a reverse DNS lookup for that IP address. + This is enabled by default. +``KerberosConfigPath`` Kerberos configuration file. +``KerberosKeytabPath`` Kerberos keytab file. +``KerberosCredentialCachePath`` Kerberos credential cache. +================================= ======================================================================= diff --git a/presto-docs/src/main/sphinx/language.rst b/presto-docs/src/main/sphinx/language.rst index 63f9d14fdeaa..a19e4b60733b 100644 --- a/presto-docs/src/main/sphinx/language.rst +++ b/presto-docs/src/main/sphinx/language.rst @@ -6,3 +6,4 @@ SQL Language :maxdepth: 1 language/types + language/reserved diff --git a/presto-docs/src/main/sphinx/language/reserved.rst b/presto-docs/src/main/sphinx/language/reserved.rst new file mode 100644 index 000000000000..9cf05af04d9e --- /dev/null +++ b/presto-docs/src/main/sphinx/language/reserved.rst @@ -0,0 +1,80 @@ +================= +Reserved Keywords +================= + +The following table lists all of the keywords that are reserved in Presto, +along with their status in the SQL standard. These reserved keywords must +be quoted (using double quotes) in order to be used as an identifier. + +============================== ============= ============= +Keyword SQL:2016 SQL-92 +============================== ============= ============= +``ALTER`` reserved reserved +``AND`` reserved reserved +``AS`` reserved reserved +``BETWEEN`` reserved reserved +``BY`` reserved reserved +``CASE`` reserved reserved +``CAST`` reserved reserved +``CONSTRAINT`` reserved reserved +``CREATE`` reserved reserved +``CROSS`` reserved reserved +``CUBE`` reserved +``CURRENT_DATE`` reserved reserved +``CURRENT_TIME`` reserved reserved +``CURRENT_TIMESTAMP`` reserved reserved +``DEALLOCATE`` reserved reserved +``DELETE`` reserved reserved +``DESCRIBE`` reserved reserved +``DISTINCT`` reserved reserved +``DROP`` reserved reserved +``ELSE`` reserved reserved +``END`` reserved reserved +``ESCAPE`` reserved reserved +``EXCEPT`` reserved reserved +``EXECUTE`` reserved reserved +``EXISTS`` reserved reserved +``EXTRACT`` reserved reserved +``FALSE`` reserved reserved +``FOR`` reserved reserved +``FROM`` reserved reserved +``FULL`` reserved reserved +``GROUP`` reserved reserved +``GROUPING`` reserved +``HAVING`` reserved reserved +``IN`` reserved reserved +``INNER`` reserved reserved +``INSERT`` reserved reserved +``INTERSECT`` reserved reserved +``INTO`` reserved reserved +``IS`` reserved reserved +``JOIN`` reserved reserved +``LEFT`` reserved reserved +``LIKE`` reserved reserved +``LOCALTIME`` reserved +``LOCALTIMESTAMP`` reserved +``NATURAL`` reserved reserved +``NORMALIZE`` reserved +``NOT`` reserved reserved +``NULL`` reserved reserved +``ON`` reserved reserved +``OR`` reserved reserved +``ORDER`` reserved reserved +``OUTER`` reserved reserved +``PREPARE`` reserved reserved +``RECURSIVE`` reserved +``RIGHT`` reserved reserved +``ROLLUP`` reserved +``SELECT`` reserved reserved +``TABLE`` reserved reserved +``THEN`` reserved reserved +``TRUE`` reserved reserved +``UESCAPE`` reserved +``UNION`` reserved reserved +``UNNEST`` reserved +``USING`` reserved reserved +``VALUES`` reserved reserved +``WHEN`` reserved reserved +``WHERE`` reserved reserved +``WITH`` reserved reserved +============================== ============= ============= diff --git a/presto-docs/src/main/sphinx/release.rst b/presto-docs/src/main/sphinx/release.rst index 2bde0e3eeab3..a6f2bdfb3b37 100644 --- a/presto-docs/src/main/sphinx/release.rst +++ b/presto-docs/src/main/sphinx/release.rst @@ -5,6 +5,8 @@ Release Notes .. toctree:: :maxdepth: 1 + release/release-0.181 + release/release-0.180 release/release-0.179 release/release-0.178 release/release-0.177 diff --git a/presto-docs/src/main/sphinx/release/release-0.180.rst b/presto-docs/src/main/sphinx/release/release-0.180.rst new file mode 100644 index 000000000000..4befdab62818 --- /dev/null +++ b/presto-docs/src/main/sphinx/release/release-0.180.rst @@ -0,0 +1,63 @@ +============= +Release 0.180 +============= + +General Changes +--------------- + +* Fix a rare bug where rows containing only ``null`` values are not returned + to the client. This only occurs when an entire result page contains only + ``null`` values. The only known case is a query over an ORC encoded Hive table + that does not perform any transformation of the data. +* Fix incorrect results when performing comparisons between values of approximate + data types (``REAL``, ``DOUBLE``) and columns of certain exact numeric types + (``INTEGER``, ``BIGINT``, ``DECIMAL``). +* Fix memory accounting for :func:`min_by` and :func:`max_by` on complex types. +* Fix query failure due to ``NoClassDefFoundError`` when scalar functions declared + in plugins are implemented with instance methods. +* Improve performance of map subscript from O(n) to O(1) in all cases. Previously, only maps + produced by certain functions and readers could take advantage of this improvement. +* Skip unknown costs in ``EXPLAIN`` output. +* Support :doc:`/security/internal-communication` between Presto nodes. +* Add initial support for ``CROSS JOIN`` against ``LATERAL`` derived tables. +* Add support for ``VARBINARY`` concatenation. +* Add :doc:`/connector/thrift` that makes it possible to use Presto with + external systems without the need to implement a custom connector. +* Add experimental ``/v1/resourceGroupState`` REST endpoint on coordinator. + +Hive Changes +------------ + +* Fix skipping short decimal values in the optimized Parquet reader + when they are backed by the ``int32`` or ``int64`` types. +* Ignore partition bucketing if table is not bucketed. This allows dropping + the bucketing from table metadata but leaving it for old partitions. +* Improve error message for Hive partitions dropped during execution. +* The optimized RCFile writer is enabled by default, but can be disabled + with the ``hive.rcfile-optimized-writer.enabled`` config option. + The writer supports validation which reads back the entire file after + writing. Validation is disabled by default, but can be enabled with the + ``hive.rcfile.writer.validate`` config option. + +Cassandra Changes +----------------- + +* Add support for ``INSERT``. +* Add support for pushdown of non-equality predicates on clustering keys. + +JDBC Driver Changes +------------------- + +* Add support for authenticating using Kerberos. +* Allow configuring SSL/TLS and Kerberos properties on a per-connection basis. +* Add support for executing queries using a SOCKS or HTTP proxy. + +CLI Changes +----------- + +* Add support for executing queries using an HTTP proxy. + +SPI Changes +----------- + +* Add running time limit and queued time limit to ``ResourceGroupInfo``. diff --git a/presto-docs/src/main/sphinx/release/release-0.181.rst b/presto-docs/src/main/sphinx/release/release-0.181.rst new file mode 100644 index 000000000000..5e473f888232 --- /dev/null +++ b/presto-docs/src/main/sphinx/release/release-0.181.rst @@ -0,0 +1,66 @@ +============= +Release 0.181 +============= + +General Changes +--------------- + +* Fix query failure and memory usage tracking when query contains + :func:`transform_keys` or :func:`transform_values`. +* Prevent ``CREATE TABLE IF NOT EXISTS`` queries from ever failing with *"Table already exists"*. +* Fix query failure when ``ORDER BY`` expressions reference columns that are used in + the ``GROUP BY`` clause by their fully-qualified name. +* Fix excessive GC overhead caused by large arrays and maps containing ``VARCHAR`` elements. +* Improve error handling when passing too many arguments to various + functions or operators that take a variable number of arguments. +* Improve performance of ``count(*)`` aggregations over subqueries with known + constant cardinality. +* Add ``VERBOSE`` option for :doc:`/sql/explain-analyze` that provides additional + low-level details about query performance. +* Add per-task distribution information to the output of ``EXPLAIN ANALYZE``. +* Add support for ``DROP COLUMN`` in :doc:`/sql/alter-table`. +* Change local scheduler to prevent starvation of long running queries + when the cluster is under constant load from short queries. The new + behavior is disabled by default and can be enabled by setting the + config property ``task.level-absolute-priority=true``. +* Improve the fairness of the local scheduler such that long-running queries + which spend more time on the CPU per scheduling quanta (e.g., due to + slow connectors) do not get a disproportionate share of CPU. The new + behavior is disabled by default and can be enabled by setting the + config property ``task.legacy-scheduling-behavior=false``. +* Add a config option to control the prioritization of queries based on + elapsed scheduled time. The ``task.level-time-multiplier`` property + controls the target scheduled time of a level relative to the next + level. Higher values for this property increase the fraction of CPU + that will be allocated to shorter queries. This config property only + has an effect when ``task.level-absolute-priority=true`` and + ``task.legacy-scheduling-behavior=false``. + +Hive Changes +------------ + +* Fix potential native memory leak when writing tables using RCFile. +* Correctly categorize certain errors when writing tables using RCFile. +* Decrease the number of file system metadata calls when reading tables. +* Add support for dropping columns. + +JDBC Driver Changes +------------------- + +* Add support for query cancellation using ``Statement.cancel()``. + +PostgreSQL Changes +------------------ + +* Add support for operations on external tables. + +Accumulo Changes +---------------- + +* Improve query performance by scanning index ranges in parallel. + +SPI Changes +----------- + +* Fix regression that broke serialization for ``SchemaTableName``. +* Add access control check for ``DROP COLUMN``. diff --git a/presto-docs/src/main/sphinx/security.rst b/presto-docs/src/main/sphinx/security.rst index f75e13ab52bb..27b08492d4d0 100644 --- a/presto-docs/src/main/sphinx/security.rst +++ b/presto-docs/src/main/sphinx/security.rst @@ -10,3 +10,4 @@ Security security/ldap security/tls security/built-in-system-access-control + security/internal-communication diff --git a/presto-docs/src/main/sphinx/security/internal-communication.rst b/presto-docs/src/main/sphinx/security/internal-communication.rst new file mode 100644 index 000000000000..e3186a45cc42 --- /dev/null +++ b/presto-docs/src/main/sphinx/security/internal-communication.rst @@ -0,0 +1,156 @@ +============================= +Secure Internal Communication +============================= + +The Presto cluster can be configured to use secured communication. Communication +between Presto nodes can be secured with SSL/TLS. + +Internal SSL/TLS configuration +------------------------------ + +SSL/TLS is configured in the ``config.properties`` file. The SSL/TLS on the +worker and coordinator nodes are configured using the same set of properties. +Every node in the cluster must be configured. Nodes that have not been +configured, or are configured incorrectly, will not be able to communicate with +other nodes in the cluster. + +To enable SSL/TLS for Presto internal communication, do the following: + +1. Disable HTTP endpoint. + + .. code-block:: none + + http-server.http.enabled=false + + .. warning:: + + You can enable HTTPS while leaving HTTP enabled. In most cases this is a + security hole. If you are certain you want to use this configuration, you + should consider using an firewall to limit access to the HTTP endpoint to + only those hosts that should be allowed to use it. + +2. Configure the cluster to communicate using the fully qualified domain name (fqdn) + of the cluster nodes. This can be done in either of the following ways: + + - If the DNS service is configured properly, we can just let the nodes to + introduce themselves to the coordinator using the hostname taken from + the system configuration (``hostname --fqdn``) + + .. code-block:: none + + node.internal-address-source=FQDN + + - It is also possible to specify each node's fully-qualified hostname manually. + This will be different for every host. Hosts should be in the same domain to + make it easy to create the correct SSL/TLS certificates. + e.g.: ``coordinator.example.com``, ``worker1.example.com``, ``worker2.example.com``. + + .. code-block:: none + + node.internal-address= + + +3. Generate a Java Keystore File. Every Presto node must be able to connect to + any other node within the same cluster. It is possible to create unique + certificates for every node using the fully-qualified hostname of each host, + create a keystore that contains all the public keys for all of the hosts, + and specify it for the client (``http-client.https.keystore.path``). In most + cases it will be simpler to use a wildcard in the certificate as shown + below. + + .. code-block:: none + + keytool -genkeypair -alias example.com -keyalg RSA -keystore keystore.jks + Enter keystore password: + Re-enter new password: + What is your first and last name? + [Unknown]: *.example.com + What is the name of your organizational unit? + [Unknown]: + What is the name of your organization? + [Unknown]: + What is the name of your City or Locality? + [Unknown]: + What is the name of your State or Province? + [Unknown]: + What is the two-letter country code for this unit? + [Unknown]: + Is CN=*.example.com, OU=Unknown, O=Unknown, L=Unknown, ST=Unknown, C=Unknown correct? + [no]: yes + + Enter key password for + (RETURN if same as keystore password): + + .. Note: Replace `example.com` with the appropriate domain. + +4. Distribute the Java Keystore File across the Presto cluster. + +5. Enable the HTTPS endpoint. + + .. code-block:: none + + http-server.https.enabled=true + http-server.https.port= + http-server.https.keystore.path= + http-server.https.keystore.key= + +6. Change the discovery uri to HTTPS. + + .. code-block:: none + + discovery.uri=https://: + +7. Configure the internal communication to require HTTPS. + + .. code-block:: none + + internal-communication.https.required=true + +8. Configure the internal communication to use the Java keystore file. + + .. code-block:: none + + internal-communication.https.keystore.path= + internal-communication.https.keystore.key= + + +Performance with SSL/TLS enabled +-------------------------------- + +Enabling encryption impacts performance. The performance degradation can vary +based on the environment, queries, and concurrency. + +For queries that do not require transferring too much data between the Presto +nodes (e.g. ``SELECT count(*) FROM table``), the performance impact is negligible. + +However, for CPU intensive queries which require a considerable amount of data +to be transferred between the nodes (for example, distributed joins, aggregations and +window functions, which require repartitioning), the performance impact might be +considerable. The slowdown may vary from 10% to even 100%+, depending on the network +traffic and the CPU utilization. + +Advanced Performance Tuning +--------------------------- + +In some cases, changing the source of random numbers will improve performance +significantly. + +By default, TLS encryption uses the ``/dev/urandom`` system device as a source of entropy. +This device has limited throughput, so on environments with high network bandwidth +(e.g. InfiniBand), it may become a bottleneck. In such situations, it is recommended to try +to switch the random number generator algorithm to ``SHA1PRNG``, by setting it via +``http-server.https.secure-random-algorithm`` property in ``config.properties`` on the coordinator +and all of the workers: + + .. code-block:: none + + http-server.https.secure-random-algorithm=SHA1PRNG + +Be aware that this algorithm takes the initial seed from +the blocking ``/dev/random`` device. For environments that do not have enough entropy to seed +the ``SHAPRNG`` algorithm, the source can be changed to ``/dev/urandom`` +by adding the ``java.security.egd`` property to ``jvm.config``: + + .. code-block:: none + + -Djava.security.egd=file:/dev/urandom diff --git a/presto-docs/src/main/sphinx/sql/alter-table.rst b/presto-docs/src/main/sphinx/sql/alter-table.rst index 16c976884096..4564863e0963 100644 --- a/presto-docs/src/main/sphinx/sql/alter-table.rst +++ b/presto-docs/src/main/sphinx/sql/alter-table.rst @@ -9,6 +9,7 @@ Synopsis ALTER TABLE name RENAME TO new_name ALTER TABLE name ADD COLUMN column_name data_type + ALTER TABLE name DROP COLUMN column_name ALTER TABLE name RENAME COLUMN column_name TO new_column_name Description @@ -27,6 +28,10 @@ Add column ``zip`` to the ``users`` table:: ALTER TABLE users ADD COLUMN zip varchar; +Drop column ``zip`` from the ``users`` table:: + + ALTER TABLE users DROP COLUMN zip; + Rename column ``id`` to ``user_id`` in the ``users`` table:: ALTER TABLE users RENAME COLUMN id TO user_id; diff --git a/presto-docs/src/main/sphinx/sql/explain-analyze.rst b/presto-docs/src/main/sphinx/sql/explain-analyze.rst index 910811bc77eb..8ae2fa9171cb 100644 --- a/presto-docs/src/main/sphinx/sql/explain-analyze.rst +++ b/presto-docs/src/main/sphinx/sql/explain-analyze.rst @@ -7,7 +7,7 @@ Synopsis .. code-block:: none - EXPLAIN ANALYZE statement + EXPLAIN ANALYZE [VERBOSE] statement Description ----------- @@ -15,6 +15,9 @@ Description Execute the statement and show the distributed execution plan of the statement along with the cost of each operation. +The ``VERBOSE`` option will give more detailed information and low-level statistics; +understanding these may require knowledge of Presto internals and implementation details. + .. note:: The stats may not be entirely accurate, especially for queries that complete quickly. @@ -69,6 +72,29 @@ relevant plan nodes). Such statistics are useful when one wants to detect data a orderdate := tpch:orderdate clerk := tpch:clerk +When the ``VERBOSE`` option is used, some operators may report additional information. +For example, the window function operator will output the following: + +.. code-block:: none + + EXPLAIN ANALYZE VERBOSE SELECT count(clerk) OVER() FROM orders WHERE orderdate > date '1995-01-01'; + + Query Plan + ----------------------------------------------------------------------------------------------- + ... + - Window[] => [clerk:varchar(15), count:bigint] + Cost: {rows: ?, bytes: ?} + CPU fraction: 75.93%, Output: 8130 rows (230.24kB) + Input avg.: 8130.00 lines, Input std.dev.: 0.00% + Active Drivers: [ 1 / 1 ] + Index size: std.dev.: 0.00 bytes , 0.00 rows + Index count per driver: std.dev.: 0.00 + Rows per driver: std.dev.: 0.00 + Size of partition: std.dev.: 0.00 + count := count("clerk") + ... + + See Also -------- diff --git a/presto-example-http/pom.xml b/presto-example-http/pom.xml index 449aaf26b8a2..601f42469d35 100644 --- a/presto-example-http/pom.xml +++ b/presto-example-http/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-example-http diff --git a/presto-hive-hadoop2/pom.xml b/presto-hive-hadoop2/pom.xml index 3356b2574997..f509c2b7401d 100644 --- a/presto-hive-hadoop2/pom.xml +++ b/presto-hive-hadoop2/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-hive-hadoop2 diff --git a/presto-hive/pom.xml b/presto-hive/pom.xml index 137afd1ee848..4c2783422b4a 100644 --- a/presto-hive/pom.xml +++ b/presto-hive/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-hive diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/BackgroundHiveSplitLoader.java b/presto-hive/src/main/java/com/facebook/presto/hive/BackgroundHiveSplitLoader.java index 10f65fc22e6a..43f6536e28ac 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/BackgroundHiveSplitLoader.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/BackgroundHiveSplitLoader.java @@ -85,6 +85,7 @@ import static com.facebook.presto.hive.HiveUtil.isLzopIndexFile; import static com.facebook.presto.hive.HiveUtil.isSplittable; import static com.facebook.presto.hive.metastore.MetastoreUtil.getHiveSchema; +import static com.facebook.presto.hive.util.ConfigurationUtils.toJobConf; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.base.Preconditions.checkState; import static java.lang.Math.toIntExact; @@ -296,6 +297,7 @@ private CompletableFuture loadSplits() file.getBlockLocations(), 0, file.getLen(), + file.getLen(), files.getSchema(), files.getPartitionKeys(), splittable, @@ -338,7 +340,7 @@ private void loadPartition(HivePartitionMetadata partition) TextInputFormat targetInputFormat = new TextInputFormat(); // get the configuration for the target path -- it may be a different hdfs instance Configuration targetConfiguration = hdfsEnvironment.getConfiguration(targetPath); - JobConf targetJob = new JobConf(targetConfiguration); + JobConf targetJob = toJobConf(targetConfiguration); targetJob.setInputFormat(TextInputFormat.class); targetInputFormat.configure(targetJob); FileInputFormat.setInputPaths(targetJob, targetPath); @@ -354,7 +356,7 @@ private void loadPartition(HivePartitionMetadata partition) // To support custom input formats, we want to call getSplits() // on the input format to obtain file splits. if (shouldUseFileSplitsFromInputFormat(inputFormat)) { - JobConf jobConf = new JobConf(configuration); + JobConf jobConf = toJobConf(configuration); FileInputFormat.setInputPaths(jobConf, path); InputSplit[] splits = inputFormat.getSplits(jobConf, 0); @@ -380,6 +382,7 @@ private void loadPartition(HivePartitionMetadata partition) file.getBlockLocations(), 0, file.getLen(), + file.getLen(), iterator.getSchema(), iterator.getPartitionKeys(), splittable, @@ -410,6 +413,7 @@ private void loadPartition(HivePartitionMetadata partition) file.getBlockLocations(), 0, file.getLen(), + file.getLen(), iterator.getSchema(), iterator.getPartitionKeys(), splittable, @@ -445,6 +449,7 @@ private boolean addSplitsToSource( targetFilesystem.getFileBlockLocations(file, split.getStart(), split.getLength()), split.getStart(), split.getLength(), + file.getLen(), schema, partitionKeys, false, @@ -531,6 +536,7 @@ private Iterator createHiveSplitIterator( BlockLocation[] blockLocations, long start, long length, + long fileSize, Properties schema, List partitionKeys, boolean splittable, @@ -604,6 +610,7 @@ protected HiveSplit computeNext() path, blockLocation.getOffset() + chunkOffset, chunkLength, + fileSize, schema, partitionKeys, addresses, @@ -645,6 +652,7 @@ protected HiveSplit computeNext() path, start, length, + fileSize, schema, partitionKeys, addresses, diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/FileFormatDataSourceStats.java b/presto-hive/src/main/java/com/facebook/presto/hive/FileFormatDataSourceStats.java index 29cf25971123..1184b19464ff 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/FileFormatDataSourceStats.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/FileFormatDataSourceStats.java @@ -25,6 +25,7 @@ public class FileFormatDataSourceStats { private final DistributionStat readBytes = new DistributionStat(); private final DistributionStat loadedBlockBytes = new DistributionStat(); + private final DistributionStat maxCombinedBytesPerRow = new DistributionStat(); private final TimeStat time0Bto100KB = new TimeStat(MILLISECONDS); private final TimeStat time100KBto1MB = new TimeStat(MILLISECONDS); private final TimeStat time1MBto10MB = new TimeStat(MILLISECONDS); @@ -44,6 +45,13 @@ public DistributionStat getLoadedBlockBytes() return loadedBlockBytes; } + @Managed + @Nested + public DistributionStat getMaxCombinedBytesPerRow() + { + return maxCombinedBytesPerRow; + } + @Managed @Nested public TimeStat get0Bto100KB() @@ -93,4 +101,9 @@ public void addLoadedBlockSize(long bytes) { loadedBlockBytes.add(bytes); } + + public void addMaxCombinedBytesPerRow(long bytes) + { + maxCombinedBytesPerRow.add(bytes); + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/GenericHiveRecordCursorProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/GenericHiveRecordCursorProvider.java index 08aeb7575bb1..ae5eaa8d8f7a 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/GenericHiveRecordCursorProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/GenericHiveRecordCursorProvider.java @@ -53,6 +53,7 @@ public Optional createRecordCursor( Path path, long start, long length, + long fileSize, Properties schema, List columns, TupleDomain effectivePredicate, diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HdfsConfigurationUpdater.java b/presto-hive/src/main/java/com/facebook/presto/hive/HdfsConfigurationUpdater.java index 69f500f77dac..6b277f414129 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HdfsConfigurationUpdater.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HdfsConfigurationUpdater.java @@ -133,7 +133,7 @@ private static Configuration readConfiguration(List resourcePaths) return result; } - public void updateConfiguration(PrestoHadoopConfiguration config) + public void updateConfiguration(Configuration config) { copy(resourcesConfiguration, config); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientConfig.java index 5b15b57a8e5f..2a5c7713b4c7 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientConfig.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientConfig.java @@ -98,8 +98,10 @@ public class HiveClientConfig private DataSize orcMaxMergeDistance = new DataSize(1, MEGABYTE); private DataSize orcMaxBufferSize = new DataSize(8, MEGABYTE); private DataSize orcStreamBufferSize = new DataSize(8, MEGABYTE); + private DataSize orcMaxReadBlockSize = new DataSize(16, MEGABYTE); - private boolean rcfileOptimizedWriterEnabled; + private boolean rcfileOptimizedWriterEnabled = true; + private boolean rcfileWriterValidate; private HiveMetastoreAuthenticationType hiveMetastoreAuthenticationType = HiveMetastoreAuthenticationType.NONE; private HdfsAuthenticationType hdfsAuthenticationType = HdfsAuthenticationType.NONE; @@ -664,6 +666,19 @@ public HiveClientConfig setOrcStreamBufferSize(DataSize orcStreamBufferSize) return this; } + @NotNull + public DataSize getOrcMaxReadBlockSize() + { + return orcMaxReadBlockSize; + } + + @Config("hive.orc.max-read-block-size") + public HiveClientConfig setOrcMaxReadBlockSize(DataSize orcMaxReadBlockSize) + { + this.orcMaxReadBlockSize = orcMaxReadBlockSize; + return this; + } + public boolean isOrcBloomFiltersEnabled() { return orcBloomFiltersEnabled; @@ -690,6 +705,19 @@ public HiveClientConfig setRcfileOptimizedWriterEnabled(boolean rcfileOptimizedW return this; } + public boolean isRcfileWriterValidate() + { + return rcfileWriterValidate; + } + + @Config("hive.rcfile.writer.validate") + @ConfigDescription("Validate RCFile after write by re-reading the whole file") + public HiveClientConfig setRcfileWriterValidate(boolean rcfileWriterValidate) + { + this.rcfileWriterValidate = rcfileWriterValidate; + return this; + } + public boolean isAssumeCanonicalPartitionKeys() { return assumeCanonicalPartitionKeys; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientModule.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientModule.java index 46be3f04f0f4..1d1e6ad853ed 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientModule.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientModule.java @@ -85,7 +85,7 @@ public void configure(Binder binder) binder.bind(HiveTableProperties.class).in(Scopes.SINGLETON); binder.bind(NamenodeStats.class).in(Scopes.SINGLETON); - newExporter(binder).export(NamenodeStats.class).as(generatedNameOf(NamenodeStats.class)); + newExporter(binder).export(NamenodeStats.class).as(generatedNameOf(NamenodeStats.class, connectorId)); binder.bind(HiveMetastoreClientFactory.class).in(Scopes.SINGLETON); binder.bind(PooledHiveMetastoreClientFactory.class).in(Scopes.SINGLETON); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveErrorCode.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveErrorCode.java index 2dd754b01f86..b46090523561 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveErrorCode.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveErrorCode.java @@ -54,7 +54,8 @@ public enum HiveErrorCode HIVE_WRITER_DATA_ERROR(27, EXTERNAL), HIVE_INVALID_BUCKET_FILES(28, EXTERNAL), HIVE_EXCEEDED_PARTITION_LIMIT(29, USER_ERROR), - HIVE_WRITE_VALIDATION_FAILED(30, INTERNAL_ERROR); + HIVE_WRITE_VALIDATION_FAILED(30, INTERNAL_ERROR), + HIVE_PARTITION_DROPPED_DURING_QUERY(31, EXTERNAL); private final ErrorCode errorCode; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveHdfsConfiguration.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveHdfsConfiguration.java index e48b429c6c85..23563bf47d0c 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveHdfsConfiguration.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveHdfsConfiguration.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.hive; -import com.google.common.collect.ImmutableClassToInstanceMap; import org.apache.hadoop.conf.Configuration; import javax.inject.Inject; @@ -45,7 +44,7 @@ public class HiveHdfsConfiguration @Override protected Configuration initialValue() { - PrestoHadoopConfiguration configuration = new PrestoHadoopConfiguration(ImmutableClassToInstanceMap.of()); + Configuration configuration = new Configuration(false); copy(INITIAL_CONFIGURATION, configuration); updater.updateConfiguration(configuration); return configuration; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java index 12639da0159c..40ec01e774ca 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java @@ -133,6 +133,7 @@ import static com.facebook.presto.hive.metastore.SemiTransactionalHiveMetastore.WriteMode.STAGE_AND_MOVE_TO_TARGET_DIRECTORY; import static com.facebook.presto.hive.metastore.StorageFormat.VIEW_STORAGE_FORMAT; import static com.facebook.presto.hive.metastore.StorageFormat.fromHiveStorageFormat; +import static com.facebook.presto.hive.util.ConfigurationUtils.toJobConf; import static com.facebook.presto.spi.StandardErrorCode.INVALID_SCHEMA_PROPERTY; import static com.facebook.presto.spi.StandardErrorCode.INVALID_TABLE_PROPERTY; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; @@ -580,6 +581,15 @@ public void renameColumn(ConnectorSession session, ConnectorTableHandle tableHan metastore.renameColumn(hiveTableHandle.getSchemaName(), hiveTableHandle.getTableName(), sourceHandle.getName(), target); } + @Override + public void dropColumn(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle column) + { + HiveTableHandle hiveTableHandle = (HiveTableHandle) tableHandle; + HiveColumnHandle columnHandle = (HiveColumnHandle) column; + + metastore.dropColumn(hiveTableHandle.getSchemaName(), hiveTableHandle.getTableName(), columnHandle.getName()); + } + @Override public void renameTable(ConnectorSession session, ConnectorTableHandle tableHandle, SchemaTableName newTableName) { @@ -738,10 +748,9 @@ private List computeFileNamesForMissingBuckets(HiveStorageFormat storage // fast path for common case return ImmutableList.of(); } - JobConf conf = new JobConf(hdfsEnvironment.getConfiguration(targetPath)); + JobConf conf = toJobConf(hdfsEnvironment.getConfiguration(targetPath)); String fileExtension = HiveWriterFactory.getFileExtension(conf, fromHiveStorageFormat(storageFormat)); - Set fileNames = partitionUpdate.getFileNames().stream() - .collect(Collectors.toSet()); + Set fileNames = ImmutableSet.copyOf(partitionUpdate.getFileNames()); ImmutableList.Builder missingFileNamesBuilder = ImmutableList.builder(); for (int i = 0; i < bucketCount; i++) { String fileName = HiveWriterFactory.computeBucketedFileName(filePrefix, i) + fileExtension; @@ -756,7 +765,7 @@ private List computeFileNamesForMissingBuckets(HiveStorageFormat storage private void createEmptyFile(Path path, Table table, Optional partition, List fileNames) { - JobConf conf = new JobConf(hdfsEnvironment.getConfiguration(path)); + JobConf conf = toJobConf(hdfsEnvironment.getConfiguration(path)); Properties schema; StorageFormat format; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceFactory.java index efa076474a33..56062ab38642 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceFactory.java @@ -32,6 +32,7 @@ Optional createPageSource( Path path, long start, long length, + long fileSize, Properties schema, List columns, TupleDomain effectivePredicate, diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceProvider.java index e6987890b3d7..ae4ffb3df5f1 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceProvider.java @@ -116,6 +116,7 @@ private ConnectorPageSource doCreatePageSource(ConnectorSession session, Connect hiveSplit.getBucketNumber(), hiveSplit.getStart(), hiveSplit.getLength(), + hiveSplit.getFileSize(), hiveSplit.getSchema(), hiveSplit.getEffectivePredicate(), hiveColumns, @@ -139,6 +140,7 @@ public static Optional createHivePageSource( OptionalInt bucketNumber, long start, long length, + long fileSize, Properties schema, TupleDomain effectivePredicate, List hiveColumns, @@ -157,6 +159,7 @@ public static Optional createHivePageSource( path, start, length, + fileSize, schema, extractRegularColumnHandles(regularColumnMappings, true), effectivePredicate, @@ -183,6 +186,7 @@ public static Optional createHivePageSource( path, start, length, + fileSize, schema, extractRegularColumnHandles(regularColumnMappings, doCoercion), effectivePredicate, diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveRecordCursorProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveRecordCursorProvider.java index d5fc472e4c9b..68e3cd709b48 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveRecordCursorProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveRecordCursorProvider.java @@ -34,6 +34,7 @@ Optional createRecordCursor( Path path, long start, long length, + long fileSize, Properties schema, List columns, TupleDomain effectivePredicate, diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java index 9c8f6383dd31..e0968dedecbf 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java @@ -33,6 +33,7 @@ public final class HiveSessionProperties private static final String ORC_MAX_MERGE_DISTANCE = "orc_max_merge_distance"; private static final String ORC_MAX_BUFFER_SIZE = "orc_max_buffer_size"; private static final String ORC_STREAM_BUFFER_SIZE = "orc_stream_buffer_size"; + private static final String ORC_MAX_READ_BLOCK_SIZE = "orc_max_read_block_size"; private static final String PARQUET_PREDICATE_PUSHDOWN_ENABLED = "parquet_predicate_pushdown_enabled"; private static final String PARQUET_OPTIMIZED_READER_ENABLED = "parquet_optimized_reader_enabled"; private static final String READ_AS_QUERY_USER = "read_as_query_user"; @@ -83,6 +84,11 @@ public HiveSessionProperties(HiveClientConfig config) "ORC: Size of buffer for streaming reads", config.getOrcStreamBufferSize(), false), + dataSizeSessionProperty( + ORC_MAX_READ_BLOCK_SIZE, + "ORC: Maximum size of a block to read", + config.getOrcMaxReadBlockSize(), + false), booleanSessionProperty( PARQUET_OPTIMIZED_READER_ENABLED, "Experimental: Parquet: Enable optimized reader", @@ -111,7 +117,7 @@ public HiveSessionProperties(HiveClientConfig config) booleanSessionProperty( RCFILE_OPTIMIZED_WRITER_VALIDATE, "Experimental: RCFile: Validate writer files", - true, + config.isRcfileWriterValidate(), false), booleanSessionProperty( STATISTICS_ENABLED, @@ -160,6 +166,11 @@ public static DataSize getOrcStreamBufferSize(ConnectorSession session) return session.getProperty(ORC_STREAM_BUFFER_SIZE, DataSize.class); } + public static DataSize getOrcMaxReadBlockSize(ConnectorSession session) + { + return session.getProperty(ORC_MAX_READ_BLOCK_SIZE, DataSize.class); + } + public static boolean isParquetPredicatePushdownEnabled(ConnectorSession session) { return session.getProperty(PARQUET_PREDICATE_PUSHDOWN_ENABLED, Boolean.class); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplit.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplit.java index ff1c203f0443..ba297877a573 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplit.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplit.java @@ -37,6 +37,7 @@ public class HiveSplit private final String path; private final long start; private final long length; + private final long fileSize; private final Properties schema; private final List partitionKeys; private final List addresses; @@ -57,6 +58,7 @@ public HiveSplit( @JsonProperty("path") String path, @JsonProperty("start") long start, @JsonProperty("length") long length, + @JsonProperty("fileSize") long fileSize, @JsonProperty("schema") Properties schema, @JsonProperty("partitionKeys") List partitionKeys, @JsonProperty("addresses") List addresses, @@ -68,6 +70,7 @@ public HiveSplit( requireNonNull(clientId, "clientId is null"); checkArgument(start >= 0, "start must be positive"); checkArgument(length >= 0, "length must be positive"); + checkArgument(fileSize >= 0, "fileSize must be positive"); requireNonNull(database, "database is null"); requireNonNull(table, "table is null"); requireNonNull(partitionName, "partitionName is null"); @@ -86,6 +89,7 @@ public HiveSplit( this.path = path; this.start = start; this.length = length; + this.fileSize = fileSize; this.schema = schema; this.partitionKeys = ImmutableList.copyOf(partitionKeys); this.addresses = ImmutableList.copyOf(addresses); @@ -137,6 +141,12 @@ public long getLength() return length; } + @JsonProperty + public long getFileSize() + { + return fileSize; + } + @JsonProperty public Properties getSchema() { @@ -193,6 +203,7 @@ public Object getInfo() .put("path", path) .put("start", start) .put("length", length) + .put("fileSize", fileSize) .put("hosts", addresses) .put("database", database) .put("table", table) @@ -208,6 +219,7 @@ public String toString() .addValue(path) .addValue(start) .addValue(length) + .addValue(fileSize) .addValue(effectivePredicate) .toString(); } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitManager.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitManager.java index d6bf69f143c8..50abdf6cfa85 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitManager.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitManager.java @@ -26,7 +26,6 @@ import com.facebook.presto.spi.TableNotFoundException; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; -import com.google.common.base.Verify; import com.google.common.collect.AbstractIterator; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -48,10 +47,9 @@ import java.util.function.Function; import static com.facebook.presto.hive.HiveErrorCode.HIVE_INVALID_METADATA; -import static com.facebook.presto.hive.HiveErrorCode.HIVE_METASTORE_ERROR; +import static com.facebook.presto.hive.HiveErrorCode.HIVE_PARTITION_DROPPED_DURING_QUERY; import static com.facebook.presto.hive.HiveErrorCode.HIVE_PARTITION_SCHEMA_MISMATCH; import static com.facebook.presto.hive.HivePartition.UNPARTITIONED_ID; -import static com.facebook.presto.hive.HiveUtil.checkCondition; import static com.facebook.presto.hive.metastore.MetastoreUtil.makePartName; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.StandardErrorCode.SERVER_SHUTTING_DOWN; @@ -206,12 +204,11 @@ private Iterable getPartitionMetadata(SemiTransactionalHi ImmutableMap.Builder partitionBuilder = ImmutableMap.builder(); for (Map.Entry> entry : batch.entrySet()) { if (!entry.getValue().isPresent()) { - throw new PrestoException(HIVE_METASTORE_ERROR, "Partition metadata not available"); + throw new PrestoException(HIVE_PARTITION_DROPPED_DURING_QUERY, "Partition no longer exists: " + entry.getKey()); } partitionBuilder.put(entry.getKey(), entry.getValue().get()); } Map partitions = partitionBuilder.build(); - Verify.verify(partitions.size() == partitionBatch.size()); if (partitionBatch.size() != partitions.size()) { throw new PrestoException(GENERIC_INTERNAL_ERROR, format("Expected %s partitions but found %s", partitionBatch.size(), partitions.size())); } @@ -266,15 +263,25 @@ private Iterable getPartitionMetadata(SemiTransactionalHi } } - Optional partitionBucketProperty = partition.getStorage().getBucketProperty(); - checkCondition( - partitionBucketProperty.equals(bucketProperty), - HiveErrorCode.HIVE_PARTITION_SCHEMA_MISMATCH, - "Hive table (%s) bucketing property (%s) does not match partition (%s) bucketing property (%s)", - hivePartition.getTableName(), - bucketProperty, - hivePartition.getPartitionId(), - partitionBucketProperty); + if (bucketProperty.isPresent()) { + Optional partitionBucketProperty = partition.getStorage().getBucketProperty(); + if (!partitionBucketProperty.isPresent()) { + throw new PrestoException(HIVE_PARTITION_SCHEMA_MISMATCH, format( + "Hive table (%s) is bucketed but partition (%s) is not bucketed", + hivePartition.getTableName(), + hivePartition.getPartitionId())); + } + if (!bucketProperty.equals(partitionBucketProperty)) { + throw new PrestoException(HIVE_PARTITION_SCHEMA_MISMATCH, format( + "Hive table (%s) bucketing (columns=%s, buckets=%s) does not match partition (%s) bucketing (columns=%s, buckets=%s)", + hivePartition.getTableName(), + bucketProperty.get().getBucketedBy(), + bucketProperty.get().getBucketCount(), + hivePartition.getPartitionId(), + partitionBucketProperty.get().getBucketedBy(), + partitionBucketProperty.get().getBucketCount())); + } + } results.add(new HivePartitionMetadata(hivePartition, Optional.of(partition), columnCoercions.build())); } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java index f082ebdc1a45..ecc57a741325 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java @@ -94,6 +94,7 @@ import static com.facebook.presto.hive.HivePartitionKey.HIVE_DEFAULT_DYNAMIC_PARTITION; import static com.facebook.presto.hive.RetryDriver.retry; import static com.facebook.presto.hive.metastore.MetastoreUtil.getHiveSchema; +import static com.facebook.presto.hive.util.ConfigurationUtils.toJobConf; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; @@ -191,7 +192,7 @@ private HiveUtil() setReadColumns(configuration, readHiveColumnIndexes); InputFormat inputFormat = getInputFormat(configuration, schema, true); - JobConf jobConf = new JobConf(configuration); + JobConf jobConf = toJobConf(configuration); FileSplit fileSplit = new FileSplit(path, start, length, (String[]) null); // propagate serialization configuration to getRecordReader @@ -235,7 +236,7 @@ public static void setReadColumns(Configuration configuration, List rea { String inputFormatName = getInputFormatName(schema); try { - JobConf jobConf = new JobConf(configuration); + JobConf jobConf = toJobConf(configuration); Class> inputFormatClass = getInputFormatClass(jobConf, inputFormatName); if (symlinkTarget && (inputFormatClass == SymlinkTextInputFormat.class)) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriterFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriterFactory.java index 910ea7333e94..8e60fc63a6ea 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriterFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriterFactory.java @@ -63,6 +63,7 @@ import static com.facebook.presto.hive.HiveWriteUtils.getField; import static com.facebook.presto.hive.metastore.MetastoreUtil.getHiveSchema; import static com.facebook.presto.hive.metastore.StorageFormat.fromHiveStorageFormat; +import static com.facebook.presto.hive.util.ConfigurationUtils.toJobConf; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -201,7 +202,7 @@ public HiveWriterFactory( entry -> session.getProperty(entry.getName(), entry.getJavaType()).toString())); Configuration conf = hdfsEnvironment.getConfiguration(writePath); - this.conf = new JobConf(conf); + this.conf = toJobConf(conf); // make sure the FileSystem is created with the correct Configuration object try { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/PrestoHadoopConfiguration.java b/presto-hive/src/main/java/com/facebook/presto/hive/PrestoHadoopConfiguration.java deleted file mode 100644 index f65d619b34c7..000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/PrestoHadoopConfiguration.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.google.common.collect.ClassToInstanceMap; -import com.google.common.collect.ImmutableClassToInstanceMap; -import org.apache.hadoop.conf.Configuration; - -import static com.google.common.base.Preconditions.checkArgument; -import static java.util.Objects.requireNonNull; - -public final class PrestoHadoopConfiguration - extends Configuration -{ - private final ClassToInstanceMap services; - - public PrestoHadoopConfiguration(ClassToInstanceMap services) - { - super(false); - this.services = ImmutableClassToInstanceMap.copyOf(requireNonNull(services, "services is null")); - } - - public T getService(Class type) - { - T service = services.getInstance(type); - checkArgument(service != null, "service not found: %s", type.getName()); - return service; - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/PrestoS3FileSystem.java b/presto-hive/src/main/java/com/facebook/presto/hive/PrestoS3FileSystem.java index dd252ef75746..d0948d19de86 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/PrestoS3FileSystem.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/PrestoS3FileSystem.java @@ -93,6 +93,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.base.Strings.nullToEmpty; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.collect.Iterables.toArray; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.Math.max; @@ -579,7 +580,7 @@ ObjectMetadata getS3ObjectMetadata(Path path) throw Throwables.propagate(e); } catch (Exception e) { - Throwables.propagateIfInstanceOf(e, IOException.class); + throwIfInstanceOf(e, IOException.class); throw Throwables.propagate(e); } } @@ -830,7 +831,7 @@ public int read(byte[] buffer, int offset, int length) throw Throwables.propagate(e); } catch (Exception e) { - Throwables.propagateIfInstanceOf(e, IOException.class); + throwIfInstanceOf(e, IOException.class); throw Throwables.propagate(e); } } @@ -918,7 +919,7 @@ private InputStream openStream(Path path, long start) throw Throwables.propagate(e); } catch (Exception e) { - Throwables.propagateIfInstanceOf(e, IOException.class); + throwIfInstanceOf(e, IOException.class); throw Throwables.propagate(e); } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/RcFileFileWriter.java b/presto-hive/src/main/java/com/facebook/presto/hive/RcFileFileWriter.java index f68e80b73ecf..4b7f23f36ac2 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/RcFileFileWriter.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/RcFileFileWriter.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.io.OutputStream; +import java.io.UncheckedIOException; import java.util.List; import java.util.Map; import java.util.Optional; @@ -108,7 +109,7 @@ public void appendRows(Page dataPage) try { rcFileWriter.write(page); } - catch (IOException e) { + catch (IOException | UncheckedIOException e) { throw new PrestoException(HIVE_WRITER_DATA_ERROR, e); } } @@ -119,7 +120,7 @@ public void commit() try { rcFileWriter.close(); } - catch (IOException e) { + catch (IOException | UncheckedIOException e) { try { rollbackAction.call(); } @@ -135,7 +136,7 @@ public void commit() rcFileWriter.validate(input); } } - catch (IOException e) { + catch (IOException | UncheckedIOException e) { throw new PrestoException(HIVE_WRITE_VALIDATION_FAILED, e); } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/StaticHiveCluster.java b/presto-hive/src/main/java/com/facebook/presto/hive/StaticHiveCluster.java index 3a37c831cbe4..4facd9771f38 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/StaticHiveCluster.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/StaticHiveCluster.java @@ -71,7 +71,7 @@ public HiveMetastoreClient createMetastoreClient() TTransportException lastException = null; for (HostAndPort metastore : metastores) { try { - return clientFactory.create(metastore.getHostText(), metastore.getPort()); + return clientFactory.create(metastore.getHost(), metastore.getPort()); } catch (TTransportException e) { lastException = e; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/BridgingHiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/BridgingHiveMetastore.java index c43c2f0b85bb..b48163d6abd9 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/BridgingHiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/BridgingHiveMetastore.java @@ -37,6 +37,7 @@ import static com.facebook.presto.hive.metastore.MetastoreUtil.toMetastoreApiPartition; import static com.facebook.presto.hive.metastore.MetastoreUtil.toMetastoreApiPrivilegeGrantInfo; import static com.facebook.presto.hive.metastore.MetastoreUtil.toMetastoreApiTable; +import static com.facebook.presto.hive.metastore.MetastoreUtil.verifyCanDropColumn; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static java.util.Objects.requireNonNull; import static java.util.function.UnaryOperator.identity; @@ -205,6 +206,20 @@ public void renameColumn(String databaseName, String tableName, String oldColumn alterTable(databaseName, tableName, table); } + @Override + public void dropColumn(String databaseName, String tableName, String columnName) + { + verifyCanDropColumn(this, databaseName, tableName, columnName); + org.apache.hadoop.hive.metastore.api.Table table = delegate.getTable(databaseName, tableName) + .orElseThrow(() -> new TableNotFoundException(new SchemaTableName(databaseName, tableName))); + for (FieldSchema fieldSchema : table.getSd().getCols()) { + if (fieldSchema.getName().equals(columnName)) { + table.getSd().getCols().remove(fieldSchema); + } + } + alterTable(databaseName, tableName, table); + } + private void alterTable(String databaseName, String tableName, org.apache.hadoop.hive.metastore.api.Table table) { delegate.alterTable(databaseName, tableName, table); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/CachingHiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/CachingHiveMetastore.java index 02d95e57afca..13d2b3255a97 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/CachingHiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/CachingHiveMetastore.java @@ -42,14 +42,13 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.stream.Collectors; -import java.util.stream.Stream; -import java.util.stream.StreamSupport; import static com.facebook.presto.hive.HiveUtil.toPartitionValues; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.cache.CacheLoader.asyncReloading; import static com.google.common.collect.Iterables.transform; +import static com.google.common.collect.Streams.stream; import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -434,11 +433,6 @@ private Map> l return resultMap.build(); } - private Stream stream(Iterable keys) - { - return StreamSupport.stream(keys.spliterator(), false); - } - @Override public Optional> getAllTables(String databaseName) { @@ -571,6 +565,17 @@ public void renameColumn(String databaseName, String tableName, String oldColumn } } + @Override + public void dropColumn(String databaseName, String tableName, String columnName) + { + try { + delegate.dropColumn(databaseName, tableName, columnName); + } + finally { + invalidateTable(databaseName, tableName); + } + } + protected void invalidateTable(String databaseName, String tableName) { tableCache.invalidate(new HiveTableName(databaseName, tableName)); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ExtendedHiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ExtendedHiveMetastore.java index c130938b0524..98d50c69ad97 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ExtendedHiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ExtendedHiveMetastore.java @@ -59,6 +59,8 @@ public interface ExtendedHiveMetastore void renameColumn(String databaseName, String tableName, String oldColumnName, String newColumnName); + void dropColumn(String databaseName, String tableName, String columnName); + Optional getPartition(String databaseName, String tableName, List partitionValues); Optional> getPartitionNames(String databaseName, String tableName); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/MetastoreUtil.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/MetastoreUtil.java index 7490414009c1..e523f4ca4aef 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/MetastoreUtil.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/MetastoreUtil.java @@ -19,6 +19,7 @@ import com.facebook.presto.hive.TableOfflineException; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.TableNotFoundException; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.apache.hadoop.hive.common.FileUtils; @@ -57,6 +58,7 @@ import static com.facebook.presto.hive.HiveErrorCode.HIVE_INVALID_METADATA; import static com.facebook.presto.hive.HiveSplitManager.PRESTO_OFFLINE; import static com.facebook.presto.hive.metastore.HivePrivilegeInfo.parsePrivilege; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.emptyToNull; import static com.google.common.base.Strings.isNullOrEmpty; @@ -602,4 +604,17 @@ public static void verifyOnline(SchemaTableName tableName, Optional part throw new TableOfflineException(tableName, true, prestoOffline); } } + + public static void verifyCanDropColumn(ExtendedHiveMetastore metastore, String databaseName, String tableName, String columnName) + { + Table table = metastore.getTable(databaseName, tableName) + .orElseThrow(() -> new TableNotFoundException(new SchemaTableName(databaseName, tableName))); + + if (table.getPartitionColumns().stream().anyMatch(column -> column.getName().equals(columnName))) { + throw new PrestoException(NOT_SUPPORTED, "Cannot drop partition columns"); + } + if (table.getDataColumns().size() <= 1) { + throw new PrestoException(NOT_SUPPORTED, "Cannot drop the only column in a table"); + } + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/SemiTransactionalHiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/SemiTransactionalHiveMetastore.java index 94ef07774db8..344addcf895d 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/SemiTransactionalHiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/SemiTransactionalHiveMetastore.java @@ -349,6 +349,11 @@ public synchronized void renameColumn(String databaseName, String tableName, Str setExclusive((delegate, hdfsEnvironment) -> delegate.renameColumn(databaseName, tableName, oldColumnName, newColumnName)); } + public synchronized void dropColumn(String databaseName, String tableName, String columnName) + { + setExclusive((delegate, hdfsEnvironment) -> delegate.dropColumn(databaseName, tableName, columnName)); + } + public synchronized void finishInsertIntoExistingTable(ConnectorSession session, String databaseName, String tableName, Path currentLocation, List fileNames) { // Data can only be inserted into partitions and unpartitioned tables. They can never be inserted into a partitioned table. diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastore.java index b6b7e3f95dde..74d208bb95be 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastore.java @@ -67,6 +67,7 @@ import static com.facebook.presto.hive.metastore.Database.DEFAULT_DATABASE_NAME; import static com.facebook.presto.hive.metastore.HivePrivilegeInfo.HivePrivilege.OWNERSHIP; import static com.facebook.presto.hive.metastore.MetastoreUtil.makePartName; +import static com.facebook.presto.hive.metastore.MetastoreUtil.verifyCanDropColumn; import static com.facebook.presto.hive.metastore.PrincipalType.ROLE; import static com.facebook.presto.hive.metastore.PrincipalType.USER; import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; @@ -427,6 +428,27 @@ public synchronized void renameColumn(String databaseName, String tableName, Str }); } + @Override + public synchronized void dropColumn(String databaseName, String tableName, String columnName) + { + alterTable(databaseName, tableName, oldTable -> { + verifyCanDropColumn(this, databaseName, tableName, columnName); + if (!oldTable.getColumn(columnName).isPresent()) { + SchemaTableName name = new SchemaTableName(databaseName, tableName); + throw new ColumnNotFoundException(name, columnName); + } + + ImmutableList.Builder newDataColumns = ImmutableList.builder(); + for (Column fieldSchema : oldTable.getDataColumns()) { + if (!fieldSchema.getName().equals(columnName)) { + newDataColumns.add(fieldSchema); + } + } + + return oldTable.withDataColumns(newDataColumns.build()); + }); + } + private void alterTable(String databaseName, String tableName, Function alterFunction) { requireNonNull(databaseName, "databaseName is null"); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfPageSourceFactory.java index 80ee0f3de20e..d52ead0fda89 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfPageSourceFactory.java @@ -35,6 +35,7 @@ import static com.facebook.presto.hive.HiveSessionProperties.getOrcMaxBufferSize; import static com.facebook.presto.hive.HiveSessionProperties.getOrcMaxMergeDistance; +import static com.facebook.presto.hive.HiveSessionProperties.getOrcMaxReadBlockSize; import static com.facebook.presto.hive.HiveSessionProperties.getOrcStreamBufferSize; import static com.facebook.presto.hive.HiveUtil.isDeserializerClass; import static com.facebook.presto.hive.orc.OrcPageSourceFactory.createOrcPageSource; @@ -61,6 +62,7 @@ public Optional createPageSource(Configuration co Path path, long start, long length, + long fileSize, Properties schema, List columns, TupleDomain effectivePredicate, @@ -78,6 +80,7 @@ public Optional createPageSource(Configuration co path, start, length, + fileSize, columns, false, effectivePredicate, @@ -86,6 +89,7 @@ public Optional createPageSource(Configuration co getOrcMaxMergeDistance(session), getOrcMaxBufferSize(session), getOrcStreamBufferSize(session), + getOrcMaxReadBlockSize(session), false, stats)); } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSource.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSource.java index 65c97e73fbf0..a09794c54cb4 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSource.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSource.java @@ -176,6 +176,7 @@ public void close() closed = true; try { + stats.addMaxCombinedBytesPerRow(recordReader.getMaxCombinedBytesPerRow()); recordReader.close(); } catch (IOException e) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactory.java index 55a4cc7502c2..eef13bb68293 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactory.java @@ -60,6 +60,7 @@ import static com.facebook.presto.hive.HiveErrorCode.HIVE_MISSING_DATA; import static com.facebook.presto.hive.HiveSessionProperties.getOrcMaxBufferSize; import static com.facebook.presto.hive.HiveSessionProperties.getOrcMaxMergeDistance; +import static com.facebook.presto.hive.HiveSessionProperties.getOrcMaxReadBlockSize; import static com.facebook.presto.hive.HiveSessionProperties.getOrcStreamBufferSize; import static com.facebook.presto.hive.HiveSessionProperties.isOrcBloomFiltersEnabled; import static com.facebook.presto.hive.HiveUtil.isDeserializerClass; @@ -97,6 +98,7 @@ public Optional createPageSource( Path path, long start, long length, + long fileSize, Properties schema, List columns, TupleDomain effectivePredicate, @@ -114,6 +116,7 @@ public Optional createPageSource( path, start, length, + fileSize, columns, useOrcColumnNames, effectivePredicate, @@ -122,6 +125,7 @@ public Optional createPageSource( getOrcMaxMergeDistance(session), getOrcMaxBufferSize(session), getOrcStreamBufferSize(session), + getOrcMaxReadBlockSize(session), isOrcBloomFiltersEnabled(session), stats)); } @@ -134,6 +138,7 @@ public static OrcPageSource createOrcPageSource( Path path, long start, long length, + long fileSize, List columns, boolean useOrcColumnNames, TupleDomain effectivePredicate, @@ -142,15 +147,15 @@ public static OrcPageSource createOrcPageSource( DataSize maxMergeDistance, DataSize maxBufferSize, DataSize streamBufferSize, + DataSize maxReadBlockSize, boolean orcBloomFiltersEnabled, FileFormatDataSourceStats stats) { OrcDataSource orcDataSource; try { FileSystem fileSystem = hdfsEnvironment.getFileSystem(sessionUser, path, configuration); - long size = fileSystem.getFileStatus(path).getLen(); FSDataInputStream inputStream = fileSystem.open(path); - orcDataSource = new HdfsOrcDataSource(new OrcDataSourceId(path.toString()), size, maxMergeDistance, maxBufferSize, streamBufferSize, inputStream, stats); + orcDataSource = new HdfsOrcDataSource(new OrcDataSourceId(path.toString()), fileSize, maxMergeDistance, maxBufferSize, streamBufferSize, inputStream, stats); } catch (Exception e) { if (nullToEmpty(e.getMessage()).trim().equals("Filesystem closed") || @@ -162,7 +167,7 @@ public static OrcPageSource createOrcPageSource( AggregatedMemoryContext systemMemoryUsage = new AggregatedMemoryContext(); try { - OrcReader reader = new OrcReader(orcDataSource, metadataReader, maxMergeDistance, maxBufferSize); + OrcReader reader = new OrcReader(orcDataSource, metadataReader, maxMergeDistance, maxBufferSize, maxReadBlockSize); List physicalColumns = getPhysicalHiveColumnHandles(columns, useOrcColumnNames, reader, path); ImmutableMap.Builder includedColumns = ImmutableMap.builder(); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/HdfsParquetDataSource.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/HdfsParquetDataSource.java index 6fbb272d66f4..1b35375d0250 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/HdfsParquetDataSource.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/HdfsParquetDataSource.java @@ -90,12 +90,11 @@ private void readInternal(long position, byte[] buffer, int bufferOffset, int bu } } - public static HdfsParquetDataSource buildHdfsParquetDataSource(FileSystem fileSystem, Path path, long start, long length) + public static HdfsParquetDataSource buildHdfsParquetDataSource(FileSystem fileSystem, Path path, long start, long length, long fileSize) { try { - long size = fileSystem.getFileStatus(path).getLen(); FSDataInputStream inputStream = fileSystem.open(path); - return new HdfsParquetDataSource(path, size, inputStream); + return new HdfsParquetDataSource(path, fileSize, inputStream); } catch (Exception e) { if (nullToEmpty(e.getMessage()).trim().equals("Filesystem closed") || diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetHiveRecordCursor.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetHiveRecordCursor.java index a2c4ae8801b1..b7d6de8b219d 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetHiveRecordCursor.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetHiveRecordCursor.java @@ -87,6 +87,7 @@ import static com.facebook.presto.spi.type.Varchars.truncateToLength; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static io.airlift.slice.Slices.wrappedBuffer; import static java.lang.Float.floatToRawIntBits; import static java.lang.Math.max; @@ -123,6 +124,7 @@ public ParquetHiveRecordCursor( Path path, long start, long length, + long fileSize, Properties splitSchema, List columns, boolean useParquetColumnNames, @@ -162,6 +164,7 @@ public ParquetHiveRecordCursor( path, start, length, + fileSize, columns, useParquetColumnNames, predicatePushdownEnabled, @@ -319,6 +322,7 @@ private ParquetRecordReader createParquetRecordReader( Path path, long start, long length, + long fileSize, List columns, boolean useParquetColumnNames, boolean predicatePushdownEnabled, @@ -327,7 +331,7 @@ private ParquetRecordReader createParquetRecordReader( ParquetDataSource dataSource = null; try { FileSystem fileSystem = hdfsEnvironment.getFileSystem(sessionUser, path, configuration); - dataSource = buildHdfsParquetDataSource(fileSystem, path, start, length); + dataSource = buildHdfsParquetDataSource(fileSystem, path, start, length, fileSize); ParquetMetadata parquetMetadata = hdfsEnvironment.doAs(sessionUser, () -> ParquetFileReader.readFooter(configuration, path, NO_FILTER)); List blocks = parquetMetadata.getBlocks(); FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); @@ -371,7 +375,7 @@ private ParquetRecordReader createParquetRecordReader( }); } catch (Exception e) { - Throwables.propagateIfInstanceOf(e, PrestoException.class); + throwIfInstanceOf(e, PrestoException.class); if (e instanceof InterruptedException) { Thread.currentThread().interrupt(); throw Throwables.propagate(e); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetPageSourceFactory.java index 9990227d2ef8..edbdf37dd04b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetPageSourceFactory.java @@ -94,6 +94,7 @@ public Optional createPageSource( Path path, long start, long length, + long fileSize, Properties schema, List columns, TupleDomain effectivePredicate, @@ -114,6 +115,7 @@ public Optional createPageSource( path, start, length, + fileSize, schema, columns, useParquetColumnNames, @@ -129,6 +131,7 @@ public static ParquetPageSource createParquetPageSource( Path path, long start, long length, + long fileSize, Properties schema, List columns, boolean useParquetColumnNames, @@ -141,8 +144,8 @@ public static ParquetPageSource createParquetPageSource( ParquetDataSource dataSource = null; try { FileSystem fileSystem = hdfsEnvironment.getFileSystem(user, path, configuration); - dataSource = buildHdfsParquetDataSource(fileSystem, path, start, length); - ParquetMetadata parquetMetadata = ParquetMetadataReader.readFooter(fileSystem, path); + dataSource = buildHdfsParquetDataSource(fileSystem, path, start, length, fileSize); + ParquetMetadata parquetMetadata = ParquetMetadataReader.readFooter(fileSystem, path, fileSize); FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetRecordCursorProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetRecordCursorProvider.java index 14f45946b392..056ff1269e2b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetRecordCursorProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetRecordCursorProvider.java @@ -68,6 +68,7 @@ public Optional createRecordCursor( Path path, long start, long length, + long fileSize, Properties schema, List columns, TupleDomain effectivePredicate, @@ -85,6 +86,7 @@ public Optional createRecordCursor( path, start, length, + fileSize, schema, columns, useParquetColumnNames, diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetMetadataReader.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetMetadataReader.java index 1b9a249f0f9a..c080f30fda1d 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetMetadataReader.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetMetadataReader.java @@ -14,7 +14,6 @@ package com.facebook.presto.hive.parquet.reader; import org.apache.hadoop.fs.FSDataInputStream; -import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import parquet.format.ColumnChunk; @@ -62,10 +61,9 @@ public final class ParquetMetadataReader private ParquetMetadataReader() {} - public static ParquetMetadata readFooter(FileSystem fileSystem, Path file) + public static ParquetMetadata readFooter(FileSystem fileSystem, Path file, long fileSize) throws IOException { - FileStatus fileStatus = fileSystem.getFileStatus(file); try (FSDataInputStream inputStream = fileSystem.open(file)) { // Parquet File Layout: // @@ -75,9 +73,8 @@ public static ParquetMetadata readFooter(FileSystem fileSystem, Path file) // 4 bytes: MetadataLength // MAGIC - long length = fileStatus.getLen(); - validateParquet(length >= MAGIC.length + PARQUET_METADATA_LENGTH + MAGIC.length, "%s is not a valid Parquet File", file); - long metadataLengthIndex = length - PARQUET_METADATA_LENGTH - MAGIC.length; + validateParquet(fileSize >= MAGIC.length + PARQUET_METADATA_LENGTH + MAGIC.length, "%s is not a valid Parquet File", file); + long metadataLengthIndex = fileSize - PARQUET_METADATA_LENGTH - MAGIC.length; inputStream.seek(metadataLengthIndex); int metadataLength = readIntLittleEndian(inputStream); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetReader.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetReader.java index d0199f546cf9..7ac5883045e6 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetReader.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetReader.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.InterleavedBlock; import com.facebook.presto.spi.block.RunLengthEncodedBlock; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.NamedTypeSignature; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; @@ -203,14 +204,13 @@ private Block readMap(Type type, List path, IntList elementOffsets) } return RunLengthEncodedBlock.create(parameters.get(0), null, batchSize); } - InterleavedBlock interleavedBlock = new InterleavedBlock(new Block[] {blocks[0], blocks[1]}); int[] offsets = new int[batchSize + 1]; for (int i = 1; i < offsets.length; i++) { - int elementPositionCount = keyOffsets.getInt(i - 1) * 2; - elementOffsets.add(elementPositionCount); + int elementPositionCount = keyOffsets.getInt(i - 1); + elementOffsets.add(elementPositionCount * 2); offsets[i] = offsets[i - 1] + elementPositionCount; } - return new ArrayBlock(batchSize, new boolean[batchSize], offsets, interleavedBlock); + return ((MapType) type).createBlockFromKeyValue(new boolean[batchSize], offsets, blocks[0], blocks[1]); } public Block readStruct(Type type, List path) diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetShortDecimalColumnReader.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetShortDecimalColumnReader.java index b96dd1efe59e..b47d42115baa 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetShortDecimalColumnReader.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetShortDecimalColumnReader.java @@ -59,7 +59,15 @@ else if (columnDescriptor.getType().equals(INT64)) { protected void skipValue() { if (definitionLevel == columnDescriptor.getMaxDefinitionLevel()) { - valuesReader.readBytes(); + if (columnDescriptor.getType().equals(INT32)) { + valuesReader.readInteger(); + } + else if (columnDescriptor.getType().equals(INT64)) { + valuesReader.readLong(); + } + else { + valuesReader.readBytes(); + } } } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSourceFactory.java index ec845329b0f7..05b3adbd2459 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSourceFactory.java @@ -94,6 +94,7 @@ public Optional createPageSource( Path path, long start, long length, + long fileSize, Properties schema, List columns, TupleDomain effectivePredicate, @@ -111,11 +112,9 @@ else if (deserializerClassName.equals(ColumnarSerDe.class.getName())) { return Optional.empty(); } - long size; FSDataInputStream inputStream; try { FileSystem fileSystem = hdfsEnvironment.getFileSystem(session.getUser(), path, configuration); - size = fileSystem.getFileStatus(path).getLen(); inputStream = fileSystem.open(path); } catch (Exception e) { @@ -133,7 +132,7 @@ else if (deserializerClassName.equals(ColumnarSerDe.class.getName())) { } RcFileReader rcFileReader = new RcFileReader( - new HdfsRcFileDataSource(path.toString(), inputStream, size, stats), + new HdfsRcFileDataSource(path.toString(), inputStream, fileSize, stats), rcFileEncoding, readColumns.build(), new AircompressorCodecFactory(new HadoopCodecFactory(configuration.getClassLoader())), diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java index 731b96e01e63..788ec92c740b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java @@ -29,6 +29,7 @@ import java.util.function.Function; import static com.facebook.presto.spi.security.AccessDeniedException.denyAddColumn; +import static com.facebook.presto.spi.security.AccessDeniedException.denyDropColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyRenameColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyRenameTable; @@ -41,6 +42,7 @@ public class LegacyAccessControl private final boolean allowDropTable; private final boolean allowRenameTable; private final boolean allowAddColumn; + private final boolean allowDropColumn; private final boolean allowRenameColumn; @Inject @@ -54,6 +56,7 @@ public LegacyAccessControl( allowDropTable = securityConfig.getAllowDropTable(); allowRenameTable = securityConfig.getAllowRenameTable(); allowAddColumn = securityConfig.getAllowAddColumn(); + allowDropColumn = securityConfig.getAllowDropColumn(); allowRenameColumn = securityConfig.getAllowRenameColumn(); } @@ -133,6 +136,14 @@ public void checkCanAddColumn(ConnectorTransactionHandle transaction, Identity i } } + @Override + public void checkCanDropColumn(ConnectorTransactionHandle transactionHandle, Identity identity, SchemaTableName tableName) + { + if (!allowDropColumn) { + denyDropColumn(tableName.toString()); + } + } + @Override public void checkCanRenameColumn(ConnectorTransactionHandle transaction, Identity identity, SchemaTableName tableName) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacySecurityConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacySecurityConfig.java index 72c3ac414073..c1573809deda 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacySecurityConfig.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacySecurityConfig.java @@ -19,6 +19,7 @@ public class LegacySecurityConfig { private boolean allowAddColumn; + private boolean allowDropColumn; private boolean allowDropTable; private boolean allowRenameTable; private boolean allowRenameColumn; @@ -36,6 +37,19 @@ public LegacySecurityConfig setAllowAddColumn(boolean allowAddColumn) return this; } + public boolean getAllowDropColumn() + { + return this.allowDropColumn; + } + + @Config("hive.allow-drop-column") + @ConfigDescription("Allow Hive connector to drop column") + public LegacySecurityConfig setAllowDropColumn(boolean allowDropColumn) + { + this.allowDropColumn = allowDropColumn; + return this; + } + public boolean getAllowDropTable() { return this.allowDropTable; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java index 8b0fa940109c..fee66951bd9d 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java @@ -43,6 +43,7 @@ import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateViewWithSelect; import static com.facebook.presto.spi.security.AccessDeniedException.denyDeleteTable; +import static com.facebook.presto.spi.security.AccessDeniedException.denyDropColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropSchema; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropView; @@ -159,6 +160,14 @@ public void checkCanAddColumn(ConnectorTransactionHandle transaction, Identity i } } + @Override + public void checkCanDropColumn(ConnectorTransactionHandle transaction, Identity identity, SchemaTableName tableName) + { + if (!checkTablePermission(transaction, identity, tableName, OWNERSHIP)) { + denyDropColumn(tableName.toString()); + } + } + @Override public void checkCanRenameColumn(ConnectorTransactionHandle transaction, Identity identity, SchemaTableName tableName) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/thrift/Transport.java b/presto-hive/src/main/java/com/facebook/presto/hive/thrift/Transport.java index b9bdd521bdae..a8ffe3df7019 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/thrift/Transport.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/thrift/Transport.java @@ -82,7 +82,7 @@ private static void closeQuietly(Closeable closeable) private static Socket createSocksSocket(HostAndPort proxy) { - SocketAddress address = InetSocketAddress.createUnresolved(proxy.getHostText(), proxy.getPort()); + SocketAddress address = InetSocketAddress.createUnresolved(proxy.getHost(), proxy.getPort()); return new Socket(new Proxy(Proxy.Type.SOCKS, address)); } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/util/ConfigurationUtils.java b/presto-hive/src/main/java/com/facebook/presto/hive/util/ConfigurationUtils.java index 27d9d0854c3e..b647ab89791c 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/util/ConfigurationUtils.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/util/ConfigurationUtils.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive.util; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapred.JobConf; import java.util.Map; @@ -27,4 +28,12 @@ public static void copy(Configuration from, Configuration to) to.set(entry.getKey(), entry.getValue()); } } + + public static JobConf toJobConf(Configuration conf) + { + if (conf instanceof JobConf) { + return (JobConf) conf; + } + return new JobConf(conf); + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/twitter/hive/thrift/ThriftHiveRecordCursorProvider.java b/presto-hive/src/main/java/com/facebook/presto/twitter/hive/thrift/ThriftHiveRecordCursorProvider.java index ed2cf1b9defd..9373a4c5232c 100644 --- a/presto-hive/src/main/java/com/facebook/presto/twitter/hive/thrift/ThriftHiveRecordCursorProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/twitter/hive/thrift/ThriftHiveRecordCursorProvider.java @@ -68,6 +68,7 @@ public Optional createRecordCursor( Path path, long start, long length, + long fileSize, Properties schema, List columns, TupleDomain effectivePredicate, diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java index 4526d4e215ee..3cc69ce64418 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java @@ -66,6 +66,8 @@ import com.facebook.presto.spi.predicate.Range; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.predicate.ValueSet; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.NamedTypeSignature; import com.facebook.presto.spi.type.SqlDate; import com.facebook.presto.spi.type.SqlTimestamp; @@ -78,8 +80,6 @@ import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.testing.TestingConnectorSession; import com.facebook.presto.testing.TestingNodeManager; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; @@ -124,7 +124,6 @@ import static com.facebook.presto.hive.HiveColumnHandle.ColumnType.PARTITION_KEY; import static com.facebook.presto.hive.HiveColumnHandle.ColumnType.REGULAR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_INVALID_PARTITION_VALUE; -import static com.facebook.presto.hive.HiveErrorCode.HIVE_METASTORE_ERROR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_PARTITION_SCHEMA_MISMATCH; import static com.facebook.presto.hive.HiveMetadata.PRESTO_QUERY_ID_NAME; import static com.facebook.presto.hive.HiveMetadata.PRESTO_VERSION_NAME; @@ -2155,7 +2154,7 @@ private void doInsertIntoNewPartition(HiveStorageFormat storageFormat, SchemaTab try (Transaction transaction = newTransaction()) { // verify partitions were created List partitionNames = transaction.getMetastore(tableName.getSchemaName()).getPartitionNames(tableName.getSchemaName(), tableName.getTableName()) - .orElseThrow(() -> new PrestoException(HIVE_METASTORE_ERROR, "Partition metadata not available")); + .orElseThrow(() -> new AssertionError("Table does not exist: " + tableName)); assertEqualsIgnoreOrder(partitionNames, CREATE_TABLE_PARTITIONED_DATA.getMaterializedRows().stream() .map(row -> "ds=" + row.getField(CREATE_TABLE_PARTITIONED_DATA.getTypes().size() - 1)) .collect(toList())); @@ -2272,7 +2271,7 @@ private void doInsertIntoExistingPartition(HiveStorageFormat storageFormat, Sche // verify partitions were created List partitionNames = transaction.getMetastore(tableName.getSchemaName()).getPartitionNames(tableName.getSchemaName(), tableName.getTableName()) - .orElseThrow(() -> new PrestoException(HIVE_METASTORE_ERROR, "Partition metadata not available")); + .orElseThrow(() -> new AssertionError("Table does not exist: " + tableName)); assertEqualsIgnoreOrder(partitionNames, CREATE_TABLE_PARTITIONED_DATA.getMaterializedRows().stream() .map(row -> "ds=" + row.getField(CREATE_TABLE_PARTITIONED_DATA.getTypes().size() - 1)) .collect(toList())); @@ -2395,7 +2394,7 @@ private void doTestMetadataDelete(HiveStorageFormat storageFormat, SchemaTableNa // verify partitions were created List partitionNames = transaction.getMetastore(tableName.getSchemaName()).getPartitionNames(tableName.getSchemaName(), tableName.getTableName()) - .orElseThrow(() -> new PrestoException(HIVE_METASTORE_ERROR, "Partition metadata not available")); + .orElseThrow(() -> new AssertionError("Table does not exist: " + tableName)); assertEqualsIgnoreOrder(partitionNames, CREATE_TABLE_PARTITIONED_DATA.getMaterializedRows().stream() .map(row -> "ds=" + row.getField(CREATE_TABLE_PARTITIONED_DATA.getTypes().size() - 1)) .collect(toList())); @@ -3253,7 +3252,7 @@ private void doTestTransactionDeleteInsert( // verify partitions List partitionNames = transaction.getMetastore(tableName.getSchemaName()) .getPartitionNames(tableName.getSchemaName(), tableName.getTableName()) - .orElseThrow(() -> new PrestoException(HIVE_METASTORE_ERROR, "Partition metadata not available")); + .orElseThrow(() -> new AssertionError("Table does not exist: " + tableName)); assertEqualsIgnoreOrder( partitionNames, expectedData.getMaterializedRows().stream() diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileFormats.java b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileFormats.java index 2b132c688354..ea56b623f34d 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileFormats.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileFormats.java @@ -23,10 +23,12 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.CharType; import com.facebook.presto.spi.type.DateType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlDate; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.SqlTimestamp; @@ -36,8 +38,6 @@ import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.tests.StructuralTestUtil; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java b/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java index 8fa624c8e28e..0c507cfd6835 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java @@ -23,12 +23,12 @@ import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.testing.TestingConnectorSession; -import com.facebook.presto.type.MapType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveClientConfig.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveClientConfig.java index 37bca3e51553..b7e90a43c3c7 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveClientConfig.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveClientConfig.java @@ -80,7 +80,9 @@ public void testDefaults() .setOrcMaxMergeDistance(new DataSize(1, Unit.MEGABYTE)) .setOrcMaxBufferSize(new DataSize(8, Unit.MEGABYTE)) .setOrcStreamBufferSize(new DataSize(8, Unit.MEGABYTE)) - .setRcfileOptimizedWriterEnabled(false) + .setOrcMaxReadBlockSize(new DataSize(16, Unit.MEGABYTE)) + .setRcfileOptimizedWriterEnabled(true) + .setRcfileWriterValidate(false) .setHiveMetastoreAuthenticationType(HiveMetastoreAuthenticationType.NONE) .setHdfsAuthenticationType(HdfsAuthenticationType.NONE) .setHdfsImpersonationEnabled(false) @@ -139,7 +141,9 @@ public void testExplicitPropertyMappings() .put("hive.orc.max-merge-distance", "22kB") .put("hive.orc.max-buffer-size", "44kB") .put("hive.orc.stream-buffer-size", "55kB") - .put("hive.rcfile-optimized-writer.enabled", "true") + .put("hive.orc.max-read-block-size", "66kB") + .put("hive.rcfile-optimized-writer.enabled", "false") + .put("hive.rcfile.writer.validate", "true") .put("hive.metastore.authentication.type", "KERBEROS") .put("hive.hdfs.authentication.type", "KERBEROS") .put("hive.hdfs.impersonation.enabled", "true") @@ -195,7 +199,9 @@ public void testExplicitPropertyMappings() .setOrcMaxMergeDistance(new DataSize(22, Unit.KILOBYTE)) .setOrcMaxBufferSize(new DataSize(44, Unit.KILOBYTE)) .setOrcStreamBufferSize(new DataSize(55, Unit.KILOBYTE)) - .setRcfileOptimizedWriterEnabled(true) + .setOrcMaxReadBlockSize(new DataSize(66, Unit.KILOBYTE)) + .setRcfileOptimizedWriterEnabled(false) + .setRcfileWriterValidate(true) .setHiveMetastoreAuthenticationType(HiveMetastoreAuthenticationType.KERBEROS) .setHdfsAuthenticationType(HdfsAuthenticationType.KERBEROS) .setHdfsImpersonationEnabled(true) diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileFormats.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileFormats.java index 8aaac9350b64..ba9633c73d0a 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileFormats.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileFormats.java @@ -24,12 +24,12 @@ import com.facebook.presto.spi.RecordCursor; import com.facebook.presto.spi.RecordPageSource; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.testing.TestingConnectorSession; import com.facebook.presto.twitter.hive.thrift.HiveThriftFieldIdResolverFactory; import com.facebook.presto.twitter.hive.thrift.ThriftGenericRow; import com.facebook.presto.twitter.hive.thrift.ThriftHiveRecordCursorProvider; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -698,6 +698,7 @@ private void testCursorProvider(HiveRecordCursorProvider cursorProvider, OptionalInt.empty(), split.getStart(), split.getLength(), + split.getLength(), splitProperties, TupleDomain.all(), getColumnHandles(testColumns), @@ -742,6 +743,7 @@ private void testPageSourceFactory(HivePageSourceFactory sourceFactory, OptionalInt.empty(), split.getStart(), split.getLength(), + split.getLength(), splitProperties, TupleDomain.all(), columnHandles, diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java index a6eb49489530..1e18af8d72ac 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java @@ -1861,6 +1861,29 @@ public void testRenameColumn() assertUpdate("DROP TABLE test_rename_column"); } + @Test + public void testDropColumn() + throws Exception + { + @Language("SQL") String createTable = "" + + "CREATE TABLE test_drop_column\n" + + "WITH (\n" + + " partitioned_by = ARRAY ['orderstatus']\n" + + ")\n" + + "AS\n" + + "SELECT custkey, orderkey, orderstatus FROM orders"; + + assertUpdate(createTable, "SELECT count(*) FROM orders"); + assertQuery("SELECT orderkey, orderstatus FROM test_drop_column", "SELECT orderkey, orderstatus FROM orders"); + + assertQueryFails("ALTER TABLE test_drop_column DROP COLUMN orderstatus", "Cannot drop partition columns"); + assertUpdate("ALTER TABLE test_drop_column DROP COLUMN orderkey"); + assertQueryFails("ALTER TABLE test_drop_column DROP COLUMN custkey", "Cannot drop the only column in a table"); + assertQuery("SELECT * FROM test_drop_column", "SELECT custkey, orderstatus FROM orders"); + + assertUpdate("DROP TABLE test_drop_column"); + } + @Test public void testAvroTypeValidation() { diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePageSink.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePageSink.java index 3b43955a88d2..508b4e3ac4be 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePageSink.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePageSink.java @@ -209,7 +209,7 @@ private static ConnectorPageSource createPageSource(HiveTransactionHandle transa splitProperties.setProperty(SERIALIZATION_LIB, config.getHiveStorageFormat().getSerDe()); splitProperties.setProperty("columns", Joiner.on(',').join(getColumnHandles().stream().map(HiveColumnHandle::getName).collect(toList()))); splitProperties.setProperty("columns.types", Joiner.on(',').join(getColumnHandles().stream().map(HiveColumnHandle::getHiveType).map(HiveType::getHiveTypeName).collect(toList()))); - HiveSplit split = new HiveSplit(CLIENT_ID, SCHEMA_NAME, TABLE_NAME, "", "file:///" + outputFile.getAbsolutePath(), 0, outputFile.length(), splitProperties, ImmutableList.of(), ImmutableList.of(), OptionalInt.empty(), false, TupleDomain.all(), ImmutableMap.of()); + HiveSplit split = new HiveSplit(CLIENT_ID, SCHEMA_NAME, TABLE_NAME, "", "file:///" + outputFile.getAbsolutePath(), 0, outputFile.length(), outputFile.length(), splitProperties, ImmutableList.of(), ImmutableList.of(), OptionalInt.empty(), false, TupleDomain.all(), ImmutableMap.of()); HivePageSourceProvider provider = new HivePageSourceProvider(config, createTestHdfsEnvironment(config), getDefaultHiveRecordCursorProvider(config), getDefaultHiveDataStreamFactories(config), TYPE_MANAGER); return provider.createPageSource(transaction, getSession(config), split, ImmutableList.copyOf(getColumnHandles())); } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java index 9b9ec00dbdfb..fcae1dee939c 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java @@ -47,6 +47,7 @@ public void testJsonRoundTrip() "path", 42, 88, + 88, schema, partitionKeys, addresses, @@ -65,6 +66,7 @@ public void testJsonRoundTrip() assertEquals(actual.getPath(), expected.getPath()); assertEquals(actual.getStart(), expected.getStart()); assertEquals(actual.getLength(), expected.getLength()); + assertEquals(actual.getFileSize(), expected.getFileSize()); assertEquals(actual.getSchema(), expected.getSchema()); assertEquals(actual.getPartitionKeys(), expected.getPartitionKeys()); assertEquals(actual.getAddresses(), expected.getAddresses()); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcPageSourceMemoryTracking.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcPageSourceMemoryTracking.java index 771f07b0f6a7..f2fbd7b518cb 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcPageSourceMemoryTracking.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcPageSourceMemoryTracking.java @@ -25,6 +25,7 @@ import com.facebook.presto.operator.project.PageProcessor; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; @@ -33,6 +34,7 @@ import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.relational.RowExpression; +import com.facebook.presto.testing.TestingConnectorSession; import com.facebook.presto.testing.TestingSplit; import com.facebook.presto.testing.TestingTransactionHandle; import com.google.common.base.Joiner; @@ -41,6 +43,8 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; +import io.airlift.stats.Distribution; +import io.airlift.units.DataSize; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; @@ -67,6 +71,7 @@ import org.joda.time.DateTimeZone; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.io.File; @@ -74,6 +79,7 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.OptionalInt; @@ -89,6 +95,7 @@ import static com.facebook.presto.hive.HiveTestUtils.SESSION; import static com.facebook.presto.hive.HiveTestUtils.TYPE_MANAGER; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.orc.OrcReader.MAX_BATCH_SIZE; import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.sql.relational.Expressions.field; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; @@ -98,6 +105,7 @@ import static com.google.common.collect.Iterables.transform; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.testing.Assertions.assertBetweenInclusive; +import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.stream.Collectors.toList; @@ -132,6 +140,12 @@ public class TestOrcPageSourceMemoryTracking private File tempFile; private TestPreparer testPreparer; + @DataProvider(name = "rowCount") + public static Object[][] rowCount() + { + return new Object[][] { { 50_000 }, { 10_000 }, { 5_000 } }; + } + @BeforeClass public void setUp() throws Exception @@ -225,6 +239,63 @@ public void testPageSource() assertEquals((int) stats.getLoadedBlockBytes().getAllTime().getCount(), 50); } + @Test(dataProvider = "rowCount") + public void testMaxReadBytes(int rowCount) + throws Exception + { + int maxReadBytes = 1_000; + HiveClientConfig config = new HiveClientConfig(); + config.setOrcMaxReadBlockSize(new DataSize(maxReadBytes, BYTE)); + ConnectorSession session = new TestingConnectorSession(new HiveSessionProperties(config).getSessionProperties()); + FileFormatDataSourceStats stats = new FileFormatDataSourceStats(); + + // Build a table where every row gets larger, so we can test that the "batchSize" reduces + int numColumns = 5; + int step = 250; + ImmutableList.Builder columnBuilder = ImmutableList.builder() + .add(new TestColumn("p_empty_string", javaStringObjectInspector, () -> "", true)); + GrowingTestColumn[] dataColumns = new GrowingTestColumn[numColumns]; + for (int i = 0; i < numColumns; i++) { + dataColumns[i] = new GrowingTestColumn("p_string", javaStringObjectInspector, () -> Long.toHexString(random.nextLong()), false, step * (i + 1)); + columnBuilder.add(dataColumns[i]); + } + List testColumns = columnBuilder.build(); + File tempFile = File.createTempFile("presto_test_orc_page_source_max_read_bytes", "orc"); + tempFile.delete(); + + TestPreparer testPreparer = new TestPreparer(tempFile.getAbsolutePath(), testColumns, rowCount, rowCount); + ConnectorPageSource pageSource = testPreparer.newPageSource(stats, session); + + try { + int positionCount = 0; + while (true) { + Page page = pageSource.getNextPage(); + if (pageSource.isFinished()) { + break; + } + assertNotNull(page); + page.assureLoaded(); + positionCount += page.getPositionCount(); + // assert upper bound is tight + // ignore the first MAX_BATCH_SIZE rows given the sizes are set when loading the blocks + if (positionCount > MAX_BATCH_SIZE) { + // either the block is bounded by maxReadBytes or we just load one single large block + // an error margin MAX_BATCH_SIZE / step is needed given the block sizes are increasing + assertTrue(page.getSizeInBytes() < maxReadBytes * (MAX_BATCH_SIZE / step) || 1 == page.getPositionCount()); + } + } + + // verify the stats are correctly recorded + Distribution distribution = stats.getMaxCombinedBytesPerRow().getAllTime(); + assertEquals((int) distribution.getCount(), 1); + assertEquals((int) distribution.getMax(), Arrays.stream(dataColumns).mapToInt(GrowingTestColumn::getMaxSize).sum()); + pageSource.close(); + } + finally { + tempFile.delete(); + } + } + @Test public void testTableScanOperator() throws Exception @@ -323,6 +394,12 @@ private class TestPreparer public TestPreparer(String tempFilePath) throws Exception + { + this(tempFilePath, testColumns, NUM_ROWS, STRIPE_ROWS); + } + + public TestPreparer(String tempFilePath, List testColumns, int numRows, int stripeRows) + throws Exception { OrcSerde serde = new OrcSerde(); schema = new Properties(); @@ -359,10 +436,20 @@ public TestPreparer(String tempFilePath) columns = columnsBuilder.build(); types = typesBuilder.build(); - fileSplit = createTestFile(tempFilePath, new OrcOutputFormat(), serde, null, testColumns, NUM_ROWS); + fileSplit = createTestFile(tempFilePath, new OrcOutputFormat(), serde, null, testColumns, numRows, stripeRows); + } + + public ConnectorPageSource newPageSource() + { + return newPageSource(new FileFormatDataSourceStats(), SESSION); } public ConnectorPageSource newPageSource(FileFormatDataSourceStats stats) + { + return newPageSource(stats, SESSION); + } + + public ConnectorPageSource newPageSource(FileFormatDataSourceStats stats, ConnectorSession session) { OrcPageSourceFactory orcPageSourceFactory = new OrcPageSourceFactory(TYPE_MANAGER, false, HDFS_ENVIRONMENT, stats); return HivePageSourceProvider.createHivePageSource( @@ -370,11 +457,12 @@ public ConnectorPageSource newPageSource(FileFormatDataSourceStats stats) ImmutableSet.of(orcPageSourceFactory), "test", new Configuration(), - SESSION, + session, fileSplit.getPath(), OptionalInt.empty(), fileSplit.getStart(), fileSplit.getLength(), + fileSplit.getLength(), schema, TupleDomain.all(), columns, @@ -387,7 +475,7 @@ public ConnectorPageSource newPageSource(FileFormatDataSourceStats stats) public SourceOperator newTableScanOperator(DriverContext driverContext) { - ConnectorPageSource pageSource = newPageSource(new FileFormatDataSourceStats()); + ConnectorPageSource pageSource = newPageSource(); SourceOperatorFactory sourceOperatorFactory = new TableScanOperatorFactory( 0, new PlanNodeId("0"), @@ -402,7 +490,7 @@ public SourceOperator newTableScanOperator(DriverContext driverContext) public SourceOperator newScanFilterAndProjectOperator(DriverContext driverContext) { - ConnectorPageSource pageSource = newPageSource(new FileFormatDataSourceStats()); + ConnectorPageSource pageSource = newPageSource(); ImmutableList.Builder projectionsBuilder = ImmutableList.builder(); for (int i = 0; i < types.size(); i++) { projectionsBuilder.add(field(i, types.get(i))); @@ -437,7 +525,8 @@ public static FileSplit createTestFile(String filePath, @SuppressWarnings("deprecation") SerDe serDe, String compressionCodec, List testColumns, - int numRows) + int numRows, + int stripeRows) throws Exception { // filter out partition keys, which are not written to the file @@ -477,7 +566,7 @@ public static FileSplit createTestFile(String filePath, Writable record = serDe.serialize(row, objectInspector); recordWriter.write(record); - if (rowNumber % STRIPE_ROWS == STRIPE_ROWS - 1) { + if (rowNumber % stripeRows == stripeRows - 1) { flushStripe(recordWriter); } } @@ -541,7 +630,7 @@ private static Constructor getOrcWriterConstructor() } } - public static final class TestColumn + public static class TestColumn { private final String name; private final ObjectInspector objectInspector; @@ -592,4 +681,41 @@ public String toString() return sb.toString(); } } + + public static final class GrowingTestColumn + extends TestColumn + { + private final Supplier writeValue; + private int counter; + private int step; + private int maxSize; + + public GrowingTestColumn(String name, ObjectInspector objectInspector, Supplier writeValue, boolean partitionKey, int step) + { + super(name, objectInspector, writeValue, partitionKey); + this.writeValue = writeValue; + this.counter = step; + this.step = step; + } + + @Override + public Object getWriteValue() + { + StringBuilder builder = new StringBuilder(); + String source = writeValue.get(); + for (int i = 0; i < counter / step; i++) { + builder.append(source); + } + counter++; + if (builder.length() > maxSize) { + maxSize = builder.length(); + } + return builder.toString(); + } + + public int getMaxSize() + { + return maxSize; + } + } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java index 703a6a376ca7..fb8e2eb6d594 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java @@ -360,6 +360,7 @@ private static ConnectorPageSource createPageSource( new Path(targetFile.getAbsolutePath()), 0, targetFile.length(), + targetFile.length(), createSchema(format, columnNames, columnTypes), columnHandles, TupleDomain.all(), @@ -392,6 +393,7 @@ private static ConnectorPageSource createPageSource( new Path(targetFile.getAbsolutePath()), 0, targetFile.length(), + targetFile.length(), createSchema(format, columnNames, columnTypes), columnHandles, TupleDomain.all(), diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveFileFormatBenchmark.java b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveFileFormatBenchmark.java index 6fd6a89eab3b..fb32e0b9aa15 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveFileFormatBenchmark.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveFileFormatBenchmark.java @@ -23,9 +23,9 @@ import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.testing.TestingConnectorSession; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import io.airlift.tpch.OrderColumn; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/TestingHiveMetastore.java b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/TestingHiveMetastore.java index f1a2503f50a6..d92737c86b2f 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/TestingHiveMetastore.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/TestingHiveMetastore.java @@ -47,6 +47,7 @@ import static com.facebook.presto.hive.metastore.Database.DEFAULT_DATABASE_NAME; import static com.facebook.presto.hive.metastore.HivePrivilegeInfo.HivePrivilege.OWNERSHIP; import static com.facebook.presto.hive.metastore.MetastoreUtil.makePartName; +import static com.facebook.presto.hive.metastore.MetastoreUtil.verifyCanDropColumn; import static com.facebook.presto.hive.metastore.PrincipalType.ROLE; import static com.facebook.presto.hive.metastore.PrincipalType.USER; import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; @@ -353,6 +354,30 @@ public synchronized void renameColumn(String databaseName, String tableName, Str relations.put(name, newTable); } + @Override + public synchronized void dropColumn(String databaseName, String tableName, String columnName) + { + SchemaTableName name = new SchemaTableName(databaseName, tableName); + Table oldTable = getRequiredTable(name); + + verifyCanDropColumn(this, databaseName, tableName, columnName); + if (!oldTable.getColumn(columnName).isPresent()) { + throw new ColumnNotFoundException(name, columnName); + } + + ImmutableList.Builder newDataColumns = ImmutableList.builder(); + for (Column fieldSchema : oldTable.getDataColumns()) { + if (!fieldSchema.getName().equals(columnName)) { + newDataColumns.add(fieldSchema); + } + } + + Table newTable = Table.builder(oldTable) + .setDataColumns(newDataColumns.build()) + .build(); + relations.put(name, newTable); + } + @Override public synchronized Optional> getAllTables(String databaseName) { diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/ParquetTester.java b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/ParquetTester.java index 939986159ed2..8d6943dd9b23 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/ParquetTester.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/ParquetTester.java @@ -164,7 +164,7 @@ private static void assertFileContents(JobConf jobConf, { Path path = new Path(tempFile.getFile().toURI()); FileSystem fileSystem = path.getFileSystem(jobConf); - ParquetMetadata parquetMetadata = ParquetMetadataReader.readFooter(fileSystem, path); + ParquetMetadata parquetMetadata = ParquetMetadataReader.readFooter(fileSystem, path, fileSystem.getFileStatus(path).getLen()); FileMetaData fileMetaData = parquetMetadata.getFileMetaData(); MessageType fileSchema = fileMetaData.getSchema(); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/security/TestLegacySecurityConfig.java b/presto-hive/src/test/java/com/facebook/presto/hive/security/TestLegacySecurityConfig.java index 3bc2bc8f11f1..0ed1a6896f90 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/security/TestLegacySecurityConfig.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/security/TestLegacySecurityConfig.java @@ -29,6 +29,7 @@ public void testDefaults() { assertRecordedDefaults(recordDefaults(LegacySecurityConfig.class) .setAllowAddColumn(false) + .setAllowDropColumn(false) .setAllowDropTable(false) .setAllowRenameTable(false) .setAllowRenameColumn(false)); @@ -39,6 +40,7 @@ public void testExplicitPropertyMappings() { Map properties = new ImmutableMap.Builder() .put("hive.allow-add-column", "true") + .put("hive.allow-drop-column", "true") .put("hive.allow-drop-table", "true") .put("hive.allow-rename-table", "true") .put("hive.allow-rename-column", "true") @@ -46,6 +48,7 @@ public void testExplicitPropertyMappings() LegacySecurityConfig expected = new LegacySecurityConfig() .setAllowAddColumn(true) + .setAllowDropColumn(true) .setAllowDropTable(true) .setAllowRenameTable(true) .setAllowRenameColumn(true); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/util/TestSerDeUtils.java b/presto-hive/src/test/java/com/facebook/presto/hive/util/TestSerDeUtils.java index 58d0fe7ea699..401bc6264610 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/util/TestSerDeUtils.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/util/TestSerDeUtils.java @@ -18,8 +18,8 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.reflect.TypeToken; diff --git a/presto-jdbc/pom.xml b/presto-jdbc/pom.xml index 7717cf7aa355..3012eb3044c1 100644 --- a/presto-jdbc/pom.xml +++ b/presto-jdbc/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-jdbc @@ -20,6 +20,12 @@ com.facebook.presto presto-client + + + javax.validation + validation-api + + @@ -28,39 +34,8 @@ - io.airlift - http-client - - - - io.airlift - configuration - - - - io.airlift - trace-token - - - - com.google.inject - guice - - - com.google.inject.extensions - guice-multibindings - - - - org.weakref - jmxutils - - - - org.eclipse.jetty - jetty-servlet - - + com.squareup.okhttp3 + okhttp @@ -72,10 +47,6 @@ io.airlift json - - javax.inject - javax.inject - com.google.inject guice @@ -97,11 +68,6 @@ guava - - com.google.code.findbugs - annotations - - com.facebook.presto @@ -144,6 +110,12 @@ concurrent test + + + com.squareup.okhttp3 + mockwebserver + test + @@ -199,12 +171,12 @@ ${shadeBase}.joda.time - org.eclipse.jetty - ${shadeBase}.jetty + okhttp3 + ${shadeBase}.okhttp3 - org.HdrHistogram - ${shadeBase}.HdrHistogram + okio + ${shadeBase}.okio @@ -212,12 +184,8 @@ *:* META-INF/maven/** - META-INF/*.xml - META-INF/services/org.eclipse.** META-INF/services/com.fasterxml.** LICENSE - *.css - *.html @@ -226,10 +194,11 @@ ** + - javax.validation:validation-api + com.squareup.okhttp3:okhttp - ** + publicsuffixes.gz diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/AbstractConnectionProperty.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/AbstractConnectionProperty.java new file mode 100644 index 000000000000..814a27e14f8d --- /dev/null +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/AbstractConnectionProperty.java @@ -0,0 +1,166 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.jdbc; + +import java.io.File; +import java.sql.DriverPropertyInfo; +import java.sql.SQLException; +import java.util.Optional; +import java.util.Properties; +import java.util.function.Predicate; + +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +abstract class AbstractConnectionProperty + implements ConnectionProperty +{ + private final String key; + private final Optional defaultValue; + private final Predicate isRequired; + private final Predicate isAllowed; + private final Converter converter; + + protected AbstractConnectionProperty( + String key, + Optional defaultValue, + Predicate isRequired, + Predicate isAllowed, + Converter converter) + { + this.key = requireNonNull(key, "key is null"); + this.defaultValue = requireNonNull(defaultValue, "defaultValue is null"); + this.isRequired = requireNonNull(isRequired, "isRequired is null"); + this.isAllowed = requireNonNull(isAllowed, "isAllowed is null"); + this.converter = requireNonNull(converter, "converter is null"); + } + + protected AbstractConnectionProperty( + String key, + Predicate required, + Predicate allowed, + Converter converter) + { + this(key, Optional.empty(), required, allowed, converter); + } + + @Override + public String getKey() + { + return key; + } + + @Override + public Optional getDefault() + { + return defaultValue; + } + + @Override + public DriverPropertyInfo getDriverPropertyInfo(Properties mergedProperties) + { + String currentValue = mergedProperties.getProperty(key); + DriverPropertyInfo result = new DriverPropertyInfo(key, currentValue); + result.required = isRequired.test(mergedProperties); + return result; + } + + @Override + public boolean isRequired(Properties properties) + { + return isRequired.test(properties); + } + + @Override + public boolean isAllowed(Properties properties) + { + return !properties.containsKey(key) || isAllowed.test(properties); + } + + @Override + public Optional getValue(Properties properties) + throws SQLException + { + String value = properties.getProperty(key); + if (value == null) { + if (isRequired(properties)) { + throw new SQLException(format("Connection property '%s' is required", key)); + } + return Optional.empty(); + } + if (value.isEmpty()) { + throw new SQLException(format("Connection property '%s' value is empty", key)); + } + + try { + return Optional.of(converter.convert(value)); + } + catch (RuntimeException e) { + throw new SQLException(format("Connection property '%s' value is invalid: %s", key, value), e); + } + } + + @Override + public void validate(Properties properties) + throws SQLException + { + if (!isAllowed(properties)) { + throw new SQLException(format("Connection property '%s' is not allowed", key)); + } + + getValue(properties); + } + + protected static final Predicate REQUIRED = properties -> true; + protected static final Predicate NOT_REQUIRED = properties -> false; + + protected static final Predicate ALLOWED = properties -> true; + + interface Converter + { + T convert(String value); + } + + protected static final Converter STRING_CONVERTER = value -> value; + protected static final Converter FILE_CONVERTER = File::new; + + protected static final Converter BOOLEAN_CONVERTER = value -> { + switch (value.toLowerCase(ENGLISH)) { + case "true": + return true; + case "false": + return false; + } + throw new IllegalArgumentException("value must be 'true' or 'false'"); + }; + + protected interface CheckedPredicate + { + boolean test(T t) + throws SQLException; + } + + protected static Predicate checkedPredicate(CheckedPredicate predicate) + { + return t -> { + try { + return predicate.test(t); + } + catch (SQLException e) { + return false; + } + }; + } +} diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperties.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperties.java new file mode 100644 index 000000000000..8ed6f7a63ecc --- /dev/null +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperties.java @@ -0,0 +1,227 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.jdbc; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.net.HostAndPort; + +import java.io.File; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.function.Predicate; + +import static com.facebook.presto.jdbc.AbstractConnectionProperty.checkedPredicate; +import static java.util.Collections.unmodifiableMap; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toMap; + +final class ConnectionProperties +{ + public static final ConnectionProperty USER = new User(); + public static final ConnectionProperty PASSWORD = new Password(); + public static final ConnectionProperty SOCKS_PROXY = new SocksProxy(); + public static final ConnectionProperty HTTP_PROXY = new HttpProxy(); + public static final ConnectionProperty SSL = new Ssl(); + public static final ConnectionProperty SSL_TRUST_STORE_PATH = new SslTrustStorePath(); + public static final ConnectionProperty SSL_TRUST_STORE_PASSWORD = new SslTrustStorePassword(); + public static final ConnectionProperty KERBEROS_REMOTE_SERICE_NAME = new KerberosRemoteServiceName(); + public static final ConnectionProperty KERBEROS_USE_CANONICAL_HOSTNAME = new KerberosUseCanonicalHostname(); + public static final ConnectionProperty KERBEROS_PRINCIPAL = new KerberosPrincipal(); + public static final ConnectionProperty KERBEROS_CONFIG_PATH = new KerberosConfigPath(); + public static final ConnectionProperty KERBEROS_KEYTAB_PATH = new KerberosKeytabPath(); + public static final ConnectionProperty KERBEROS_CREDENTIAL_CACHE_PATH = new KerberosCredentialCachePath(); + + private static final Set> ALL_PROPERTIES = ImmutableSet.>builder() + .add(USER) + .add(PASSWORD) + .add(SOCKS_PROXY) + .add(HTTP_PROXY) + .add(SSL) + .add(SSL_TRUST_STORE_PATH) + .add(SSL_TRUST_STORE_PASSWORD) + .add(KERBEROS_REMOTE_SERICE_NAME) + .add(KERBEROS_USE_CANONICAL_HOSTNAME) + .add(KERBEROS_PRINCIPAL) + .add(KERBEROS_CONFIG_PATH) + .add(KERBEROS_KEYTAB_PATH) + .add(KERBEROS_CREDENTIAL_CACHE_PATH) + .build(); + + private static final Map> KEY_LOOKUP = unmodifiableMap(ALL_PROPERTIES.stream() + .collect(toMap(ConnectionProperty::getKey, identity()))); + + private static final Map DEFAULTS; + + static { + ImmutableMap.Builder defaults = ImmutableMap.builder(); + for (ConnectionProperty property : ALL_PROPERTIES) { + property.getDefault().ifPresent(value -> defaults.put(property.getKey(), value)); + } + DEFAULTS = defaults.build(); + } + + private ConnectionProperties() {} + + public static ConnectionProperty forKey(String propertiesKey) + { + return KEY_LOOKUP.get(propertiesKey); + } + + public static Set> allProperties() + { + return ALL_PROPERTIES; + } + + public static Map getDefaults() + { + return DEFAULTS; + } + + private static class User + extends AbstractConnectionProperty + { + public User() + { + super("user", REQUIRED, ALLOWED, STRING_CONVERTER); + } + } + + private static class Password + extends AbstractConnectionProperty + { + public Password() + { + super("password", NOT_REQUIRED, ALLOWED, STRING_CONVERTER); + } + } + + private static class SocksProxy + extends AbstractConnectionProperty + { + private static final Predicate NO_HTTP_PROXY = + checkedPredicate(properties -> !HTTP_PROXY.getValue(properties).isPresent()); + + public SocksProxy() + { + super("socksProxy", NOT_REQUIRED, NO_HTTP_PROXY, HostAndPort::fromString); + } + } + + private static class HttpProxy + extends AbstractConnectionProperty + { + private static final Predicate NO_SOCKS_PROXY = + checkedPredicate(properties -> !SOCKS_PROXY.getValue(properties).isPresent()); + + public HttpProxy() + { + super("httpProxy", NOT_REQUIRED, NO_SOCKS_PROXY, HostAndPort::fromString); + } + } + + private static class Ssl + extends AbstractConnectionProperty + { + public Ssl() + { + super("SSL", Optional.of("false"), NOT_REQUIRED, ALLOWED, BOOLEAN_CONVERTER); + } + } + + private static class SslTrustStorePath + extends AbstractConnectionProperty + { + private static final Predicate IF_SSL_ENABLED = + checkedPredicate(properties -> SSL.getValue(properties).orElse(false)); + + public SslTrustStorePath() + { + super("SSLTrustStorePath", NOT_REQUIRED, IF_SSL_ENABLED, STRING_CONVERTER); + } + } + + private static class SslTrustStorePassword + extends AbstractConnectionProperty + { + private static final Predicate IF_TRUST_STORE = + checkedPredicate(properties -> SSL_TRUST_STORE_PATH.getValue(properties).isPresent()); + + public SslTrustStorePassword() + { + super("SSLTrustStorePassword", NOT_REQUIRED, IF_TRUST_STORE, STRING_CONVERTER); + } + } + + private static class KerberosRemoteServiceName + extends AbstractConnectionProperty + { + public KerberosRemoteServiceName() + { + super("KerberosRemoteServiceName", NOT_REQUIRED, ALLOWED, STRING_CONVERTER); + } + } + + private static Predicate isKerberosEnabled() + { + return checkedPredicate(properties -> KERBEROS_REMOTE_SERICE_NAME.getValue(properties).isPresent()); + } + + private static class KerberosPrincipal + extends AbstractConnectionProperty + { + public KerberosPrincipal() + { + super("KerberosPrincipal", NOT_REQUIRED, isKerberosEnabled(), STRING_CONVERTER); + } + } + + private static class KerberosUseCanonicalHostname + extends AbstractConnectionProperty + { + public KerberosUseCanonicalHostname() + { + super("KerberosUseCanonicalHostname", Optional.of("true"), isKerberosEnabled(), ALLOWED, BOOLEAN_CONVERTER); + } + } + + private static class KerberosConfigPath + extends AbstractConnectionProperty + { + public KerberosConfigPath() + { + super("KerberosConfigPath", NOT_REQUIRED, isKerberosEnabled(), FILE_CONVERTER); + } + } + + private static class KerberosKeytabPath + extends AbstractConnectionProperty + { + public KerberosKeytabPath() + { + super("KerberosKeytabPath", NOT_REQUIRED, isKerberosEnabled(), FILE_CONVERTER); + } + } + + private static class KerberosCredentialCachePath + extends AbstractConnectionProperty + { + public KerberosCredentialCachePath() + { + super("KerberosCredentialCachePath", NOT_REQUIRED, isKerberosEnabled(), FILE_CONVERTER); + } + } +} diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperty.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperty.java new file mode 100644 index 000000000000..ac2e90897803 --- /dev/null +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperty.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.jdbc; + +import java.sql.DriverPropertyInfo; +import java.sql.SQLException; +import java.util.Optional; +import java.util.Properties; + +import static java.lang.String.format; + +interface ConnectionProperty +{ + String getKey(); + + Optional getDefault(); + + DriverPropertyInfo getDriverPropertyInfo(Properties properties); + + boolean isRequired(Properties properties); + + boolean isAllowed(Properties properties); + + Optional getValue(Properties properties) + throws SQLException; + + default T getRequiredValue(Properties properties) + throws SQLException + { + return getValue(properties).orElseThrow(() -> + new SQLException(format("Connection property '%s' is required", getKey()))); + } + + void validate(Properties properties) + throws SQLException; +} diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoConnection.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoConnection.java index 5f62ed93c7be..4e54ac4f3b39 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoConnection.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoConnection.java @@ -73,7 +73,7 @@ public class PrestoConnection private final AtomicReference transactionId = new AtomicReference<>(); private final QueryExecutor queryExecutor; - PrestoConnection(PrestoDriverUri uri, String user, QueryExecutor queryExecutor) + PrestoConnection(PrestoDriverUri uri, QueryExecutor queryExecutor) throws SQLException { requireNonNull(uri, "uri is null"); @@ -81,9 +81,10 @@ public class PrestoConnection this.httpUri = uri.getHttpUri(); this.schema.set(uri.getSchema()); this.catalog.set(uri.getCatalog()); + this.user = uri.getUser(); - this.user = requireNonNull(user, "user is null"); this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); + timeZoneId.set(TimeZone.getDefault().getID()); locale.set(Locale.getDefault()); } diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java index febb90e4b552..cc420899d678 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java @@ -970,7 +970,7 @@ public ResultSet getColumns(String catalog, String schemaPattern, String tableNa " CHAR_OCTET_LENGTH, ORDINAL_POSITION, IS_NULLABLE,\n" + " SCOPE_CATALOG, SCOPE_SCHEMA, SCOPE_TABLE,\n" + " SOURCE_DATA_TYPE, IS_AUTOINCREMENT, IS_GENERATEDCOLUMN\n" + - "FROM system.jdbc.columns\n"); + "FROM system.jdbc.columns"); List filters = new ArrayList<>(); emptyStringEqualsFilter(filters, "TABLE_CAT", catalog); diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriver.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriver.java index bfbfe35ac628..9ee216947c5a 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriver.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriver.java @@ -14,6 +14,7 @@ package com.facebook.presto.jdbc; import com.google.common.base.Throwables; +import okhttp3.OkHttpClient; import java.io.Closeable; import java.sql.Connection; @@ -27,10 +28,9 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; -import static com.google.common.base.Strings.isNullOrEmpty; +import static com.facebook.presto.client.OkHttpUtil.userAgent; import static com.google.common.base.Strings.nullToEmpty; import static java.lang.Integer.parseInt; -import static java.lang.String.format; public class PrestoDriver implements Driver, Closeable @@ -40,13 +40,11 @@ public class PrestoDriver static final int DRIVER_VERSION_MAJOR; static final int DRIVER_VERSION_MINOR; - private static final DriverPropertyInfo[] DRIVER_PROPERTY_INFOS = {}; - private static final String DRIVER_URL_START = "jdbc:presto:"; - private static final String USER_PROPERTY = "user"; - - private final QueryExecutor queryExecutor; + private final OkHttpClient httpClient = new OkHttpClient().newBuilder() + .addInterceptor(userAgent(DRIVER_NAME + "/" + DRIVER_VERSION)) + .build(); static { String version = nullToEmpty(PrestoDriver.class.getPackage().getImplementationVersion()); @@ -70,15 +68,11 @@ public class PrestoDriver } } - public PrestoDriver() - { - this.queryExecutor = QueryExecutor.create(DRIVER_NAME + "/" + DRIVER_VERSION); - } - @Override public void close() { - queryExecutor.close(); + httpClient.dispatcher().executorService().shutdown(); + httpClient.connectionPool().evictAll(); } @Override @@ -89,12 +83,13 @@ public Connection connect(String url, Properties info) return null; } - String user = info.getProperty(USER_PROPERTY); - if (isNullOrEmpty(user)) { - throw new SQLException(format("Username property (%s) must be set", USER_PROPERTY)); - } + PrestoDriverUri uri = new PrestoDriverUri(url, info); - return new PrestoConnection(new PrestoDriverUri(url), user, queryExecutor); + OkHttpClient.Builder builder = httpClient.newBuilder(); + uri.setupClient(builder); + QueryExecutor executor = new QueryExecutor(builder.build()); + + return new PrestoConnection(uri, executor); } @Override @@ -108,7 +103,11 @@ public boolean acceptsURL(String url) public DriverPropertyInfo[] getPropertyInfo(String url, Properties info) throws SQLException { - return DRIVER_PROPERTY_INFOS; + Properties properties = new PrestoDriverUri(url, info).getProperties(); + + return ConnectionProperties.allProperties().stream() + .map(property -> property.getDriverPropertyInfo(properties)) + .toArray(DriverPropertyInfo[]::new); } @Override diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriverUri.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriverUri.java index 35908043eff7..d366d41e4aa8 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriverUri.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriverUri.java @@ -13,18 +13,44 @@ */ package com.facebook.presto.jdbc; +import com.facebook.presto.client.ClientException; import com.google.common.base.Splitter; +import com.google.common.collect.Maps; import com.google.common.net.HostAndPort; +import okhttp3.OkHttpClient; +import java.io.File; import java.net.URI; import java.net.URISyntaxException; import java.sql.SQLException; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Properties; +import static com.facebook.presto.client.KerberosUtil.defaultCredentialCachePath; +import static com.facebook.presto.client.OkHttpUtil.basicAuth; +import static com.facebook.presto.client.OkHttpUtil.setupHttpProxy; +import static com.facebook.presto.client.OkHttpUtil.setupKerberos; +import static com.facebook.presto.client.OkHttpUtil.setupSocksProxy; +import static com.facebook.presto.client.OkHttpUtil.setupSsl; +import static com.facebook.presto.jdbc.ConnectionProperties.HTTP_PROXY; +import static com.facebook.presto.jdbc.ConnectionProperties.KERBEROS_CONFIG_PATH; +import static com.facebook.presto.jdbc.ConnectionProperties.KERBEROS_CREDENTIAL_CACHE_PATH; +import static com.facebook.presto.jdbc.ConnectionProperties.KERBEROS_KEYTAB_PATH; +import static com.facebook.presto.jdbc.ConnectionProperties.KERBEROS_PRINCIPAL; +import static com.facebook.presto.jdbc.ConnectionProperties.KERBEROS_REMOTE_SERICE_NAME; +import static com.facebook.presto.jdbc.ConnectionProperties.KERBEROS_USE_CANONICAL_HOSTNAME; +import static com.facebook.presto.jdbc.ConnectionProperties.PASSWORD; +import static com.facebook.presto.jdbc.ConnectionProperties.SOCKS_PROXY; +import static com.facebook.presto.jdbc.ConnectionProperties.SSL; +import static com.facebook.presto.jdbc.ConnectionProperties.SSL_TRUST_STORE_PASSWORD; +import static com.facebook.presto.jdbc.ConnectionProperties.SSL_TRUST_STORE_PATH; +import static com.facebook.presto.jdbc.ConnectionProperties.USER; import static com.google.common.base.Strings.isNullOrEmpty; -import static io.airlift.http.client.HttpUriBuilder.uriBuilder; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; /** @@ -40,25 +66,29 @@ final class PrestoDriverUri private final HostAndPort address; private final URI uri; + private final Properties properties; + private String catalog; private String schema; private final boolean useSecureConnection; - public PrestoDriverUri(String url) + public PrestoDriverUri(String url, Properties driverProperties) throws SQLException { - this(parseDriverUrl(url)); + this(parseDriverUrl(url), driverProperties); } - private PrestoDriverUri(URI uri) + private PrestoDriverUri(URI uri, Properties driverProperties) throws SQLException { this.uri = requireNonNull(uri, "uri is null"); - this.address = HostAndPort.fromParts(uri.getHost(), uri.getPort()); + address = HostAndPort.fromParts(uri.getHost(), uri.getPort()); + properties = mergeConnectionProperties(uri, driverProperties); + + validateConnectionProperties(properties); - Map params = parseParameters(uri.getQuery()); - useSecureConnection = Boolean.parseBoolean(params.get("secure")); + useSecureConnection = SSL.getRequiredValue(properties); initCatalogAndSchema(); } @@ -83,7 +113,64 @@ public URI getHttpUri() return buildHttpUri(); } + public String getUser() + throws SQLException + { + return USER.getRequiredValue(properties); + } + + public Properties getProperties() + { + return properties; + } + + public void setupClient(OkHttpClient.Builder builder) + throws SQLException + { + try { + setupSocksProxy(builder, SOCKS_PROXY.getValue(properties)); + setupHttpProxy(builder, HTTP_PROXY.getValue(properties)); + + // TODO: fix Tempto to allow empty passwords + String password = PASSWORD.getValue(properties).orElse(""); + if (!password.isEmpty() && !password.equals("***empty***")) { + if (!useSecureConnection) { + throw new SQLException("Authentication using username/password requires SSL to be enabled"); + } + builder.addInterceptor(basicAuth(getUser(), password)); + } + + if (useSecureConnection) { + Optional trustStorePath = SSL_TRUST_STORE_PATH.getValue(properties); + Optional trustStorePassword = SSL_TRUST_STORE_PASSWORD.getValue(properties); + setupSsl(builder, Optional.empty(), Optional.empty(), trustStorePath, trustStorePassword); + } + + if (KERBEROS_REMOTE_SERICE_NAME.getValue(properties).isPresent()) { + if (!useSecureConnection) { + throw new SQLException("Authentication using Kerberos requires SSL to be enabled"); + } + setupKerberos( + builder, + KERBEROS_REMOTE_SERICE_NAME.getRequiredValue(properties), + KERBEROS_USE_CANONICAL_HOSTNAME.getRequiredValue(properties), + KERBEROS_PRINCIPAL.getValue(properties), + KERBEROS_CONFIG_PATH.getValue(properties), + KERBEROS_KEYTAB_PATH.getValue(properties), + Optional.ofNullable(KERBEROS_CREDENTIAL_CACHE_PATH.getValue(properties) + .orElseGet(() -> defaultCredentialCachePath().map(File::new).orElse(null)))); + } + } + catch (ClientException e) { + throw new SQLException(e.getMessage(), e); + } + catch (RuntimeException e) { + throw new SQLException("Error setting up connection", e); + } + } + private static Map parseParameters(String query) + throws SQLException { Map result = new HashMap<>(); @@ -91,7 +178,9 @@ private static Map parseParameters(String query) Iterable queryArgs = QUERY_SPLITTER.split(query); for (String queryArg : queryArgs) { List parts = ARG_SPLITTER.splitToList(queryArg); - result.put(parts.get(0), parts.get(1)); + if (result.put(parts.get(0), parts.get(1)) != null) { + throw new SQLException(format("Connection property '%s' is in URL multiple times", parts.get(0))); + } } } @@ -122,12 +211,13 @@ private static URI parseDriverUrl(String url) private URI buildHttpUri() { - String scheme = (address.getPort() == 443 || useSecureConnection) ? "https" : "http"; - - return uriBuilder() - .scheme(scheme) - .host(address.getHostText()).port(address.getPort()) - .build(); + String scheme = useSecureConnection ? "https" : "http"; + try { + return new URI(scheme, null, address.getHost(), address.getPort(), null, null, null); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } } private void initCatalogAndSchema() @@ -166,4 +256,45 @@ private void initCatalogAndSchema() schema = parts.get(1); } } + + private static Properties mergeConnectionProperties(URI uri, Properties driverProperties) + throws SQLException + { + Map defaults = ConnectionProperties.getDefaults(); + Map urlProperties = parseParameters(uri.getQuery()); + Map suppliedProperties = Maps.fromProperties(driverProperties); + + for (String key : urlProperties.keySet()) { + if (suppliedProperties.containsKey(key)) { + throw new SQLException(format("Connection property '%s' is both in the URL and an argument", key)); + } + } + + Properties result = new Properties(); + setProperties(result, defaults); + setProperties(result, urlProperties); + setProperties(result, suppliedProperties); + return result; + } + + private static void setProperties(Properties properties, Map values) + { + for (Entry entry : values.entrySet()) { + properties.setProperty(entry.getKey(), entry.getValue()); + } + } + + private static void validateConnectionProperties(Properties connectionProperties) + throws SQLException + { + for (String propertyName : connectionProperties.stringPropertyNames()) { + if (ConnectionProperties.forKey(propertyName) == null) { + throw new SQLException(format("Unrecognized connection property '%s'", propertyName)); + } + } + + for (ConnectionProperty property : ConnectionProperties.allProperties()) { + property.validate(connectionProperties); + } + } } diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoStatement.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoStatement.java index 28fb75805998..4fd53921883a 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoStatement.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoStatement.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.jdbc; +import com.facebook.presto.client.ClientException; import com.facebook.presto.client.StatementClient; import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; @@ -170,7 +171,12 @@ public void setQueryTimeout(int seconds) public void cancel() throws SQLException { - throw new SQLFeatureNotSupportedException("cancel"); + // TODO: handle non-query statements + checkOpen(); + ResultSet resultSet = currentResult.get(); + if (resultSet != null) { + resultSet.close(); + } } @Override @@ -239,6 +245,9 @@ public boolean execute(String sql) return false; } + catch (ClientException e) { + throw new SQLException(e.getMessage(), e); + } catch (RuntimeException e) { throw new SQLException("Error executing query", e); } diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/QueryExecutor.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/QueryExecutor.java index e1964b1bde68..e9a587df3824 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/QueryExecutor.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/QueryExecutor.java @@ -13,103 +13,52 @@ */ package com.facebook.presto.jdbc; +import com.facebook.presto.client.ClientException; import com.facebook.presto.client.ClientSession; -import com.facebook.presto.client.QueryResults; +import com.facebook.presto.client.JsonResponse; import com.facebook.presto.client.ServerInfo; import com.facebook.presto.client.StatementClient; -import com.google.common.collect.ImmutableSet; -import com.google.common.net.HostAndPort; -import io.airlift.http.client.HttpClient; -import io.airlift.http.client.HttpClientConfig; -import io.airlift.http.client.Request; -import io.airlift.http.client.jetty.JettyHttpClient; -import io.airlift.http.client.jetty.JettyIoPool; -import io.airlift.http.client.jetty.JettyIoPoolConfig; import io.airlift.json.JsonCodec; -import io.airlift.units.Duration; +import okhttp3.HttpUrl; +import okhttp3.OkHttpClient; +import okhttp3.Request; -import javax.annotation.Nullable; - -import java.io.Closeable; -import java.net.InetSocketAddress; -import java.net.Proxy; -import java.net.ProxySelector; import java.net.URI; -import java.util.concurrent.TimeUnit; -import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; -import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; -import static io.airlift.http.client.Request.Builder.prepareGet; import static io.airlift.json.JsonCodec.jsonCodec; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; class QueryExecutor - implements Closeable { - private final JsonCodec queryInfoCodec; - private final JsonCodec serverInfoCodec; - private final HttpClient httpClient; + private static final JsonCodec SERVER_INFO_CODEC = jsonCodec(ServerInfo.class); + + private final OkHttpClient httpClient; - private QueryExecutor(JsonCodec queryResultsCodec, JsonCodec serverInfoCodec, HttpClient httpClient) + public QueryExecutor(OkHttpClient httpClient) { - this.queryInfoCodec = requireNonNull(queryResultsCodec, "queryResultsCodec is null"); - this.serverInfoCodec = requireNonNull(serverInfoCodec, "serverInfoCodec is null"); this.httpClient = requireNonNull(httpClient, "httpClient is null"); } public StatementClient startQuery(ClientSession session, String query) { - return new StatementClient(httpClient, queryInfoCodec, session, query); - } - - @Override - public void close() - { - httpClient.close(); + return new StatementClient(httpClient, session, query); } public ServerInfo getServerInfo(URI server) { - URI uri = uriBuilderFrom(server).replacePath("/v1/info").build(); - Request request = prepareGet().setUri(uri).build(); - return httpClient.execute(request, createJsonResponseHandler(serverInfoCodec)); - } - - // TODO: replace this with a phantom reference - @SuppressWarnings("FinalizeDeclaration") - @Override - protected void finalize() - { - close(); - } - - static QueryExecutor create(String userAgent) - { - return create(new JettyHttpClient( - new HttpClientConfig() - .setConnectTimeout(new Duration(10, TimeUnit.SECONDS)) - .setSocksProxy(getSystemSocksProxy()), - new JettyIoPool("presto-jdbc", new JettyIoPoolConfig()), - ImmutableSet.of(new UserAgentRequestFilter(userAgent)))); - } + HttpUrl url = HttpUrl.get(server); + if (url == null) { + throw new ClientException("Invalid server URL: " + server); + } + url = url.newBuilder().encodedPath("/v1/info").build(); - static QueryExecutor create(HttpClient httpClient) - { - return new QueryExecutor(jsonCodec(QueryResults.class), jsonCodec(ServerInfo.class), httpClient); - } + Request request = new Request.Builder().url(url).build(); - @Nullable - private static HostAndPort getSystemSocksProxy() - { - URI uri = URI.create("socket://0.0.0.0:80"); - for (Proxy proxy : ProxySelector.getDefault().select(uri)) { - if (proxy.type() == Proxy.Type.SOCKS) { - if (proxy.address() instanceof InetSocketAddress) { - InetSocketAddress address = (InetSocketAddress) proxy.address(); - return HostAndPort.fromParts(address.getHostString(), address.getPort()); - } - } + JsonResponse response = JsonResponse.execute(SERVER_INFO_CODEC, httpClient, request); + if (!response.hasValue()) { + throw new RuntimeException(format("Request to %s failed: %s [Error: %s]", server, response, response.getResponseBody())); } - return null; + return response.getValue(); } } diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/UserAgentRequestFilter.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/UserAgentRequestFilter.java deleted file mode 100644 index 7b0be8c475c1..000000000000 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/UserAgentRequestFilter.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.jdbc; - -import com.google.common.net.HttpHeaders; -import io.airlift.http.client.HttpRequestFilter; -import io.airlift.http.client.Request; - -import static io.airlift.http.client.Request.Builder.fromRequest; -import static java.util.Objects.requireNonNull; - -class UserAgentRequestFilter - implements HttpRequestFilter -{ - private final String userAgent; - - public UserAgentRequestFilter(String userAgent) - { - this.userAgent = requireNonNull(userAgent, "userAgent is null"); - } - - @Override - public Request filterRequest(Request request) - { - return fromRequest(request) - .addHeader(HttpHeaders.USER_AGENT, userAgent) - .build(); - } -} diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriver.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriver.java index e12add0eb779..20c10a57c7d6 100644 --- a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriver.java +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriver.java @@ -16,6 +16,7 @@ import com.facebook.presto.execution.QueryState; import com.facebook.presto.plugin.blackhole.BlackHolePlugin; import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.DateType; @@ -33,7 +34,6 @@ import com.facebook.presto.spi.type.VarbinaryType; import com.facebook.presto.tpch.TpchMetadata; import com.facebook.presto.tpch.TpchPlugin; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.ColorType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -135,6 +135,14 @@ private void setupTestTables() try (Connection connection = createConnection("blackhole", "blackhole"); Statement statement = connection.createStatement()) { assertEquals(statement.executeUpdate("CREATE TABLE test_table (x bigint)"), 0); + + assertEquals(statement.executeUpdate("CREATE TABLE slow_test_table (x bigint) " + + "WITH (" + + " split_count = 1, " + + " pages_per_split = 1, " + + " rows_per_page = 1, " + + " page_processing_delay = '1m'" + + ")"), 0); } } @@ -1359,30 +1367,19 @@ public void testBadQuery() } } - @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Username property \\(user\\) must be set") + @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Connection property 'user' is required") public void testUserIsRequired() throws Exception { - try (Connection ignored = DriverManager.getConnection("jdbc:presto://test.invalid/")) { + try (Connection ignored = DriverManager.getConnection(format("jdbc:presto://%s", server.getAddress()))) { fail("expected exception"); } } @Test(timeOut = 10000) - public void testQueryCancellation() + public void testQueryCancelByInterrupt() throws Exception { - try (Connection connection = createConnection("blackhole", "blackhole"); - Statement statement = connection.createStatement()) { - statement.executeUpdate("CREATE TABLE test_cancellation (key BIGINT) " + - "WITH (" + - " split_count = 1, " + - " pages_per_split = 1, " + - " rows_per_page = 1, " + - " page_processing_delay = '1m'" + - ")"); - } - CountDownLatch queryStarted = new CountDownLatch(1); CountDownLatch queryFinished = new CountDownLatch(1); AtomicReference queryId = new AtomicReference<>(); @@ -1391,7 +1388,7 @@ public void testQueryCancellation() Future queryFuture = executorService.submit(() -> { try (Connection connection = createConnection("blackhole", "default"); Statement statement = connection.createStatement(); - ResultSet resultSet = statement.executeQuery("SELECT * FROM test_cancellation")) { + ResultSet resultSet = statement.executeQuery("SELECT * FROM slow_test_table")) { queryId.set(resultSet.unwrap(PrestoResultSet.class).getQueryId()); queryStarted.countDown(); try { @@ -1419,10 +1416,46 @@ public void testQueryCancellation() assertTrue(queryFinished.await(10, SECONDS)); assertNotNull(queryFailure.get()); assertEquals(getQueryState(queryId.get()), FAILED); + } - try (Connection connection = createConnection("blackhole", "blackhole"); + @Test(timeOut = 10000) + public void testQueryCancelExplicit() + throws Exception + { + CountDownLatch queryStarted = new CountDownLatch(1); + CountDownLatch queryFinished = new CountDownLatch(1); + AtomicReference queryId = new AtomicReference<>(); + AtomicReference queryFailure = new AtomicReference<>(); + + try (Connection connection = createConnection("blackhole", "default"); Statement statement = connection.createStatement()) { - statement.executeUpdate("DROP TABLE test_cancellation"); + // execute the slow query on another thread + executorService.execute(() -> { + try (ResultSet resultSet = statement.executeQuery("SELECT * FROM slow_test_table")) { + queryId.set(resultSet.unwrap(PrestoResultSet.class).getQueryId()); + queryStarted.countDown(); + resultSet.next(); + } + catch (SQLException t) { + queryFailure.set(t); + } + finally { + queryFinished.countDown(); + } + }); + + // start query and make sure it is not finished + queryStarted.await(10, SECONDS); + assertNotNull(queryId.get()); + assertFalse(getQueryState(queryId.get()).isDone()); + + // cancel the query from this test thread + statement.cancel(); + + // make sure the query was aborted + queryFinished.await(10, SECONDS); + assertNotNull(queryFailure.get()); + assertEquals(getQueryState(queryId.get()), FAILED); } } diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriverUri.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriverUri.java index bb84f9132dc6..20f2d62fb300 100644 --- a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriverUri.java +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriverUri.java @@ -17,87 +17,189 @@ import java.net.URI; import java.sql.SQLException; +import java.util.Properties; +import static com.facebook.presto.jdbc.ConnectionProperties.HTTP_PROXY; +import static com.facebook.presto.jdbc.ConnectionProperties.SOCKS_PROXY; +import static com.facebook.presto.jdbc.ConnectionProperties.SSL_TRUST_STORE_PASSWORD; +import static com.facebook.presto.jdbc.ConnectionProperties.SSL_TRUST_STORE_PATH; import static java.lang.String.format; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.fail; public class TestPrestoDriverUri { - private static final String SERVER = "127.0.0.1:60429"; + @Test + public void testInvalidUrls() + { + // missing port + assertInvalid("jdbc:presto://localhost/", "No port number specified:"); + + // extra path segments + assertInvalid("jdbc:presto://localhost:8080/hive/default/abc", "Invalid path segments in URL:"); + + // extra slash + assertInvalid("jdbc:presto://localhost:8080//", "Catalog name is empty:"); + + // has schema but is missing catalog + assertInvalid("jdbc:presto://localhost:8080//default", "Catalog name is empty:"); + + // has catalog but schema is missing + assertInvalid("jdbc:presto://localhost:8080/a//", "Schema name is empty:"); + + // unrecognized property + assertInvalid("jdbc:presto://localhost:8080/hive/default?ShoeSize=13", "Unrecognized connection property 'ShoeSize'"); + + // empty property + assertInvalid("jdbc:presto://localhost:8080/hive/default?password=", "Connection property 'password' value is empty"); + + // property in url multiple times + assertInvalid("presto://localhost:8080/blackhole?password=a&password=b", "Connection property 'password' is in URL multiple times"); + + // property in both url and arguments + assertInvalid("presto://localhost:8080/blackhole?user=test123", "Connection property 'user' is both in the URL and an argument"); + + // setting both socks and http proxy + assertInvalid("presto://localhost:8080?socksProxy=localhost:1080&httpProxy=localhost:8888", "Connection property 'socksProxy' is not allowed"); + assertInvalid("presto://localhost:8080?httpProxy=localhost:8888&socksProxy=localhost:1080", "Connection property 'socksProxy' is not allowed"); + + // invalid ssl flag + assertInvalid("jdbc:presto://localhost:8080?SSL=0", "Connection property 'SSL' value is invalid: 0"); + assertInvalid("jdbc:presto://localhost:8080?SSL=1", "Connection property 'SSL' value is invalid: 1"); + assertInvalid("jdbc:presto://localhost:8080?SSL=2", "Connection property 'SSL' value is invalid: 2"); + assertInvalid("jdbc:presto://localhost:8080?SSL=abc", "Connection property 'SSL' value is invalid: abc"); - @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Invalid path segments in URL: .*") - public void testBadUrlExtraPathSegments() + // ssl trust store password without path + assertInvalid("jdbc:presto://localhost:8080?SSL=true&SSLTrustStorePassword=password", "Connection property 'SSLTrustStorePassword' is not allowed"); + + // trust store path without ssl + assertInvalid("jdbc:presto://localhost:8080?SSLTrustStorePath=truststore.jks", "Connection property 'SSLTrustStorePath' is not allowed"); + + // trust store password without ssl + assertInvalid("jdbc:presto://localhost:8080?SSLTrustStorePassword=password", "Connection property 'SSLTrustStorePassword' is not allowed"); + + // kerberos config without service name + assertInvalid("jdbc:presto://localhost:8080?KerberosCredentialCachePath=/test", "Connection property 'KerberosCredentialCachePath' is not allowed"); + } + + @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Connection property 'user' is required") + public void testRequireUser() throws Exception { - String url = format("jdbc:presto://%s/hive/default/bad_string", SERVER); - new PrestoDriverUri(url); + new PrestoDriverUri("jdbc:presto://localhost:8080", new Properties()); } - @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Catalog name is empty: .*") - public void testBadUrlMissingCatalog() - throws Exception + @Test + void testUriWithSocksProxy() + throws SQLException { - String url = format("jdbc:presto://%s//default", SERVER); - new PrestoDriverUri(url); + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080?socksProxy=localhost:1234"); + assertUriPortScheme(parameters, 8080, "http"); + + Properties properties = parameters.getProperties(); + assertEquals(properties.getProperty(SOCKS_PROXY.getKey()), "localhost:1234"); } - @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Catalog name is empty: .*") - public void testBadUrlEndsInSlashes() - throws Exception + @Test + void testUriWithHttpProxy() + throws SQLException { - String url = format("jdbc:presto://%s//", SERVER); - new PrestoDriverUri(url); + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080?httpProxy=localhost:5678"); + assertUriPortScheme(parameters, 8080, "http"); + + Properties properties = parameters.getProperties(); + assertEquals(properties.getProperty(HTTP_PROXY.getKey()), "localhost:5678"); } - @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Schema name is empty: .*") - public void testBadUrlMissingSchema() - throws Exception + @Test + public void testUriWithoutSsl() + throws SQLException { - String url = format("jdbc:presto://%s/a//", SERVER); - new PrestoDriverUri(url); + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080/blackhole"); + assertUriPortScheme(parameters, 8080, "http"); } @Test - public void testUrlWithSsl() + public void testUriWithSslPortDoesNotUseSsl() throws SQLException { - PrestoDriverUri parameters = new PrestoDriverUri("presto://some-ssl-server:443/blackhole"); + PrestoDriverUri parameters = createDriverUri("presto://somelocalhost:443/blackhole"); + assertUriPortScheme(parameters, 443, "http"); + } - URI uri = parameters.getHttpUri(); - assertEquals(uri.getPort(), 443); - assertEquals(uri.getScheme(), "https"); + @Test + public void testUriWithSslDisabled() + throws SQLException + { + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080/blackhole?SSL=false"); + assertUriPortScheme(parameters, 8080, "http"); } @Test - public void testUriWithSecureMissing() + public void testUriWithSslEnabled() throws SQLException { - PrestoDriverUri parameters = new PrestoDriverUri("presto://localhost:8080/blackhole"); + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080/blackhole?SSL=true"); + assertUriPortScheme(parameters, 8080, "https"); - URI uri = parameters.getHttpUri(); - assertEquals(uri.getPort(), 8080); - assertEquals(uri.getScheme(), "http"); + Properties properties = parameters.getProperties(); + assertNull(properties.getProperty(SSL_TRUST_STORE_PATH.getKey())); + assertNull(properties.getProperty(SSL_TRUST_STORE_PASSWORD.getKey())); } @Test - public void testUriWithSecureTrue() + public void testUriWithSslEnabledPathOnly() throws SQLException { - PrestoDriverUri parameters = new PrestoDriverUri("presto://localhost:8080/blackhole?secure=true"); + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080/blackhole?SSL=true&SSLTrustStorePath=truststore.jks"); + assertUriPortScheme(parameters, 8080, "https"); - URI uri = parameters.getHttpUri(); - assertEquals(uri.getPort(), 8080); - assertEquals(uri.getScheme(), "https"); + Properties properties = parameters.getProperties(); + assertEquals(properties.getProperty(SSL_TRUST_STORE_PATH.getKey()), "truststore.jks"); + assertNull(properties.getProperty(SSL_TRUST_STORE_PASSWORD.getKey())); } @Test - public void testUriWithSecureFalse() + public void testUriWithSslEnabledPassword() throws SQLException { - PrestoDriverUri parameters = new PrestoDriverUri("presto://localhost:8080/blackhole?secure=false"); + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080/blackhole?SSL=true&SSLTrustStorePath=truststore.jks&SSLTrustStorePassword=password"); + assertUriPortScheme(parameters, 8080, "https"); + Properties properties = parameters.getProperties(); + assertEquals(properties.getProperty(SSL_TRUST_STORE_PATH.getKey()), "truststore.jks"); + assertEquals(properties.getProperty(SSL_TRUST_STORE_PASSWORD.getKey()), "password"); + } + + private static void assertUriPortScheme(PrestoDriverUri parameters, int port, String scheme) + { URI uri = parameters.getHttpUri(); - assertEquals(uri.getPort(), 8080); - assertEquals(uri.getScheme(), "http"); + assertEquals(uri.getPort(), port); + assertEquals(uri.getScheme(), scheme); + } + + private static PrestoDriverUri createDriverUri(String url) + throws SQLException + { + Properties properties = new Properties(); + properties.setProperty("user", "test"); + + return new PrestoDriverUri(url, properties); + } + + private static void assertInvalid(String url, String prefix) + { + try { + createDriverUri(url); + fail("expected exception"); + } + catch (SQLException e) { + assertNotNull(e.getMessage()); + if (!e.getMessage().startsWith(prefix)) { + fail(format("expected:<%s> to start with <%s>", e.getMessage(), prefix)); + } + } } } diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestProgressMonitor.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestProgressMonitor.java index 66c21737a46f..f937aafd1eef 100644 --- a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestProgressMonitor.java +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestProgressMonitor.java @@ -18,32 +18,28 @@ import com.facebook.presto.client.QueryResults; import com.facebook.presto.client.StatementStats; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.net.HttpHeaders; -import io.airlift.http.client.HttpClient; -import io.airlift.http.client.HttpStatus; -import io.airlift.http.client.Request; -import io.airlift.http.client.Response; -import io.airlift.http.client.testing.TestingHttpClient; -import io.airlift.http.client.testing.TestingResponse; import io.airlift.json.JsonCodec; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -import java.net.URI; +import java.io.IOException; import java.sql.Connection; +import java.sql.DriverManager; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; -import java.util.Iterator; import java.util.List; import java.util.function.Consumer; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; import static io.airlift.json.JsonCodec.jsonCodec; import static io.airlift.testing.Assertions.assertGreaterThanOrEqual; import static java.lang.String.format; -import static java.util.Objects.requireNonNull; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -51,28 +47,46 @@ @Test(singleThreaded = true) public class TestProgressMonitor { - private static final String SERVER_ADDRESS = "127.0.0.1:8080"; private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); - private static final String QUERY_ID = "20160128_214710_00012_rk68b"; - private static final String INFO_URI = "http://" + SERVER_ADDRESS + "/query.html?" + QUERY_ID; - private static final String PARTIAL_CANCEL_URI = "http://" + SERVER_ADDRESS + "/v1/stage/" + QUERY_ID + ".%d"; - private static final String NEXT_URI = "http://" + SERVER_ADDRESS + "/v1/statement/" + QUERY_ID + "/%d"; - private static final List RESPONSE_COLUMNS = ImmutableList.of(new Column("_col0", "bigint", new ClientTypeSignature("bigint", ImmutableList.of()))); - private static final List RESPONSES = ImmutableList.of( - newQueryResults(null, 1, null, null, "QUEUED"), - newQueryResults(1, 2, RESPONSE_COLUMNS, null, "RUNNING"), - newQueryResults(1, 3, RESPONSE_COLUMNS, null, "RUNNING"), - newQueryResults(0, 4, RESPONSE_COLUMNS, ImmutableList.of(ImmutableList.of(253161)), "RUNNING"), - newQueryResults(null, null, RESPONSE_COLUMNS, null, "FINISHED")); - - private static String newQueryResults(Integer partialCancelId, Integer nextUriId, List responseColumns, List> data, String state) + private MockWebServer server; + + @BeforeMethod + public void setup() + throws IOException + { + server = new MockWebServer(); + server.start(); + } + + @AfterMethod + public void teardown() + throws IOException + { + server.close(); + } + + private List createResults() + { + List columns = ImmutableList.of(new Column("_col0", "bigint", new ClientTypeSignature("bigint", ImmutableList.of()))); + return ImmutableList.builder() + .add(newQueryResults(null, 1, null, null, "QUEUED")) + .add(newQueryResults(1, 2, columns, null, "RUNNING")) + .add(newQueryResults(1, 3, columns, null, "RUNNING")) + .add(newQueryResults(0, 4, columns, ImmutableList.of(ImmutableList.of(253161)), "RUNNING")) + .add(newQueryResults(null, null, columns, null, "FINISHED")) + .build(); + } + + private String newQueryResults(Integer partialCancelId, Integer nextUriId, List responseColumns, List> data, String state) { + String queryId = "20160128_214710_00012_rk68b"; + QueryResults queryResults = new QueryResults( - QUERY_ID, - URI.create(INFO_URI), - partialCancelId == null ? null : URI.create(format(PARTIAL_CANCEL_URI, partialCancelId)), - nextUriId == null ? null : URI.create(format(NEXT_URI, nextUriId)), + queryId, + server.url("/query.html?" + queryId).uri(), + partialCancelId == null ? null : server.url(format("/v1/stage/%s.%s", queryId, partialCancelId)).uri(), + nextUriId == null ? null : server.url(format("/v1/statement/%s/%s", queryId, nextUriId)).uri(), responseColumns, data, new StatementStats(state, state.equals("QUEUED"), true, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, null), @@ -87,6 +101,12 @@ private static String newQueryResults(Integer partialCancelId, Integer nextUriId public void test() throws SQLException { + for (String result : createResults()) { + server.enqueue(new MockResponse() + .addHeader(CONTENT_TYPE, "application/json") + .setBody(result)); + } + try (Connection connection = createConnection()) { try (Statement statement = connection.createStatement()) { PrestoStatement prestoStatement = statement.unwrap(PrestoStatement.class); @@ -116,40 +136,15 @@ public void test() private Connection createConnection() throws SQLException { - HttpClient client = new TestingHttpClient(new TestingHttpClientProcessor(RESPONSES)); - QueryExecutor testQueryExecutor = QueryExecutor.create(client); - String uri = format("prestotest://%s", SERVER_ADDRESS); - return new PrestoConnection(new PrestoDriverUri(uri), "test", testQueryExecutor); - } - - private static class TestingHttpClientProcessor - implements TestingHttpClient.Processor - { - private final Iterator responses; - - public TestingHttpClientProcessor(List responses) - { - this.responses = ImmutableList.copyOf(requireNonNull(responses, "responses is null")).iterator(); - } - - @Override - public synchronized Response handle(Request request) - throws Exception - { - checkState(responses.hasNext(), "too many requests (ran out of test responses)"); - Response response = new TestingResponse( - HttpStatus.OK, - ImmutableListMultimap.of(HttpHeaders.CONTENT_TYPE, "application/json"), - responses.next().getBytes()); - return response; - } + String url = format("jdbc:presto://%s", server.url("/").uri().getAuthority()); + return DriverManager.getConnection(url, "test", null); } private static class RecordingProgressMonitor implements Consumer { private final ImmutableList.Builder builder = ImmutableList.builder(); - private boolean finished = false; + private boolean finished; @Override public synchronized void accept(QueryStats queryStats) diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestQueryExecutor.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestQueryExecutor.java index 1e369a14f05c..6693748e55ad 100644 --- a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestQueryExecutor.java +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestQueryExecutor.java @@ -14,37 +14,62 @@ package com.facebook.presto.jdbc; import com.facebook.presto.client.ServerInfo; -import com.google.common.collect.ImmutableListMultimap; -import io.airlift.http.client.testing.TestingHttpClient; -import io.airlift.http.client.testing.TestingResponse; import io.airlift.json.JsonCodec; import io.airlift.units.Duration; +import okhttp3.OkHttpClient; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -import java.net.URI; +import java.io.IOException; import java.util.Optional; import static com.facebook.presto.client.NodeVersion.UNKNOWN; import static com.google.common.net.HttpHeaders.CONTENT_TYPE; -import static io.airlift.http.client.HttpStatus.OK; import static io.airlift.json.JsonCodec.jsonCodec; import static org.testng.Assert.assertEquals; +@Test(singleThreaded = true) public class TestQueryExecutor { private static final JsonCodec SERVER_INFO_CODEC = jsonCodec(ServerInfo.class); + private MockWebServer server; + + @BeforeMethod + public void setup() + throws IOException + { + server = new MockWebServer(); + server.start(); + } + + @AfterMethod + public void teardown() + throws IOException + { + server.close(); + } + @Test public void testGetServerInfo() throws Exception { - ServerInfo serverInfo = new ServerInfo(UNKNOWN, "test", true, Optional.of(Duration.valueOf("2m"))); + ServerInfo expected = new ServerInfo(UNKNOWN, "test", true, Optional.of(Duration.valueOf("2m"))); + + server.enqueue(new MockResponse() + .addHeader(CONTENT_TYPE, "application/json") + .setBody(SERVER_INFO_CODEC.toJson(expected))); + + QueryExecutor executor = new QueryExecutor(new OkHttpClient()); - QueryExecutor executor = QueryExecutor.create(new TestingHttpClient(input -> new TestingResponse( - OK, - ImmutableListMultimap.of(CONTENT_TYPE, "application/json"), - SERVER_INFO_CODEC.toJsonBytes(serverInfo)))); + ServerInfo actual = executor.getServerInfo(server.url("/v1/info").uri()); + assertEquals(actual.getEnvironment(), "test"); + assertEquals(actual.getUptime(), Optional.of(Duration.valueOf("2m"))); - assertEquals(executor.getServerInfo(new URI("http://example.com")).getUptime().get(), Duration.valueOf("2m")); + assertEquals(server.getRequestCount(), 1); + assertEquals(server.takeRequest().getPath(), "/v1/info"); } } diff --git a/presto-jmx/pom.xml b/presto-jmx/pom.xml index a6bb4234cfe7..94ef3bc69c7a 100644 --- a/presto-jmx/pom.xml +++ b/presto-jmx/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-jmx diff --git a/presto-kafka/pom.xml b/presto-kafka/pom.xml index d59c0ea4bca2..38ad83f0bf10 100644 --- a/presto-kafka/pom.xml +++ b/presto-kafka/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-kafka diff --git a/presto-local-file/pom.xml b/presto-local-file/pom.xml index d4f488fd21e9..483e0b2246da 100644 --- a/presto-local-file/pom.xml +++ b/presto-local-file/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-local-file diff --git a/presto-main/pom.xml b/presto-main/pom.xml index 50d62b627dd5..357dc08934ff 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-main diff --git a/presto-main/src/main/java/com/facebook/presto/connector/system/jdbc/ColumnJdbcTable.java b/presto-main/src/main/java/com/facebook/presto/connector/system/jdbc/ColumnJdbcTable.java index 3786624e651a..a4363f46b26f 100644 --- a/presto-main/src/main/java/com/facebook/presto/connector/system/jdbc/ColumnJdbcTable.java +++ b/presto-main/src/main/java/com/facebook/presto/connector/system/jdbc/ColumnJdbcTable.java @@ -26,11 +26,11 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.CharType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarcharType; -import com.facebook.presto.type.ArrayType; import javax.inject.Inject; diff --git a/presto-main/src/main/java/com/facebook/presto/execution/CallTask.java b/presto-main/src/main/java/com/facebook/presto/execution/CallTask.java index b1e6a867c347..4abbdc9f69d5 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/CallTask.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/CallTask.java @@ -54,7 +54,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_CATALOG; import static com.facebook.presto.sql.planner.ExpressionInterpreter.evaluateConstantExpression; import static com.facebook.presto.util.Failures.checkCondition; -import static com.google.common.base.Throwables.propagateIfInstanceOf; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.util.concurrent.Futures.immediateFuture; import static java.util.Arrays.asList; @@ -164,7 +164,7 @@ else if (i < procedure.getArguments().size()) { if (t instanceof InterruptedException) { Thread.currentThread().interrupt(); } - propagateIfInstanceOf(t, PrestoException.class); + throwIfInstanceOf(t, PrestoException.class); throw new PrestoException(PROCEDURE_CALL_FAILED, t); } diff --git a/presto-main/src/main/java/com/facebook/presto/execution/CreateTableTask.java b/presto-main/src/main/java/com/facebook/presto/execution/CreateTableTask.java index 733e54ad1484..43b7644d9df2 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/CreateTableTask.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/CreateTableTask.java @@ -31,6 +31,7 @@ import com.facebook.presto.sql.tree.LikeClause; import com.facebook.presto.sql.tree.TableElement; import com.facebook.presto.transaction.TransactionManager; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; @@ -43,6 +44,7 @@ import java.util.Set; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; +import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; @@ -73,10 +75,15 @@ public String explain(CreateTable statement, List parameters) @Override public ListenableFuture execute(CreateTable statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, QueryStateMachine stateMachine, List parameters) + { + return internalExecute(statement, metadata, accessControl, stateMachine.getSession(), parameters); + } + + @VisibleForTesting + public ListenableFuture internalExecute(CreateTable statement, Metadata metadata, AccessControl accessControl, Session session, List parameters) { checkArgument(!statement.getElements().isEmpty(), "no columns for table"); - Session session = stateMachine.getSession(); QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getName()); Optional tableHandle = metadata.getTableHandle(session, tableName); if (tableHandle.isPresent()) { @@ -155,7 +162,14 @@ else if (element instanceof LikeClause) { ConnectorTableMetadata tableMetadata = new ConnectorTableMetadata(tableName.asSchemaTableName(), ImmutableList.copyOf(columns.values()), finalProperties, statement.getComment()); - metadata.createTable(session, tableName.getCatalogName(), tableMetadata); + try { + metadata.createTable(session, tableName.getCatalogName(), tableMetadata); + } + catch (PrestoException e) { + if (!e.getErrorCode().equals(ALREADY_EXISTS.toErrorCode()) || !statement.isNotExists()) { + throw e; + } + } return immediateFuture(null); } diff --git a/presto-main/src/main/java/com/facebook/presto/execution/DataDefinitionExecution.java b/presto-main/src/main/java/com/facebook/presto/execution/DataDefinitionExecution.java index 02288e3b90c6..50ae0d6c14eb 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/DataDefinitionExecution.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/DataDefinitionExecution.java @@ -20,6 +20,7 @@ import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.security.AccessControl; import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.resourceGroups.QueryType; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.tree.Expression; @@ -41,6 +42,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import static com.facebook.presto.spi.resourceGroups.QueryType.DATA_DEFINITION; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @@ -155,6 +157,12 @@ public void addFinalQueryInfoListener(StateChangeListener stateChange stateMachine.addQueryInfoStateChangeListener(stateChangeListener); } + @Override + public Optional getQueryType() + { + return Optional.of(DATA_DEFINITION); + } + @Override public void fail(Throwable cause) { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/DropColumnTask.java b/presto-main/src/main/java/com/facebook/presto/execution/DropColumnTask.java new file mode 100644 index 000000000000..ebf4608627d4 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/execution/DropColumnTask.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.QualifiedObjectName; +import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.security.AccessControl; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.sql.analyzer.SemanticException; +import com.facebook.presto.sql.tree.DropColumn; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.transaction.TransactionManager; +import com.google.common.util.concurrent.ListenableFuture; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_COLUMN; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_TABLE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static java.util.Locale.ENGLISH; + +public class DropColumnTask + implements DataDefinitionTask +{ + @Override + public String getName() + { + return "DROP COLUMN"; + } + + @Override + public ListenableFuture execute(DropColumn statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, QueryStateMachine stateMachine, List parameters) + { + Session session = stateMachine.getSession(); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTable()); + Optional tableHandle = metadata.getTableHandle(session, tableName); + + String column = statement.getColumn().toLowerCase(ENGLISH); + + if (!tableHandle.isPresent()) { + throw new SemanticException(MISSING_TABLE, statement, "Table '%s' does not exist", tableName); + } + accessControl.checkCanDropColumn(session.getRequiredTransactionId(), session.getIdentity(), tableName); + + Map columnHandles = metadata.getColumnHandles(session, tableHandle.get()); + if (!columnHandles.containsKey(column)) { + throw new SemanticException(MISSING_COLUMN, statement, "Column '%s' does not exist", column); + } + if (columnHandles.size() == 1) { + throw new SemanticException(NOT_SUPPORTED, statement, "Cannot drop column from a table with only one column"); + } + metadata.dropColumn(session, tableHandle.get(), columnHandles.get(column)); + + return immediateFuture(null); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/execution/FailedQueryExecution.java b/presto-main/src/main/java/com/facebook/presto/execution/FailedQueryExecution.java index 20ee5867e80d..1c739af50a02 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/FailedQueryExecution.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/FailedQueryExecution.java @@ -18,6 +18,7 @@ import com.facebook.presto.memory.VersionedMemoryPoolId; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.resourceGroups.QueryType; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.transaction.TransactionManager; @@ -128,6 +129,12 @@ public void addFinalQueryInfoListener(StateChangeListener stateChange executor.execute(() -> stateChangeListener.stateChanged(queryInfo)); } + @Override + public Optional getQueryType() + { + return Optional.empty(); + } + @Override public void fail(Throwable cause) { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/QueryExecution.java b/presto-main/src/main/java/com/facebook/presto/execution/QueryExecution.java index 098ed4dbecc7..1b5aba35396a 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/QueryExecution.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/QueryExecution.java @@ -17,6 +17,7 @@ import com.facebook.presto.execution.StateMachine.StateChangeListener; import com.facebook.presto.memory.VersionedMemoryPoolId; import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.resourceGroups.QueryType; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.tree.Expression; @@ -74,4 +75,6 @@ interface QueryExecutionFactory { T createQueryExecution(QueryId queryId, String query, Session session, Statement statement, List parameters); } + + Optional getQueryType(); } diff --git a/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java b/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java index 87067b1fdb60..92fbe996d05e 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java @@ -31,6 +31,7 @@ import com.facebook.presto.security.AccessControl; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.resourceGroups.QueryType; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.split.SplitManager; import com.facebook.presto.split.SplitSource; @@ -52,11 +53,26 @@ import com.facebook.presto.sql.planner.StageExecutionPlan; import com.facebook.presto.sql.planner.SubPlan; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; +import com.facebook.presto.sql.tree.CreateTableAsSelect; +import com.facebook.presto.sql.tree.Delete; +import com.facebook.presto.sql.tree.DescribeInput; +import com.facebook.presto.sql.tree.DescribeOutput; import com.facebook.presto.sql.tree.Explain; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.Insert; +import com.facebook.presto.sql.tree.Query; +import com.facebook.presto.sql.tree.ShowCatalogs; +import com.facebook.presto.sql.tree.ShowColumns; +import com.facebook.presto.sql.tree.ShowCreate; +import com.facebook.presto.sql.tree.ShowFunctions; +import com.facebook.presto.sql.tree.ShowGrants; +import com.facebook.presto.sql.tree.ShowPartitions; +import com.facebook.presto.sql.tree.ShowSchemas; +import com.facebook.presto.sql.tree.ShowSession; +import com.facebook.presto.sql.tree.ShowStats; +import com.facebook.presto.sql.tree.ShowTables; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.transaction.TransactionManager; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableSet; import io.airlift.concurrent.SetThreadName; import io.airlift.log.Logger; @@ -76,7 +92,13 @@ import static com.facebook.presto.OutputBuffers.BROADCAST_PARTITION_ID; import static com.facebook.presto.OutputBuffers.createInitialEmptyOutputBuffers; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.spi.resourceGroups.QueryType.DELETE; +import static com.facebook.presto.spi.resourceGroups.QueryType.DESCRIBE; +import static com.facebook.presto.spi.resourceGroups.QueryType.EXPLAIN; +import static com.facebook.presto.spi.resourceGroups.QueryType.INSERT; +import static com.facebook.presto.spi.resourceGroups.QueryType.SELECT; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.SECONDS; @@ -266,7 +288,7 @@ public void start() } catch (Throwable e) { fail(e); - Throwables.propagateIfInstanceOf(e, Error.class); + throwIfInstanceOf(e, Error.class); } } } @@ -285,6 +307,30 @@ public void addFinalQueryInfoListener(StateChangeListener stateChange stateMachine.addQueryInfoStateChangeListener(stateChangeListener); } + @Override + public Optional getQueryType() + { + if (statement instanceof Query) { + return Optional.of(SELECT); + } + else if (statement instanceof Explain) { + return Optional.of(EXPLAIN); + } + else if (statement instanceof ShowCatalogs || statement instanceof ShowCreate || statement instanceof ShowFunctions || + statement instanceof ShowGrants || statement instanceof ShowPartitions || statement instanceof ShowSchemas || + statement instanceof ShowSession || statement instanceof ShowStats || statement instanceof ShowTables || + statement instanceof ShowColumns || statement instanceof DescribeInput || statement instanceof DescribeOutput) { + return Optional.of(DESCRIBE); + } + else if (statement instanceof CreateTableAsSelect || statement instanceof Insert) { + return Optional.of(INSERT); + } + else if (statement instanceof Delete) { + return Optional.of(DELETE); + } + return Optional.empty(); + } + private PlanRoot analyzeQuery() { try { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/TaskManagerConfig.java b/presto-main/src/main/java/com/facebook/presto/execution/TaskManagerConfig.java index 9c551abd2009..e809b8edfd3d 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/TaskManagerConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/TaskManagerConfig.java @@ -27,6 +27,7 @@ import javax.validation.constraints.Min; import javax.validation.constraints.NotNull; +import java.math.BigDecimal; import java.util.concurrent.TimeUnit; @DefunctConfig({ @@ -64,6 +65,11 @@ public class TaskManagerConfig private int taskNotificationThreads = 5; + private boolean levelAbsolutePriority = true; + private BigDecimal levelTimeMultiplier = new BigDecimal(2.0); + + private boolean legacySchedulingBehavior = true; + @MinDuration("1ms") @MaxDuration("10s") @NotNull @@ -158,6 +164,35 @@ public TaskManagerConfig setShareIndexLoading(boolean shareIndexLoading) return this; } + @Deprecated + @NotNull + public boolean isLevelAbsolutePriority() + { + return levelAbsolutePriority; + } + + @Deprecated + @Config("task.level-absolute-priority") + public TaskManagerConfig setLevelAbsolutePriority(boolean levelAbsolutePriority) + { + this.levelAbsolutePriority = levelAbsolutePriority; + return this; + } + + public BigDecimal getLevelTimeMultiplier() + { + return levelTimeMultiplier; + } + + @Config("task.level-time-multiplier") + @ConfigDescription("Factor that determines the target scheduled time for a level relative to the next") + @Min(0) + public TaskManagerConfig setLevelTimeMultiplier(BigDecimal levelTimeMultiplier) + { + this.levelTimeMultiplier = levelTimeMultiplier; + return this; + } + @Min(1) public int getMaxWorkerThreads() { @@ -339,4 +374,18 @@ public TaskManagerConfig setTaskNotificationThreads(int taskNotificationThreads) this.taskNotificationThreads = taskNotificationThreads; return this; } + + @Deprecated + public boolean isLegacySchedulingBehavior() + { + return legacySchedulingBehavior; + } + + @Deprecated + @Config("task.legacy-scheduling-behavior") + public TaskManagerConfig setLegacySchedulingBehavior(boolean legacySchedulingBehavior) + { + this.legacySchedulingBehavior = legacySchedulingBehavior; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/execution/executor/LegacyPrioritizedSplitRunner.java b/presto-main/src/main/java/com/facebook/presto/execution/executor/LegacyPrioritizedSplitRunner.java new file mode 100644 index 000000000000..6500564c4503 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/execution/executor/LegacyPrioritizedSplitRunner.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.executor; + +import com.facebook.presto.execution.SplitRunner; +import com.google.common.base.Ticker; +import io.airlift.stats.CounterStat; +import io.airlift.stats.TimeStat; + +import static com.facebook.presto.execution.executor.MultilevelSplitQueue.LEVEL_THRESHOLD_SECONDS; + +public class LegacyPrioritizedSplitRunner + extends PrioritizedSplitRunner +{ + public LegacyPrioritizedSplitRunner(TaskHandle taskHandle, SplitRunner split, Ticker ticker, CounterStat globalCpuTimeMicros, CounterStat globalScheduledTimeMicros, TimeStat blockedQuantaWallTime, TimeStat unblockedQuantaWallTime) + { + super(taskHandle, split, ticker, globalCpuTimeMicros, globalScheduledTimeMicros, blockedQuantaWallTime, unblockedQuantaWallTime); + } + + @Override + public int compareTo(PrioritizedSplitRunner o) + { + int level = priority.get().getLevel(); + + int result = 0; + if (level == LEVEL_THRESHOLD_SECONDS.length - 1) { + result = Long.compare(lastRun.get(), o.lastRun.get()); + } + + if (result != 0) { + return result; + } + + return super.compareTo(o); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/execution/executor/LegacyTaskHandle.java b/presto-main/src/main/java/com/facebook/presto/execution/executor/LegacyTaskHandle.java new file mode 100644 index 000000000000..8de3cca8291e --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/execution/executor/LegacyTaskHandle.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.executor; + +import com.facebook.presto.execution.TaskId; +import io.airlift.units.Duration; + +import java.util.function.DoubleSupplier; + +import static com.facebook.presto.execution.executor.MultilevelSplitQueue.LEVEL_THRESHOLD_SECONDS; +import static com.facebook.presto.execution.executor.MultilevelSplitQueue.computeLevel; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class LegacyTaskHandle + extends TaskHandle +{ + public LegacyTaskHandle(TaskId taskId, MultilevelSplitQueue splitQueue, DoubleSupplier utilizationSupplier, int initialSplitConcurrency, Duration splitConcurrencyAdjustFrequency) + { + super(taskId, splitQueue, utilizationSupplier, initialSplitConcurrency, splitConcurrencyAdjustFrequency); + } + + @Override + public synchronized Priority addScheduledNanos(long durationNanos) + { + concurrencyController.update(durationNanos, utilizationSupplier.getAsDouble(), runningLeafSplits.size()); + scheduledNanos += durationNanos; + + Priority oldPriority = priority.get(); + Priority newPriority; + + if (oldPriority.getLevel() < (LEVEL_THRESHOLD_SECONDS.length - 1) && scheduledNanos >= SECONDS.toNanos(LEVEL_THRESHOLD_SECONDS[oldPriority.getLevel() + 1])) { + int newLevel = computeLevel(scheduledNanos); + newPriority = new Priority(newLevel, scheduledNanos); + } + else { + newPriority = new Priority(oldPriority.getLevel(), scheduledNanos); + } + + priority.set(newPriority); + return newPriority; + } + + @Override + public synchronized Priority resetLevelPriority() + { + return priority.get(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/execution/executor/MultilevelSplitQueue.java b/presto-main/src/main/java/com/facebook/presto/execution/executor/MultilevelSplitQueue.java new file mode 100644 index 000000000000..44a72eb2d515 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/execution/executor/MultilevelSplitQueue.java @@ -0,0 +1,339 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.executor; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import io.airlift.stats.CounterStat; + +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.PriorityQueue; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; + +@ThreadSafe +public class MultilevelSplitQueue +{ + static final int[] LEVEL_THRESHOLD_SECONDS = {0, 1, 10, 60, 300}; + static final long LEVEL_CONTRIBUTION_CAP = SECONDS.toNanos(30); + + @GuardedBy("lock") + private final List> levelWaitingSplits; + @GuardedBy("lock") + private final long[] levelScheduledTime = new long[LEVEL_THRESHOLD_SECONDS.length]; + + private final AtomicLong[] levelMinPriority; + private final List selectedLevelCounters; + + private final ReentrantLock lock = new ReentrantLock(); + private final Condition notEmpty = lock.newCondition(); + + private final boolean levelAbsolutePriority; + private final double levelTimeMultiplier; + + public MultilevelSplitQueue(boolean levelAbsolutePriority, double levelTimeMultiplier) + { + this.levelMinPriority = new AtomicLong[LEVEL_THRESHOLD_SECONDS.length]; + this.levelWaitingSplits = new ArrayList<>(LEVEL_THRESHOLD_SECONDS.length); + ImmutableList.Builder counters = ImmutableList.builder(); + + for (int i = 0; i < LEVEL_THRESHOLD_SECONDS.length; i++) { + levelMinPriority[i] = new AtomicLong(-1); + levelWaitingSplits.add(new PriorityQueue<>()); + counters.add(new CounterStat()); + } + + this.selectedLevelCounters = counters.build(); + + this.levelAbsolutePriority = levelAbsolutePriority; + this.levelTimeMultiplier = levelTimeMultiplier; + } + + private void addLevelTime(int level, long nanos) + { + lock.lock(); + try { + levelScheduledTime[level] += nanos; + } + finally { + lock.unlock(); + } + } + + public void offer(PrioritizedSplitRunner split) + { + checkArgument(split != null, "split is null"); + + split.setReady(); + lock.lock(); + try { + levelWaitingSplits.get(split.getPriority().getLevel()).offer(split); + notEmpty.signal(); + } + finally { + lock.unlock(); + } + } + + public PrioritizedSplitRunner take() + throws InterruptedException + { + while (true) { + lock.lockInterruptibly(); + try { + PrioritizedSplitRunner result; + while ((result = pollSplit()) == null) { + notEmpty.await(); + } + + if (result.updateLevelPriority()) { + offer(result); + continue; + } + + int selectedLevel = result.getPriority().getLevel(); + levelMinPriority[selectedLevel].set(result.getPriority().getLevelPriority()); + selectedLevelCounters.get(selectedLevel).update(1); + + return result; + } + finally { + lock.unlock(); + } + } + } + + /** + * Presto attempts to give each level a target amount of scheduled time, which is configurable + * using levelTimeMultiplier. + * + * This function selects the level that has the the lowest ratio of actual to the target time + * with the objective of minimizing deviation from the target scheduled time. From this level, + * we pick the split with the lowest priority. + */ + @GuardedBy("lock") + private PrioritizedSplitRunner pollSplit() + { + if (levelAbsolutePriority) { + return pollFirstSplit(); + } + + long targetScheduledTime = updateLevelTimes(); + double worstRatio = 1; + int selectedLevel = -1; + for (int level = 0; level < LEVEL_THRESHOLD_SECONDS.length; level++) { + if (!levelWaitingSplits.get(level).isEmpty()) { + double ratio = levelScheduledTime[level] == 0 ? 0 : targetScheduledTime / (1.0 * levelScheduledTime[level]); + if (selectedLevel == -1 || ratio > worstRatio) { + worstRatio = ratio; + selectedLevel = level; + } + } + + targetScheduledTime /= levelTimeMultiplier; + } + + if (selectedLevel == -1) { + return null; + } + + PrioritizedSplitRunner result = levelWaitingSplits.get(selectedLevel).poll(); + checkState(result != null, "pollSplit cannot return null"); + + return result; + } + + /** + * During periods of time when a level has no waiting splits, it will not accumulate + * accumulate scheduled time and will fall behind relative to other levels. + * + * This can cause temporary starvation for other levels when splits do reach the + * previously-empty level. + * + * To prevent this we set the scheduled time for levels which are empty to the expected + * scheduled time. + * + * @return target scheduled time for level 0 + */ + @GuardedBy("lock") + private long updateLevelTimes() + { + long level0ExpectedTime = levelScheduledTime[0]; + boolean updated; + do { + double currentMultiplier = levelTimeMultiplier; + updated = false; + for (int level = 0; level < LEVEL_THRESHOLD_SECONDS.length; level++) { + currentMultiplier /= levelTimeMultiplier; + long levelExpectedTime = (long) (level0ExpectedTime * currentMultiplier); + + if (levelWaitingSplits.get(level).isEmpty()) { + levelScheduledTime[level] = levelExpectedTime; + continue; + } + + if (levelScheduledTime[level] > levelExpectedTime) { + level0ExpectedTime = (long) (levelScheduledTime[level] / currentMultiplier); + updated = true; + break; + } + } + } while (updated && level0ExpectedTime != 0); + + return level0ExpectedTime; + } + + @GuardedBy("lock") + private PrioritizedSplitRunner pollFirstSplit() + { + for (PriorityQueue level : levelWaitingSplits) { + PrioritizedSplitRunner split = level.poll(); + if (split != null) { + return split; + } + } + + return null; + } + + /** + * Presto 'charges' the quanta run time to the task and the level it belongs to in + * an effort to maintain the target thread utilization ratios between levels and to + * maintain fairness within a level. + * + * Consider an example split where a read hung for several minutes. This is either a bug + * or a failing dependency. In either case we do not want to charge the task too much, + * and we especially do not want to charge the level too much - i.e. cause other queries + * in this level to starve. + * + * @return the new priority for the task + */ + public Priority updatePriority(Priority oldPriority, long quantaNanos, long scheduledNanos) + { + int oldLevel = oldPriority.getLevel(); + int newLevel = computeLevel(scheduledNanos); + + long levelContribution = Math.min(quantaNanos, LEVEL_CONTRIBUTION_CAP); + + if (oldLevel == newLevel) { + addLevelTime(oldLevel, levelContribution); + return new Priority(oldLevel, oldPriority.getLevelPriority() + quantaNanos); + } + + long remainingLevelContribution = levelContribution; + long remainingTaskTime = quantaNanos; + + // a task normally slowly accrues scheduled time in a level and then moves to the next, but + // if the split had a particularly long quanta, accrue time to each level as if it had run + // in that level up to the level limit. + for (int currentLevel = oldLevel; currentLevel < newLevel; currentLevel++) { + long timeAccruedToLevel = Math.min(SECONDS.toNanos(LEVEL_THRESHOLD_SECONDS[currentLevel + 1] - LEVEL_THRESHOLD_SECONDS[currentLevel]), remainingLevelContribution); + addLevelTime(currentLevel, timeAccruedToLevel); + remainingLevelContribution -= timeAccruedToLevel; + remainingTaskTime -= timeAccruedToLevel; + } + + addLevelTime(newLevel, remainingLevelContribution); + long newLevelMinPriority = getLevelMinPriority(newLevel, scheduledNanos); + return new Priority(newLevel, newLevelMinPriority + remainingTaskTime); + } + + public void remove(PrioritizedSplitRunner split) + { + checkArgument(split != null, "split is null"); + lock.lock(); + try { + for (PriorityQueue level : levelWaitingSplits) { + level.remove(split); + } + } + finally { + lock.unlock(); + } + } + + public void removeAll(Collection splits) + { + lock.lock(); + try { + for (PriorityQueue level : levelWaitingSplits) { + level.removeAll(splits); + } + } + finally { + lock.unlock(); + } + } + + public long getLevelMinPriority(int level, long taskThreadUsageNanos) + { + levelMinPriority[level].compareAndSet(-1, taskThreadUsageNanos); + return levelMinPriority[level].get(); + } + + public int size() + { + lock.lock(); + try { + int total = 0; + for (PriorityQueue level : levelWaitingSplits) { + total += level.size(); + } + return total; + } + finally { + lock.unlock(); + } + } + + public List getSelectedLevelCounters() + { + return selectedLevelCounters; + } + + public static int computeLevel(long threadUsageNanos) + { + long seconds = NANOSECONDS.toSeconds(threadUsageNanos); + for (int i = 0; i < (LEVEL_THRESHOLD_SECONDS.length - 1); i++) { + if (seconds < LEVEL_THRESHOLD_SECONDS[i + 1]) { + return i; + } + } + + return LEVEL_THRESHOLD_SECONDS.length - 1; + } + + @VisibleForTesting + long[] getLevelScheduledTime() + { + lock.lock(); + try { + return levelScheduledTime; + } + finally { + lock.unlock(); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/execution/executor/PrioritizedSplitRunner.java b/presto-main/src/main/java/com/facebook/presto/execution/executor/PrioritizedSplitRunner.java index 34ff5ac57288..ba0676b4f954 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/executor/PrioritizedSplitRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/executor/PrioritizedSplitRunner.java @@ -25,8 +25,8 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import static com.facebook.presto.operator.Operator.NOT_BLOCKED; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -54,10 +54,9 @@ class PrioritizedSplitRunner private final AtomicBoolean destroyed = new AtomicBoolean(); - private final AtomicInteger priorityLevel = new AtomicInteger(); - private final AtomicLong taskScheduledNanos = new AtomicLong(); + protected final AtomicReference priority = new AtomicReference<>(new Priority(0, 0)); - private final AtomicLong lastRun = new AtomicLong(); + protected final AtomicLong lastRun = new AtomicLong(); private final AtomicLong lastReady = new AtomicLong(); private final AtomicLong start = new AtomicLong(); @@ -90,6 +89,8 @@ class PrioritizedSplitRunner this.globalScheduledTimeMicros = globalScheduledTimeMicros; this.blockedQuantaWallTime = blockedQuantaWallTime; this.unblockedQuantaWallTime = unblockedQuantaWallTime; + + this.updateLevelPriority(); } public TaskHandle getTaskHandle() @@ -162,21 +163,14 @@ public ListenableFuture process() ListenableFuture blocked = split.processFor(SPLIT_RUN_QUANTA); CpuTimer.CpuDuration elapsed = timer.elapsedTime(); - // update priority level base on total thread usage of task - long quantaScheduledNanos = elapsed.getWall().roundTo(NANOSECONDS); + long quantaScheduledNanos = ticker.read() - startNanos; scheduledNanos.addAndGet(quantaScheduledNanos); - long taskScheduledTimeNanos = taskHandle.addThreadUsageNanos(quantaScheduledNanos); - taskScheduledNanos.set(taskScheduledTimeNanos); - - priorityLevel.set(calculatePriorityLevel(taskScheduledTimeNanos)); - - // record last run for prioritization within a level + priority.set(taskHandle.addScheduledNanos(quantaScheduledNanos)); lastRun.set(ticker.read()); if (blocked == NOT_BLOCKED) { unblockedQuantaWallTime.add(elapsed.getWall()); - setReady(); } else { blockedQuantaWallTime.add(elapsed.getWall()); @@ -196,34 +190,39 @@ public ListenableFuture process() } } - public boolean updatePriorityLevel() + public void setReady() { - int newPriority = calculatePriorityLevel(taskHandle.getThreadUsageNanos()); - if (newPriority == priorityLevel.getAndSet(newPriority)) { - return false; - } + lastReady.set(ticker.read()); + } + + /** + * Updates the (potentially stale) priority value cached in this object. + * This should be called when this object is outside the queue. + * + * @return true if the level changed. + */ + public boolean updateLevelPriority() + { + Priority newPriority = taskHandle.getPriority(); + Priority oldPriority = priority.getAndSet(newPriority); + return newPriority.getLevel() != oldPriority.getLevel(); + } - // update thread usage while if level changed - taskScheduledNanos.set(taskHandle.getThreadUsageNanos()); - return true; + /** + * Updates the task level priority to be greater than or equal to the minimum + * priority within that level. This ensures that tasks that spend time blocked do + * not return and starve already-running tasks. Also updates the cached priority + * object. + */ + public void resetLevelPriority() + { + priority.set(taskHandle.resetLevelPriority()); } @Override public int compareTo(PrioritizedSplitRunner o) { - int level = priorityLevel.get(); - - int result = Integer.compare(level, o.priorityLevel.get()); - if (result != 0) { - return result; - } - - if (level < 4) { - result = Long.compare(taskScheduledNanos.get(), o.taskScheduledNanos.get()); - } - else { - result = Long.compare(lastRun.get(), o.lastRun.get()); - } + int result = Long.compare(priority.get().getLevelPriority(), o.getPriority().getLevelPriority()); if (result != 0) { return result; } @@ -236,37 +235,9 @@ public int getSplitId() return splitId; } - public void setReady() - { - lastReady.set(ticker.read()); - } - - public AtomicInteger getPriorityLevel() - { - return priorityLevel; - } - - public static int calculatePriorityLevel(long threadUsageNanos) + public Priority getPriority() { - long millis = NANOSECONDS.toMillis(threadUsageNanos); - - int priorityLevel; - if (millis < 1000) { - priorityLevel = 0; - } - else if (millis < 10_000) { - priorityLevel = 1; - } - else if (millis < 60_000) { - priorityLevel = 2; - } - else if (millis < 300_000) { - priorityLevel = 3; - } - else { - priorityLevel = 4; - } - return priorityLevel; + return priority.get(); } public String getInfo() diff --git a/presto-main/src/main/java/com/facebook/presto/execution/executor/Priority.java b/presto-main/src/main/java/com/facebook/presto/execution/executor/Priority.java new file mode 100644 index 000000000000..6bb3609a492f --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/execution/executor/Priority.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.executor; + +import javax.annotation.concurrent.Immutable; + +import static com.google.common.base.MoreObjects.toStringHelper; + +/** + * Task (and split) priority is composed of a level and a within-level + * priority. Level decides which queue the split is placed in, while + * within-level priority decides which split is executed next in that level. + *

+ * Tasks move from a lower to higher level as they exceed level thresholds + * of total scheduled time accrued to a task. + *

+ * The priority within a level increases with the scheduled time accumulated + * in that level. This is necessary to achieve fairness when tasks acquire + * scheduled time at varying rates. + *

+ * However, this priority is not equal to the task total accrued + * scheduled time. When a task graduates to a higher level, the level + * priority is set to the minimum current priority in the new level. This + * allows us to maintain instantaneous fairness in terms of scheduled time. + */ +@Immutable +public final class Priority +{ + private final int level; + private final long levelPriority; + + public Priority(int level, long levelPriority) + { + this.level = level; + this.levelPriority = levelPriority; + } + + public int getLevel() + { + return level; + } + + public long getLevelPriority() + { + return levelPriority; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("level", level) + .add("levelPriority", levelPriority) + .toString(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/execution/executor/TaskExecutor.java b/presto-main/src/main/java/com/facebook/presto/execution/executor/TaskExecutor.java index 6652791540a8..d1c3396b60aa 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/executor/TaskExecutor.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/executor/TaskExecutor.java @@ -49,7 +49,6 @@ import java.util.concurrent.ConcurrentSkipListSet; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; -import java.util.concurrent.PriorityBlockingQueue; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadPoolExecutor; @@ -58,7 +57,7 @@ import java.util.concurrent.atomic.AtomicLongArray; import java.util.function.DoubleSupplier; -import static com.facebook.presto.execution.executor.PrioritizedSplitRunner.calculatePriorityLevel; +import static com.facebook.presto.execution.executor.MultilevelSplitQueue.computeLevel; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -77,7 +76,7 @@ public class TaskExecutor private static final Logger log = Logger.get(TaskExecutor.class); // each task is guaranteed a minimum number of splits - private static final int GUARANTEED_SPLITS_PER_TASK = 3; + static final int GUARANTEED_SPLITS_PER_TASK = 3; // print out split call stack if it has been running for a certain amount of time private static final Duration LONG_SPLIT_WARNING_THRESHOLD = new Duration(1000, TimeUnit.SECONDS); @@ -113,7 +112,7 @@ public class TaskExecutor /** * Splits waiting for a runner thread. */ - private final PriorityBlockingQueue waitingSplits; + private final MultilevelSplitQueue waitingSplits; /** * Splits running on a thread. @@ -128,8 +127,6 @@ public class TaskExecutor private final AtomicLongArray completedTasksPerLevel = new AtomicLongArray(5); private final AtomicLongArray completedSplitsPerLevel = new AtomicLongArray(5); - private final CounterStat[] selectedLevelCounters = new CounterStat[5]; - private final TimeStat splitQueuedTime = new TimeStat(NANOSECONDS); private final TimeStat splitWallTime = new TimeStat(NANOSECONDS); @@ -149,12 +146,14 @@ public class TaskExecutor private final TimeStat blockedQuantaWallTime = new TimeStat(MICROSECONDS); private final TimeStat unblockedQuantaWallTime = new TimeStat(MICROSECONDS); + private final boolean legacySchedulingBehavior; + private volatile boolean closed; @Inject public TaskExecutor(TaskManagerConfig config) { - this(requireNonNull(config, "config is null").getMaxWorkerThreads(), config.getMinDrivers()); + this(requireNonNull(config, "config is null").getMaxWorkerThreads(), config.getMinDrivers(), config.getLevelTimeMultiplier().doubleValue(), config.isLevelAbsolutePriority(), config.isLegacySchedulingBehavior(), Ticker.systemTicker()); } public TaskExecutor(int runnerThreads, int minDrivers) @@ -162,8 +161,13 @@ public TaskExecutor(int runnerThreads, int minDrivers) this(runnerThreads, minDrivers, Ticker.systemTicker()); } - @VisibleForTesting public TaskExecutor(int runnerThreads, int minDrivers, Ticker ticker) + { + this(runnerThreads, minDrivers, 2, false, true, ticker); + } + + @VisibleForTesting + public TaskExecutor(int runnerThreads, int minDrivers, double levelTimeMultiplier, boolean levelAbsolutePriority, boolean legacySchedulingBehavior, Ticker ticker) { checkArgument(runnerThreads > 0, "runnerThreads must be at least 1"); @@ -175,12 +179,9 @@ public TaskExecutor(int runnerThreads, int minDrivers, Ticker ticker) this.ticker = requireNonNull(ticker, "ticker is null"); this.minimumNumberOfDrivers = minDrivers; - this.waitingSplits = new PriorityBlockingQueue<>(Runtime.getRuntime().availableProcessors() * 10); + this.waitingSplits = new MultilevelSplitQueue(levelAbsolutePriority, levelTimeMultiplier); this.tasks = new LinkedList<>(); - - for (int i = 0; i < 5; i++) { - selectedLevelCounters[i] = new CounterStat(); - } + this.legacySchedulingBehavior = legacySchedulingBehavior; } @PostConstruct @@ -230,7 +231,15 @@ public synchronized TaskHandle addTask(TaskId taskId, DoubleSupplier utilization log.debug("Task scheduled " + taskId); - TaskHandle taskHandle = new TaskHandle(taskId, utilizationSupplier, initialSplitConcurrency, splitConcurrencyAdjustFrequency); + TaskHandle taskHandle; + + if (legacySchedulingBehavior) { + taskHandle = new LegacyTaskHandle(taskId, waitingSplits, utilizationSupplier, initialSplitConcurrency, splitConcurrencyAdjustFrequency); + } + else { + taskHandle = new TaskHandle(taskId, waitingSplits, utilizationSupplier, initialSplitConcurrency, splitConcurrencyAdjustFrequency); + } + tasks.add(taskHandle); return taskHandle; } @@ -255,9 +264,8 @@ public void removeTask(TaskHandle taskHandle) } // record completed stats - long threadUsageNanos = taskHandle.getThreadUsageNanos(); - int priorityLevel = calculatePriorityLevel(threadUsageNanos); - completedTasksPerLevel.incrementAndGet(priorityLevel); + long threadUsageNanos = taskHandle.getScheduledNanos(); + completedTasksPerLevel.incrementAndGet(computeLevel(threadUsageNanos)); log.debug("Task finished or failed " + taskHandle.getTaskId()); @@ -271,14 +279,27 @@ public List> enqueueSplits(TaskHandle taskHandle, boolean in List> finishedFutures = new ArrayList<>(taskSplits.size()); synchronized (this) { for (SplitRunner taskSplit : taskSplits) { - PrioritizedSplitRunner prioritizedSplitRunner = new PrioritizedSplitRunner( - taskHandle, - taskSplit, - ticker, - globalCpuTimeMicros, - globalScheduledTimeMicros, - blockedQuantaWallTime, - unblockedQuantaWallTime); + PrioritizedSplitRunner prioritizedSplitRunner; + if (legacySchedulingBehavior) { + prioritizedSplitRunner = new LegacyPrioritizedSplitRunner( + taskHandle, + taskSplit, + ticker, + globalCpuTimeMicros, + globalScheduledTimeMicros, + blockedQuantaWallTime, + unblockedQuantaWallTime); + } + else { + prioritizedSplitRunner = new PrioritizedSplitRunner( + taskHandle, + taskSplit, + ticker, + globalCpuTimeMicros, + globalScheduledTimeMicros, + blockedQuantaWallTime, + unblockedQuantaWallTime); + } if (taskHandle.isDestroyed()) { // If the handle is destroyed, we destroy the task splits to complete the future @@ -310,7 +331,7 @@ else if (intermediate) { private void splitFinished(PrioritizedSplitRunner split) { - completedSplitsPerLevel.incrementAndGet(split.getPriorityLevel().get()); + completedSplitsPerLevel.incrementAndGet(split.getPriority().getLevel()); synchronized (this) { allSplits.remove(split); @@ -383,7 +404,7 @@ private synchronized void startIntermediateSplit(PrioritizedSplitRunner split) private synchronized void startSplit(PrioritizedSplitRunner split) { allSplits.add(split); - waitingSplits.put(split); + waitingSplits.offer(split); } private synchronized PrioritizedSplitRunner pollNextSplitWorker() @@ -440,18 +461,12 @@ public void run() final PrioritizedSplitRunner split; try { split = waitingSplits.take(); - if (split.updatePriorityLevel()) { - // priority level changed, return split to queue for re-prioritization - waitingSplits.put(split); - continue; - } } catch (InterruptedException e) { Thread.currentThread().interrupt(); return; } - selectedLevelCounters[split.getPriorityLevel().get()].update(1); String threadId = split.getTaskHandle().getTaskId() + "-" + split.getSplitId(); try (SetThreadName splitName = new SetThreadName(threadId)) { RunningSplitInfo splitInfo = new RunningSplitInfo(ticker.read(), threadId, Thread.currentThread()); @@ -473,15 +488,15 @@ public void run() } else { if (blocked.isDone()) { - waitingSplits.put(split); + waitingSplits.offer(split); } else { blockedSplits.put(split, blocked); blocked.addListener(() -> { blockedSplits.remove(split); - split.updatePriorityLevel(); - split.setReady(); - waitingSplits.put(split); + // reset the level priority to prevent previously-blocked splits from starving existing splits + split.resetLevelPriority(); + waitingSplits.offer(split); }, executor); } } @@ -625,66 +640,66 @@ public long getCompletedSplitsLevel4() @Managed public long getRunningTasksLevel0() { - return calculateRunningTasksForLevel(0); + return getRunningTasksForLevel(0); } @Managed public long getRunningTasksLevel1() { - return calculateRunningTasksForLevel(1); + return getRunningTasksForLevel(1); } @Managed public long getRunningTasksLevel2() { - return calculateRunningTasksForLevel(2); + return getRunningTasksForLevel(2); } @Managed public long getRunningTasksLevel3() { - return calculateRunningTasksForLevel(3); + return getRunningTasksForLevel(3); } @Managed public long getRunningTasksLevel4() { - return calculateRunningTasksForLevel(4); + return getRunningTasksForLevel(4); } @Managed @Nested public CounterStat getSelectedCountLevel0() { - return selectedLevelCounters[0]; + return waitingSplits.getSelectedLevelCounters().get(0); } @Managed @Nested public CounterStat getSelectedCountLevel1() { - return selectedLevelCounters[1]; + return waitingSplits.getSelectedLevelCounters().get(1); } @Managed @Nested public CounterStat getSelectedCountLevel2() { - return selectedLevelCounters[2]; + return waitingSplits.getSelectedLevelCounters().get(2); } @Managed @Nested public CounterStat getSelectedCountLevel3() { - return selectedLevelCounters[3]; + return waitingSplits.getSelectedLevelCounters().get(3); } @Managed @Nested public CounterStat getSelectedCountLevel4() { - return selectedLevelCounters[4]; + return waitingSplits.getSelectedLevelCounters().get(4); } @Managed @@ -743,6 +758,20 @@ public TimeDistribution getIntermediateSplitWallTime() return intermediateSplitWallTime; } + @Managed + @Nested + public TimeDistribution getLeafSplitWaitTime() + { + return leafSplitWaitTime; + } + + @Managed + @Nested + public TimeDistribution getIntermediateSplitWaitTime() + { + return intermediateSplitWaitTime; + } + @Managed @Nested public CounterStat getGlobalScheduledTimeMicros() @@ -757,11 +786,11 @@ public CounterStat getGlobalCpuTimeMicros() return globalCpuTimeMicros; } - private synchronized int calculateRunningTasksForLevel(int level) + private synchronized int getRunningTasksForLevel(int level) { int count = 0; for (TaskHandle task : tasks) { - if (calculatePriorityLevel(task.getThreadUsageNanos()) == level) { + if (task.getPriority().getLevel() == level) { count++; } } diff --git a/presto-main/src/main/java/com/facebook/presto/execution/executor/TaskHandle.java b/presto-main/src/main/java/com/facebook/presto/execution/executor/TaskHandle.java index baf5d291a830..adeaf6452cce 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/executor/TaskHandle.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/executor/TaskHandle.java @@ -15,7 +15,6 @@ import com.facebook.presto.execution.SplitConcurrencyController; import com.facebook.presto.execution.TaskId; -import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import io.airlift.units.Duration; @@ -27,49 +26,68 @@ import java.util.List; import java.util.Queue; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.DoubleSupplier; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; @ThreadSafe public class TaskHandle { private final TaskId taskId; - private final DoubleSupplier utilizationSupplier; + protected final DoubleSupplier utilizationSupplier; @GuardedBy("this") - private final Queue queuedLeafSplits = new ArrayDeque<>(10); + protected final Queue queuedLeafSplits = new ArrayDeque<>(10); @GuardedBy("this") - private final List runningLeafSplits = new ArrayList<>(10); + protected final List runningLeafSplits = new ArrayList<>(10); @GuardedBy("this") - private final List runningIntermediateSplits = new ArrayList<>(10); + protected final List runningIntermediateSplits = new ArrayList<>(10); @GuardedBy("this") - private long taskThreadUsageNanos; + protected long scheduledNanos; @GuardedBy("this") private boolean destroyed; @GuardedBy("this") - private final SplitConcurrencyController concurrencyController; + protected final SplitConcurrencyController concurrencyController; private final AtomicInteger nextSplitId = new AtomicInteger(); - public TaskHandle(TaskId taskId, DoubleSupplier utilizationSupplier, int initialSplitConcurrency, Duration splitConcurrencyAdjustFrequency) + protected final AtomicReference priority = new AtomicReference<>(new Priority(0, 0)); + private final MultilevelSplitQueue splitQueue; + + public TaskHandle(TaskId taskId, MultilevelSplitQueue splitQueue, DoubleSupplier utilizationSupplier, int initialSplitConcurrency, Duration splitConcurrencyAdjustFrequency) { - this.taskId = taskId; - this.utilizationSupplier = utilizationSupplier; - this.concurrencyController = new SplitConcurrencyController(initialSplitConcurrency, splitConcurrencyAdjustFrequency); + this.taskId = requireNonNull(taskId, "taskId is null"); + this.splitQueue = requireNonNull(splitQueue, "splitQueue is null"); + this.utilizationSupplier = requireNonNull(utilizationSupplier, "utilizationSupplier is null"); + this.concurrencyController = new SplitConcurrencyController( + initialSplitConcurrency, + requireNonNull(splitConcurrencyAdjustFrequency, "splitConcurrencyAdjustFrequency is null")); } - public synchronized long addThreadUsageNanos(long durationNanos) + public synchronized Priority addScheduledNanos(long durationNanos) { concurrencyController.update(durationNanos, utilizationSupplier.getAsDouble(), runningLeafSplits.size()); - taskThreadUsageNanos += durationNanos; - return taskThreadUsageNanos; + scheduledNanos += durationNanos; + + Priority newPriority = splitQueue.updatePriority(priority.get(), durationNanos, scheduledNanos); + + priority.set(newPriority); + return newPriority; } - public TaskId getTaskId() + public synchronized Priority resetLevelPriority() { - return taskId; + long levelMinPriority = splitQueue.getLevelMinPriority(priority.get().getLevel(), scheduledNanos); + if (priority.get().getLevelPriority() < levelMinPriority) { + Priority newPriority = new Priority(priority.get().getLevel(), levelMinPriority); + priority.set(newPriority); + return newPriority; + } + + return priority.get(); } public synchronized boolean isDestroyed() @@ -77,6 +95,16 @@ public synchronized boolean isDestroyed() return destroyed; } + public Priority getPriority() + { + return priority.get(); + } + + public TaskId getTaskId() + { + return taskId; + } + // Returns any remaining splits. The caller must destroy these. public synchronized List destroy() { @@ -104,15 +132,14 @@ public synchronized void recordIntermediateSplit(PrioritizedSplitRunner split) runningIntermediateSplits.add(split); } - @VisibleForTesting - public synchronized int getRunningLeafSplits() + synchronized int getRunningLeafSplits() { return runningLeafSplits.size(); } - public synchronized long getThreadUsageNanos() + public synchronized long getScheduledNanos() { - return taskThreadUsageNanos; + return scheduledNanos; } public synchronized PrioritizedSplitRunner pollNextSplit() diff --git a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroup.java b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroup.java index aa5a1873bd46..7e4416d3740a 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroup.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroup.java @@ -15,6 +15,8 @@ import com.facebook.presto.execution.QueryExecution; import com.facebook.presto.execution.QueryState; +import com.facebook.presto.server.QueryStateInfo; +import com.facebook.presto.server.ResourceGroupStateInfo; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.resourceGroups.ResourceGroup; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; @@ -40,6 +42,7 @@ import java.util.function.BiConsumer; import static com.facebook.presto.SystemSessionProperties.getQueryPriority; +import static com.facebook.presto.server.QueryStateInfo.createQueryStateInfo; import static com.facebook.presto.spi.ErrorType.USER_ERROR; import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_TIME_LIMIT; import static com.facebook.presto.spi.resourceGroups.ResourceGroupState.CAN_QUEUE; @@ -145,31 +148,84 @@ public ResourceGroupInfo getInfo() .map(InternalResourceGroup::getInfo) .collect(toImmutableList()); - ResourceGroupState resourceGroupState; - if (canRunMore()) { - resourceGroupState = CAN_RUN; - } - else if (canQueueMore()) { - resourceGroupState = CAN_QUEUE; - } - else { - resourceGroupState = FULL; - } - return new ResourceGroupInfo( id, - new DataSize(softMemoryLimitBytes, BYTE), + DataSize.succinctBytes(softMemoryLimitBytes), maxRunningQueries, + runningTimeLimit, maxQueuedQueries, - resourceGroupState, + queuedTimeLimit, + getState(), eligibleSubGroups.size(), - new DataSize(cachedMemoryUsageBytes, BYTE), + DataSize.succinctBytes(cachedMemoryUsageBytes), runningQueries.size() + descendantRunningQueries, queuedQueries.size() + descendantQueuedQueries, infos); } } + public ResourceGroupStateInfo getStateInfo() + { + synchronized (root) { + return new ResourceGroupStateInfo( + id, + getState(), + DataSize.succinctBytes(softMemoryLimitBytes), + DataSize.succinctBytes(cachedMemoryUsageBytes), + maxRunningQueries, + maxQueuedQueries, + runningTimeLimit, + queuedTimeLimit, + getAggregatedRunningQueriesInfo(), + queuedQueries.size() + descendantQueuedQueries, + subGroups.values().stream() + .map(subGroup -> new ResourceGroupInfo( + subGroup.getId(), + DataSize.succinctBytes(softMemoryLimitBytes), + maxRunningQueries, + runningTimeLimit, + maxQueuedQueries, + queuedTimeLimit, + getState(), + eligibleSubGroups.size(), + DataSize.succinctBytes(cachedMemoryUsageBytes), + runningQueries.size() + descendantRunningQueries, + queuedQueries.size() + descendantQueuedQueries)) + .collect(toImmutableList())); + } + } + + private ResourceGroupState getState() + { + synchronized (root) { + if (canRunMore()) { + return CAN_RUN; + } + else if (canQueueMore()) { + return CAN_QUEUE; + } + else { + return FULL; + } + } + } + + private List getAggregatedRunningQueriesInfo() + { + synchronized (root) { + if (subGroups.isEmpty()) { + return runningQueries.stream() + .map(QueryExecution::getQueryInfo) + .map(queryInfo -> createQueryStateInfo(queryInfo, Optional.of(id), Optional.empty())) + .collect(toImmutableList()); + } + return subGroups.values().stream() + .map(InternalResourceGroup::getAggregatedRunningQueriesInfo) + .flatMap(List::stream) + .collect(toImmutableList()); + } + } + @Override public ResourceGroupId getId() { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroupManager.java b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroupManager.java index aa21ac9d9b47..1b2276ee0b0b 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroupManager.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroupManager.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.QueryExecution; import com.facebook.presto.execution.resourceGroups.InternalResourceGroup.RootInternalResourceGroup; +import com.facebook.presto.server.ResourceGroupStateInfo; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.memory.ClusterMemoryPoolManager; import com.facebook.presto.spi.resourceGroups.ResourceGroupConfigurationManager; @@ -44,6 +45,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.Properties; import java.util.concurrent.ConcurrentHashMap; @@ -103,19 +105,28 @@ public ResourceGroupInfo getResourceGroupInfo(ResourceGroupId id) return groups.get(id).getInfo(); } + @Override + public ResourceGroupStateInfo getResourceGroupStateInfo(ResourceGroupId id) + { + if (!groups.containsKey(id)) { + throw new NoSuchElementException(); + } + return groups.get(id).getStateInfo(); + } + @Override public void submit(Statement statement, QueryExecution queryExecution, Executor executor) { checkState(configurationManager.get() != null, "configurationManager not set"); ResourceGroupId group; try { - group = selectGroup(queryExecution.getSession()); + group = selectGroup(queryExecution); } catch (PrestoException e) { queryExecution.fail(e); return; } - createGroupIfNecessary(group, queryExecution.getSession(), executor); + createGroupIfNecessary(group, queryExecution, executor); groups.get(group).run(queryExecution); } @@ -224,13 +235,19 @@ else if (elapsedSeconds < 0) { } } - private synchronized void createGroupIfNecessary(ResourceGroupId id, Session session, Executor executor) + private synchronized void createGroupIfNecessary(ResourceGroupId id, QueryExecution queryExecution, Executor executor) { - SelectionContext context = new SelectionContext(session.getIdentity().getPrincipal().isPresent(), session.getUser(), session.getSource(), getQueryPriority(session)); + Session session = queryExecution.getSession(); + SelectionContext context = new SelectionContext( + session.getIdentity().getPrincipal().isPresent(), + session.getUser(), + session.getSource(), + getQueryPriority(session), + determineQueryType(queryExecution)); if (!groups.containsKey(id)) { InternalResourceGroup group; if (id.getParent().isPresent()) { - createGroupIfNecessary(id.getParent().get(), session, executor); + createGroupIfNecessary(id.getParent().get(), queryExecution, executor); InternalResourceGroup parent = groups.get(id.getParent().get()); requireNonNull(parent, "parent is null"); group = parent.getOrCreateSubGroup(id.getLastSegment()); @@ -261,9 +278,15 @@ private void exportGroup(InternalResourceGroup group, Boolean export) } } - private ResourceGroupId selectGroup(Session session) + private ResourceGroupId selectGroup(QueryExecution queryExecution) { - SelectionContext context = new SelectionContext(session.getIdentity().getPrincipal().isPresent(), session.getUser(), session.getSource(), getQueryPriority(session)); + Session session = queryExecution.getSession(); + SelectionContext context = new SelectionContext( + session.getIdentity().getPrincipal().isPresent(), + session.getUser(), + session.getSource(), + getQueryPriority(session), + determineQueryType(queryExecution)); for (ResourceGroupSelector selector : configurationManager.get().getSelectors()) { Optional group = selector.match(context); if (group.isPresent()) { @@ -272,4 +295,9 @@ private ResourceGroupId selectGroup(Session session) } throw new PrestoException(QUERY_REJECTED, "Query did not match any selection rule"); } + + private Optional determineQueryType(QueryExecution queryExecution) + { + return queryExecution.getQueryType().map(Enum::toString); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/NoOpResourceGroupManager.java b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/NoOpResourceGroupManager.java index b51b129e8bb0..9fadfda21f9e 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/NoOpResourceGroupManager.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/NoOpResourceGroupManager.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution.resourceGroups; import com.facebook.presto.execution.QueryExecution; +import com.facebook.presto.server.ResourceGroupStateInfo; import com.facebook.presto.spi.resourceGroups.ResourceGroupConfigurationManagerFactory; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.resourceGroups.ResourceGroupInfo; @@ -39,6 +40,12 @@ public ResourceGroupInfo getResourceGroupInfo(ResourceGroupId id) throw new UnsupportedOperationException(); } + @Override + public ResourceGroupStateInfo getResourceGroupStateInfo(ResourceGroupId id) + { + throw new UnsupportedOperationException(); + } + @Override public void addConfigurationManagerFactory(ResourceGroupConfigurationManagerFactory factory) { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupManager.java b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupManager.java index 8749f7ab876b..b0e17c36d511 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupManager.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupManager.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution.resourceGroups; import com.facebook.presto.execution.QueryQueueManager; +import com.facebook.presto.server.ResourceGroupStateInfo; import com.facebook.presto.spi.resourceGroups.ResourceGroupConfigurationManagerFactory; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.resourceGroups.ResourceGroupInfo; @@ -23,6 +24,8 @@ public interface ResourceGroupManager { ResourceGroupInfo getResourceGroupInfo(ResourceGroupId id); + ResourceGroupStateInfo getResourceGroupStateInfo(ResourceGroupId id); + void addConfigurationManagerFactory(ResourceGroupConfigurationManagerFactory factory); void loadConfigurationManager() throws Exception; diff --git a/presto-main/src/main/java/com/facebook/presto/failureDetector/HeartbeatFailureDetector.java b/presto-main/src/main/java/com/facebook/presto/failureDetector/HeartbeatFailureDetector.java index 5d4407d64121..018b29e3397e 100644 --- a/presto-main/src/main/java/com/facebook/presto/failureDetector/HeartbeatFailureDetector.java +++ b/presto-main/src/main/java/com/facebook/presto/failureDetector/HeartbeatFailureDetector.java @@ -13,7 +13,11 @@ */ package com.facebook.presto.failureDetector; +import com.facebook.presto.client.FailureInfo; +import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.util.Failures; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; @@ -32,6 +36,7 @@ import org.joda.time.DateTime; import org.weakref.jmx.Managed; +import javax.annotation.Nullable; import javax.annotation.PostConstruct; import javax.annotation.PreDestroy; import javax.annotation.concurrent.GuardedBy; @@ -87,6 +92,7 @@ public class HeartbeatFailureDetector private final boolean isEnabled; private final Duration warmupInterval; private final Duration gcGraceInterval; + private final boolean httpsRequired; private final AtomicBoolean started = new AtomicBoolean(); @@ -94,25 +100,28 @@ public class HeartbeatFailureDetector public HeartbeatFailureDetector( @ServiceType("presto") ServiceSelector selector, @ForFailureDetector HttpClient httpClient, - FailureDetectorConfig config, - NodeInfo nodeInfo) + FailureDetectorConfig failureDetectorConfig, + NodeInfo nodeInfo, + InternalCommunicationConfig internalCommunicationConfig) { requireNonNull(selector, "selector is null"); requireNonNull(httpClient, "httpClient is null"); requireNonNull(nodeInfo, "nodeInfo is null"); - requireNonNull(config, "config is null"); - checkArgument(config.getHeartbeatInterval().toMillis() >= 1, "heartbeat interval must be >= 1ms"); + requireNonNull(failureDetectorConfig, "config is null"); + checkArgument(failureDetectorConfig.getHeartbeatInterval().toMillis() >= 1, "heartbeat interval must be >= 1ms"); this.selector = selector; this.httpClient = httpClient; this.nodeInfo = nodeInfo; - this.failureRatioThreshold = config.getFailureRatioThreshold(); - this.heartbeat = config.getHeartbeatInterval(); - this.warmupInterval = config.getWarmupInterval(); - this.gcGraceInterval = config.getExpirationGraceInterval(); + this.failureRatioThreshold = failureDetectorConfig.getFailureRatioThreshold(); + this.heartbeat = failureDetectorConfig.getHeartbeatInterval(); + this.warmupInterval = failureDetectorConfig.getWarmupInterval(); + this.gcGraceInterval = failureDetectorConfig.getExpirationGraceInterval(); - this.isEnabled = config.isEnabled(); + this.isEnabled = failureDetectorConfig.isEnabled(); + + this.httpsRequired = internalCommunicationConfig.isHttpsRequired(); } @PostConstruct @@ -164,7 +173,6 @@ public State getState(HostAddress hostAddress) if (lastFailureException instanceof ConnectException) { return GONE; } - if (lastFailureException instanceof SocketTimeoutException) { // TODO: distinguish between process unresponsiveness (e.g GC pause) and host reboot return UNRESPONSIVE; @@ -251,18 +259,16 @@ void updateMonitoredServices() } } - private static URI getHttpUri(ServiceDescriptor service) + private URI getHttpUri(ServiceDescriptor descriptor) { - try { - String uri = service.getProperties().get("http"); - if (uri != null) { - return new URI(uri); + String url = descriptor.getProperties().get(httpsRequired ? "https" : "http"); + if (url != null) { + try { + return new URI(url); + } + catch (URISyntaxException ignored) { } } - catch (URISyntaxException e) { - // ignore, not a valid http uri - } - return null; } @@ -486,12 +492,23 @@ public DateTime getLastResponseTime() return lastResponseTime.get(); } - @JsonProperty + @JsonIgnore public Exception getLastFailureException() { return lastFailureException.get(); } + @Nullable + @JsonProperty + public FailureInfo getLastFailureInfo() + { + Exception lastFailureException = getLastFailureException(); + if (lastFailureException == null) { + return null; + } + return Failures.toFailure(lastFailureException).toFailureInfo(); + } + @JsonProperty public synchronized Map getRecentFailuresByType() { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/Mergeable.java b/presto-main/src/main/java/com/facebook/presto/matching/Matchable.java similarity index 85% rename from presto-main/src/main/java/com/facebook/presto/operator/Mergeable.java rename to presto-main/src/main/java/com/facebook/presto/matching/Matchable.java index 8cbe5c3e5729..a25c0b7374b3 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/Mergeable.java +++ b/presto-main/src/main/java/com/facebook/presto/matching/Matchable.java @@ -11,9 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.operator; -public interface Mergeable +package com.facebook.presto.matching; + +public interface Matchable { - T mergeWith(T other); + Pattern getPattern(); } diff --git a/presto-main/src/main/java/com/facebook/presto/matching/MatchingEngine.java b/presto-main/src/main/java/com/facebook/presto/matching/MatchingEngine.java new file mode 100644 index 000000000000..aa2ab6b30c10 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/matching/MatchingEngine.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.matching; + +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; +import com.google.common.reflect.TypeToken; + +import java.util.Set; +import java.util.stream.Stream; + +public class MatchingEngine +{ + private final ListMultimap matchablesByClass; + + private MatchingEngine(ListMultimap matchablesByClass) + { + this.matchablesByClass = ImmutableListMultimap.copyOf(matchablesByClass); + } + + public Stream getCandidates(Object object) + { + return supertypes(object.getClass()) + .flatMap(clazz -> matchablesByClass.get(clazz).stream()); + } + + private static Stream> supertypes(Class type) + { + return TypeToken.of(type).getTypes().stream() + .map(TypeToken::getRawType); + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private final ImmutableListMultimap.Builder matchablesByClass = ImmutableListMultimap.builder(); + + public Builder register(Set matchables) + { + matchables.forEach(this::register); + return this; + } + + public Builder register(T matchable) + { + Pattern pattern = matchable.getPattern(); + if (pattern instanceof Pattern.TypeOf) { + matchablesByClass.put(((Pattern.TypeOf) pattern).getType(), matchable); + } + else { + throw new IllegalArgumentException("Unexpected Pattern: " + pattern); + } + return this; + } + + public MatchingEngine build() + { + return new MatchingEngine(matchablesByClass.build()); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Pattern.java b/presto-main/src/main/java/com/facebook/presto/matching/Pattern.java similarity index 57% rename from presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Pattern.java rename to presto-main/src/main/java/com/facebook/presto/matching/Pattern.java index a9ceb1316872..034fc5ad8e77 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Pattern.java +++ b/presto-main/src/main/java/com/facebook/presto/matching/Pattern.java @@ -12,57 +12,55 @@ * limitations under the License. */ -package com.facebook.presto.sql.planner.iterative; - -import com.facebook.presto.sql.planner.plan.PlanNode; +package com.facebook.presto.matching; import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; public abstract class Pattern { - private static final Pattern ANY_NODE = new MatchNodeClass(PlanNode.class); + private static final Pattern ANY = new TypeOf(Object.class); private Pattern() {} - public abstract boolean matches(PlanNode node); + public abstract boolean matches(Object object); public static Pattern any() { - return ANY_NODE; + return ANY; } - public static Pattern node(Class nodeClass) + public static Pattern typeOf(Class objectClass) { - return new MatchNodeClass(nodeClass); + return new TypeOf(objectClass); } - static class MatchNodeClass + static class TypeOf extends Pattern { - private final Class nodeClass; + private final Class type; - MatchNodeClass(Class nodeClass) + TypeOf(Class type) { - this.nodeClass = requireNonNull(nodeClass, "nodeClass is null"); + this.type = requireNonNull(type, "type is null"); } - Class getNodeClass() + Class getType() { - return nodeClass; + return type; } @Override - public boolean matches(PlanNode node) + public boolean matches(Object object) { - return nodeClass.isInstance(node); + return type.isInstance(object); } @Override public String toString() { return toStringHelper(this) - .add("nodeClass", nodeClass) + .add("type", type) .toString(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/AbstractTypedJacksonModule.java b/presto-main/src/main/java/com/facebook/presto/metadata/AbstractTypedJacksonModule.java index 9af98a9f8d64..b492df349a4c 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/AbstractTypedJacksonModule.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/AbstractTypedJacksonModule.java @@ -34,7 +34,6 @@ import com.fasterxml.jackson.databind.ser.BeanSerializerFactory; import com.fasterxml.jackson.databind.ser.std.StdSerializer; import com.fasterxml.jackson.databind.type.SimpleType; -import com.google.common.base.Throwables; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; @@ -43,7 +42,7 @@ import java.util.function.Function; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Throwables.propagateIfInstanceOf; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static java.util.Objects.requireNonNull; public abstract class AbstractTypedJacksonModule @@ -111,8 +110,11 @@ public void serialize(T value, JsonGenerator generator, SerializerProvider provi serializer.serializeWithType(value, generator, provider, typeSerializer); } catch (ExecutionException e) { - propagateIfInstanceOf(e.getCause(), IOException.class); - throw Throwables.propagate(e.getCause()); + Throwable cause = e.getCause(); + if (cause != null) { + throwIfInstanceOf(cause, IOException.class); + } + throw new RuntimeException(e); } } diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/DiscoveryNodeManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/DiscoveryNodeManager.java index cf8c6edf6edd..42ac7b711d61 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/DiscoveryNodeManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/DiscoveryNodeManager.java @@ -17,6 +17,7 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.connector.system.GlobalSystemConnector; import com.facebook.presto.failureDetector.FailureDetector; +import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.spi.Node; import com.facebook.presto.spi.NodeState; import com.google.common.base.Splitter; @@ -56,7 +57,6 @@ import static com.google.common.collect.Sets.difference; import static io.airlift.concurrent.Threads.threadsNamed; import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; -import static java.util.Arrays.asList; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; @@ -76,6 +76,7 @@ public final class DiscoveryNodeManager private final ConcurrentHashMap nodeStates = new ConcurrentHashMap<>(); private final HttpClient httpClient; private final ScheduledExecutorService nodeStateUpdateExecutor; + private final boolean httpsRequired; @GuardedBy("this") private SetMultimap activeNodesByConnectorId; @@ -97,7 +98,8 @@ public DiscoveryNodeManager( NodeInfo nodeInfo, FailureDetector failureDetector, NodeVersion expectedNodeVersion, - @ForNodeManager HttpClient httpClient) + @ForNodeManager HttpClient httpClient, + InternalCommunicationConfig internalCommunicationConfig) { this.serviceSelector = requireNonNull(serviceSelector, "serviceSelector is null"); this.nodeInfo = requireNonNull(nodeInfo, "nodeInfo is null"); @@ -105,6 +107,7 @@ public DiscoveryNodeManager( this.expectedNodeVersion = requireNonNull(expectedNodeVersion, "expectedNodeVersion is null"); this.httpClient = requireNonNull(httpClient, "httpClient is null"); this.nodeStateUpdateExecutor = newSingleThreadScheduledExecutor(threadsNamed("node-state-poller-%s")); + this.httpsRequired = internalCommunicationConfig.isHttpsRequired(); this.currentNode = refreshNodesInternal(); } @@ -329,16 +332,14 @@ public synchronized Set getCoordinators() return coordinators; } - private static URI getHttpUri(ServiceDescriptor descriptor) + private URI getHttpUri(ServiceDescriptor descriptor) { - for (String type : asList("http", "https")) { - String url = descriptor.getProperties().get(type); - if (url != null) { - try { - return new URI(url); - } - catch (URISyntaxException ignored) { - } + String url = descriptor.getProperties().get(httpsRequired ? "https" : "http"); + if (url != null) { + try { + return new URI(url); + } + catch (URISyntaxException ignored) { } } return null; diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java index 2bbb59aee778..ad6d26d482fc 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java @@ -221,7 +221,8 @@ import static com.facebook.presto.operator.scalar.ArrayToJsonCast.ARRAY_TO_JSON; import static com.facebook.presto.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_FUNCTION; import static com.facebook.presto.operator.scalar.CastFromUnknownOperator.CAST_FROM_UNKNOWN; -import static com.facebook.presto.operator.scalar.ConcatFunction.CONCAT; +import static com.facebook.presto.operator.scalar.ConcatFunction.VARBINARY_CONCAT; +import static com.facebook.presto.operator.scalar.ConcatFunction.VARCHAR_CONCAT; import static com.facebook.presto.operator.scalar.ElementToArrayConcatFunction.ELEMENT_TO_ARRAY_CONCAT_FUNCTION; import static com.facebook.presto.operator.scalar.Greatest.GREATEST; import static com.facebook.presto.operator.scalar.IdentityCast.IDENTITY_CAST; @@ -300,9 +301,7 @@ import static com.facebook.presto.type.DecimalSaturatedFloorCasts.DECIMAL_TO_INTEGER_SATURATED_FLOOR_CAST; import static com.facebook.presto.type.DecimalSaturatedFloorCasts.DECIMAL_TO_SMALLINT_SATURATED_FLOOR_CAST; import static com.facebook.presto.type.DecimalSaturatedFloorCasts.DECIMAL_TO_TINYINT_SATURATED_FLOOR_CAST; -import static com.facebook.presto.type.DecimalSaturatedFloorCasts.DOUBLE_TO_DECIMAL_SATURATED_FLOOR_CAST; import static com.facebook.presto.type.DecimalSaturatedFloorCasts.INTEGER_TO_DECIMAL_SATURATED_FLOOR_CAST; -import static com.facebook.presto.type.DecimalSaturatedFloorCasts.REAL_TO_DECIMAL_SATURATED_FLOOR_CAST; import static com.facebook.presto.type.DecimalSaturatedFloorCasts.SMALLINT_TO_DECIMAL_SATURATED_FLOOR_CAST; import static com.facebook.presto.type.DecimalSaturatedFloorCasts.TINYINT_TO_DECIMAL_SATURATED_FLOOR_CAST; import static com.facebook.presto.type.DecimalToDecimalCasts.DECIMAL_TO_DECIMAL_CAST; @@ -541,7 +540,7 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key) .function(ARRAY_FLATTEN_FUNCTION) .function(ARRAY_CONCAT_FUNCTION) .functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, ARRAY_TO_JSON, JSON_TO_ARRAY) - .functions(new MapSubscriptOperator(featuresConfig.isLegacyMapSubscript())) + .functions(new MapSubscriptOperator(featuresConfig.isLegacyMapSubscript(), featuresConfig.isNewMapBlock())) .functions(MAP_CONSTRUCTOR, MAP_TO_JSON, JSON_TO_MAP) .functions(MAP_AGG, MULTIMAP_AGG, MAP_UNION) .functions(DECIMAL_TO_VARCHAR_CAST, DECIMAL_TO_INTEGER_CAST, DECIMAL_TO_BIGINT_CAST, DECIMAL_TO_DOUBLE_CAST, DECIMAL_TO_REAL_CAST, DECIMAL_TO_BOOLEAN_CAST, DECIMAL_TO_TINYINT_CAST, DECIMAL_TO_SMALLINT_CAST) @@ -552,7 +551,6 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key) .functions(DECIMAL_LESS_THAN_OPERATOR, DECIMAL_LESS_THAN_OR_EQUAL_OPERATOR) .functions(DECIMAL_GREATER_THAN_OPERATOR, DECIMAL_GREATER_THAN_OR_EQUAL_OPERATOR) .function(DECIMAL_TO_DECIMAL_SATURATED_FLOOR_CAST) - .functions(DOUBLE_TO_DECIMAL_SATURATED_FLOOR_CAST, REAL_TO_DECIMAL_SATURATED_FLOOR_CAST) .functions(DECIMAL_TO_BIGINT_SATURATED_FLOOR_CAST, BIGINT_TO_DECIMAL_SATURATED_FLOOR_CAST) .functions(DECIMAL_TO_INTEGER_SATURATED_FLOOR_CAST, INTEGER_TO_DECIMAL_SATURATED_FLOOR_CAST) .functions(DECIMAL_TO_SMALLINT_SATURATED_FLOOR_CAST, SMALLINT_TO_DECIMAL_SATURATED_FLOOR_CAST) @@ -568,7 +566,7 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key) .functions(MAX_AGGREGATION, MIN_AGGREGATION, MAX_N_AGGREGATION, MIN_N_AGGREGATION) .function(COUNT_COLUMN) .functions(ROW_HASH_CODE, ROW_TO_JSON, ROW_DISTINCT_FROM, ROW_EQUAL, ROW_GREATER_THAN, ROW_GREATER_THAN_OR_EQUAL, ROW_LESS_THAN, ROW_LESS_THAN_OR_EQUAL, ROW_NOT_EQUAL, ROW_TO_ROW_CAST) - .function(CONCAT) + .functions(VARCHAR_CONCAT, VARBINARY_CONCAT) .function(DECIMAL_TO_DECIMAL_CAST) .function(castVarcharToRe2JRegexp(featuresConfig.getRe2JDfaStatesLimit(), featuresConfig.getRe2JDfaRetries())) .function(castCharToRe2JRegexp(featuresConfig.getRe2JDfaStatesLimit(), featuresConfig.getRe2JDfaRetries())) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/Metadata.java b/presto-main/src/main/java/com/facebook/presto/metadata/Metadata.java index 427350a48298..6efd2c5ae9e7 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/Metadata.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/Metadata.java @@ -139,6 +139,11 @@ public interface Metadata */ void addColumn(Session session, TableHandle tableHandle, ColumnMetadata column); + /** + * Drop the specified column. + */ + void dropColumn(Session session, TableHandle tableHandle, ColumnHandle column); + /** * Drops the specified table * diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java index 7e878d1ed9db..d79eb60de2d4 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java @@ -531,6 +531,14 @@ public void addColumn(Session session, TableHandle tableHandle, ColumnMetadata c metadata.addColumn(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), column); } + @Override + public void dropColumn(Session session, TableHandle tableHandle, ColumnHandle column) + { + ConnectorId connectorId = tableHandle.getConnectorId(); + ConnectorMetadata metadata = getMetadataForWrite(session, connectorId); + metadata.dropColumn(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), column); + } + @Override public void dropTable(Session session, TableHandle tableHandle) { diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java index a2fe45843ab6..34d5bc986085 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java @@ -20,17 +20,17 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.session.PropertyMetadata; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.DoubleType; import com.facebook.presto.spi.type.IntegerType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarcharType; import com.facebook.presto.sql.planner.ParameterRewriter; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; import io.airlift.json.JsonCodec; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/ArrayUnnester.java b/presto-main/src/main/java/com/facebook/presto/operator/ArrayUnnester.java index c15b8a35eef3..cb0e87a2f833 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/ArrayUnnester.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/ArrayUnnester.java @@ -16,8 +16,8 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; import javax.annotation.Nullable; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/Driver.java b/presto-main/src/main/java/com/facebook/presto/operator/Driver.java index d8840fea9ae5..ac89987740a2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/Driver.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/Driver.java @@ -288,16 +288,13 @@ private ListenableFuture processInternal() Operator current = operators.get(i); Operator next = operators.get(i + 1); - // skip blocked operators + // skip blocked operator if (getBlockedFuture(current).isPresent()) { continue; } - if (getBlockedFuture(next).isPresent()) { - continue; - } - // if the current operator is not finished and next operator needs input... - if (!current.isFinished() && next.needsInput()) { + // if the current operator is not finished and next operator isn't blocked and needs input... + if (!current.isFinished() && !getBlockedFuture(next).isPresent() && next.needsInput()) { // get an output page from current operator current.getOperatorContext().startIntervalTimer(); Page page = current.getOutput(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/ExchangeClientStatus.java b/presto-main/src/main/java/com/facebook/presto/operator/ExchangeClientStatus.java index 70655d5c79b9..de3bab3a828a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/ExchangeClientStatus.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/ExchangeClientStatus.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.presto.util.Mergeable; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/ExplainAnalyzeOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/ExplainAnalyzeOperator.java index f53679de0a47..e39863fc65f2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/ExplainAnalyzeOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/ExplainAnalyzeOperator.java @@ -46,15 +46,17 @@ public static class ExplainAnalyzeOperatorFactory private final QueryPerformanceFetcher queryPerformanceFetcher; private final Metadata metadata; private final CostCalculator costCalculator; + private final boolean verbose; private boolean closed; - public ExplainAnalyzeOperatorFactory(int operatorId, PlanNodeId planNodeId, QueryPerformanceFetcher queryPerformanceFetcher, Metadata metadata, CostCalculator costCalculator) + public ExplainAnalyzeOperatorFactory(int operatorId, PlanNodeId planNodeId, QueryPerformanceFetcher queryPerformanceFetcher, Metadata metadata, CostCalculator costCalculator, boolean verbose) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); this.queryPerformanceFetcher = requireNonNull(queryPerformanceFetcher, "queryPerformanceFetcher is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); + this.verbose = verbose; } @Override @@ -68,7 +70,7 @@ public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, ExplainAnalyzeOperator.class.getSimpleName()); - return new ExplainAnalyzeOperator(operatorContext, queryPerformanceFetcher, metadata, costCalculator); + return new ExplainAnalyzeOperator(operatorContext, queryPerformanceFetcher, metadata, costCalculator, verbose); } @Override @@ -80,7 +82,7 @@ public void close() @Override public OperatorFactory duplicate() { - return new ExplainAnalyzeOperatorFactory(operatorId, planNodeId, queryPerformanceFetcher, metadata, costCalculator); + return new ExplainAnalyzeOperatorFactory(operatorId, planNodeId, queryPerformanceFetcher, metadata, costCalculator, verbose); } } @@ -88,15 +90,17 @@ public OperatorFactory duplicate() private final QueryPerformanceFetcher queryPerformanceFetcher; private final Metadata metadata; private final CostCalculator costCalculator; + private final boolean verbose; private boolean finishing; private boolean outputConsumed; - public ExplainAnalyzeOperator(OperatorContext operatorContext, QueryPerformanceFetcher queryPerformanceFetcher, Metadata metadata, CostCalculator costCalculator) + public ExplainAnalyzeOperator(OperatorContext operatorContext, QueryPerformanceFetcher queryPerformanceFetcher, Metadata metadata, CostCalculator costCalculator, boolean verbose) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.queryPerformanceFetcher = requireNonNull(queryPerformanceFetcher, "queryPerformanceFetcher is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); + this.verbose = verbose; } @Override @@ -151,7 +155,7 @@ public Page getOutput() return null; } - String plan = textDistributedPlan(queryInfo.getOutputStage().get(), metadata, costCalculator, operatorContext.getSession()); + String plan = textDistributedPlan(queryInfo.getOutputStage().get(), metadata, costCalculator, operatorContext.getSession(), verbose); BlockBuilder builder = VARCHAR.createBlockBuilder(new BlockBuilderStatus(), 1); VARCHAR.writeString(builder, plan); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/GroupByIdBlock.java b/presto-main/src/main/java/com/facebook/presto/operator/GroupByIdBlock.java index 6e0da9db77ed..be2c6588e0b0 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/GroupByIdBlock.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/GroupByIdBlock.java @@ -20,6 +20,7 @@ import org.openjdk.jol.info.ClassLayout; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.google.common.base.MoreObjects.toStringHelper; @@ -57,7 +58,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public int getRegionSizeInBytes(int positionOffset, int length) + public long getRegionSizeInBytes(int positionOffset, int length) { return block.getRegionSizeInBytes(positionOffset, length); } @@ -171,17 +172,24 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return block.getSizeInBytes(); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return INSTANCE_SIZE + block.getRetainedSizeInBytes(); } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(block, block.getRetainedSizeInBytes()); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public BlockEncoding getEncoding() { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/HashCollisionsInfo.java b/presto-main/src/main/java/com/facebook/presto/operator/HashCollisionsInfo.java index 391e1c6717cd..573f235965cd 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/HashCollisionsInfo.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/HashCollisionsInfo.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.presto.util.Mergeable; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/JoinOperatorInfo.java b/presto-main/src/main/java/com/facebook/presto/operator/JoinOperatorInfo.java index c82fb0e3da94..9f627234e975 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/JoinOperatorInfo.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/JoinOperatorInfo.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator; import com.facebook.presto.operator.LookupJoinOperators.JoinType; +import com.facebook.presto.util.Mergeable; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/LookupJoinOperatorFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/LookupJoinOperatorFactory.java index 9bfe19e9952a..e299571daf68 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/LookupJoinOperatorFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/LookupJoinOperatorFactory.java @@ -15,7 +15,6 @@ import com.facebook.presto.operator.LookupJoinOperators.JoinType; import com.facebook.presto.operator.LookupOuterOperator.LookupOuterOperatorFactory; -import com.facebook.presto.operator.LookupSource.OuterPositionIterator; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.collect.ImmutableList; @@ -83,9 +82,8 @@ public LookupJoinOperatorFactory(int operatorId, // when all join operators finish (and lookup source is ready), set the outer position future to start the outer operator ListenableFuture lookupSourceAfterProbeFinished = transformAsync(probeReferenceCount.getFreeFuture(), ignored -> lookupSourceFactory.createLookupSource()); ListenableFuture outerPositionsFuture = transform(lookupSourceAfterProbeFinished, lookupSource -> { - try (LookupSource ignore = lookupSource) { - return lookupSource.getOuterPositionIterator(); - } + lookupSource.close(); + return lookupSourceFactory.getOuterPositionIterator(); }); lookupSourceFactoryUsersCount.retain(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/LookupOuterOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/LookupOuterOperator.java index 78e675223fc4..4ce1881ba079 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/LookupOuterOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/LookupOuterOperator.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.operator; -import com.facebook.presto.operator.LookupSource.OuterPositionIterator; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.type.Type; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/LookupSource.java b/presto-main/src/main/java/com/facebook/presto/operator/LookupSource.java index 65dafedd9b3f..4d482b5d1c01 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/LookupSource.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/LookupSource.java @@ -38,18 +38,8 @@ public interface LookupSource void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset); - default OuterPositionIterator getOuterPositionIterator() - { - return (pageBuilder, outputChannelOffset) -> false; - } - boolean isJoinPositionEligible(long currentJoinPosition, int probePosition, Page allProbeChannelsPage); @Override void close(); - - interface OuterPositionIterator - { - boolean appendToNext(PageBuilder pageBuilder, int outputChannelOffset); - } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/LookupSourceFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/LookupSourceFactory.java index c487e013b4f9..d9f9c8449452 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/LookupSourceFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/LookupSourceFactory.java @@ -28,6 +28,11 @@ public interface LookupSourceFactory ListenableFuture createLookupSource(); + /** + * Can be called only after {@link #createLookupSource()} is done and all users of {@link LookupSource}-s finished. + */ + OuterPositionIterator getOuterPositionIterator(); + Map getLayout(); // this is only here for the index lookup source diff --git a/presto-main/src/main/java/com/facebook/presto/operator/MapUnnester.java b/presto-main/src/main/java/com/facebook/presto/operator/MapUnnester.java index aa9ab6776bb9..1342f26ee156 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/MapUnnester.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/MapUnnester.java @@ -16,8 +16,8 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.MapType; import javax.annotation.Nullable; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java b/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java index 20dd58527247..4732b01b6598 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java @@ -512,8 +512,7 @@ private class OperatorSpillContext implements SpillContext { private final DriverContext driverContext; - - private long reservedBytes; + private final AtomicLong reservedBytes = new AtomicLong(); public OperatorSpillContext(DriverContext driverContext) { @@ -523,21 +522,35 @@ public OperatorSpillContext(DriverContext driverContext) @Override public void updateBytes(long bytes) { - if (bytes > 0) { + if (bytes >= 0) { + reservedBytes.addAndGet(bytes); driverContext.reserveSpill(bytes); } else { - checkArgument(reservedBytes + bytes >= 0, "tried to free %s spilled bytes from %s bytes reserved", -bytes, reservedBytes); + reservedBytes.accumulateAndGet(-bytes, this::decrementSpilledReservation); driverContext.freeSpill(-bytes); } - reservedBytes += bytes; + } + + private long decrementSpilledReservation(long reservedBytes, long bytesBeingFreed) + { + checkArgument(bytesBeingFreed >= 0); + checkArgument(bytesBeingFreed <= reservedBytes, "tried to free %s spilled bytes from %s bytes reserved", bytesBeingFreed, reservedBytes); + return reservedBytes - bytesBeingFreed; + } + + @Override + public void close() + { + // Only products of SpillContext.newLocalSpillContext() should be closed. + throw new UnsupportedOperationException(format("%s should not be closed directly", getClass())); } @Override public String toString() { return toStringHelper(this) - .add("usedBytes", reservedBytes) + .add("usedBytes", reservedBytes.get()) .toString(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/OperatorInfo.java b/presto-main/src/main/java/com/facebook/presto/operator/OperatorInfo.java index 11e476e8d047..14985009cbaf 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/OperatorInfo.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/OperatorInfo.java @@ -29,7 +29,8 @@ @JsonSubTypes.Type(value = SplitOperatorInfo.class, name = "splitOperator"), @JsonSubTypes.Type(value = HashCollisionsInfo.class, name = "hashCollisionsInfo"), @JsonSubTypes.Type(value = PartitionedOutputInfo.class, name = "partitionedOutput"), - @JsonSubTypes.Type(value = JoinOperatorInfo.class, name = "joinOperatorInfo") + @JsonSubTypes.Type(value = JoinOperatorInfo.class, name = "joinOperatorInfo"), + @JsonSubTypes.Type(value = WindowInfo.class, name = "windowInfo") }) public interface OperatorInfo { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/OperatorStats.java b/presto-main/src/main/java/com/facebook/presto/operator/OperatorStats.java index b065d1b8f717..e39cf8041024 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/OperatorStats.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/OperatorStats.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.util.Mergeable; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/OuterLookupSource.java b/presto-main/src/main/java/com/facebook/presto/operator/OuterLookupSource.java index 98794ffbf670..2c53982cb678 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/OuterLookupSource.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/OuterLookupSource.java @@ -31,7 +31,7 @@ public final class OuterLookupSource implements LookupSource { - public static Supplier createOuterLookupSourceSupplier(Supplier lookupSourceSupplier) + public static TrackingLookupSourceSupplier createOuterLookupSourceSupplier(Supplier lookupSourceSupplier) { return new OuterLookupSourceSupplier(lookupSourceSupplier); } @@ -94,12 +94,6 @@ public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOf outerPositionTracker.positionVisited(position); } - @Override - public OuterPositionIterator getOuterPositionIterator() - { - return outerPositionTracker.getOuterPositionIterator(); - } - @Override public void close() { @@ -140,7 +134,7 @@ public synchronized boolean appendToNext(PageBuilder pageBuilder, int outputChan @ThreadSafe private static class OuterLookupSourceSupplier - implements Supplier + implements TrackingLookupSourceSupplier { private final Supplier lookupSourceSupplier; private final OuterPositionTracker outerPositionTracker; @@ -152,10 +146,15 @@ public OuterLookupSourceSupplier(Supplier lookupSourceSupplier) } @Override - public LookupSource get() + public LookupSource getLookupSource() { return new OuterLookupSource(lookupSourceSupplier.get(), outerPositionTracker); } + + public OuterPositionIterator getOuterPositionIterator() + { + return outerPositionTracker.getOuterPositionIterator(); + } } @ThreadSafe diff --git a/presto-main/src/main/java/com/facebook/presto/operator/OuterPositionIterator.java b/presto-main/src/main/java/com/facebook/presto/operator/OuterPositionIterator.java new file mode 100644 index 000000000000..6a945f44d309 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/OuterPositionIterator.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.spi.PageBuilder; + +public interface OuterPositionIterator +{ + boolean appendToNext(PageBuilder pageBuilder, int outputChannelOffset); +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java b/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java index a2f0e3935ce3..90b33534c13c 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java @@ -153,9 +153,12 @@ public void clear() { for (ObjectArrayList channel : channels) { channel.clear(); + channel.trim(); } valueAddresses.clear(); + valueAddresses.trim(); positionCount = 0; + nextBlockToCompact = 0; pagesMemorySize = 0; estimatedSize = calculateEstimatedSize(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PartitionedLookupSource.java b/presto-main/src/main/java/com/facebook/presto/operator/PartitionedLookupSource.java index 64f2cd8958b6..2d719ffe04d8 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PartitionedLookupSource.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PartitionedLookupSource.java @@ -25,10 +25,10 @@ import java.util.Arrays; import java.util.List; import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; -import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.Integer.numberOfTrailingZeros; @@ -38,21 +38,40 @@ public class PartitionedLookupSource implements LookupSource { - public static Supplier createPartitionedLookupSourceSupplier(List> partitions, List hashChannelTypes, boolean outer) + public static TrackingLookupSourceSupplier createPartitionedLookupSourceSupplier(List> partitions, List hashChannelTypes, boolean outer) { - Optional outerPositionTrackerFactory = outer ? - Optional.of(new OuterPositionTracker.Factory( - partitions.stream() - .map(partition -> partition.get().getJoinPositionCount()) - .collect(toImmutableList()))) - : Optional.empty(); - - return () -> new PartitionedLookupSource( - partitions.stream() - .map(Supplier::get) - .collect(toImmutableList()), - hashChannelTypes, - outerPositionTrackerFactory.map(OuterPositionTracker.Factory::create)); + if (outer) { + OuterPositionTracker.Factory outerPositionTrackerFactory = new OuterPositionTracker.Factory(partitions); + + return new TrackingLookupSourceSupplier() + { + @Override + public LookupSource getLookupSource() + { + return new PartitionedLookupSource( + partitions.stream() + .map(Supplier::get) + .collect(toImmutableList()), + hashChannelTypes, + Optional.of(outerPositionTrackerFactory.create())); + } + + @Override + public OuterPositionIterator getOuterPositionIterator() + { + return outerPositionTrackerFactory.getOuterPositionIterator(); + } + }; + } + else { + return TrackingLookupSourceSupplier.nonTracking( + () -> new PartitionedLookupSource( + partitions.stream() + .map(Supplier::get) + .collect(toImmutableList()), + hashChannelTypes, + Optional.empty())); + } } private final LookupSource[] lookupSources; @@ -148,13 +167,6 @@ public void appendTo(long partitionedJoinPosition, PageBuilder pageBuilder, int } } - @Override - public OuterPositionIterator getOuterPositionIterator() - { - checkState(outerPositionTracker != null, "This is not an outer lookup source"); - return new PartitionedLookupOuterPositionIterator(lookupSources, outerPositionTracker.getVisitedPositions()); - } - @Override public void close() { @@ -233,30 +245,46 @@ private static class OuterPositionTracker { public static class Factory { + private final LookupSource[] lookupSources; private final boolean[][] visitedPositions; + private final AtomicBoolean finished = new AtomicBoolean(); private final AtomicLong referenceCount = new AtomicLong(); - public Factory(List positionCounts) + public Factory(List> partitions) { - visitedPositions = new boolean[positionCounts.size()][]; - for (int partition = 0; partition < visitedPositions.length; partition++) { - visitedPositions[partition] = new boolean[positionCounts.get(partition)]; - } + this.lookupSources = partitions.stream() + .map(Supplier::get) + .toArray(LookupSource[]::new); + + visitedPositions = Arrays.stream(this.lookupSources) + .map(LookupSource::getJoinPositionCount) + .map(boolean[]::new) + .toArray(boolean[][]::new); } public OuterPositionTracker create() { - return new OuterPositionTracker(visitedPositions, referenceCount); + return new OuterPositionTracker(visitedPositions, finished, referenceCount); + } + + public OuterPositionIterator getOuterPositionIterator() + { + // touching atomic values ensures memory visibility between commit and getVisitedPositions + verify(referenceCount.get() == 0); + finished.set(true); + return new PartitionedLookupOuterPositionIterator(lookupSources, visitedPositions); } } private final boolean[][] visitedPositions; // shared across multiple operators/drivers + private final AtomicBoolean finished; // shared across multiple operators/drivers private final AtomicLong referenceCount; // shared across multiple operators/drivers private boolean written; // unique per each operator/driver - private OuterPositionTracker(boolean[][] visitedPositions, AtomicLong referenceCount) + private OuterPositionTracker(boolean[][] visitedPositions, AtomicBoolean finished, AtomicLong referenceCount) { this.visitedPositions = visitedPositions; + this.finished = finished; this.referenceCount = referenceCount; } @@ -267,7 +295,8 @@ public void positionVisited(int partition, int position) { if (!written) { written = true; - incrementReferenceCount(); + verify(!finished.get()); + referenceCount.incrementAndGet(); } visitedPositions[partition][position] = true; } @@ -279,17 +308,5 @@ public void commit() referenceCount.decrementAndGet(); } } - - public boolean[][] getVisitedPositions() - { - // touching atomic values ensures memory visibility between commit and getVisitedPositions - verify(referenceCount.get() == 0); - return visitedPositions; - } - - private void incrementReferenceCount() - { - referenceCount.incrementAndGet(); - } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PartitionedLookupSourceFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/PartitionedLookupSourceFactory.java index e814fa0e579a..e0f7f30351e5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PartitionedLookupSourceFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PartitionedLookupSourceFactory.java @@ -50,7 +50,7 @@ public final class PartitionedLookupSourceFactory private int partitionsSet; @GuardedBy("this") - private Supplier lookupSourceSupplier; + private TrackingLookupSourceSupplier lookupSourceSupplier; @GuardedBy("this") private final List> lookupSourceFutures = new ArrayList<>(); @@ -90,7 +90,7 @@ public Map getLayout() public synchronized ListenableFuture createLookupSource() { if (lookupSourceSupplier != null) { - return Futures.immediateFuture(lookupSourceSupplier.get()); + return Futures.immediateFuture(lookupSourceSupplier.getLookupSource()); } SettableFuture lookupSourceFuture = SettableFuture.create(); @@ -102,7 +102,7 @@ public void setPartitionLookupSourceSupplier(int partitionIndex, Supplier lookupSourceSupplier = null; + TrackingLookupSourceSupplier lookupSourceSupplier = null; List> lookupSourceFutures = null; synchronized (this) { if (destroyed.isDone()) { @@ -122,7 +122,7 @@ else if (outer) { this.lookupSourceSupplier = createOuterLookupSourceSupplier(partitionLookupSource); } else { - this.lookupSourceSupplier = partitionLookupSource; + this.lookupSourceSupplier = TrackingLookupSourceSupplier.nonTracking(partitionLookupSource); } // store lookup source supplier and futures into local variables so they can be used outside of the lock @@ -133,11 +133,22 @@ else if (outer) { if (lookupSourceSupplier != null) { for (SettableFuture lookupSourceFuture : lookupSourceFutures) { - lookupSourceFuture.set(lookupSourceSupplier.get()); + lookupSourceFuture.set(lookupSourceSupplier.getLookupSource()); } } } + @Override + public OuterPositionIterator getOuterPositionIterator() + { + TrackingLookupSourceSupplier lookupSourceSupplier; + synchronized (this) { + checkState(this.lookupSourceSupplier != null, "lookup source not ready yet"); + lookupSourceSupplier = this.lookupSourceSupplier; + } + return lookupSourceSupplier.getOuterPositionIterator(); + } + @Override public void destroy() { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PartitionedOutputOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/PartitionedOutputOperator.java index beaa7ce7e756..c2da434a773d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PartitionedOutputOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PartitionedOutputOperator.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.util.Mergeable; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PipelineContext.java b/presto-main/src/main/java/com/facebook/presto/operator/PipelineContext.java index a8b39e6688db..30f26ecea503 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PipelineContext.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PipelineContext.java @@ -455,7 +455,7 @@ else if (driverStats.isFullyBlocked()) { new Duration(totalCpuTime, NANOSECONDS).convertToMostSuccinctTimeUnit(), new Duration(totalUserTime, NANOSECONDS).convertToMostSuccinctTimeUnit(), new Duration(totalBlockedTime, NANOSECONDS).convertToMostSuccinctTimeUnit(), - fullyBlocked && (runningDrivers > 0 || runningPartitionedDrivers > 0), + fullyBlocked, blockedReasons, succinctBytes(rawInputDataSize), diff --git a/presto-main/src/main/java/com/facebook/presto/operator/TrackingLookupSourceSupplier.java b/presto-main/src/main/java/com/facebook/presto/operator/TrackingLookupSourceSupplier.java new file mode 100644 index 000000000000..174be9337aa8 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/TrackingLookupSourceSupplier.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import java.util.function.Supplier; + +import static java.util.Objects.requireNonNull; + +public interface TrackingLookupSourceSupplier +{ + LookupSource getLookupSource(); + + OuterPositionIterator getOuterPositionIterator(); + + static TrackingLookupSourceSupplier nonTracking(Supplier lookupSourceSupplier) + { + requireNonNull(lookupSourceSupplier, "lookupSourceSupplier is null"); + return new TrackingLookupSourceSupplier() + { + @Override + public LookupSource getLookupSource() + { + return lookupSourceSupplier.get(); + } + + @Override + public OuterPositionIterator getOuterPositionIterator() + { + throw new UnsupportedOperationException(); + } + }; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/UnnestOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/UnnestOperator.java index cda3a083f195..fd7eb17072b3 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/UnnestOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/UnnestOperator.java @@ -16,10 +16,10 @@ import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.PlanNodeId; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import java.util.ArrayList; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/WindowInfo.java b/presto-main/src/main/java/com/facebook/presto/operator/WindowInfo.java new file mode 100644 index 000000000000..ddd9493ab1b8 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/WindowInfo.java @@ -0,0 +1,247 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.operator.window.WindowPartition; +import com.facebook.presto.util.Mergeable; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import javax.annotation.concurrent.Immutable; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.Iterables.concat; + +public class WindowInfo + implements Mergeable, OperatorInfo +{ + private final List windowInfos; + + @JsonCreator + public WindowInfo(@JsonProperty("windowInfos") List windowInfos) + { + this.windowInfos = ImmutableList.copyOf(windowInfos); + } + + @JsonProperty + public List getWindowInfos() + { + return windowInfos; + } + + @Override + public WindowInfo mergeWith(WindowInfo other) + { + return new WindowInfo(ImmutableList.copyOf(concat(this.windowInfos, other.windowInfos))); + } + + static class DriverWindowInfoBuilder + { + private ImmutableList.Builder indexInfosBuilder = ImmutableList.builder(); + private IndexInfoBuilder currentIndexInfoBuilder = null; + + public void addIndex(PagesIndex index) + { + if (currentIndexInfoBuilder != null) { + Optional indexInfo = currentIndexInfoBuilder.build(); + indexInfo.ifPresent(indexInfosBuilder::add); + } + currentIndexInfoBuilder = new IndexInfoBuilder(index.getPositionCount(), index.getEstimatedSize().toBytes()); + } + + public void addPartition(WindowPartition partition) + { + checkState(currentIndexInfoBuilder != null, "addIndex must be called before addPartition"); + currentIndexInfoBuilder.addPartition(partition); + } + + public DriverWindowInfo build() + { + if (currentIndexInfoBuilder != null) { + Optional indexInfo = currentIndexInfoBuilder.build(); + indexInfo.ifPresent(indexInfosBuilder::add); + currentIndexInfoBuilder = null; + } + + List indexInfos = indexInfosBuilder.build(); + if (indexInfos.size() == 0) { + return new DriverWindowInfo(0.0, 0.0, 0.0, 0, 0, 0); + } + long totalRowsCount = indexInfos.stream() + .mapToLong(IndexInfo::getTotalRowsCount) + .sum(); + double averageIndexPositions = totalRowsCount / indexInfos.size(); + double squaredDifferencesPositionsOfIndex = indexInfos.stream() + .mapToDouble(index -> Math.pow(index.getTotalRowsCount() - averageIndexPositions, 2)) + .sum(); + double averageIndexSize = indexInfos.stream() + .mapToLong(IndexInfo::getSizeInBytes) + .average() + .getAsDouble(); + double squaredDifferencesSizeOfIndex = indexInfos.stream() + .mapToDouble(index -> Math.pow(index.getSizeInBytes() - averageIndexSize, 2)) + .sum(); + double squaredDifferencesSizeInPartition = indexInfos.stream() + .mapToDouble(IndexInfo::getSumSquaredDifferencesSizeInPartition) + .sum(); + + long totalPartitionsCount = indexInfos.stream() + .mapToLong(IndexInfo::getNumberOfPartitions) + .sum(); + + return new DriverWindowInfo(squaredDifferencesPositionsOfIndex, + squaredDifferencesSizeOfIndex, + squaredDifferencesSizeInPartition, + totalPartitionsCount, + totalRowsCount, + indexInfos.size()); + } + } + + @Immutable + public static class DriverWindowInfo + { + private final double sumSquaredDifferencesPositionsOfIndex; // sum of (indexPositions - averageIndexPositions) ^ 2 for all indexes + private final double sumSquaredDifferencesSizeOfIndex; // sum of (indexSize - averageIndexSize) ^ 2 for all indexes + private final double sumSquaredDifferencesSizeInPartition; // sum of (partitionSize - averagePartitionSize)^2 for each partition + private final long totalPartitionsCount; + private final long totalRowsCount; + private final long numberOfIndexes; + + @JsonCreator + public DriverWindowInfo( + @JsonProperty("sumSquaredDifferencesPositionsOfIndex") double sumSquaredDifferencesPositionsOfIndex, + @JsonProperty("sumSquaredDifferencesSizeOfIndex") double sumSquaredDifferencesSizeOfIndex, + @JsonProperty("sumSquaredDifferencesSizeInPartition") double sumSquaredDifferencesSizeInPartition, + @JsonProperty("totalPartitionsCount") long totalPartitionsCount, + @JsonProperty("totalRowsCount") long totalRowsCount, + @JsonProperty("numberOfIndexes") long numberOfIndexes) + { + this.sumSquaredDifferencesPositionsOfIndex = sumSquaredDifferencesPositionsOfIndex; + this.sumSquaredDifferencesSizeOfIndex = sumSquaredDifferencesSizeOfIndex; + this.sumSquaredDifferencesSizeInPartition = sumSquaredDifferencesSizeInPartition; + this.totalPartitionsCount = totalPartitionsCount; + this.totalRowsCount = totalRowsCount; + this.numberOfIndexes = numberOfIndexes; + } + + @JsonProperty + public double getSumSquaredDifferencesPositionsOfIndex() + { + return sumSquaredDifferencesPositionsOfIndex; + } + + @JsonProperty + public double getSumSquaredDifferencesSizeOfIndex() + { + return sumSquaredDifferencesSizeOfIndex; + } + + @JsonProperty + public double getSumSquaredDifferencesSizeInPartition() + { + return sumSquaredDifferencesSizeInPartition; + } + + @JsonProperty + public long getTotalPartitionsCount() + { + return totalPartitionsCount; + } + + @JsonProperty + public long getTotalRowsCount() + { + return totalRowsCount; + } + + @JsonProperty + public long getNumberOfIndexes() + { + return numberOfIndexes; + } + } + + private static class IndexInfoBuilder + { + private final long rowsNumber; + private final long sizeInBytes; + private final ImmutableList.Builder partitionsSizes = ImmutableList.builder(); + + public IndexInfoBuilder(long rowsNumber, long sizeInBytes) + { + this.rowsNumber = rowsNumber; + this.sizeInBytes = sizeInBytes; + } + + public void addPartition(WindowPartition partition) + { + partitionsSizes.add(partition.getPartitionEnd() - partition.getPartitionStart()); + } + + public Optional build() + { + List partitions = partitionsSizes.build(); + if (partitions.size() == 0) { + return Optional.empty(); + } + double avgSize = partitions.stream().mapToLong(Integer::longValue).average().getAsDouble(); + double squaredDifferences = partitions.stream().mapToDouble(size -> Math.pow(size - avgSize, 2)).sum(); + checkState(partitions.stream().mapToLong(Integer::longValue).sum() == rowsNumber, "Total number of rows in index does not match number of rows in partitions within that index"); + + return Optional.of(new IndexInfo(rowsNumber, sizeInBytes, squaredDifferences, partitions.size())); + } + } + + @Immutable + public static class IndexInfo + { + private final long totalRowsCount; + private final long sizeInBytes; + private final double sumSquaredDifferencesSizeInPartition; // sum of (partitionSize - averagePartitionSize)^2 for each partition + private final long numberOfPartitions; + + public IndexInfo(long totalRowsCount, long sizeInBytes, double sumSquaredDifferencesSizeInPartition, long numberOfPartitions) + { + this.totalRowsCount = totalRowsCount; + this.sizeInBytes = sizeInBytes; + this.sumSquaredDifferencesSizeInPartition = sumSquaredDifferencesSizeInPartition; + this.numberOfPartitions = numberOfPartitions; + } + + public long getTotalRowsCount() + { + return totalRowsCount; + } + + public long getSizeInBytes() + { + return sizeInBytes; + } + + public double getSumSquaredDifferencesSizeInPartition() + { + return sumSquaredDifferencesSizeInPartition; + } + + public long getNumberOfPartitions() + { + return numberOfPartitions; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java index b426ffa61d46..8ac15ab510e7 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiPredicate; import java.util.stream.Stream; @@ -146,18 +147,18 @@ public void close() public OperatorFactory duplicate() { return new WindowOperatorFactory( - operatorId, - planNodeId, - sourceTypes, - outputChannels, - windowFunctionDefinitions, - partitionChannels, - preGroupedChannels, - sortChannels, - sortOrder, - preSortedChannelPrefix, - expectedPositions, - pagesIndexFactory); + operatorId, + planNodeId, + sourceTypes, + outputChannels, + windowFunctionDefinitions, + partitionChannels, + preGroupedChannels, + sortChannels, + sortOrder, + preSortedChannelPrefix, + expectedPositions, + pagesIndexFactory); } } @@ -187,6 +188,9 @@ private enum State private final PageBuilder pageBuilder; + private final WindowInfo.DriverWindowInfoBuilder windowInfo; + private AtomicReference> driverWindowInfo = new AtomicReference<>(Optional.empty()); + private State state = State.NEEDS_INPUT; private WindowPartition partition; @@ -257,6 +261,14 @@ public WindowOperator( this.orderChannels = ImmutableList.copyOf(concat(unGroupedPartitionChannels, sortChannels)); this.ordering = ImmutableList.copyOf(concat(nCopies(unGroupedPartitionChannels.size(), ASC_NULLS_LAST), sortOrder)); } + + windowInfo = new WindowInfo.DriverWindowInfoBuilder(); + operatorContext.setInfoSupplier(this::getWindowInfo); + } + + private OperatorInfo getWindowInfo() + { + return new WindowInfo(driverWindowInfo.get().map(ImmutableList::of).orElse(ImmutableList.of())); } @Override @@ -279,7 +291,7 @@ public void finish() } if (state == State.NEEDS_INPUT) { // Since was waiting for more input, prepare what we have for output since we will not be getting any more input - sortPagesIndexIfNecessary(); + finishPagesIndex(); } state = State.FINISHING; } @@ -324,7 +336,7 @@ private boolean processPendingInput() // If we have unused input or are finishing, then we have buffered a full group if (pendingInput != null || state == State.FINISHING) { - sortPagesIndexIfNecessary(); + finishPagesIndex(); return true; } else { @@ -421,6 +433,7 @@ else if (state == State.FINISHING) { int partitionEnd = findGroupEnd(pagesIndex, unGroupedPartitionHashStrategy, partitionStart); partition = new WindowPartition(pagesIndex, partitionStart, partitionEnd, outputChannels, windowFunctions, peerGroupHashStrategy); + windowInfo.addPartition(partition); } partition.processNextRow(pageBuilder); @@ -443,6 +456,12 @@ private void sortPagesIndexIfNecessary() } } + private void finishPagesIndex() + { + sortPagesIndexIfNecessary(); + windowInfo.addIndex(pagesIndex); + } + // Assumes input grouped on relevant pagesHashStrategy columns private static int findGroupEnd(Page page, PagesHashStrategy pagesHashStrategy, int startPosition) { @@ -517,4 +536,11 @@ else if (!previousPairsWereEqual) { // the input is sorted, but the algorithm has still failed throw new IllegalArgumentException("failed to find a group ending"); } + + @Override + public void close() + throws Exception + { + driverWindowInfo.set(Optional.of(windowInfo.build())); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxAggregationFunction.java index 053785624a85..fc3783348b96 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxAggregationFunction.java @@ -24,7 +24,6 @@ import com.facebook.presto.operator.aggregation.state.NullableLongState; import com.facebook.presto.operator.aggregation.state.SliceState; import com.facebook.presto.operator.aggregation.state.StateCompiler; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AccumulatorState; @@ -33,7 +32,6 @@ import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; @@ -46,9 +44,9 @@ import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL; import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.util.Failures.internalError; import static com.facebook.presto.util.Reflection.methodHandle; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @@ -185,9 +183,7 @@ public static void input(MethodHandle methodHandle, NullableDoubleState state, d } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -204,9 +200,7 @@ public static void input(MethodHandle methodHandle, NullableLongState state, lon } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -222,9 +216,7 @@ public static void input(MethodHandle methodHandle, SliceState state, Slice valu } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -241,9 +233,7 @@ public static void input(MethodHandle methodHandle, NullableBooleanState state, } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -259,9 +249,7 @@ public static void input(MethodHandle methodHandle, BlockState state, Block valu } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -278,9 +266,7 @@ public static void combine(MethodHandle methodHandle, NullableLongState state, N } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -297,9 +283,7 @@ public static void combine(MethodHandle methodHandle, NullableDoubleState state, } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -316,9 +300,7 @@ public static void combine(MethodHandle methodHandle, NullableBooleanState state } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -334,9 +316,7 @@ public static void combine(MethodHandle methodHandle, SliceState state, SliceSta } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -352,9 +332,7 @@ public static void combine(MethodHandle methodHandle, BlockState state, BlockSta } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxByNAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxByNAggregationFunction.java index 6f29a0721c99..23e0f86ccfcc 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxByNAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxByNAggregationFunction.java @@ -23,10 +23,10 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxNAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxNAggregationFunction.java index 093e0fccb989..af718c7de681 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxNAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxNAggregationFunction.java @@ -23,10 +23,10 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ArrayAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ArrayAggregationFunction.java index 7f20c738dffc..398e3f4917e5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ArrayAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ArrayAggregationFunction.java @@ -26,9 +26,9 @@ import com.facebook.presto.spi.function.AccumulatorState; import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/Histogram.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/Histogram.java index d1244c883547..edfcf1efa2a1 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/Histogram.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/Histogram.java @@ -156,8 +156,7 @@ public static void output(Type type, HistogramState state, BlockBuilder out) out.appendNull(); } else { - Block block = typedHistogram.serialize(); - type.writeObject(out, block); + typedHistogram.serialize(out); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapAggregationFunction.java index 3cb0f142abc9..f5bc8a320082 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapAggregationFunction.java @@ -22,11 +22,11 @@ import com.facebook.presto.operator.aggregation.state.KeyValuePairsStateFactory; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapUnionAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapUnionAggregation.java index dbbee380e99a..d644858cb159 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapUnionAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapUnionAggregation.java @@ -22,11 +22,11 @@ import com.facebook.presto.operator.aggregation.state.KeyValuePairsStateFactory; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultiKeyValuePairs.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultiKeyValuePairs.java index bd12c2bb0911..d992282d6a0a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultiKeyValuePairs.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultiKeyValuePairs.java @@ -16,11 +16,8 @@ import com.facebook.presto.array.ObjectBigArray; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.google.common.collect.ImmutableList; import org.openjdk.jol.info.ClassLayout; import static com.facebook.presto.type.TypeUtils.expectedValueSize; @@ -85,7 +82,7 @@ public void serialize(BlockBuilder out) /** * Serialize as a multimap: map(key, array(value)), each key can be associated with multiple values */ - public Block toMultimapNativeEncoding() + public void toMultimapNativeEncoding(BlockBuilder blockBuilder) { Block keys = keyBlockBuilder.build(); Block values = valueBlockBuilder.build(); @@ -108,13 +105,12 @@ public Block toMultimapNativeEncoding() // Write keys and value arrays into one Block Block distinctKeys = distinctKeyBlockBuilder.build(); Type valueArrayType = new ArrayType(valueType); - BlockBuilder multimapBlockBuilder = new InterleavedBlockBuilder(ImmutableList.of(keyType, valueArrayType), new BlockBuilderStatus(), distinctKeyBlockBuilder.getPositionCount()); + BlockBuilder multimapBlockBuilder = blockBuilder.beginBlockEntry(); for (int i = 0; i < distinctKeys.getPositionCount(); i++) { keyType.appendTo(distinctKeys, i, multimapBlockBuilder); valueArrayType.writeObject(multimapBlockBuilder, valueArrayBlockBuilders.get(i).build()); } - - return multimapBlockBuilder.build(); + blockBuilder.closeEntry(); } public long estimatedInMemorySize() diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultimapAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultimapAggregationFunction.java index 263eb2e0e56f..3f9efbf7c3a8 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultimapAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultimapAggregationFunction.java @@ -22,11 +22,11 @@ import com.facebook.presto.operator.aggregation.state.MultiKeyValuePairsStateFactory; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; @@ -146,9 +146,7 @@ public static void output(MultiKeyValuePairsState state, BlockBuilder out) out.appendNull(); } else { - Block block = pairs.toMultimapNativeEncoding(); - out.writeObject(block); - out.closeEntry(); + pairs.toMultimapNativeEncoding(out); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedHistogram.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedHistogram.java index 0646db318dc2..215f0ef45834 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedHistogram.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedHistogram.java @@ -18,11 +18,8 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.type.Type; import com.facebook.presto.type.TypeUtils; -import com.google.common.collect.ImmutableList; import io.airlift.units.DataSize; import org.openjdk.jol.info.ClassLayout; @@ -93,15 +90,15 @@ private LongBigArray getCounts() return counts; } - public Block serialize() + public void serialize(BlockBuilder out) { Block valuesBlock = values.build(); - BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(type, BIGINT), new BlockBuilderStatus(), valuesBlock.getPositionCount() * 2); + BlockBuilder blockBuilder = out.beginBlockEntry(); for (int i = 0; i < valuesBlock.getPositionCount(); i++) { type.appendTo(valuesBlock, i, blockBuilder); BIGINT.writeLong(blockBuilder, counts.get(i)); } - return blockBuilder.build(); + out.closeEntry(); } public void addAll(TypedHistogram other) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedKeyValueHeap.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedKeyValueHeap.java index 5cc73b827b48..1a52fcd6c8e5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedKeyValueHeap.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedKeyValueHeap.java @@ -15,9 +15,9 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import org.openjdk.jol.info.ClassLayout; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java index d113299223a3..0f2e25b33294 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java @@ -26,9 +26,11 @@ import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.google.common.collect.ImmutableList; +import com.google.common.io.Closer; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; +import java.io.IOException; import java.util.Iterator; import java.util.List; import java.util.Optional; @@ -189,14 +191,13 @@ public Iterator buildResult() @Override public void close() { - if (merger.isPresent()) { - merger.get().close(); + try (Closer closer = Closer.create()) { + merger.ifPresent(closer::register); + spiller.ifPresent(closer::register); + mergeHashSort.ifPresent(closer::register); } - if (spiller.isPresent()) { - spiller.get().close(); - } - if (mergeHashSort.isPresent()) { - mergeHashSort.get().close(); + catch (IOException e) { + throw new RuntimeException(e); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/ArrayAggregationStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/ArrayAggregationStateSerializer.java index e5acf9da238f..fe25ee8a4686 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/ArrayAggregationStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/ArrayAggregationStateSerializer.java @@ -16,8 +16,8 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; public class ArrayAggregationStateSerializer implements AccumulatorStateSerializer diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HistogramStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HistogramStateSerializer.java index 67d69205a911..88553458b5c6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HistogramStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HistogramStateSerializer.java @@ -46,7 +46,7 @@ public void serialize(HistogramState state, BlockBuilder out) out.appendNull(); } else { - serializedType.writeObject(out, state.get().serialize()); + state.get().serialize(out); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/KeyValuePairStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/KeyValuePairStateSerializer.java index 3255356d65f0..23f7b02daa27 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/KeyValuePairStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/KeyValuePairStateSerializer.java @@ -17,8 +17,8 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.MapType; public class KeyValuePairStateSerializer implements AccumulatorStateSerializer diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxNStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxNStateSerializer.java index 1c8068b68171..3115fd6806d2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxNStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxNStateSerializer.java @@ -18,9 +18,9 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.util.Optional; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MultiKeyValuePairStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MultiKeyValuePairStateSerializer.java index d523333f2e8f..a10d05e22079 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MultiKeyValuePairStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MultiKeyValuePairStateSerializer.java @@ -17,9 +17,9 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.util.Optional; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java index 2bb86102a03a..8c0efe6db760 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java @@ -35,10 +35,10 @@ import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.function.AccumulatorStateMetadata; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.CallSiteBinder; import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; -import com.facebook.presto.type.RowType; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/exchange/LocalExchangeBufferInfo.java b/presto-main/src/main/java/com/facebook/presto/operator/exchange/LocalExchangeBufferInfo.java index e9de92b82cd5..abe455dca99f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/exchange/LocalExchangeBufferInfo.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/exchange/LocalExchangeBufferInfo.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.operator.exchange; -import com.facebook.presto.operator.Mergeable; import com.facebook.presto.operator.OperatorInfo; +import com.facebook.presto.util.Mergeable; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/index/IndexLookupSourceFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/index/IndexLookupSourceFactory.java index 4a1c6f7fa221..d16c571db123 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/index/IndexLookupSourceFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/index/IndexLookupSourceFactory.java @@ -15,6 +15,7 @@ import com.facebook.presto.operator.LookupSource; import com.facebook.presto.operator.LookupSourceFactory; +import com.facebook.presto.operator.OuterPositionIterator; import com.facebook.presto.operator.PagesIndex; import com.facebook.presto.operator.TaskContext; import com.facebook.presto.spi.type.Type; @@ -102,6 +103,12 @@ public ListenableFuture createLookupSource() return Futures.immediateFuture(new IndexLookupSource(indexLoader)); } + @Override + public OuterPositionIterator getOuterPositionIterator() + { + throw new UnsupportedOperationException(); + } + @Override public void destroy() { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageFilter.java b/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageFilter.java index 2844a942ab51..5bb1d25208c6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageFilter.java @@ -64,10 +64,12 @@ public SelectedPositions filter(ConnectorSession session, Page page) if (block instanceof RunLengthEncodedBlock) { Block value = ((RunLengthEncodedBlock) block).getValue(); - Optional selectedDictionaryPositions = processDictionary(session, value); - // single value block is always considered effective - verify(selectedDictionaryPositions.isPresent()); - return SelectedPositions.positionsRange(0, selectedDictionaryPositions.get()[0] ? page.getPositionCount() : 0); + Optional selectedPosition = processDictionary(session, value); + // single value block is always considered effective, but the processing could have thrown + // in that case we fallback and process again so the correct error message sent + if (selectedPosition.isPresent()) { + return SelectedPositions.positionsRange(0, selectedPosition.get()[0] ? page.getPositionCount() : 0); + } } if (block instanceof DictionaryBlock) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageProjection.java b/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageProjection.java index be740bd85d1a..c9d39bcd6bcb 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageProjection.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageProjection.java @@ -75,9 +75,11 @@ public Block project(ConnectorSession session, Page page, SelectedPositions sele if (block instanceof RunLengthEncodedBlock) { Block value = ((RunLengthEncodedBlock) block).getValue(); Optional projectedValue = processDictionary(session, value); - // single value block is always considered effective - verify(projectedValue.isPresent()); - return new RunLengthEncodedBlock(projectedValue.get(), selectedPositions.size()); + // single value block is always considered effective, but the processing could have thrown + // in that case we fallback and process again so the correct error message sent + if (projectedValue.isPresent()) { + return new RunLengthEncodedBlock(projectedValue.get(), selectedPositions.size()); + } } if (block instanceof DictionaryBlock) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/project/PageProcessor.java b/presto-main/src/main/java/com/facebook/presto/operator/project/PageProcessor.java index fb71cd7fafef..be1d69ae35ba 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/project/PageProcessor.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/project/PageProcessor.java @@ -18,6 +18,7 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.DictionaryBlock; import com.facebook.presto.spi.block.DictionaryId; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.AbstractIterator; import com.google.common.collect.Iterators; @@ -96,6 +97,12 @@ public PageProcessorOutput process(ConnectorSession session, Page page) new PositionsPageProcessorIterator(session, page, SelectedPositions.positionsRange(0, page.getPositionCount()))); } + @VisibleForTesting + public List getProjections() + { + return projections; + } + private class PositionsPageProcessorIterator extends AbstractIterator { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/AbstractGreatestLeast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/AbstractGreatestLeast.java index 008141e46857..b47121814497 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/AbstractGreatestLeast.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/AbstractGreatestLeast.java @@ -52,9 +52,11 @@ import static com.facebook.presto.metadata.Signature.internalOperator; import static com.facebook.presto.metadata.Signature.orderableTypeParameter; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.sql.gen.BytecodeUtils.invoke; +import static com.facebook.presto.util.Failures.checkCondition; import static com.facebook.presto.util.Reflection.methodHandle; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -123,6 +125,7 @@ public static void checkNotNaN(String name, double value) private Class generate(List> javaTypes, Type type, MethodHandle compareMethod) { + checkCondition(javaTypes.size() <= 127, NOT_SUPPORTED, "Too many arguments for function call %s()", getSignature().getName()); String javaTypeName = javaTypes.stream() .map(Class::getSimpleName) .collect(joining()); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConstructor.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConstructor.java index 99a2d79e1555..fcd4a73c7a9b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConstructor.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConstructor.java @@ -58,8 +58,10 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.equal; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; import static com.facebook.presto.metadata.Signature.typeVariable; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; +import static com.facebook.presto.util.Failures.checkCondition; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.invoke.MethodHandles.lookup; @@ -130,6 +132,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in private static Class generateArrayConstructor(List> stackTypes, Type elementType) { + checkCondition(stackTypes.size() <= 254, NOT_SUPPORTED, "Too many arguments for array constructor"); List stackTypeNames = stackTypes.stream() .map(Class::getSimpleName) .collect(toImmutableList()); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayContains.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayContains.java index 148552e1db4d..a0d85f074afc 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayContains.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayContains.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.OperatorDependency; @@ -23,13 +22,12 @@ import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import io.airlift.slice.Slice; import java.lang.invoke.MethodHandle; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.EQUAL; +import static com.facebook.presto.util.Failures.internalError; @Description("Determines whether given value exists in the array") @ScalarFunction("contains") @@ -64,10 +62,7 @@ public static Boolean contains(@TypeParameter("T") Type elementType, } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } if (foundNull) { @@ -96,10 +91,7 @@ public static Boolean contains(@TypeParameter("T") Type elementType, } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } if (foundNull) { @@ -128,10 +120,7 @@ public static Boolean contains(@TypeParameter("T") Type elementType, } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } if (foundNull) { @@ -160,10 +149,7 @@ public static Boolean contains(@TypeParameter("T") Type elementType, } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } if (foundNull) { @@ -192,10 +178,7 @@ public static Boolean contains(@TypeParameter("T") Type elementType, } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } if (foundNull) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayDistinctFromOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayDistinctFromOperator.java index ad3fe14bc2d3..0257615dbbff 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayDistinctFromOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayDistinctFromOperator.java @@ -13,7 +13,6 @@ * limitations under the License. */ -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.OperatorDependency; @@ -22,13 +21,12 @@ import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import java.lang.invoke.MethodHandle; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.IS_DISTINCT_FROM; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; +import static com.facebook.presto.util.Failures.internalError; import static com.google.common.base.Defaults.defaultValue; @ScalarOperator(IS_DISTINCT_FROM) @@ -76,10 +74,7 @@ public static boolean isDistinctFrom( } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } return false; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayEqualOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayEqualOperator.java index 0e6db85df20a..5f1aa88b3aaf 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayEqualOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayEqualOperator.java @@ -13,7 +13,6 @@ * limitations under the License. */ -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarOperator; @@ -21,15 +20,14 @@ import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import java.lang.invoke.MethodHandle; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.EQUAL; +import static com.facebook.presto.spi.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; -import static com.facebook.presto.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; +import static com.facebook.presto.util.Failures.internalError; @ScalarOperator(EQUAL) public final class ArrayEqualOperator @@ -58,10 +56,7 @@ public static boolean equals( } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } return true; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFlattenFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFlattenFunction.java index b25fc849d209..06ad23192ff6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFlattenFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFlattenFunction.java @@ -32,6 +32,7 @@ import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.util.Reflection.methodHandle; +import static java.lang.Math.toIntExact; public class ArrayFlattenFunction extends SqlScalarFunction @@ -84,7 +85,7 @@ public static Block flatten(Type type, Type arrayType, Block array) return type.createBlockBuilder(new BlockBuilderStatus(), 0).build(); } - BlockBuilder builder = type.createBlockBuilder(new BlockBuilderStatus(), array.getPositionCount(), array.getSizeInBytes() / array.getPositionCount()); + BlockBuilder builder = type.createBlockBuilder(new BlockBuilderStatus(), array.getPositionCount(), toIntExact(array.getSizeInBytes() / array.getPositionCount())); for (int i = 0; i < array.getPositionCount(); i++) { if (!array.isNull(i)) { Block subArray = (Block) arrayType.getObject(array, i); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFunctions.java index 71465038997b..2ecb5903c320 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFunctions.java @@ -18,7 +18,7 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import static com.facebook.presto.type.UnknownType.UNKNOWN; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOperator.java index 9fb1ab3bab83..2d590c16d0de 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOperator.java @@ -13,7 +13,6 @@ * limitations under the License. */ -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarOperator; @@ -22,15 +21,14 @@ import com.facebook.presto.spi.function.TypeParameterSpecialization; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import java.lang.invoke.MethodHandle; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN; +import static com.facebook.presto.spi.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; -import static com.facebook.presto.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; +import static com.facebook.presto.util.Failures.internalError; @ScalarOperator(GREATER_THAN) public final class ArrayGreaterThanOperator @@ -61,10 +59,7 @@ public static boolean greaterThan( } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } index++; } @@ -97,10 +92,7 @@ public static boolean greaterThanLong( } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } index++; } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOrEqualOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOrEqualOperator.java index 0cb9b023e887..d0a066185190 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOrEqualOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOrEqualOperator.java @@ -13,7 +13,6 @@ * limitations under the License. */ -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarOperator; @@ -21,16 +20,15 @@ import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import java.lang.invoke.MethodHandle; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN; import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN_OR_EQUAL; +import static com.facebook.presto.spi.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; -import static com.facebook.presto.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; +import static com.facebook.presto.util.Failures.internalError; @ScalarOperator(GREATER_THAN_OR_EQUAL) public final class ArrayGreaterThanOrEqualOperator @@ -61,10 +59,7 @@ public static boolean greaterThanOrEqual( } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } index++; } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayHashCodeOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayHashCodeOperator.java index dd8eef6f0a19..cc26bf36d930 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayHashCodeOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayHashCodeOperator.java @@ -13,8 +13,6 @@ * limitations under the License. */ -import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarOperator; @@ -23,7 +21,6 @@ import com.facebook.presto.spi.function.TypeParameterSpecialization; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import io.airlift.slice.Slice; import java.lang.invoke.MethodHandle; @@ -31,6 +28,7 @@ import static com.facebook.presto.spi.function.OperatorType.HASH_CODE; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; import static com.facebook.presto.type.TypeUtils.NULL_HASH_CODE; +import static com.facebook.presto.util.Failures.internalError; @ScalarOperator(HASH_CODE) public final class ArrayHashCodeOperator @@ -135,12 +133,4 @@ public static long hashDouble( throw internalError(t); } } - - private static PrestoException internalError(Throwable t) - { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - return new PrestoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, t); - } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayJoin.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayJoin.java index 01bd53a45c1c..aef060e9cca9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayJoin.java @@ -45,6 +45,7 @@ import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.util.Reflection.methodHandle; +import static java.lang.Math.toIntExact; import static java.lang.String.format; public final class ArrayJoin @@ -202,7 +203,7 @@ public static Slice arrayJoin(MethodHandle castFunction, ConnectorSession sessio { int numElements = arrayBlock.getPositionCount(); - DynamicSliceOutput sliceOutput = new DynamicSliceOutput(arrayBlock.getSizeInBytes() + delimiter.length() * arrayBlock.getPositionCount()); + DynamicSliceOutput sliceOutput = new DynamicSliceOutput(toIntExact(arrayBlock.getSizeInBytes() + delimiter.length() * arrayBlock.getPositionCount())); for (int i = 0; i < numElements; i++) { if (arrayBlock.isNull(i)) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOperator.java index b966d16121a7..c7e2c11e817b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOperator.java @@ -13,7 +13,6 @@ * limitations under the License. */ -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarOperator; @@ -22,15 +21,14 @@ import com.facebook.presto.spi.function.TypeParameterSpecialization; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import java.lang.invoke.MethodHandle; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.LESS_THAN; +import static com.facebook.presto.spi.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; -import static com.facebook.presto.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; +import static com.facebook.presto.util.Failures.internalError; @ScalarOperator(LESS_THAN) public final class ArrayLessThanOperator @@ -61,10 +59,7 @@ public static boolean lessThan( } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } index++; } @@ -97,10 +92,7 @@ public static boolean lessThanLong( } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } index++; } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOrEqualOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOrEqualOperator.java index 104f1b3cb80b..cf8f472ac2df 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOrEqualOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOrEqualOperator.java @@ -13,7 +13,6 @@ * limitations under the License. */ -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarOperator; @@ -22,16 +21,15 @@ import com.facebook.presto.spi.function.TypeParameterSpecialization; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import java.lang.invoke.MethodHandle; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.LESS_THAN; import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL; +import static com.facebook.presto.spi.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; -import static com.facebook.presto.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; +import static com.facebook.presto.util.Failures.internalError; @ScalarOperator(LESS_THAN_OR_EQUAL) public final class ArrayLessThanOrEqualOperator @@ -62,10 +60,7 @@ public static boolean lessThanOrEqual( } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } index++; } @@ -98,10 +93,7 @@ public static boolean lessThanOrEqualLong( } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } index++; } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayMaxFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayMaxFunction.java index a2983de4c246..14018ad53acf 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayMaxFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayMaxFunction.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.OperatorDependency; @@ -31,9 +30,8 @@ import static com.facebook.presto.operator.scalar.ArrayMinMaxUtils.doubleArrayMinMax; import static com.facebook.presto.operator.scalar.ArrayMinMaxUtils.longArrayMinMax; import static com.facebook.presto.operator.scalar.ArrayMinMaxUtils.sliceArrayMinMax; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN; -import static com.google.common.base.Throwables.propagateIfInstanceOf; +import static com.facebook.presto.util.Failures.internalError; @ScalarFunction("array_max") @Description("Get maximum value of array") @@ -122,9 +120,7 @@ public static Block blockArrayMax( return selectedValue; } catch (Throwable t) { - propagateIfInstanceOf(t, Error.class); - propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayMinFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayMinFunction.java index 720c0d923c3b..f606d640c3f9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayMinFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayMinFunction.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.OperatorDependency; @@ -31,9 +30,8 @@ import static com.facebook.presto.operator.scalar.ArrayMinMaxUtils.doubleArrayMinMax; import static com.facebook.presto.operator.scalar.ArrayMinMaxUtils.longArrayMinMax; import static com.facebook.presto.operator.scalar.ArrayMinMaxUtils.sliceArrayMinMax; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.LESS_THAN; -import static com.google.common.base.Throwables.propagateIfInstanceOf; +import static com.facebook.presto.util.Failures.internalError; @ScalarFunction("array_min") @Description("Get minimum value of array") @@ -122,9 +120,7 @@ public static Block blockArrayMin( return selectedValue; } catch (Throwable t) { - propagateIfInstanceOf(t, Error.class); - propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayMinMaxUtils.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayMinMaxUtils.java index 32eb9f0c8f91..40a3601bfd60 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayMinMaxUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayMinMaxUtils.java @@ -14,15 +14,13 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.annotation.UsedByGeneratedCode; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.Type; import io.airlift.slice.Slice; import java.lang.invoke.MethodHandle; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; -import static com.google.common.base.Throwables.propagateIfInstanceOf; +import static com.facebook.presto.util.Failures.internalError; public final class ArrayMinMaxUtils { @@ -50,9 +48,7 @@ public static Long longArrayMinMax(MethodHandle compareMethodHandle, Type elemen return selectedValue; } catch (Throwable t) { - propagateIfInstanceOf(t, Error.class); - propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -78,9 +74,7 @@ public static Boolean booleanArrayMinMax(MethodHandle compareMethodHandle, Type return selectedValue; } catch (Throwable t) { - propagateIfInstanceOf(t, Error.class); - propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -106,9 +100,7 @@ public static Double doubleArrayMinMax(MethodHandle compareMethodHandle, Type el return selectedValue; } catch (Throwable t) { - propagateIfInstanceOf(t, Error.class); - propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -134,9 +126,7 @@ public static Slice sliceArrayMinMax(MethodHandle compareMethodHandle, Type elem return selectedValue; } catch (Throwable t) { - propagateIfInstanceOf(t, Error.class); - propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayPositionFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayPositionFunction.java index 30698eec89b2..c42632221ed9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayPositionFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayPositionFunction.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.OperatorDependency; @@ -23,13 +22,12 @@ import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import io.airlift.slice.Slice; import java.lang.invoke.MethodHandle; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.EQUAL; +import static com.facebook.presto.util.Failures.internalError; @Description("Returns the position of the first occurrence of the given value in array (or 0 if not found)") @ScalarFunction("array_position") @@ -54,9 +52,7 @@ public static long arrayPosition(@TypeParameter("T") Type type, } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } } @@ -80,9 +76,7 @@ public static long arrayPosition(@TypeParameter("T") Type type, } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } } @@ -106,9 +100,7 @@ public static long arrayPosition(@TypeParameter("T") Type type, } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } } @@ -132,9 +124,7 @@ public static long arrayPosition(@TypeParameter("T") Type type, } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } } @@ -158,9 +148,7 @@ public static long arrayPosition(@TypeParameter("T") Type type, } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayRemoveFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayRemoveFunction.java index 78fab02a2b34..d63082bd0ecc 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayRemoveFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayRemoveFunction.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; @@ -24,15 +23,14 @@ import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import java.lang.invoke.MethodHandle; import java.util.ArrayList; import java.util.List; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.EQUAL; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; +import static com.facebook.presto.util.Failures.internalError; @ScalarFunction("array_remove") @Description("Remove specified values from the given array") @@ -94,10 +92,7 @@ public static Block remove( } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java index 15e41191c1cc..3af984b239ad 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java @@ -30,11 +30,11 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.gen.CallSiteBinder; import com.facebook.presto.sql.gen.lambda.UnaryFunctionInterface; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ConcatFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ConcatFunction.java index a3b8c1fe5dc0..fcaf10cf08f2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ConcatFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ConcatFunction.java @@ -30,6 +30,7 @@ import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignature; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; @@ -52,7 +53,10 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeStatic; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.util.Failures.checkCondition; import static com.facebook.presto.util.Reflection.methodHandle; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.Math.addExact; @@ -60,19 +64,24 @@ public final class ConcatFunction extends SqlScalarFunction { - public static final ConcatFunction CONCAT = new ConcatFunction(); + // TODO design new variadic functions binding mechanism that will allow to produce VARCHAR(x) where x < MAX_LENGTH. + public static final ConcatFunction VARCHAR_CONCAT = new ConcatFunction(createUnboundedVarcharType().getTypeSignature(), "concatenates given strings"); - public ConcatFunction() + public static final ConcatFunction VARBINARY_CONCAT = new ConcatFunction(VARBINARY.getTypeSignature(), "concatenates given varbinary values"); + + private String description; + + private ConcatFunction(TypeSignature type, String description) { - // TODO design new variadic functions binding mechanism that will allow to produce VARCHAR(x) where x < MAX_LENGTH. super(new Signature( "concat", FunctionKind.SCALAR, ImmutableList.of(), ImmutableList.of(), - createUnboundedVarcharType().getTypeSignature(), - ImmutableList.of(createUnboundedVarcharType().getTypeSignature()), + type, + ImmutableList.of(type), true)); + this.description = description; } @Override @@ -90,7 +99,7 @@ public boolean isDeterministic() @Override public String getDescription() { - return "concatenates given strings"; + return description; } @Override @@ -100,18 +109,19 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "There must be two or more concatenation arguments"); } - Class clazz = generateConcat(arity); + Class clazz = generateConcat(getSignature().getReturnType(), arity); MethodHandle methodHandle = methodHandle(clazz, "concat", Collections.nCopies(arity, Slice.class).toArray(new Class[arity])); List nullableParameters = ImmutableList.copyOf(Collections.nCopies(arity, false)); return new ScalarFunctionImplementation(false, nullableParameters, methodHandle, isDeterministic()); } - private static Class generateConcat(int arity) + private static Class generateConcat(TypeSignature type, int arity) { + checkCondition(arity <= 254, NOT_SUPPORTED, "Too many arguments for string concatenation"); ClassDefinition definition = new ClassDefinition( a(PUBLIC, FINAL), - CompilerUtils.makeClassName("Concat" + arity + "ScalarFunction"), + CompilerUtils.makeClassName(type.getBase() + "_concat" + arity + "ScalarFunction"), type(Object.class)); // Generate constructor diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/EmptyMapConstructor.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/EmptyMapConstructor.java index 54c4ddc3bc8c..226dbb909996 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/EmptyMapConstructor.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/EmptyMapConstructor.java @@ -14,26 +14,32 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; -import com.google.common.collect.ImmutableList; - -import static com.facebook.presto.type.UnknownType.UNKNOWN; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.Type; public final class EmptyMapConstructor { - private static final Block EMPTY_MAP = new InterleavedBlockBuilder(ImmutableList.of(UNKNOWN, UNKNOWN), new BlockBuilderStatus(), 0).build(); + private final Block emptyMap; - private EmptyMapConstructor() {} + public EmptyMapConstructor(@TypeParameter("map(unknown,unknown)") Type mapType) + { + BlockBuilder mapBlockBuilder = mapType.createBlockBuilder(new BlockBuilderStatus(), 1); + mapBlockBuilder.beginBlockEntry(); + mapBlockBuilder.closeEntry(); + emptyMap = ((MapType) mapType).getObject(mapBlockBuilder.build(), 0); + } @Description("Creates an empty map") @ScalarFunction @SqlType("map(unknown,unknown)") - public static Block map() + public Block map() { - return EMPTY_MAP; + return emptyMap; } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToArrayCast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToArrayCast.java index ebd161623c95..97a133ea68e5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToArrayCast.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToArrayCast.java @@ -23,11 +23,11 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToMapCast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToMapCast.java index 89ab4680f25a..aee7ff4e835b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToMapCast.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToMapCast.java @@ -22,13 +22,12 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; @@ -50,7 +49,7 @@ public class JsonToMapCast extends SqlOperator { public static final JsonToMapCast JSON_TO_MAP = new JsonToMapCast(); - private static final MethodHandle METHOD_HANDLE = methodHandle(JsonToMapCast.class, "toMap", Type.class, ConnectorSession.class, Slice.class); + private static final MethodHandle METHOD_HANDLE = methodHandle(JsonToMapCast.class, "toMap", MapType.class, ConnectorSession.class, Slice.class); private JsonToMapCast() { @@ -67,28 +66,30 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in checkArgument(arity == 1, "Expected arity to be 1"); Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); - Type mapType = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of(TypeSignatureParameter.of(keyType.getTypeSignature()), TypeSignatureParameter.of(valueType.getTypeSignature()))); + MapType mapType = (MapType) typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of(TypeSignatureParameter.of(keyType.getTypeSignature()), TypeSignatureParameter.of(valueType.getTypeSignature()))); checkCondition(canCastFromJson(mapType), INVALID_CAST_ARGUMENT, "Cannot cast JSON to %s", mapType); MethodHandle methodHandle = METHOD_HANDLE.bindTo(mapType); return new ScalarFunctionImplementation(true, ImmutableList.of(false), methodHandle, isDeterministic()); } @UsedByGeneratedCode - public static Block toMap(Type mapType, ConnectorSession connectorSession, Slice json) + public static Block toMap(MapType mapType, ConnectorSession connectorSession, Slice json) { try { Map map = (Map) stackRepresentationToObject(connectorSession, json, mapType); if (map == null) { return null; } - Type keyType = ((MapType) mapType).getKeyType(); - Type valueType = ((MapType) mapType).getValueType(); - BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(keyType, valueType), new BlockBuilderStatus(), map.size() * 2); + BlockBuilder mapBlockBuilder = mapType.createBlockBuilder(new BlockBuilderStatus(), 1); + BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry(); + Type keyType = mapType.getKeyType(); + Type valueType = mapType.getValueType(); for (Map.Entry entry : map.entrySet()) { appendToBlockBuilder(keyType, entry.getKey(), blockBuilder); appendToBlockBuilder(valueType, entry.getValue(), blockBuilder); } - return blockBuilder.build(); + mapBlockBuilder.closeEntry(); + return mapType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1); } catch (RuntimeException e) { throw new PrestoException(INVALID_CAST_ARGUMENT, "Value cannot be cast to " + mapType, e); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ListLiteralCast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ListLiteralCast.java index 9365fd76b6be..b34b86de18c2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ListLiteralCast.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ListLiteralCast.java @@ -17,7 +17,7 @@ import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlType; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.type.ListLiteralType; import com.google.common.collect.ImmutableList; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java index cd7fdc45f9c3..dca56aa4c424 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java @@ -24,9 +24,11 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.InterleavedBlock; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.gen.VarArgsToArrayAdapterGenerator.MethodHandleAndConstructor; import com.google.common.collect.ImmutableList; @@ -49,8 +51,8 @@ public final class MapConcatFunction private static final String FUNCTION_NAME = "map_concat"; private static final String DESCRIPTION = "Concatenates given maps"; - private static final MethodHandle USER_STATE_FACTORY = methodHandle(MapConcatFunction.class, "createMapState", Type.class, Type.class); - private static final MethodHandle METHOD_HANDLE = methodHandle(MapConcatFunction.class, "mapConcat", Type.class, Type.class, Object.class, Block[].class); + private static final MethodHandle USER_STATE_FACTORY = methodHandle(MapConcatFunction.class, "createMapState", MapType.class); + private static final MethodHandle METHOD_HANDLE = methodHandle(MapConcatFunction.class, "mapConcat", MapType.class, Object.class, Block[].class); private MapConcatFunction() { @@ -90,13 +92,16 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); + MapType mapType = (MapType) typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(valueType.getTypeSignature()))); MethodHandleAndConstructor methodHandleAndConstructor = generateVarArgsToArrayAdapter( Block.class, Block.class, arity, - METHOD_HANDLE.bindTo(keyType).bindTo(valueType), - USER_STATE_FACTORY.bindTo(keyType).bindTo(valueType)); + METHOD_HANDLE.bindTo(mapType), + USER_STATE_FACTORY.bindTo(mapType)); return new ScalarFunctionImplementation( false, @@ -109,13 +114,13 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in } @UsedByGeneratedCode - public static Object createMapState(Type keyType, Type valueType) + public static Object createMapState(MapType mapType) { - return new PageBuilder(ImmutableList.of(keyType, valueType)); + return new PageBuilder(ImmutableList.of(mapType)); } @UsedByGeneratedCode - public static Block mapConcat(Type keyType, Type valueType, Object state, Block[] maps) + public static Block mapConcat(MapType mapType, Object state, Block[] maps) { int entries = 0; int lastMapIndex = maps.length - 1; @@ -137,17 +142,19 @@ public static Block mapConcat(Type keyType, Type valueType, Object state, Block[ } // TODO: we should move TypedSet into user state as well + Type keyType = mapType.getKeyType(); + Type valueType = mapType.getValueType(); TypedSet typedSet = new TypedSet(keyType, entries / 2); - BlockBuilder keyBlockBuilder = pageBuilder.getBlockBuilder(0); - BlockBuilder valueBlockBuilder = pageBuilder.getBlockBuilder(1); + BlockBuilder mapBlockBuilder = pageBuilder.getBlockBuilder(0); + BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry(); // the last map Block map = maps[lastMapIndex]; int total = 0; for (int i = 0; i < map.getPositionCount(); i += 2) { typedSet.add(map, i); - keyType.appendTo(map, i, keyBlockBuilder); - valueType.appendTo(map, i + 1, valueBlockBuilder); + keyType.appendTo(map, i, blockBuilder); + valueType.appendTo(map, i + 1, blockBuilder); total++; } // the map between the last and the first @@ -156,8 +163,8 @@ public static Block mapConcat(Type keyType, Type valueType, Object state, Block[ for (int i = 0; i < map.getPositionCount(); i += 2) { if (!typedSet.contains(map, i)) { typedSet.add(map, i); - keyType.appendTo(map, i, keyBlockBuilder); - valueType.appendTo(map, i + 1, valueBlockBuilder); + keyType.appendTo(map, i, blockBuilder); + valueType.appendTo(map, i + 1, blockBuilder); total++; } } @@ -166,16 +173,14 @@ public static Block mapConcat(Type keyType, Type valueType, Object state, Block[ map = maps[firstMapIndex]; for (int i = 0; i < map.getPositionCount(); i += 2) { if (!typedSet.contains(map, i)) { - keyType.appendTo(map, i, keyBlockBuilder); - valueType.appendTo(map, i + 1, valueBlockBuilder); + keyType.appendTo(map, i, blockBuilder); + valueType.appendTo(map, i + 1, blockBuilder); total++; } } - pageBuilder.declarePositions(total); - Block[] blocks = new Block[2]; - blocks[0] = keyBlockBuilder.getRegion(keyBlockBuilder.getPositionCount() - total, total); - blocks[1] = valueBlockBuilder.getRegion(valueBlockBuilder.getPositionCount() - total, total); - return new InterleavedBlock(blocks); + mapBlockBuilder.closeEntry(); + pageBuilder.declarePosition(); + return mapType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java index e73f0a2b576c..ac12e9690e88 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java @@ -24,11 +24,11 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; @@ -104,6 +104,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in @UsedByGeneratedCode public static Block createMap(MapType mapType, MethodHandle keyEqual, MethodHandle keyHashCode, State state, Block keyBlock, Block valueBlock) { + checkCondition(keyBlock.getPositionCount() == valueBlock.getPositionCount(), INVALID_FUNCTION_ARGUMENT, "Key and value arrays must be the same length"); PageBuilder pageBuilder = state.getPageBuilder(); if (pageBuilder.isFull()) { pageBuilder.reset(); @@ -111,9 +112,11 @@ public static Block createMap(MapType mapType, MethodHandle keyEqual, MethodHand BlockBuilder mapBlockBuilder = pageBuilder.getBlockBuilder(0); BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry(); - checkCondition(keyBlock.getPositionCount() == valueBlock.getPositionCount(), INVALID_FUNCTION_ARGUMENT, "Key and value arrays must be the same length"); for (int i = 0; i < keyBlock.getPositionCount(); i++) { if (keyBlock.isNull(i)) { + // close block builder before throwing as we may be in a TRY() call + // so that subsequent calls do not find it in an inconsistent state + mapBlockBuilder.closeEntry(); throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "map key cannot be null"); } mapType.getKeyType().appendTo(keyBlock, i, blockBuilder); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapElementAtFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapElementAtFunction.java index 77f8003d3865..0723e04822dd 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapElementAtFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapElementAtFunction.java @@ -19,13 +19,11 @@ import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlScalarFunction; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; import io.airlift.slice.Slice; @@ -34,9 +32,9 @@ import static com.facebook.presto.metadata.Signature.internalOperator; import static com.facebook.presto.metadata.Signature.typeVariable; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; +import static com.facebook.presto.util.Failures.internalError; import static com.facebook.presto.util.Reflection.methodHandle; public class MapElementAtFunction @@ -127,9 +125,7 @@ public static Object elementAt(MethodHandle keyEqualsMethod, Type keyType, Type } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } return null; @@ -145,9 +141,7 @@ public static Object elementAt(MethodHandle keyEqualsMethod, Type keyType, Type } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } return null; @@ -163,9 +157,7 @@ public static Object elementAt(MethodHandle keyEqualsMethod, Type keyType, Type } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } return null; @@ -181,9 +173,7 @@ public static Object elementAt(MethodHandle keyEqualsMethod, Type keyType, Type } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } return null; @@ -199,9 +189,7 @@ public static Object elementAt(MethodHandle keyEqualsMethod, Type keyType, Type } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } return null; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapGenericEquality.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapGenericEquality.java index d64c97c80caa..6aed7d1b222a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapGenericEquality.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapGenericEquality.java @@ -13,17 +13,15 @@ * limitations under the License. */ -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import java.lang.invoke.MethodHandle; import java.util.LinkedHashMap; import java.util.Map; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; +import static com.facebook.presto.util.Failures.internalError; public final class MapGenericEquality { @@ -73,10 +71,7 @@ else if (!result) { } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } return true; @@ -102,10 +97,7 @@ public int hashCode() return Long.hashCode((long) hashCode.invoke(key)); } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } @@ -120,10 +112,7 @@ public boolean equals(Object obj) return (boolean) equals.invoke(key, other.key); } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapSubscriptOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapSubscriptOperator.java index 061738cfc721..788b2eab145c 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapSubscriptOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapSubscriptOperator.java @@ -27,7 +27,6 @@ import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.VarcharType; import com.facebook.presto.sql.FunctionInvoker; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; import io.airlift.slice.Slice; @@ -37,27 +36,28 @@ import static com.facebook.presto.metadata.Signature.internalOperator; import static com.facebook.presto.metadata.Signature.typeVariable; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.function.OperatorType.SUBSCRIPT; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; import static com.facebook.presto.sql.relational.Signatures.castSignature; +import static com.facebook.presto.util.Failures.internalError; import static com.facebook.presto.util.Reflection.methodHandle; import static java.lang.String.format; public class MapSubscriptOperator extends SqlOperator { - private static final MethodHandle METHOD_HANDLE_BOOLEAN = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, boolean.class); - private static final MethodHandle METHOD_HANDLE_LONG = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, long.class); - private static final MethodHandle METHOD_HANDLE_DOUBLE = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, double.class); - private static final MethodHandle METHOD_HANDLE_SLICE = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, Slice.class); - private static final MethodHandle METHOD_HANDLE_OBJECT = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, Object.class); + private static final MethodHandle METHOD_HANDLE_BOOLEAN = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, boolean.class); + private static final MethodHandle METHOD_HANDLE_LONG = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, long.class); + private static final MethodHandle METHOD_HANDLE_DOUBLE = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, double.class); + private static final MethodHandle METHOD_HANDLE_SLICE = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, Slice.class); + private static final MethodHandle METHOD_HANDLE_OBJECT = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, Object.class); private final boolean legacyMissingKey; + private final boolean useNewMapBlock; - public MapSubscriptOperator(boolean legacyMissingKey) + public MapSubscriptOperator(boolean legacyMissingKey, boolean useNewMapBlock) { super(SUBSCRIPT, ImmutableList.of(typeVariable("K"), typeVariable("V")), @@ -65,6 +65,7 @@ public MapSubscriptOperator(boolean legacyMissingKey) parseTypeSignature("V"), ImmutableList.of(parseTypeSignature("map(K,V)"), parseTypeSignature("K"))); this.legacyMissingKey = legacyMissingKey; + this.useNewMapBlock = useNewMapBlock; } @Override @@ -91,7 +92,7 @@ else if (keyType.getJavaType() == Slice.class) { else { methodHandle = METHOD_HANDLE_OBJECT; } - methodHandle = MethodHandles.insertArguments(methodHandle, 0, legacyMissingKey); + methodHandle = MethodHandles.insertArguments(methodHandle, 0, legacyMissingKey, useNewMapBlock); FunctionInvoker functionInvoker = new FunctionInvoker(functionRegistry); methodHandle = methodHandle.bindTo(functionInvoker).bindTo(keyEqualsMethod).bindTo(keyType).bindTo(valueType); @@ -107,9 +108,9 @@ else if (keyType.getJavaType() == Slice.class) { } @UsedByGeneratedCode - public static Object subscript(boolean legacyMissingKey, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, boolean key) + public static Object subscript(boolean legacyMissingKey, boolean useNewMapBlock, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, boolean key) { - if (map instanceof SingleMapBlock) { + if (map instanceof SingleMapBlock && useNewMapBlock) { SingleMapBlock mapBlock = (SingleMapBlock) map; int valuePosition = mapBlock.seekKeyExact(key); if (valuePosition == -1) { @@ -128,9 +129,7 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } if (legacyMissingKey) { @@ -140,9 +139,9 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } @UsedByGeneratedCode - public static Object subscript(boolean legacyMissingKey, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, long key) + public static Object subscript(boolean legacyMissingKey, boolean useNewMapBlock, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, long key) { - if (map instanceof SingleMapBlock) { + if (map instanceof SingleMapBlock && useNewMapBlock) { SingleMapBlock mapBlock = (SingleMapBlock) map; int valuePosition = mapBlock.seekKeyExact(key); if (valuePosition == -1) { @@ -161,9 +160,7 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } if (legacyMissingKey) { @@ -173,9 +170,9 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } @UsedByGeneratedCode - public static Object subscript(boolean legacyMissingKey, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, double key) + public static Object subscript(boolean legacyMissingKey, boolean useNewMapBlock, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, double key) { - if (map instanceof SingleMapBlock) { + if (map instanceof SingleMapBlock && useNewMapBlock) { SingleMapBlock mapBlock = (SingleMapBlock) map; int valuePosition = mapBlock.seekKeyExact(key); if (valuePosition == -1) { @@ -194,9 +191,7 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } if (legacyMissingKey) { @@ -206,9 +201,9 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } @UsedByGeneratedCode - public static Object subscript(boolean legacyMissingKey, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, Slice key) + public static Object subscript(boolean legacyMissingKey, boolean useNewMapBlock, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, Slice key) { - if (map instanceof SingleMapBlock) { + if (map instanceof SingleMapBlock && useNewMapBlock) { SingleMapBlock mapBlock = (SingleMapBlock) map; int valuePosition = mapBlock.seekKeyExact(key); if (valuePosition == -1) { @@ -227,9 +222,7 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } if (legacyMissingKey) { @@ -239,9 +232,9 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } @UsedByGeneratedCode - public static Object subscript(boolean legacyMissingKey, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, Object key) + public static Object subscript(boolean legacyMissingKey, boolean useNewMapBlock, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, Object key) { - if (map instanceof SingleMapBlock) { + if (map instanceof SingleMapBlock && useNewMapBlock) { SingleMapBlock mapBlock = (SingleMapBlock) map; int valuePosition = mapBlock.seekKeyExact((Block) key); if (valuePosition == -1) { @@ -260,9 +253,7 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } if (legacyMissingKey) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToMapCast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToMapCast.java index 3aa6494dca60..0ae7d0cd8135 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToMapCast.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToMapCast.java @@ -20,15 +20,12 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; -import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; @@ -37,6 +34,7 @@ import static com.facebook.presto.spi.function.OperatorType.EQUAL; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; import static com.facebook.presto.spi.type.TypeUtils.writeNativeValue; +import static com.facebook.presto.util.Failures.internalError; @ScalarOperator(CAST) public final class MapToMapCast @@ -56,6 +54,7 @@ public static Block toMap( @TypeParameter("FV") Type fromValueType, @TypeParameter("TK") Type toKeyType, @TypeParameter("TV") Type toValueType, + @TypeParameter("map(TK,TV)") Type toMapType, ConnectorSession session, @SqlType("map(FK,FV)") Block fromMap) { @@ -88,13 +87,13 @@ public static Block toMap( writeNativeValue(toKeyType, keyBlockBuilder, toKey); } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } Block keyBlock = keyBlockBuilder.build(); - BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(toKeyType, toValueType), new BlockBuilderStatus(), fromMap.getPositionCount()); + + BlockBuilder mapBlockBuilder = toMapType.createBlockBuilder(new BlockBuilderStatus(), 1); + BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry(); for (int i = 0; i < fromMap.getPositionCount(); i += 2) { if (!typedSet.contains(keyBlock, i / 2)) { typedSet.add(keyBlock, i / 2); @@ -110,9 +109,7 @@ public static Block toMap( writeNativeValue(toValueType, blockBuilder, toValue); } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } else { @@ -120,6 +117,8 @@ public static Block toMap( throw new PrestoException(StandardErrorCode.INVALID_CAST_ARGUMENT, "duplicate keys"); } } - return blockBuilder.build(); + + mapBlockBuilder.closeEntry(); + return (Block) toMapType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformKeyFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformKeyFunction.java index 86a1cc37c22c..9e36d1e075a6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformKeyFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformKeyFunction.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.scalar; +import com.facebook.presto.annotation.UsedByGeneratedCode; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.ClassDefinition; @@ -30,11 +31,11 @@ import com.facebook.presto.operator.aggregation.TypedSet; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ErrorCodeSupplier; +import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; @@ -46,7 +47,6 @@ import com.google.common.primitives.Primitives; import java.lang.invoke.MethodHandle; -import java.util.List; import java.util.Optional; import static com.facebook.presto.bytecode.Access.FINAL; @@ -69,6 +69,7 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.lessThan; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.subtract; import static com.facebook.presto.bytecode.instruction.VariableInstruction.incrementVariable; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; @@ -81,6 +82,7 @@ public final class MapTransformKeyFunction extends SqlScalarFunction { public static final MapTransformKeyFunction MAP_TRANSFORM_KEY_FUNCTION = new MapTransformKeyFunction(); + private static final MethodHandle STATE_FACTORY = methodHandle(MapTransformKeyFunction.class, "createState", MapType.class); private MapTransformKeyFunction() { @@ -118,7 +120,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Type keyType = boundVariables.getTypeVariable("K1"); Type transformedKeyType = boundVariables.getTypeVariable("K2"); Type valueType = boundVariables.getTypeVariable("V"); - Type resultMapType = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + MapType resultMapType = (MapType) typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( TypeSignatureParameter.of(transformedKeyType.getTypeSignature()), TypeSignatureParameter.of(valueType.getTypeSignature()))); return new ScalarFunctionImplementation( @@ -127,9 +129,16 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in ImmutableList.of(false, false), ImmutableList.of(Optional.empty(), Optional.of(BinaryFunctionInterface.class)), generateTransformKey(keyType, transformedKeyType, valueType, resultMapType), + Optional.of(STATE_FACTORY.bindTo(resultMapType)), isDeterministic()); } + @UsedByGeneratedCode + public static Object createState(MapType mapType) + { + return new PageBuilder(ImmutableList.of(mapType)); + } + private static MethodHandle generateTransformKey(Type keyType, Type transformedKeyType, Type valueType, Type resultMapType) { CallSiteBinder binder = new CallSiteBinder(); @@ -143,6 +152,7 @@ private static MethodHandle generateTransformKey(Type keyType, Type transformedK type(Object.class)); definition.declareDefaultConstructor(a(PRIVATE)); + Parameter state = arg("state", Object.class); Parameter session = arg("session", ConnectorSession.class); Parameter block = arg("block", Block.class); Parameter function = arg("function", BinaryFunctionInterface.class); @@ -150,12 +160,14 @@ private static MethodHandle generateTransformKey(Type keyType, Type transformedK a(PUBLIC, STATIC), "transform", type(Block.class), - ImmutableList.of(session, block, function)); + ImmutableList.of(state, session, block, function)); BytecodeBlock body = method.getBody(); Scope scope = method.getScope(); Variable positionCount = scope.declareVariable(int.class, "positionCount"); Variable position = scope.declareVariable(int.class, "position"); + Variable pageBuilder = scope.declareVariable(PageBuilder.class, "pageBuilder"); + Variable mapBlockBuilder = scope.declareVariable(BlockBuilder.class, "mapBlockBuilder"); Variable blockBuilder = scope.declareVariable(BlockBuilder.class, "blockBuilder"); Variable typedSet = scope.declareVariable(TypedSet.class, "typeSet"); Variable keyElement = scope.declareVariable(keyJavaType, "keyElement"); @@ -165,12 +177,13 @@ private static MethodHandle generateTransformKey(Type keyType, Type transformedK // invoke block.getPositionCount() body.append(positionCount.set(block.invoke("getPositionCount", int.class))); - // create the interleaved block builder - body.append(blockBuilder.set(newInstance( - InterleavedBlockBuilder.class, - constantType(binder, resultMapType).invoke("getTypeParameters", List.class), - newInstance(BlockBuilderStatus.class), - positionCount))); + // prepare the single map block builder + body.append(pageBuilder.set(state.cast(PageBuilder.class))); + body.append(new IfStatement() + .condition(pageBuilder.invoke("isFull", boolean.class)) + .ifTrue(pageBuilder.invoke("reset", void.class))); + body.append(mapBlockBuilder.set(pageBuilder.invoke("getBlockBuilder", BlockBuilder.class, constantInt(0)))); + body.append(blockBuilder.set(mapBlockBuilder.invoke("beginBlockEntry", BlockBuilder.class))); // create typed set body.append(typedSet.set(newInstance( @@ -194,7 +207,10 @@ private static MethodHandle generateTransformKey(Type keyType, Type transformedK else { // make sure invokeExact will not take uninitialized keys during compile time // but if we reach this point during runtime, it is an exception + // also close the block builder before throwing as we may be in a TRY() call + // so that subsequent calls do not find it in an inconsistent state loadKeyElement = new BytecodeBlock() + .append(mapBlockBuilder.invoke("closeEntry", BlockBuilder.class).pop()) .append(keyElement.set(constantNull(keyJavaType))) .append(throwNullKeyException); } @@ -227,6 +243,7 @@ private static MethodHandle generateTransformKey(Type keyType, Type transformedK // make sure getObjectValue takes a known key type throwDuplicatedKeyException = new BytecodeBlock() + .append(mapBlockBuilder.invoke("closeEntry", BlockBuilder.class).pop()) .append(newInstance( PrestoException.class, getStatic(INVALID_FUNCTION_ARGUMENT.getDeclaringClass(), "INVALID_FUNCTION_ARGUMENT").cast(ErrorCodeSupplier.class), @@ -258,9 +275,19 @@ private static MethodHandle generateTransformKey(Type keyType, Type transformedK .ifTrue(throwDuplicatedKeyException) .ifFalse(typedSet.invoke("add", void.class, blockBuilder.cast(Block.class), position))))); - body.append(blockBuilder.invoke("build", Block.class).ret()); + body.append(mapBlockBuilder + .invoke("closeEntry", BlockBuilder.class) + .pop()); + body.append(pageBuilder.invoke("declarePosition", void.class)); + body.append(constantType(binder, resultMapType) + .invoke( + "getObject", + Object.class, + mapBlockBuilder.cast(Block.class), + subtract(mapBlockBuilder.invoke("getPositionCount", int.class), constantInt(1))) + .ret()); Class generatedClass = defineClass(definition, Object.class, binder.getBindings(), MapTransformKeyFunction.class.getClassLoader()); - return methodHandle(generatedClass, "transform", ConnectorSession.class, Block.class, BinaryFunctionInterface.class); + return methodHandle(generatedClass, "transform", Object.class, ConnectorSession.class, Block.class, BinaryFunctionInterface.class); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java index cb8330d546c9..50e7c82d1d7d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.scalar; +import com.facebook.presto.annotation.UsedByGeneratedCode; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.ClassDefinition; @@ -28,11 +29,11 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.spi.ErrorCodeSupplier; +import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; @@ -44,7 +45,6 @@ import com.google.common.primitives.Primitives; import java.lang.invoke.MethodHandle; -import java.util.List; import java.util.Optional; import static com.facebook.presto.bytecode.Access.FINAL; @@ -64,6 +64,7 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.lessThan; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.subtract; import static com.facebook.presto.bytecode.instruction.VariableInstruction.incrementVariable; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; @@ -76,6 +77,7 @@ public final class MapTransformValueFunction extends SqlScalarFunction { public static final MapTransformValueFunction MAP_TRANSFORM_VALUE_FUNCTION = new MapTransformValueFunction(); + private static final MethodHandle STATE_FACTORY = methodHandle(MapTransformKeyFunction.class, "createState", MapType.class); private MapTransformValueFunction() { @@ -122,9 +124,16 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in ImmutableList.of(false, false), ImmutableList.of(Optional.empty(), Optional.of(BinaryFunctionInterface.class)), generateTransform(keyType, valueType, transformedValueType, resultMapType), + Optional.of(STATE_FACTORY.bindTo(resultMapType)), isDeterministic()); } + @UsedByGeneratedCode + public static Object createState(MapType mapType) + { + return new PageBuilder(ImmutableList.of(mapType)); + } + private static MethodHandle generateTransform(Type keyType, Type valueType, Type transformedValueType, Type resultMapType) { CallSiteBinder binder = new CallSiteBinder(); @@ -139,18 +148,21 @@ private static MethodHandle generateTransform(Type keyType, Type valueType, Type definition.declareDefaultConstructor(a(PRIVATE)); // define transform method + Parameter state = arg("state", Object.class); Parameter block = arg("block", Block.class); Parameter function = arg("function", BinaryFunctionInterface.class); MethodDefinition method = definition.declareMethod( a(PUBLIC, STATIC), "transform", type(Block.class), - ImmutableList.of(block, function)); + ImmutableList.of(state, block, function)); BytecodeBlock body = method.getBody(); Scope scope = method.getScope(); Variable positionCount = scope.declareVariable(int.class, "positionCount"); Variable position = scope.declareVariable(int.class, "position"); + Variable pageBuilder = scope.declareVariable(PageBuilder.class, "pageBuilder"); + Variable mapBlockBuilder = scope.declareVariable(BlockBuilder.class, "mapBlockBuilder"); Variable blockBuilder = scope.declareVariable(BlockBuilder.class, "blockBuilder"); Variable keyElement = scope.declareVariable(keyJavaType, "keyElement"); Variable valueElement = scope.declareVariable(valueJavaType, "valueElement"); @@ -159,12 +171,13 @@ private static MethodHandle generateTransform(Type keyType, Type valueType, Type // invoke block.getPositionCount() body.append(positionCount.set(block.invoke("getPositionCount", int.class))); - // create the interleaved block builder - body.append(blockBuilder.set(newInstance( - InterleavedBlockBuilder.class, - constantType(binder, resultMapType).invoke("getTypeParameters", List.class), - newInstance(BlockBuilderStatus.class), - positionCount))); + // prepare the single map block builder + body.append(pageBuilder.set(state.cast(PageBuilder.class))); + body.append(new IfStatement() + .condition(pageBuilder.invoke("isFull", boolean.class)) + .ifTrue(pageBuilder.invoke("reset", void.class))); + body.append(mapBlockBuilder.set(pageBuilder.invoke("getBlockBuilder", BlockBuilder.class, constantInt(0)))); + body.append(blockBuilder.set(mapBlockBuilder.invoke("beginBlockEntry", BlockBuilder.class))); // throw null key exception block BytecodeNode throwNullKeyException = new BytecodeBlock() @@ -182,7 +195,10 @@ private static MethodHandle generateTransform(Type keyType, Type valueType, Type else { // make sure invokeExact will not take uninitialized keys during compile time // but if we reach this point during runtime, it is an exception + // also close the block builder before throwing as we may be in a TRY() call + // so that subsequent calls do not find it in an inconsistent state loadKeyElement = new BytecodeBlock() + .append(mapBlockBuilder.invoke("closeEntry", BlockBuilder.class).pop()) .append(keyElement.set(constantNull(keyJavaType))) .append(throwNullKeyException); } @@ -221,9 +237,19 @@ private static MethodHandle generateTransform(Type keyType, Type valueType, Type .append(keySqlType.invoke("appendTo", void.class, block, position, blockBuilder)) .append(writeTransformedValueElement))); - body.append(blockBuilder.invoke("build", Block.class).ret()); + body.append(mapBlockBuilder + .invoke("closeEntry", BlockBuilder.class) + .pop()); + body.append(pageBuilder.invoke("declarePosition", void.class)); + body.append(constantType(binder, resultMapType) + .invoke( + "getObject", + Object.class, + mapBlockBuilder.cast(Block.class), + subtract(mapBlockBuilder.invoke("getPositionCount", int.class), constantInt(1))) + .ret()); Class generatedClass = defineClass(definition, Object.class, binder.getBindings(), MapTransformValueFunction.class.getClassLoader()); - return methodHandle(generatedClass, "transform", Block.class, BinaryFunctionInterface.class); + return methodHandle(generatedClass, "transform", Object.class, Block.class, BinaryFunctionInterface.class); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowComparisonOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowComparisonOperator.java index 44dac4602ed8..0b3a6eaaa976 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowComparisonOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowComparisonOperator.java @@ -16,23 +16,21 @@ import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlOperator; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.RowType; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; import java.util.List; import static com.facebook.presto.metadata.Signature.orderableWithVariadicBound; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; +import static com.facebook.presto.util.Failures.internalError; public abstract class RowComparisonOperator extends SqlOperator @@ -77,10 +75,7 @@ protected static int compare( } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } return 0; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowDistinctFromOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowDistinctFromOperator.java index 28da4342f0de..2f3c3bfb015d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowDistinctFromOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowDistinctFromOperator.java @@ -17,22 +17,20 @@ import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlOperator; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; import java.util.List; import static com.facebook.presto.metadata.Signature.comparableWithVariadicBound; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.IS_DISTINCT_FROM; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; +import static com.facebook.presto.util.Failures.internalError; import static com.facebook.presto.util.Reflection.methodHandle; import static com.google.common.base.Defaults.defaultValue; @@ -100,10 +98,7 @@ public static boolean isDistinctFrom(Type rowType, List argumentMe } } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + throw internalError(t); } } return false; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOperator.java index b2b3dc3c67f3..c13be1736708 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOperator.java @@ -16,9 +16,9 @@ import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOrEqualOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOrEqualOperator.java index b7fc55e79b6a..8aadf24d884c 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOrEqualOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOrEqualOperator.java @@ -16,9 +16,9 @@ import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOperator.java index 9fa72bdf2d8b..013df743bf74 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOperator.java @@ -16,9 +16,9 @@ import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOrEqualOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOrEqualOperator.java index 82c7519499b7..4f71aaf7c70e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOrEqualOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOrEqualOperator.java @@ -16,9 +16,9 @@ import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipFunction.java index 5138cd93df45..36ea2d1f2c9c 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipFunction.java @@ -22,10 +22,10 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java index 8141b0900591..a606b81b4dd6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java @@ -18,6 +18,7 @@ import com.facebook.presto.metadata.LongVariableConstraint; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TypeVariableConstraint; +import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.function.FunctionDependency; import com.facebook.presto.spi.function.IsNull; @@ -318,7 +319,14 @@ public Signature getSignature() public MethodHandle resolve(BoundVariables boundVariables, TypeManager typeManager, FunctionRegistry functionRegistry) { Signature signature = applyBoundVariables(this.signature, boundVariables, this.signature.getArgumentTypes().size()); - return functionRegistry.getScalarFunctionImplementation(signature).getMethodHandle(); + ScalarFunctionImplementation scalarFunctionImplementation = functionRegistry.getScalarFunctionImplementation(signature); + if (scalarFunctionImplementation.getInstanceFactory().isPresent()) { + // TODO: This feature is useful for a few casts, e.g. MapToMapCast, JsonToMapCast + // Implementing this requires a revamp because we must be able to defer binding of MethodHandles, + // and be able to express such need in a recursive way in ScalarFunctionImplementation. + throw new UnsupportedOperationException("OperatorDependency/FunctionDependency cannot refer to methods with instance factory"); + } + return scalarFunctionImplementation.getMethodHandle(); } } @@ -546,7 +554,9 @@ private Optional getConstructor(Method method, Map> getDeclaredSpecializedTypeParameters(Method method) @@ -570,6 +580,8 @@ private MethodHandle getMethodHandle(Method method) { MethodHandle methodHandle = methodHandle(FUNCTION_IMPLEMENTATION_ERROR, method); if (!isStatic(method.getModifiers())) { + // Change type of "this" argument to Object to make sure callers won't have classloader issues + methodHandle = methodHandle.asType(methodHandle.type().changeParameterType(0, Object.class)); // Re-arrange the parameters, so that the "this" parameter is after the meta parameters int[] permutedIndices = new int[methodHandle.type().parameterCount()]; permutedIndices[0] = dependencies.size(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java b/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java index c882a5fa9c0a..78d08058e7e6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java @@ -71,6 +71,11 @@ public WindowPartition(PagesIndex pagesIndex, updatePeerGroup(); } + public int getPartitionStart() + { + return partitionStart; + } + public int getPartitionEnd() { return partitionEnd; diff --git a/presto-main/src/main/java/com/facebook/presto/security/AccessControl.java b/presto-main/src/main/java/com/facebook/presto/security/AccessControl.java index 42768dae4413..9c8c0d407182 100644 --- a/presto-main/src/main/java/com/facebook/presto/security/AccessControl.java +++ b/presto-main/src/main/java/com/facebook/presto/security/AccessControl.java @@ -115,6 +115,12 @@ public interface AccessControl */ void checkCanAddColumns(TransactionId transactionId, Identity identity, QualifiedObjectName tableName); + /** + * Check if identity is allowed to drop columns from the specified table. + * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed + */ + void checkCanDropColumn(TransactionId transactionId, Identity identity, QualifiedObjectName tableName); + /** * Check if identity is allowed to rename a column in the specified table. * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed diff --git a/presto-main/src/main/java/com/facebook/presto/security/AccessControlManager.java b/presto-main/src/main/java/com/facebook/presto/security/AccessControlManager.java index e45681386c6c..f89a20af52de 100644 --- a/presto-main/src/main/java/com/facebook/presto/security/AccessControlManager.java +++ b/presto-main/src/main/java/com/facebook/presto/security/AccessControlManager.java @@ -352,6 +352,22 @@ public void checkCanAddColumns(TransactionId transactionId, Identity identity, Q } } + @Override + public void checkCanDropColumn(TransactionId transactionId, Identity identity, QualifiedObjectName tableName) + { + requireNonNull(identity, "identity is null"); + requireNonNull(tableName, "tableName is null"); + + authenticationCheck(() -> checkCanAccessCatalog(identity, tableName.getCatalogName())); + + authorizationCheck(() -> systemAccessControl.get().checkCanDropColumn(identity, tableName.asCatalogSchemaTableName())); + + CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); + if (entry != null) { + authorizationCheck(() -> entry.getAccessControl().checkCanDropColumn(entry.getTransactionHandle(transactionId), identity, tableName.asSchemaTableName())); + } + } + @Override public void checkCanRenameColumn(TransactionId transactionId, Identity identity, QualifiedObjectName tableName) { diff --git a/presto-main/src/main/java/com/facebook/presto/security/AllowAllAccessControl.java b/presto-main/src/main/java/com/facebook/presto/security/AllowAllAccessControl.java index 356ff83c75b8..5163236409d9 100644 --- a/presto-main/src/main/java/com/facebook/presto/security/AllowAllAccessControl.java +++ b/presto-main/src/main/java/com/facebook/presto/security/AllowAllAccessControl.java @@ -99,6 +99,11 @@ public void checkCanAddColumns(TransactionId transactionId, Identity identity, Q { } + @Override + public void checkCanDropColumn(TransactionId transactionId, Identity identity, QualifiedObjectName tableName) + { + } + @Override public void checkCanRenameColumn(TransactionId transactionId, Identity identity, QualifiedObjectName tableName) { diff --git a/presto-main/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java b/presto-main/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java index 442718d4e337..233d25def6b5 100644 --- a/presto-main/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java +++ b/presto-main/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java @@ -131,6 +131,11 @@ public void checkCanAddColumn(Identity identity, CatalogSchemaTableName table) { } + @Override + public void checkCanDropColumn(Identity identity, CatalogSchemaTableName table) + { + } + @Override public void checkCanRenameColumn(Identity identity, CatalogSchemaTableName table) { diff --git a/presto-main/src/main/java/com/facebook/presto/security/DenyAllAccessControl.java b/presto-main/src/main/java/com/facebook/presto/security/DenyAllAccessControl.java index 1195c1d94d2a..82609d4ae495 100644 --- a/presto-main/src/main/java/com/facebook/presto/security/DenyAllAccessControl.java +++ b/presto-main/src/main/java/com/facebook/presto/security/DenyAllAccessControl.java @@ -31,6 +31,7 @@ import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateViewWithSelect; import static com.facebook.presto.spi.security.AccessDeniedException.denyDeleteTable; +import static com.facebook.presto.spi.security.AccessDeniedException.denyDropColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropSchema; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropView; @@ -141,6 +142,12 @@ public void checkCanRenameColumn(TransactionId transactionId, Identity identity, denyRenameColumn(tableName.toString()); } + @Override + public void checkCanDropColumn(TransactionId transactionId, Identity identity, QualifiedObjectName tableName) + { + denyDropColumn(tableName.toString()); + } + @Override public void checkCanSelectFromTable(TransactionId transactionId, Identity identity, QualifiedObjectName tableName) { diff --git a/presto-main/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java b/presto-main/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java index 1490f6ec74cc..0faacfd0201b 100644 --- a/presto-main/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java +++ b/presto-main/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java @@ -206,6 +206,11 @@ public void checkCanAddColumn(Identity identity, CatalogSchemaTableName table) { } + @Override + public void checkCanDropColumn(Identity identity, CatalogSchemaTableName table) + { + } + @Override public void checkCanRenameColumn(Identity identity, CatalogSchemaTableName table) { diff --git a/presto-main/src/main/java/com/facebook/presto/server/CoordinatorModule.java b/presto-main/src/main/java/com/facebook/presto/server/CoordinatorModule.java index 64626c28f5f4..70b98a66b9d0 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/CoordinatorModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/CoordinatorModule.java @@ -22,6 +22,7 @@ import com.facebook.presto.execution.CreateViewTask; import com.facebook.presto.execution.DataDefinitionTask; import com.facebook.presto.execution.DeallocateTask; +import com.facebook.presto.execution.DropColumnTask; import com.facebook.presto.execution.DropSchemaTask; import com.facebook.presto.execution.DropTableTask; import com.facebook.presto.execution.DropViewTask; @@ -73,6 +74,7 @@ import com.facebook.presto.sql.tree.Delete; import com.facebook.presto.sql.tree.DescribeInput; import com.facebook.presto.sql.tree.DescribeOutput; +import com.facebook.presto.sql.tree.DropColumn; import com.facebook.presto.sql.tree.DropSchema; import com.facebook.presto.sql.tree.DropTable; import com.facebook.presto.sql.tree.DropView; @@ -154,6 +156,7 @@ protected void setup(Binder binder) jaxrsBinder(binder).bind(QueryResource.class); jaxrsBinder(binder).bind(StageResource.class); jaxrsBinder(binder).bind(QueryStateInfoResource.class); + jaxrsBinder(binder).bind(ResourceGroupStateInfoResource.class); binder.bind(QueryIdGenerator.class).in(Scopes.SINGLETON); binder.bind(QueryManager.class).to(SqlQueryManager.class).in(Scopes.SINGLETON); binder.bind(InternalResourceGroupManager.class).in(Scopes.SINGLETON); @@ -239,6 +242,7 @@ protected void setup(Binder binder) bindDataDefinitionTask(binder, executionBinder, CreateTable.class, CreateTableTask.class); bindDataDefinitionTask(binder, executionBinder, RenameTable.class, RenameTableTask.class); bindDataDefinitionTask(binder, executionBinder, RenameColumn.class, RenameColumnTask.class); + bindDataDefinitionTask(binder, executionBinder, DropColumn.class, DropColumnTask.class); bindDataDefinitionTask(binder, executionBinder, DropTable.class, DropTableTask.class); bindDataDefinitionTask(binder, executionBinder, CreateView.class, CreateViewTask.class); bindDataDefinitionTask(binder, executionBinder, DropView.class, DropViewTask.class); diff --git a/presto-main/src/main/java/com/facebook/presto/server/InternalCommunicationConfig.java b/presto-main/src/main/java/com/facebook/presto/server/InternalCommunicationConfig.java new file mode 100644 index 000000000000..c471f2cd6445 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/InternalCommunicationConfig.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server; + +import io.airlift.configuration.Config; + +public class InternalCommunicationConfig +{ + private boolean httpsRequired; + private String keyStorePath; + private String keyStorePassword; + + public boolean isHttpsRequired() + { + return httpsRequired; + } + + @Config("internal-communication.https.required") + public InternalCommunicationConfig setHttpsRequired(boolean httpsRequired) + { + this.httpsRequired = httpsRequired; + return this; + } + + public String getKeyStorePath() + { + return keyStorePath; + } + + @Config("internal-communication.https.keystore.path") + public InternalCommunicationConfig setKeyStorePath(String keyStorePath) + { + this.keyStorePath = keyStorePath; + return this; + } + + public String getKeyStorePassword() + { + return keyStorePassword; + } + + @Config("internal-communication.https.keystore.key") + public InternalCommunicationConfig setKeyStorePassword(String keyStorePassword) + { + this.keyStorePassword = keyStorePassword; + return this; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfo.java b/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfo.java new file mode 100644 index 000000000000..4c5f842fc712 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfo.java @@ -0,0 +1,143 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server; + +import com.facebook.presto.spi.resourceGroups.ResourceGroupId; +import com.facebook.presto.spi.resourceGroups.ResourceGroupInfo; +import com.facebook.presto.spi.resourceGroups.ResourceGroupState; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.airlift.units.DataSize; +import io.airlift.units.Duration; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class ResourceGroupStateInfo +{ + private final ResourceGroupId id; + private final ResourceGroupState state; + + private final DataSize softMemoryLimit; + private final DataSize memoryUsage; + + private final List subGroups; + + private final int maxRunningQueries; + private final int maxQueuedQueries; + private final Duration runningTimeLimit; + private final Duration queuedTimeLimit; + private final List runningQueries; + private final int numQueuedQueries; + + @JsonCreator + public ResourceGroupStateInfo( + @JsonProperty("id") ResourceGroupId id, + @JsonProperty("state") ResourceGroupState state, + @JsonProperty("softMemoryLimit") DataSize softMemoryLimit, + @JsonProperty("memoryUsage") DataSize memoryUsage, + @JsonProperty("maxRunningQueries") int maxRunningQueries, + @JsonProperty("maxQueuedQueries") int maxQueuedQueries, + @JsonProperty("runningTimeLimit") Duration runningTimeLimit, + @JsonProperty("queuedTimeLimit") Duration queuedTimeLimit, + @JsonProperty("runningQueries") List runningQueries, + @JsonProperty("numQueuedQueries") int numQueuedQueries, + @JsonProperty("subGroups") List subGroups) + { + this.id = requireNonNull(id, "id is null"); + this.state = requireNonNull(state, "state is null"); + + this.softMemoryLimit = requireNonNull(softMemoryLimit, "softMemoryLimit is null"); + this.memoryUsage = requireNonNull(memoryUsage, "memoryUsage is null"); + + this.maxRunningQueries = maxRunningQueries; + this.maxQueuedQueries = maxQueuedQueries; + + this.runningTimeLimit = requireNonNull(runningTimeLimit, "runningTimeLimit is null"); + this.queuedTimeLimit = requireNonNull(queuedTimeLimit, "queuedTimeLimit is null"); + + this.runningQueries = ImmutableList.copyOf(requireNonNull(runningQueries, "runningQueries is null")); + this.numQueuedQueries = numQueuedQueries; + + this.subGroups = ImmutableList.copyOf(requireNonNull(subGroups, "subGroups is null")); + } + + @JsonProperty + public ResourceGroupId getId() + { + return id; + } + + @JsonProperty + public ResourceGroupState getState() + { + return state; + } + + @JsonProperty + public DataSize getSoftMemoryLimit() + { + return softMemoryLimit; + } + + @JsonProperty + public DataSize getMemoryUsage() + { + return memoryUsage; + } + + @JsonProperty + public int getMaxRunningQueries() + { + return maxRunningQueries; + } + + @JsonProperty + public int getMaxQueuedQueries() + { + return maxQueuedQueries; + } + + @JsonProperty + public Duration getQueuedTimeLimit() + { + return queuedTimeLimit; + } + + @JsonProperty + public Duration getRunningTimeLimit() + { + return runningTimeLimit; + } + + @JsonProperty + public List getRunningQueries() + { + return runningQueries; + } + + @JsonProperty + public int getNumQueuedQueries() + { + return numQueuedQueries; + } + + @JsonProperty + public List getSubGroups() + { + return subGroups; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfoResource.java b/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfoResource.java new file mode 100644 index 000000000000..1b649d367deb --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfoResource.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server; + +import com.facebook.presto.execution.resourceGroups.ResourceGroupManager; +import com.facebook.presto.spi.resourceGroups.ResourceGroupId; + +import javax.inject.Inject; +import javax.ws.rs.Encoded; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.MediaType; + +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.util.Arrays; +import java.util.NoSuchElementException; + +import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; +import static javax.ws.rs.core.Response.Status.BAD_REQUEST; +import static javax.ws.rs.core.Response.Status.NOT_FOUND; + +@Path("/v1/resourceGroupState") +public class ResourceGroupStateInfoResource +{ + private final ResourceGroupManager resourceGroupManager; + + @Inject + public ResourceGroupStateInfoResource(ResourceGroupManager resourceGroupManager) + { + this.resourceGroupManager = requireNonNull(resourceGroupManager, "resourceGroupManager is null"); + } + + @GET + @Produces(MediaType.APPLICATION_JSON) + @Encoded + @Path("{resourceGroupId: .+}") + public ResourceGroupStateInfo getQueryStateInfos(@PathParam("resourceGroupId") String resourceGroupIdString) + { + if (!isNullOrEmpty(resourceGroupIdString)) { + try { + return resourceGroupManager.getResourceGroupStateInfo( + new ResourceGroupId( + Arrays.stream(resourceGroupIdString.split("/")) + .map(ResourceGroupStateInfoResource::urlDecode) + .collect(toImmutableList()))); + } + catch (NoSuchElementException e) { + throw new WebApplicationException(NOT_FOUND); + } + } + throw new WebApplicationException(NOT_FOUND); + } + + private static String urlDecode(String value) + { + try { + return URLDecoder.decode(value, "UTF-8"); + } + catch (UnsupportedEncodingException e) { + throw new WebApplicationException(BAD_REQUEST); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index c6a54bf228a1..bd88e4c5020b 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -131,6 +131,7 @@ import io.airlift.concurrent.BoundedExecutor; import io.airlift.configuration.AbstractConfigurationAwareModule; import io.airlift.discovery.client.ServiceDescriptor; +import io.airlift.http.client.HttpClientConfig; import io.airlift.slice.Slice; import io.airlift.stats.PauseMeter; import io.airlift.units.DataSize; @@ -198,6 +199,12 @@ protected void setup(Binder binder) })); } + InternalCommunicationConfig internalCommunicationConfig = buildConfigObject(InternalCommunicationConfig.class); + configBinder(binder).bindConfigGlobalDefaults(HttpClientConfig.class, config -> { + config.setKeyStorePath(internalCommunicationConfig.getKeyStorePath()); + config.setKeyStorePassword(internalCommunicationConfig.getKeyStorePassword()); + }); + configBinder(binder).bindConfig(FeaturesConfig.class); binder.bind(SqlParser.class).in(Scopes.SINGLETON); diff --git a/presto-main/src/main/java/com/facebook/presto/server/StatementResource.java b/presto-main/src/main/java/com/facebook/presto/server/StatementResource.java index 7f36b1e8446f..684a77b642cf 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/StatementResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/StatementResource.java @@ -534,11 +534,12 @@ private synchronized Iterable> getData(Duration maxWait) maxWait = new Duration(0, MILLISECONDS); } - if (bytes == 0) { + List rowIterables = pages.build(); + if (rowIterables.isEmpty()) { return null; } - return Iterables.concat(pages.build()); + return Iterables.concat(rowIterables); } private static boolean isQueryStarted(QueryInfo queryInfo) @@ -624,7 +625,7 @@ private static StatementStats toStatementStats(QueryInfo queryInfo) .setNodes(globalUniqueNodes(outputStage).size()) .setTotalSplits(queryStats.getTotalDrivers()) .setQueuedSplits(queryStats.getQueuedDrivers()) - .setRunningSplits(queryStats.getRunningDrivers()) + .setRunningSplits(queryStats.getRunningDrivers() + queryStats.getBlockedDrivers()) .setCompletedSplits(queryStats.getCompletedDrivers()) .setUserTimeMillis(queryStats.getTotalUserTime().toMillis()) .setCpuTimeMillis(queryStats.getTotalCpuTime().toMillis()) @@ -662,7 +663,7 @@ private static StageStats toStageStats(StageInfo stageInfo) .setNodes(uniqueNodes.size()) .setTotalSplits(stageStats.getTotalDrivers()) .setQueuedSplits(stageStats.getQueuedDrivers()) - .setRunningSplits(stageStats.getRunningDrivers()) + .setRunningSplits(stageStats.getRunningDrivers() + stageStats.getBlockedDrivers()) .setCompletedSplits(stageStats.getCompletedDrivers()) .setUserTimeMillis(stageStats.getTotalUserTime().toMillis()) .setCpuTimeMillis(stageStats.getTotalCpuTime().toMillis()) diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpLocationFactory.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpLocationFactory.java index 5f5b56983db7..878fb511508a 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpLocationFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpLocationFactory.java @@ -17,6 +17,7 @@ import com.facebook.presto.execution.StageId; import com.facebook.presto.execution.TaskId; import com.facebook.presto.metadata.InternalNodeManager; +import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.spi.Node; import com.facebook.presto.spi.QueryId; import io.airlift.http.server.HttpServerInfo; @@ -35,15 +36,15 @@ public class HttpLocationFactory private final URI baseUri; @Inject - public HttpLocationFactory(InternalNodeManager nodeManager, HttpServerInfo httpServerInfo) + public HttpLocationFactory(InternalNodeManager nodeManager, HttpServerInfo httpServerInfo, InternalCommunicationConfig config) { - this(nodeManager, httpServerInfo.getHttpUri()); + this(nodeManager, config.isHttpsRequired() ? httpServerInfo.getHttpsUri() : httpServerInfo.getHttpUri()); } public HttpLocationFactory(InternalNodeManager nodeManager, URI baseUri) { - this.nodeManager = nodeManager; - this.baseUri = baseUri; + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.baseUri = requireNonNull(baseUri, "baseUri is null"); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/LdapFilter.java b/presto-main/src/main/java/com/facebook/presto/server/security/LdapFilter.java index 343190330dd4..2cb3ec4b90f6 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/LdapFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/LdapFilter.java @@ -55,7 +55,7 @@ import static com.google.common.base.CharMatcher.JAVA_ISO_CONTROL; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Throwables.propagateIfInstanceOf; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.io.ByteStreams.copy; import static com.google.common.io.ByteStreams.nullOutputStream; import static com.google.common.net.HttpHeaders.AUTHORIZATION; @@ -186,7 +186,9 @@ private Principal getPrincipal(Credentials credentials) } catch (ExecutionException e) { Throwable cause = e.getCause(); - propagateIfInstanceOf(cause, AuthenticationException.class); + if (cause != null) { + throwIfInstanceOf(cause, AuthenticationException.class); + } throw Throwables.propagate(cause); } } diff --git a/presto-main/src/main/java/com/facebook/presto/spiller/LocalSpillContext.java b/presto-main/src/main/java/com/facebook/presto/spiller/LocalSpillContext.java index 2c02205a0c36..86029c20636f 100644 --- a/presto-main/src/main/java/com/facebook/presto/spiller/LocalSpillContext.java +++ b/presto-main/src/main/java/com/facebook/presto/spiller/LocalSpillContext.java @@ -15,11 +15,17 @@ import com.facebook.presto.operator.SpillContext; -public class LocalSpillContext - implements SpillContext +import javax.annotation.concurrent.ThreadSafe; + +import static com.google.common.base.Preconditions.checkState; + +@ThreadSafe +public final class LocalSpillContext + implements SpillContext { private final SpillContext parentSpillContext; private long spilledBytes; + private boolean closed; public LocalSpillContext(SpillContext parentSpillContext) { @@ -27,15 +33,21 @@ public LocalSpillContext(SpillContext parentSpillContext) } @Override - public void updateBytes(long bytes) + public synchronized void updateBytes(long bytes) { + checkState(!closed, "Already closed"); parentSpillContext.updateBytes(bytes); spilledBytes += bytes; } @Override - public void close() + public synchronized void close() { + if (closed) { + return; + } + + closed = true; parentSpillContext.updateBytes(-spilledBytes); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java index b5856d0d65d1..4ecc694ec6a0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.sql; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.DeterminismEvaluator; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; @@ -250,7 +250,7 @@ public static Function expressionOrNullSymbols(final Pre resultDisjunct.add(expression); for (Predicate nullSymbolScope : nullSymbolScopes) { - List symbols = DependencyExtractor.extractUnique(expression).stream() + List symbols = SymbolsExtractor.extractUnique(expression).stream() .filter(nullSymbolScope) .collect(toImmutableList()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index b81ddc48b7b1..ba3cb4244e4f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.type.CharType; import com.facebook.presto.spi.type.DecimalParseResult; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; @@ -90,7 +91,6 @@ import com.facebook.presto.sql.tree.WhenClause; import com.facebook.presto.sql.tree.WindowFrame; import com.facebook.presto.type.FunctionType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.SliceUtf8; @@ -117,6 +117,7 @@ import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.RealType.REAL; +import static com.facebook.presto.spi.type.RowType.RowField; import static com.facebook.presto.spi.type.SmallintType.SMALLINT; import static com.facebook.presto.spi.type.TimeType.TIME; import static com.facebook.presto.spi.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; @@ -143,7 +144,6 @@ import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH; import static com.facebook.presto.type.JsonType.JSON; -import static com.facebook.presto.type.RowType.RowField; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.facebook.presto.util.DateTimeUtils.parseTimestampLiteral; import static com.facebook.presto.util.DateTimeUtils.timeHasTimeZone; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java index d82eabdb1390..2da9a95ac316 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.metadata.FunctionRegistry; -import com.facebook.presto.sql.planner.optimizations.Predicates; import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; @@ -24,6 +23,7 @@ import java.util.List; import java.util.function.Predicate; +import static com.google.common.base.Predicates.alwaysTrue; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @@ -45,7 +45,7 @@ static List extractExpressions( Iterable nodes, Class clazz) { - return extractExpressions(nodes, clazz, Predicates.alwaysTrue()); + return extractExpressions(nodes, clazz, alwaysTrue()); } private static Predicate isAggregationPredicate(FunctionRegistry functionRegistry) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index be18f6d7185e..8d8fb01bf6ee 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -54,6 +54,7 @@ public class FeaturesConfig private boolean legacyArrayAgg; private boolean legacyOrderBy; private boolean legacyMapSubscript; + private boolean newMapBlock = true; private boolean optimizeMixedDistinctAggregations; private boolean dictionaryAggregation; @@ -137,6 +138,18 @@ public boolean isLegacyMapSubscript() return legacyMapSubscript; } + @Config("deprecated.new-map-block") + public FeaturesConfig setNewMapBlock(boolean value) + { + this.newMapBlock = value; + return this; + } + + public boolean isNewMapBlock() + { + return newMapBlock; + } + @Config("distributed-joins-enabled") public FeaturesConfig setDistributedJoinsEnabled(boolean distributedJoinsEnabled) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index 409d22fddb43..0031238f984a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -29,13 +29,16 @@ import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.parser.ParsingException; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.ExpressionInterpreter; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.tree.AddColumn; import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.AllColumns; @@ -50,6 +53,7 @@ import com.facebook.presto.sql.tree.DefaultTraversalVisitor; import com.facebook.presto.sql.tree.Delete; import com.facebook.presto.sql.tree.DereferenceExpression; +import com.facebook.presto.sql.tree.DropColumn; import com.facebook.presto.sql.tree.DropSchema; import com.facebook.presto.sql.tree.DropTable; import com.facebook.presto.sql.tree.DropView; @@ -73,6 +77,7 @@ import com.facebook.presto.sql.tree.JoinCriteria; import com.facebook.presto.sql.tree.JoinOn; import com.facebook.presto.sql.tree.JoinUsing; +import com.facebook.presto.sql.tree.Lateral; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.NaturalJoin; import com.facebook.presto.sql.tree.Node; @@ -109,9 +114,6 @@ import com.facebook.presto.sql.tree.With; import com.facebook.presto.sql.tree.WithQuery; import com.facebook.presto.sql.util.AstUtils; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -194,6 +196,7 @@ import static com.google.common.collect.Iterables.transform; import static java.lang.Math.toIntExact; import static java.util.Collections.emptyList; +import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; class StatementAnalyzer @@ -506,6 +509,12 @@ protected Scope visitRenameColumn(RenameColumn node, Optional scope) return createAndAssignScope(node, scope); } + @Override + protected Scope visitDropColumn(DropColumn node, Optional scope) + { + return createAndAssignScope(node, scope); + } + @Override protected Scope visitDropView(DropView node, Optional scope) { @@ -645,6 +654,14 @@ else if (expressionType instanceof MapType) { return createAndAssignScope(node, scope, outputFields.build()); } + @Override + protected Scope visitLateral(Lateral node, Optional scope) + { + StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, session); + Scope queryScope = analyzer.analyze(node.getQuery(), scope); + return createAndAssignScope(node, scope, queryScope.getRelationType()); + } + @Override protected Scope visitTable(Table table, Optional scope) { @@ -803,7 +820,7 @@ protected Scope visitAliasedRelation(AliasedRelation relation, Optional s @Override protected Scope visitSampledRelation(SampledRelation relation, Optional scope) { - if (!DependencyExtractor.extractNames(relation.getSamplePercentage(), analysis.getColumnReferences()).isEmpty()) { + if (!SymbolsExtractor.extractNames(relation.getSamplePercentage(), analysis.getColumnReferences()).isEmpty()) { throw new SemanticException(NON_NUMERIC_SAMPLE_PERCENTAGE, relation.getSamplePercentage(), "Sample percentage cannot contain column references"); } @@ -891,7 +908,7 @@ protected Scope visitQuerySpecification(QuerySpecification node, Optional // Original ORDER BY scope "sees" FROM query fields. However, during planning // and when aggregation is present, ORDER BY expressions should only be resolvable against // output scope, group by expressions and aggregation expressions. - computeAndAssignOrderByScopeWithAggregation(node.getOrderBy().get(), outputScope, aggregations, groupByExpressions, analysis.getGroupingOperations(node)); + computeAndAssignOrderByScopeWithAggregation(node.getOrderBy().get(), sourceScope, outputScope, aggregations, groupByExpressions, analysis.getGroupingOperations(node)); } return outputScope; @@ -952,7 +969,7 @@ protected Scope visitSetOperation(SetOperation node, Optional scope) int outputFieldSize = outputFieldTypes.length; RelationType relationType = relationScope.getRelationType(); int descFieldSize = relationType.getVisibleFields().size(); - String setOperationName = node.getClass().getSimpleName(); + String setOperationName = node.getClass().getSimpleName().toUpperCase(ENGLISH); if (outputFieldSize != descFieldSize) { throw new SemanticException(MISMATCHED_SET_COLUMN_TYPES, node, @@ -966,7 +983,7 @@ protected Scope visitSetOperation(SetOperation node, Optional scope) throw new SemanticException(TYPE_MISMATCH, node, "column %d in %s query has incompatible types: %s, %s", - i, outputFieldTypes[i].getDisplayName(), setOperationName, descFieldType.getDisplayName()); + i, setOperationName, outputFieldTypes[i].getDisplayName(), descFieldType.getDisplayName()); } outputFieldTypes[i] = commonSuperType.get(); } @@ -1030,7 +1047,7 @@ protected Scope visitJoin(Join node, Optional scope) } Scope left = process(node.getLeft(), scope); - Scope right = process(node.getRight(), isUnnestRelation(node.getRight()) ? Optional.of(left) : scope); + Scope right = process(node.getRight(), isLateralRelation(node.getRight()) ? Optional.of(left) : scope); Scope output = createAndAssignScope(node, scope, left.getRelationType().joinWith(right.getRelationType())); @@ -1086,12 +1103,12 @@ else if (criteria instanceof JoinOn) { return output; } - private boolean isUnnestRelation(Relation node) + private boolean isLateralRelation(Relation node) { if (node instanceof AliasedRelation) { - return isUnnestRelation(((AliasedRelation) node).getRelation()); + return isLateralRelation(((AliasedRelation) node).getRelation()); } - return node instanceof Unnest; + return node instanceof Unnest || node instanceof Lateral; } private void addCoercionForJoinCriteria(Join node, Expression leftExpression, Expression rightExpression) @@ -1572,7 +1589,7 @@ private Scope computeAndAssignOrderByScope(OrderBy node, Scope sourceScope, Scop return orderByScope; } - private Scope computeAndAssignOrderByScopeWithAggregation(OrderBy node, Scope outputScope, List aggregations, List> groupByExpressions, List groupingOperations) + private Scope computeAndAssignOrderByScopeWithAggregation(OrderBy node, Scope sourceScope, Scope outputScope, List aggregations, List> groupByExpressions, List groupingOperations) { // This scope is only used for planning. When aggregation is present then // only output fields, groups and aggregation expressions should be visible from ORDER BY expression @@ -1583,20 +1600,27 @@ private Scope computeAndAssignOrderByScopeWithAggregation(OrderBy node, Scope ou orderByAggregationExpressionsBuilder.addAll(aggregations); orderByAggregationExpressionsBuilder.addAll(groupingOperations); - // Don't add aggregate expression that contains references to output column because the names would clash in TranslationMap during planning. + // Don't add aggregate complex expressions that contains references to output column because the names would clash in TranslationMap during planning. List orderByExpressionsReferencingOutputScope = AstUtils.preOrder(node) .filter(Expression.class::isInstance) .map(Expression.class::cast) .filter(expression -> hasReferencesToScope(expression, analysis, outputScope)) .collect(toImmutableList()); List orderByAggregationExpressions = orderByAggregationExpressionsBuilder.build().stream() - .filter(expression -> !orderByExpressionsReferencingOutputScope.contains(expression)) + .filter(expression -> !orderByExpressionsReferencingOutputScope.contains(expression) || analysis.getColumnReferences().contains(NodeRef.of(expression))) .collect(toImmutableList()); // generate placeholder fields + Set seen = new HashSet<>(); List orderByAggregationSourceFields = orderByAggregationExpressions.stream() - .map(analysis::getType) - .map(type -> Field.newUnqualified(Optional.empty(), type)) + .map(expression -> { + // generate qualified placeholder field for GROUP BY expressions that are column references + Optional sourceField = sourceScope.tryResolveField(expression) + .filter(resolvedField -> seen.add(resolvedField.getField())) + .map(ResolvedField::getField); + return sourceField + .orElse(Field.newUnqualified(Optional.empty(), analysis.getType(expression))); + }) .collect(toImmutableList()); Scope orderByAggregationScope = Scope.builder() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java index 0e57f454bc2b..699111aa2bb6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java @@ -296,8 +296,7 @@ private static void generateConstructor(ClassDefinition classDefinition, .append( channel.invoke("get", Object.class, blockIndex) .cast(type(Block.class)) - .invoke("getRetainedSizeInBytes", int.class) - .cast(long.class)) + .invoke("getRetainedSizeInBytes", long.class)) .longAdd() .putField(sizeField); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java index ca172e4abdb4..98ae92e55578 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java @@ -64,6 +64,7 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.setStatic; import static com.facebook.presto.spi.StandardErrorCode.COMPILER_ERROR; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.gen.BytecodeUtils.boxPrimitiveIfNecessary; import static com.facebook.presto.sql.gen.BytecodeUtils.unboxPrimitiveIfNecessary; import static com.facebook.presto.sql.gen.LambdaCapture.LAMBDA_CAPTURE_METHOD; @@ -125,6 +126,7 @@ private static CompiledLambda defineLambdaMethodAndField( List inputParameters, LambdaDefinitionExpression lambda) { + checkCondition(inputParameters.size() <= 254, NOT_SUPPORTED, "Too many arguments for lambda expression"); Class returnType = Primitives.wrap(lambda.getBody().getType().getJavaType()); MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), fieldAndMethodName, type(returnType), inputParameters); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/TryCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/TryCodeGenerator.java index d7e86add779b..68cdde929eb6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/TryCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/TryCodeGenerator.java @@ -42,10 +42,12 @@ import static com.facebook.presto.spi.StandardErrorCode.DIVISION_BY_ZERO; import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static com.facebook.presto.sql.gen.BytecodeUtils.boxPrimitiveIfNecessary; import static com.facebook.presto.sql.gen.BytecodeUtils.invoke; import static com.facebook.presto.sql.gen.BytecodeUtils.unboxPrimitiveIfNecessary; +import static com.facebook.presto.util.Failures.checkCondition; import static com.facebook.presto.util.Reflection.methodHandle; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -95,6 +97,7 @@ public static MethodDefinition defineTryMethod( RowExpression innerRowExpression, CallSiteBinder callSiteBinder) { + checkCondition(inputParameters.size() <= 254, NOT_SUPPORTED, "Too many arguments for method"); MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), methodName, type(returnType), inputParameters); Scope calleeMethodScope = method.getScope(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/VarArgsToArrayAdapterGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/VarArgsToArrayAdapterGenerator.java index 13ed4e55ea69..a2416105e6b2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/VarArgsToArrayAdapterGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/VarArgsToArrayAdapterGenerator.java @@ -37,7 +37,9 @@ import static com.facebook.presto.bytecode.ParameterizedType.type; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.gen.BytecodeUtils.loadConstant; +import static com.facebook.presto.util.Failures.checkCondition; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; @@ -102,6 +104,8 @@ public static MethodHandleAndConstructor generateVarArgsToArrayAdapter( requireNonNull(function, "function is null"); requireNonNull(userStateFactory, "userStateFactory is null"); + checkCondition(argsLength <= 253, NOT_SUPPORTED, "Too many arguments for vararg function"); + MethodType methodType = function.type(); Class javaArrayType = toArrayClass(javaType); checkArgument(methodType.returnType() == returnType, "returnType does not match"); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/VarArgsToMapAdapterGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/VarArgsToMapAdapterGenerator.java index 52d77776ce3f..9d8f4dcd3ff4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/VarArgsToMapAdapterGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/VarArgsToMapAdapterGenerator.java @@ -40,7 +40,9 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantString; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeStatic; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.gen.BytecodeUtils.loadConstant; +import static com.facebook.presto.util.Failures.checkCondition; public class VarArgsToMapAdapterGenerator { @@ -59,6 +61,7 @@ private VarArgsToMapAdapterGenerator() */ public static MethodHandle generateVarArgsToMapAdapter(Class returnType, List> javaTypes, List names, Function, Object> function) { + checkCondition(javaTypes.size() <= 254, NOT_SUPPORTED, "Too many arguments for vararg function"); CallSiteBinder callSiteBinder = new CallSiteBinder(); ClassDefinition classDefinition = new ClassDefinition(a(PUBLIC, FINAL), makeClassName("VarArgsToMapAdapter"), type(Object.class)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java index 285d3e58863e..6d6d8345962b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java @@ -263,7 +263,7 @@ private static Iterable pullNullableConjunctsThroughOuterJoin(List pullExpressionThroughSymbols(expression, outputSymbols)) - .map(expression -> DependencyExtractor.extractAll(expression).isEmpty() ? TRUE_LITERAL : expression) + .map(expression -> SymbolsExtractor.extractAll(expression).isEmpty() ? TRUE_LITERAL : expression) .map(expressionOrNullSymbols(nullSymbolScopes)) .collect(toImmutableList()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/EqualityInference.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/EqualityInference.java index 83b3c49674fd..f71a6346eb8b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/EqualityInference.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/EqualityInference.java @@ -65,7 +65,7 @@ public int compare(Expression expression1, Expression expression2) // 3) Sort the expressions alphabetically - creates a stable consistent ordering (extremely useful for unit testing) // TODO: be more precise in determining the cost of an expression return ComparisonChain.start() - .compare(DependencyExtractor.extractAll(expression1).size(), DependencyExtractor.extractAll(expression2).size()) + .compare(SymbolsExtractor.extractAll(expression1).size(), SymbolsExtractor.extractAll(expression2).size()) .compare(SubExpressionExtractor.extract(expression1).size(), SubExpressionExtractor.extract(expression2).size()) .compare(expression1.toString(), expression2.toString()) .result(); @@ -244,7 +244,7 @@ Expression getScopedCanonical(Expression expression, Predicate symbolSco private static Predicate symbolToExpressionPredicate(final Predicate symbolScope) { - return expression -> Iterables.all(DependencyExtractor.extractUnique(expression), symbolScope); + return expression -> Iterables.all(SymbolsExtractor.extractUnique(expression), symbolScope); } /** diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java index 9bb81a58f225..09f859a22518 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java @@ -28,6 +28,9 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; +import com.facebook.presto.spi.type.RowType.RowField; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; @@ -83,11 +86,8 @@ import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.TryExpression; import com.facebook.presto.sql.tree.WhenClause; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.FunctionType; import com.facebook.presto.type.LikeFunctions; -import com.facebook.presto.type.RowType; -import com.facebook.presto.type.RowType.RowField; import com.facebook.presto.util.Failures; import com.facebook.presto.util.FastutilSetHelper; import com.google.common.annotations.VisibleForTesting; @@ -128,6 +128,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Predicates.instanceOf; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.any; @@ -760,8 +761,8 @@ protected Object visitArithmeticUnary(ArithmeticUnaryExpression node, Object con return handle.invokeWithArguments(value); } catch (Throwable throwable) { - Throwables.propagateIfInstanceOf(throwable, RuntimeException.class); - Throwables.propagateIfInstanceOf(throwable, Error.class); + throwIfInstanceOf(throwable, RuntimeException.class); + throwIfInstanceOf(throwable, Error.class); throw new RuntimeException(throwable.getMessage(), throwable); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LiteralInterpreter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LiteralInterpreter.java index 8f8b61eb2c2f..a17813f551b4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LiteralInterpreter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LiteralInterpreter.java @@ -75,6 +75,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public final class LiteralInterpreter @@ -193,7 +194,7 @@ else if (value.equals(Float.POSITIVE_INFINITY)) { } if (object instanceof Block) { - SliceOutput output = new DynamicSliceOutput(((Block) object).getSizeInBytes()); + SliceOutput output = new DynamicSliceOutput(toIntExact(((Block) object).getSizeInBytes())); BlockSerdeUtil.writeBlock(output, (Block) object); object = output.slice(); // This if condition will evaluate to true: object instanceof Slice && !type.equals(VARCHAR) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index 5aaebc24afa1..1c94317f6899 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -608,7 +608,7 @@ public PhysicalOperation visitExplainAnalyze(ExplainAnalyzeNode node, LocalExecu checkState(queryPerformanceFetcher.isPresent(), "ExplainAnalyze can only run on coordinator"); PhysicalOperation source = node.getSource().accept(this, context); - OperatorFactory operatorFactory = new ExplainAnalyzeOperatorFactory(context.getNextOperatorId(), node.getId(), queryPerformanceFetcher.get(), metadata, costCalculator); + OperatorFactory operatorFactory = new ExplainAnalyzeOperatorFactory(context.getNextOperatorId(), node.getId(), queryPerformanceFetcher.get(), metadata, costCalculator, node.isVerbose()); return new PhysicalOperation(operatorFactory, makeLayout(node), source); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index 3f7ed4be01a5..c62c93e331c3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -125,6 +125,8 @@ public Plan plan(Analysis analysis, Stage stage) { PlanNode root = planStatement(analysis, analysis.getStatement()); + PlanSanityChecker.validateIntermediatePlan(root, session, metadata, sqlParser, symbolAllocator.getTypes()); + if (stage.ordinal() >= Stage.OPTIMIZED.ordinal()) { for (PlanOptimizer optimizer : planOptimizers) { root = optimizer.optimize(root, session, symbolAllocator.getTypes(), symbolAllocator, idAllocator); @@ -134,7 +136,7 @@ public Plan plan(Analysis analysis, Stage stage) if (stage.ordinal() >= Stage.OPTIMIZED_AND_VALIDATED.ordinal()) { // make sure we produce a valid plan after optimizations run. This is mainly to catch programming errors - PlanSanityChecker.validate(root, session, metadata, sqlParser, symbolAllocator.getTypes()); + PlanSanityChecker.validateFinalPlan(root, session, metadata, sqlParser, symbolAllocator.getTypes()); } Map planNodeCosts = costCalculator.calculateCostForPlan(session, symbolAllocator.getTypes(), root); @@ -185,7 +187,7 @@ private RelationPlan createExplainAnalyzePlan(Analysis analysis, Explain stateme PlanNode root = underlyingPlan.getRoot(); Scope scope = analysis.getScope(statement); Symbol outputSymbol = symbolAllocator.newSymbol(scope.getRelationType().getFieldByIndex(0)); - root = new ExplainAnalyzeNode(idAllocator.getNextId(), root, outputSymbol); + root = new ExplainAnalyzeNode(idAllocator.getNextId(), root, outputSymbol, statement.isVerbose()); return new RelationPlan(root, scope, ImmutableList.of(outputSymbol)); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index a99da8b05497..ba97582c497d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -19,6 +19,10 @@ import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.rule.AddIntermediateAggregations; +import com.facebook.presto.sql.planner.iterative.rule.CanonicalizeFilterExpressions; +import com.facebook.presto.sql.planner.iterative.rule.CanonicalizeJoinExpressions; +import com.facebook.presto.sql.planner.iterative.rule.CanonicalizeProjectExpressions; +import com.facebook.presto.sql.planner.iterative.rule.CanonicalizeTableScanExpressions; import com.facebook.presto.sql.planner.iterative.rule.CreatePartialTopN; import com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins; import com.facebook.presto.sql.planner.iterative.rule.EvaluateZeroLimit; @@ -32,6 +36,15 @@ import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithSort; import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithTopN; import com.facebook.presto.sql.planner.iterative.rule.MergeLimits; +import com.facebook.presto.sql.planner.iterative.rule.PruneCountAggregationOverScalar; +import com.facebook.presto.sql.planner.iterative.rule.PruneCrossJoinColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneIndexSourceColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneJoinChildrenColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneJoinColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneMarkDistinctColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneOutputColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneSemiJoinColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneSemiJoinFilteringSourceColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneTableScanColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneValuesColumns; import com.facebook.presto.sql.planner.iterative.rule.PushAggregationThroughOuterJoin; @@ -40,10 +53,13 @@ import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughSemiJoin; import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughExchange; import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughUnion; +import com.facebook.presto.sql.planner.iterative.rule.PushTableWriteThroughUnion; import com.facebook.presto.sql.planner.iterative.rule.PushTopNThroughUnion; import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete; import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; +import com.facebook.presto.sql.planner.iterative.rule.RemoveTrivialFilters; +import com.facebook.presto.sql.planner.iterative.rule.RemoveUnreferencedScalarApplyNodes; import com.facebook.presto.sql.planner.iterative.rule.SimplifyCountOverConstant; import com.facebook.presto.sql.planner.iterative.rule.SingleMarkDistinctToGroupBy; import com.facebook.presto.sql.planner.iterative.rule.SwapAdjacentWindowsBySpecifications; @@ -68,11 +84,12 @@ import com.facebook.presto.sql.planner.optimizations.PredicatePushDown; import com.facebook.presto.sql.planner.optimizations.ProjectionPushDown; import com.facebook.presto.sql.planner.optimizations.PruneUnreferencedOutputs; -import com.facebook.presto.sql.planner.optimizations.PushTableWriteThroughUnion; import com.facebook.presto.sql.planner.optimizations.RemoveUnreferencedScalarLateralNodes; import com.facebook.presto.sql.planner.optimizations.SetFlatteningOptimizer; import com.facebook.presto.sql.planner.optimizations.SimplifyExpressions; +import com.facebook.presto.sql.planner.optimizations.TransformCorrelatedNoAggregationSubqueryToJoin; import com.facebook.presto.sql.planner.optimizations.TransformCorrelatedScalarAggregationToJoin; +import com.facebook.presto.sql.planner.optimizations.TransformCorrelatedSingleRowSubqueryToProject; import com.facebook.presto.sql.planner.optimizations.TransformQuantifiedComparisonApplyToLateralJoin; import com.facebook.presto.sql.planner.optimizations.TransformUncorrelatedInPredicateSubqueryToSemiJoin; import com.facebook.presto.sql.planner.optimizations.TransformUncorrelatedLateralToJoin; @@ -121,6 +138,19 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea Set predicatePushDownRules = ImmutableSet.of( new MergeFilters()); + // TODO: Once we've migrated handling all the plan node types, replace uses of PruneUnreferencedOutputs with an IterativeOptimizer containing these rules. + Set columnPruningRules = ImmutableSet.of( + new PruneCrossJoinColumns(), + new PruneIndexSourceColumns(), + new PruneJoinChildrenColumns(), + new PruneJoinColumns(), + new PruneMarkDistinctColumns(), + new PruneOutputColumns(), + new PruneSemiJoinColumns(), + new PruneSemiJoinFilteringSourceColumns(), + new PruneValuesColumns(), + new PruneTableScanColumns()); + IterativeOptimizer inlineProjections = new IterativeOptimizer( stats, ImmutableSet.of( @@ -136,11 +166,21 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea builder.add( new DesugaringOptimizer(metadata, sqlParser), // Clean up all the sugar in expressions, e.g. AtTimeZone, must be run before all the other optimizers - new CanonicalizeExpressions(), + new IterativeOptimizer( + stats, + ImmutableList.of(new CanonicalizeExpressions()), + ImmutableSet.of( + new CanonicalizeJoinExpressions(), + new CanonicalizeProjectExpressions(), + new CanonicalizeFilterExpressions(), + new CanonicalizeTableScanExpressions() + ) + ), new IterativeOptimizer( stats, ImmutableSet.builder() .addAll(predicatePushDownRules) + .addAll(columnPruningRules) .addAll(ImmutableSet.of( new RemoveRedundantIdentityProjections(), new RemoveFullSample(), @@ -152,16 +192,13 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea new MergeLimitWithTopN(), new PushLimitThroughMarkDistinct(), new PushLimitThroughSemiJoin(), + new RemoveTrivialFilters(), + new ImplementFilteredAggregations(), + new ImplementBernoulliSampleAsFilter(), new MergeLimitWithDistinct(), - new PruneValuesColumns(), - new PruneTableScanColumns())) + new PruneCountAggregationOverScalar())) .build() ), - new IterativeOptimizer( - stats, - ImmutableSet.of( - new ImplementFilteredAggregations(), - new ImplementBernoulliSampleAsFilter())), new SimplifyExpressions(metadata, sqlParser), new UnaliasSymbolReferences(), new IterativeOptimizer( @@ -177,19 +214,26 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea stats, ImmutableSet.of(new TransformExistsApplyToLateralNode(metadata.getFunctionRegistry()))), new TransformQuantifiedComparisonApplyToLateralJoin(metadata), - new RemoveUnreferencedScalarLateralNodes(), - new TransformUncorrelatedInPredicateSubqueryToSemiJoin(), - new TransformUncorrelatedLateralToJoin(), - new IterativeOptimizer( - stats, - ImmutableList.of(new TransformCorrelatedScalarAggregationToJoin(metadata.getFunctionRegistry())), - ImmutableSet.of(new com.facebook.presto.sql.planner.iterative.rule.TransformCorrelatedScalarAggregationToJoin(metadata.getFunctionRegistry()))), + new IterativeOptimizer(stats, + ImmutableList.of( + new RemoveUnreferencedScalarLateralNodes(), + new TransformUncorrelatedLateralToJoin(), + new TransformUncorrelatedInPredicateSubqueryToSemiJoin(), + new TransformCorrelatedScalarAggregationToJoin(metadata.getFunctionRegistry())), + ImmutableSet.of( + new com.facebook.presto.sql.planner.iterative.rule.RemoveUnreferencedScalarLateralNodes(), + new com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedLateralToJoin(), + new com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToSemiJoin(), + new com.facebook.presto.sql.planner.iterative.rule.TransformCorrelatedScalarAggregationToJoin(metadata.getFunctionRegistry()))), new IterativeOptimizer( stats, ImmutableSet.of( + new RemoveUnreferencedScalarApplyNodes(), new TransformCorrelatedInPredicateToJoin(), // must be run after PruneUnreferencedOutputs new ImplementFilteredAggregations()) ), + new TransformCorrelatedNoAggregationSubqueryToJoin(), + new TransformCorrelatedSingleRowSubqueryToProject(), new PredicatePushDown(metadata, sqlParser), new PruneUnreferencedOutputs(), new IterativeOptimizer( @@ -247,7 +291,12 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea if (!forceSingleNode) { builder.add(new DetermineJoinDistributionType()); // Must run before AddExchanges - builder.add(new PushTableWriteThroughUnion()); // Must run before AddExchanges + builder.add( + new IterativeOptimizer( + stats, + ImmutableList.of(new com.facebook.presto.sql.planner.optimizations.PushTableWriteThroughUnion()), // Must run before AddExchanges + ImmutableSet.of(new PushTableWriteThroughUnion()) + )); builder.add(new AddExchanges(metadata, sqlParser)); } @@ -279,7 +328,6 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea ImmutableSet.of( new AddIntermediateAggregations(), new RemoveRedundantIdentityProjections()))); - // DO NOT add optimizers that change the plan shape (computations) after this point // Precomputed hashes - this assumes that partitioning will not change diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 6baf1e9888f9..b39ac8a1afc5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -70,7 +70,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; @@ -78,7 +77,6 @@ import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; -import java.util.stream.StreamSupport; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; @@ -89,6 +87,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.Streams.stream; import static java.util.Objects.requireNonNull; class QueryPlanner @@ -171,7 +170,10 @@ public RelationPlan plan(QuerySpecification node) List orderByAggregates = analysis.getOrderByAggregates(node.getOrderBy().get()); builder = project(builder, Iterables.concat(outputs, orderByAggregates)); outputs = toSymbolReferences(computeOutputs(builder, outputs)); - builder = planBuilderFor(builder, analysis.getScope(node.getOrderBy().get()), orderByAggregates); + List complexOrderByAggregatesToRemap = orderByAggregates.stream() + .filter(expression -> !analysis.getColumnReferences().contains(NodeRef.of(expression))) + .collect(toImmutableList()); + builder = planBuilderFor(builder, analysis.getScope(node.getOrderBy().get()), complexOrderByAggregatesToRemap); } builder = window(builder, node.getOrderBy().get()); @@ -407,8 +409,9 @@ private PlanBuilder explicitCoercionSymbols(PlanBuilder subPlan, Iterable toSymbolReferences(List symbols) private static Map symbolsForExpressions(PlanBuilder builder, Iterable expressions) { - Set added = new HashSet<>(); - return StreamSupport.stream(expressions.spliterator(), false) - .filter(added::add) + return stream(expressions) + .distinct() .collect(toImmutableMap(expression -> expression, builder::translate)); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 7129317e68bc..6ce3c4ab570f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -18,6 +18,8 @@ import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.analyzer.Analysis; @@ -51,6 +53,7 @@ import com.facebook.presto.sql.tree.Join; import com.facebook.presto.sql.tree.JoinUsing; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; +import com.facebook.presto.sql.tree.Lateral; import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.Query; @@ -65,8 +68,6 @@ import com.facebook.presto.sql.tree.Union; import com.facebook.presto.sql.tree.Unnest; import com.facebook.presto.sql.tree.Values; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; @@ -184,19 +185,20 @@ protected RelationPlan visitJoin(Join node, Void context) // TODO: translate the RIGHT join into a mirrored LEFT join when we refactor (@martint) RelationPlan leftPlan = process(node.getLeft(), context); - // Convert CROSS JOIN UNNEST to an UnnestNode - if (node.getRight() instanceof Unnest || (node.getRight() instanceof AliasedRelation && ((AliasedRelation) node.getRight()).getRelation() instanceof Unnest)) { - Unnest unnest; - if (node.getRight() instanceof AliasedRelation) { - unnest = (Unnest) ((AliasedRelation) node.getRight()).getRelation(); - } - else { - unnest = (Unnest) node.getRight(); + Optional unnest = getUnnest(node.getRight()); + if (unnest.isPresent()) { + if (node.getType() != Join.Type.CROSS && node.getType() != Join.Type.IMPLICIT) { + throw notSupportedException(unnest.get(), "UNNEST on other than the right side of CROSS JOIN"); } + return planCrossJoinUnnest(leftPlan, node, unnest.get()); + } + + Optional lateral = getLateral(node.getRight()); + if (lateral.isPresent()) { if (node.getType() != Join.Type.CROSS && node.getType() != Join.Type.IMPLICIT) { - throw notSupportedException(unnest, "UNNEST on other than the right side of CROSS JOIN"); + throw notSupportedException(lateral.get(), "LATERAL on other than the right side of CROSS JOIN"); } - return planCrossJoinUnnest(leftPlan, node, unnest); + return planLateralJoin(node, leftPlan, lateral.get()); } RelationPlan rightPlan = process(node.getRight(), context); @@ -232,7 +234,7 @@ protected RelationPlan visitJoin(Join node, Void context) continue; } - Set dependencies = DependencyExtractor.extractNames(conjunct, analysis.getColumnReferences()); + Set dependencies = SymbolsExtractor.extractNames(conjunct, analysis.getColumnReferences()); boolean isJoinUsing = node.getCriteria().filter(JoinUsing.class::isInstance).isPresent(); if (!isJoinUsing && (dependencies.stream().allMatch(left::canResolve) || dependencies.stream().allMatch(right::canResolve))) { // If the conjunct can be evaluated entirely with the inputs on either side of the join, add @@ -246,8 +248,8 @@ else if (conjunct instanceof ComparisonExpression) { Expression firstExpression = ((ComparisonExpression) conjunct).getLeft(); Expression secondExpression = ((ComparisonExpression) conjunct).getRight(); ComparisonExpressionType comparisonType = ((ComparisonExpression) conjunct).getType(); - Set firstDependencies = DependencyExtractor.extractNames(firstExpression, analysis.getColumnReferences()); - Set secondDependencies = DependencyExtractor.extractNames(secondExpression, analysis.getColumnReferences()); + Set firstDependencies = SymbolsExtractor.extractNames(firstExpression, analysis.getColumnReferences()); + Set secondDependencies = SymbolsExtractor.extractNames(secondExpression, analysis.getColumnReferences()); if (firstDependencies.stream().allMatch(left::canResolve) && secondDependencies.stream().allMatch(right::canResolve)) { leftComparisonExpressions.add(firstExpression); @@ -364,6 +366,43 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende return new RelationPlan(root, analysis.getScope(node), outputSymbols); } + private Optional getUnnest(Relation relation) + { + if (relation instanceof AliasedRelation) { + return getUnnest(((AliasedRelation) relation).getRelation()); + } + if (relation instanceof Unnest) { + return Optional.of((Unnest) relation); + } + return Optional.empty(); + } + + private Optional getLateral(Relation relation) + { + if (relation instanceof AliasedRelation) { + return getLateral(((AliasedRelation) relation).getRelation()); + } + if (relation instanceof Lateral) { + return Optional.of((Lateral) relation); + } + return Optional.empty(); + } + + private RelationPlan planLateralJoin(Join join, RelationPlan leftPlan, Lateral lateral) + { + RelationPlan rightPlan = process(lateral.getQuery(), null); + PlanBuilder leftPlanBuilder = initializePlanBuilder(leftPlan); + PlanBuilder rightPlanBuilder = initializePlanBuilder(rightPlan); + + PlanBuilder planBuilder = subqueryPlanner.appendLateralJoin(leftPlanBuilder, rightPlanBuilder, lateral.getQuery(), true); + + List outputSymbols = ImmutableList.builder() + .addAll(leftPlan.getRoot().getOutputSymbols()) + .addAll(rightPlan.getRoot().getOutputSymbols()) + .build(); + return new RelationPlan(planBuilder.getRoot(), analysis.getScope(join), outputSymbols); + } + private static boolean isEqualComparisonExpression(Expression conjunct) { return conjunct instanceof ComparisonExpression && ((ComparisonExpression) conjunct).getType() == ComparisonExpressionType.EQUAL; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java index bed6ef3bf76f..5f9efd4cda0a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java @@ -16,7 +16,6 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.analyzer.Analysis; -import com.facebook.presto.sql.planner.optimizations.Predicates; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.Assignments; @@ -41,8 +40,10 @@ import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; import com.facebook.presto.sql.tree.QuantifiedComparisonExpression.Quantifier; +import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.SubqueryExpression; import com.facebook.presto.sql.tree.SymbolReference; +import com.facebook.presto.util.MorePredicates; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -230,12 +231,16 @@ private PlanBuilder appendScalarSubqueryApplyNode(PlanBuilder subPlan, SubqueryE subPlan.getTranslations().put(coercion, coercionSymbol); } + return appendLateralJoin(subPlan, subqueryPlan, scalarSubquery.getQuery(), correlationAllowed); + } + + public PlanBuilder appendLateralJoin(PlanBuilder subPlan, PlanBuilder subqueryPlan, Query query, boolean correlationAllowed) + { PlanNode subqueryNode = subqueryPlan.getRoot(); Map correlation = extractCorrelation(subPlan, subqueryNode); if (!correlationAllowed && !correlation.isEmpty()) { - throw notSupportedException(scalarSubquery.getQuery(), "Correlated subquery in given context"); + throw notSupportedException(query, "Correlated subquery in given context"); } - subPlan = subPlan.appendProjections(correlation.keySet(), symbolAllocator, idAllocator); subqueryNode = replaceExpressionsWithSymbols(subqueryNode, correlation); return new PlanBuilder( @@ -244,7 +249,7 @@ private PlanBuilder appendScalarSubqueryApplyNode(PlanBuilder subPlan, SubqueryE idAllocator.getNextId(), subPlan.getRoot(), subqueryNode, - ImmutableList.copyOf(DependencyExtractor.extractUnique(correlation.values())), + ImmutableList.copyOf(SymbolsExtractor.extractUnique(correlation.values())), LateralJoinNode.Type.INNER), analysis.getParameters()); } @@ -394,7 +399,7 @@ private PlanBuilder planQuantifiedApplyNode(PlanBuilder subPlan, QuantifiedCompa private static boolean isAggregationWithEmptyGroupBy(PlanNode planNode) { return searchFrom(planNode) - .skipOnlyWhen(Predicates.isInstanceOfAny(ProjectNode.class)) + .recurseOnlyWhen(MorePredicates.isInstanceOfAny(ProjectNode.class)) .where(AggregationNode.class::isInstance) .findFirst() .map(AggregationNode.class::cast) @@ -442,7 +447,7 @@ private PlanBuilder appendApplyNode( root, subqueryNode, subqueryAssignments, - ImmutableList.copyOf(DependencyExtractor.extractUnique(correlation.values()))), + ImmutableList.copyOf(SymbolsExtractor.extractUnique(correlation.values()))), analysis.getParameters()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java index 6f4a6554ec76..6d25ebb48558 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java @@ -26,6 +26,7 @@ import java.util.HashMap; import java.util.Map; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; @@ -44,6 +45,12 @@ public SymbolAllocator(Map initial) symbols = new HashMap<>(initial); } + public Symbol newSymbol(Symbol symbolHint) + { + checkArgument(symbols.containsKey(symbolHint), "symbolHint not in symbols map"); + return newSymbol(symbolHint.getName(), symbols.get(symbolHint)); + } + public Symbol newSymbol(String nameHint, Type type) { return newSymbol(nameHint, type, null); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/DependencyExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java similarity index 92% rename from presto-main/src/main/java/com/facebook/presto/sql/planner/DependencyExtractor.java rename to presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java index f75722671f63..114b30881d1f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/DependencyExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java @@ -33,9 +33,9 @@ import static com.facebook.presto.sql.planner.ExpressionExtractor.extractExpressionsNonRecursive; import static java.util.Objects.requireNonNull; -public final class DependencyExtractor +public final class SymbolsExtractor { - private DependencyExtractor() {} + private SymbolsExtractor() {} public static Set extractUnique(PlanNode node) { @@ -45,9 +45,6 @@ public static Set extractUnique(PlanNode node) return uniqueSymbols.build(); } - // TODO: this is a temporary hack. We need to clarify the semantics of extractUniqueNonRecursive and extractUnique. - // The notion of extracting dependencies recursively is odd, since the dependencies to a plan node are the inputs - // to its expressions. We need to figure out if the distinction between the two functions is needed at all. public static Set extractUniqueNonRecursive(PlanNode node) { ImmutableSet.Builder uniqueSymbols = ImmutableSet.builder(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java index 612836aee46d..7ca7963ac34e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.matching.MatchingEngine; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; @@ -41,7 +42,7 @@ public class IterativeOptimizer implements PlanOptimizer { private final List legacyRules; - private final RuleStore ruleStore; + private final MatchingEngine ruleStore; private final StatsRecorder stats; public IterativeOptimizer(StatsRecorder stats, Set rules) @@ -52,7 +53,7 @@ public IterativeOptimizer(StatsRecorder stats, Set rules) public IterativeOptimizer(StatsRecorder stats, List legacyRules, Set newRules) { this.legacyRules = ImmutableList.copyOf(legacyRules); - this.ruleStore = RuleStore.builder() + this.ruleStore = MatchingEngine.builder() .register(newRules) .build(); @@ -127,7 +128,7 @@ private boolean exploreNode(int group, Context context) long duration; try { long start = System.nanoTime(); - transformed = rule.apply(node, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator(), context.getSession()); + transformed = rule.apply(node, context); duration = System.nanoTime() - start; } catch (RuntimeException e) { @@ -169,7 +170,7 @@ private boolean exploreChildren(int group, Context context) return progress; } - private static class Context + private static class Context implements Rule.Context { private final Memo memo; private final Lookup lookup; @@ -204,16 +205,19 @@ public Memo getMemo() return memo; } + @Override public Lookup getLookup() { return lookup; } + @Override public PlanNodeIdAllocator getIdAllocator() { return idAllocator; } + @Override public SymbolAllocator getSymbolAllocator() { return symbolAllocator; @@ -229,6 +233,7 @@ public long getTimeoutInMilliseconds() return timeoutInMilliseconds; } + @Override public Session getSession() { return session; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Rule.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Rule.java index 4d2776bf068b..31d6d189bd0d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Rule.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Rule.java @@ -14,23 +14,34 @@ package com.facebook.presto.sql.planner.iterative; import com.facebook.presto.Session; +import com.facebook.presto.matching.Matchable; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.PlanNode; import java.util.Optional; -public interface Rule +public interface Rule extends Matchable { /** * Returns a pattern to which plan nodes this rule applies. - * Notice that rule may be still invoked for plan nodes which given pattern does not apply, - * then rule should return Optional.empty() in such case */ default Pattern getPattern() { return Pattern.any(); } - Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session); + Optional apply(PlanNode node, Context context); + + interface Context + { + Lookup getLookup(); + + PlanNodeIdAllocator getIdAllocator(); + + SymbolAllocator getSymbolAllocator(); + + Session getSession(); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStore.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStore.java deleted file mode 100644 index f190a5656092..000000000000 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStore.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.facebook.presto.sql.planner.iterative; - -import com.facebook.presto.sql.planner.plan.PlanNode; -import com.google.common.collect.AbstractIterator; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ListMultimap; -import com.google.common.collect.Streams; - -import java.util.Iterator; -import java.util.Set; -import java.util.stream.Stream; - -public class RuleStore -{ - private final ListMultimap, Rule> rulesByClass; - - private RuleStore(ListMultimap, Rule> rulesByClass) - { - this.rulesByClass = ImmutableListMultimap.copyOf(rulesByClass); - } - - public Stream getCandidates(PlanNode planNode) - { - return Streams.stream(ancestors(planNode.getClass())) - .flatMap(clazz -> rulesByClass.get(clazz).stream()); - } - - private static Iterator> ancestors(Class planNodeClass) - { - return new AbstractIterator>() { - private Class current = planNodeClass; - - @Override - protected Class computeNext() - { - if (!PlanNode.class.isAssignableFrom(current)) { - return endOfData(); - } - - Class result = (Class) current; - current = current.getSuperclass(); - - return result; - } - }; - } - - public static Builder builder() - { - return new Builder(); - } - - public static class Builder - { - private final ImmutableListMultimap.Builder, Rule> rulesByClass = ImmutableListMultimap.builder(); - - public Builder register(Set newRules) - { - newRules.forEach(this::register); - return this; - } - - public Builder register(Rule newRule) - { - Pattern pattern = newRule.getPattern(); - if (pattern instanceof Pattern.MatchNodeClass) { - rulesByClass.put(((Pattern.MatchNodeClass) pattern).getNodeClass(), newRule); - } - else { - throw new IllegalArgumentException("Unexpected Pattern: " + pattern); - } - return this; - } - - public RuleStore build() - { - return new RuleStore(rulesByClass.build()); - } - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java index 3a3d9c7e4455..9d180fc4a9f1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java @@ -15,14 +15,13 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; -import com.facebook.presto.sql.planner.DependencyExtractor; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -67,7 +66,7 @@ public class AddIntermediateAggregations implements Rule { - private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + private static final Pattern PATTERN = Pattern.typeOf(AggregationNode.class); @Override public Pattern getPattern() @@ -76,8 +75,12 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { + Lookup lookup = context.getLookup(); + PlanNodeIdAllocator idAllocator = context.getIdAllocator(); + Session session = context.getSession(); + if (!SystemSessionProperties.isEnableIntermediateAggregations(session)) { return Optional.empty(); } @@ -195,7 +198,7 @@ private static Map inputsAsOutputs(Map builder = ImmutableMap.builder(); for (Map.Entry entry : assignments.entrySet()) { // Should only have one input symbol - Symbol input = getOnlyElement(DependencyExtractor.extractAll(entry.getValue().getCall())); + Symbol input = getOnlyElement(SymbolsExtractor.extractAll(entry.getValue().getCall())); builder.put(input, entry.getValue()); } return builder.build(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CanonicalizeFilterExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CanonicalizeFilterExpressions.java new file mode 100644 index 000000000000..6081bf820a45 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CanonicalizeFilterExpressions.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.tree.Expression; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.CanonicalizeExpressions.canonicalizeExpression; + +public class CanonicalizeFilterExpressions + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(FilterNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + FilterNode filterNode = (FilterNode) node; + + Expression canonicalized = canonicalizeExpression(filterNode.getPredicate()); + + if (canonicalized.equals(filterNode.getPredicate())) { + return Optional.empty(); + } + + return Optional.of(new FilterNode(node.getId(), filterNode.getSource(), canonicalized)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CanonicalizeJoinExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CanonicalizeJoinExpressions.java new file mode 100644 index 000000000000..b201382ef1bc --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CanonicalizeJoinExpressions.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.tree.Expression; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.CanonicalizeExpressions.canonicalizeExpression; + +public class CanonicalizeJoinExpressions + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(JoinNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + JoinNode joinNode = (JoinNode) node; + + if (!joinNode.getFilter().isPresent()) { + return Optional.empty(); + } + + Expression canonicalizedExpression = canonicalizeExpression(joinNode.getFilter().get()); + if (canonicalizedExpression.equals(joinNode.getFilter().get())) { + return Optional.empty(); + } + + JoinNode replacement = new JoinNode( + joinNode.getId(), + joinNode.getType(), + joinNode.getLeft(), + joinNode.getRight(), + joinNode.getCriteria(), + joinNode.getOutputSymbols(), + Optional.of(canonicalizedExpression), + joinNode.getLeftHashSymbol(), + joinNode.getRightHashSymbol(), + joinNode.getDistributionType()); + + return Optional.of(replacement); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CanonicalizeProjectExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CanonicalizeProjectExpressions.java new file mode 100644 index 000000000000..8496b035e780 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CanonicalizeProjectExpressions.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.optimizations.CanonicalizeExpressions; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; + +import java.util.Optional; + +public class CanonicalizeProjectExpressions + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(ProjectNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + ProjectNode projectNode = (ProjectNode) node; + + Assignments assignments = projectNode.getAssignments() + .rewrite(CanonicalizeExpressions::canonicalizeExpression); + + if (assignments.equals(projectNode.getAssignments())) { + return Optional.empty(); + } + + PlanNode replacement = new ProjectNode(node.getId(), projectNode.getSource(), assignments); + + return Optional.of(replacement); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CanonicalizeTableScanExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CanonicalizeTableScanExpressions.java new file mode 100644 index 000000000000..e49eeda3d4fd --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CanonicalizeTableScanExpressions.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.sql.tree.Expression; + +import java.util.Objects; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.CanonicalizeExpressions.canonicalizeExpression; + +public class CanonicalizeTableScanExpressions + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(TableScanNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + TableScanNode tableScanNode = (TableScanNode) node; + + Expression originalConstraint = null; + if (tableScanNode.getOriginalConstraint() != null) { + originalConstraint = canonicalizeExpression(tableScanNode.getOriginalConstraint()); + } + + if (Objects.equals(tableScanNode.getOriginalConstraint(), originalConstraint)) { + return Optional.empty(); + } + + TableScanNode replacement = new TableScanNode( + tableScanNode.getId(), + tableScanNode.getTable(), + tableScanNode.getOutputSymbols(), + tableScanNode.getAssignments(), + tableScanNode.getLayout(), + tableScanNode.getCurrentConstraint(), + originalConstraint); + + return Optional.of(replacement); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CreatePartialTopN.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CreatePartialTopN.java index 82364de8c05c..4c4e122033cc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CreatePartialTopN.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CreatePartialTopN.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.TopNNode; @@ -31,7 +27,7 @@ public class CreatePartialTopN implements Rule { - private static final Pattern PATTERN = Pattern.node(TopNNode.class); + private static final Pattern PATTERN = Pattern.typeOf(TopNNode.class); @Override public Pattern getPattern() @@ -40,7 +36,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { if (!(node instanceof TopNNode)) { return Optional.empty(); @@ -52,10 +48,10 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato return Optional.empty(); } - PlanNode source = lookup.resolve(single.getSource()); + PlanNode source = context.getLookup().resolve(single.getSource()); TopNNode partial = new TopNNode( - idAllocator.getNextId(), + context.getIdAllocator().getNextId(), source, single.getCount(), single.getOrderBy(), @@ -63,7 +59,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato PARTIAL); return Optional.of(new TopNNode( - idAllocator.getNextId(), + context.getIdAllocator().getNextId(), partial, single.getCount(), single.getOrderBy(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java index 10a44b0bc928..467feca97f23 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -13,13 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph; import com.facebook.presto.sql.planner.plan.Assignments; @@ -30,8 +27,8 @@ import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; -import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -40,15 +37,17 @@ import java.util.PriorityQueue; import java.util.Set; +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; public class EliminateCrossJoins implements Rule { - private static final Pattern PATTERN = Pattern.node(JoinNode.class); + private static final Pattern PATTERN = Pattern.typeOf(JoinNode.class); @Override public Pattern getPattern() @@ -57,17 +56,17 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { if (!(node instanceof JoinNode)) { return Optional.empty(); } - if (!SystemSessionProperties.isJoinReorderingEnabled(session)) { + if (!SystemSessionProperties.isJoinReorderingEnabled(context.getSession())) { return Optional.empty(); } - JoinGraph joinGraph = JoinGraph.buildShallowFrom(node, lookup); + JoinGraph joinGraph = JoinGraph.buildShallowFrom(node, context.getLookup()); if (joinGraph.size() < 3) { return Optional.empty(); } @@ -77,7 +76,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato return Optional.empty(); } - PlanNode replacement = buildJoinTree(node.getOutputSymbols(), joinGraph, joinOrder, idAllocator); + PlanNode replacement = buildJoinTree(node.getOutputSymbols(), joinGraph, joinOrder, context.getIdAllocator()); return Optional.of(replacement); } @@ -109,7 +108,7 @@ public static List getJoinOrder(JoinGraph graph) PriorityQueue nodesToVisit = new PriorityQueue<>( graph.size(), - (Comparator) (node1, node2) -> priorities.get(node1.getId()).compareTo(priorities.get(node2.getId()))); + comparing(node -> priorities.get(node.getId()))); Set visited = new HashSet<>(); nodesToVisit.add(graph.getNode(0)); @@ -200,15 +199,8 @@ public static PlanNode buildJoinTree(List expectedOutputSymbols, JoinGra Assignments.copyOf(graph.getAssignments().get())); } - if (!result.getOutputSymbols().equals(expectedOutputSymbols)) { - // Introduce a projection to constrain the outputs to what was originally expected - // Some nodes are sensitive to what's produced (e.g., DistinctLimit node) - result = new ProjectNode( - idAllocator.getNextId(), - result, - Assignments.identity(expectedOutputSymbols)); - } - - return result; + // If needed, introduce a projection to constrain the outputs to what was originally expected + // Some nodes are sensitive to what's produced (e.g., DistinctLimit node) + return restrictOutputs(idAllocator, result, ImmutableSet.copyOf(expectedOutputSymbols)).orElse(result); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java index b379303dfe77..0340f9635e05 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -29,7 +25,7 @@ public class EvaluateZeroLimit implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.typeOf(LimitNode.class); @Override public Pattern getPattern() @@ -38,7 +34,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { LimitNode limit = (LimitNode) node; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java index ecf5b95424fc..898c10a6a969 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SampleNode; @@ -32,7 +28,7 @@ public class EvaluateZeroSample implements Rule { - private static final Pattern PATTERN = Pattern.node(SampleNode.class); + private static final Pattern PATTERN = Pattern.typeOf(SampleNode.class); @Override public Pattern getPattern() @@ -41,7 +37,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { SampleNode sample = (SampleNode) node; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java index 9705352c30c1..f4e065c3246d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -25,7 +21,6 @@ import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.ComparisonExpressionType; import com.facebook.presto.sql.tree.DoubleLiteral; -import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; @@ -50,7 +45,7 @@ public class ImplementBernoulliSampleAsFilter implements Rule { - private static final Pattern PATTERN = Pattern.node(SampleNode.class); + private static final Pattern PATTERN = Pattern.typeOf(SampleNode.class); @Override public Pattern getPattern() @@ -59,7 +54,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { SampleNode sample = (SampleNode) node; @@ -72,7 +67,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato sample.getSource(), new ComparisonExpression( ComparisonExpressionType.LESS_THAN, - new FunctionCall(QualifiedName.of("rand"), ImmutableList.of()), + new FunctionCall(QualifiedName.of("rand"), ImmutableList.of()), new DoubleLiteral(Double.toString(sample.getSampleRatio()))))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java index e1e54119393c..661eb58fa8f3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -13,12 +13,8 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; @@ -57,7 +53,7 @@ public class ImplementFilteredAggregations implements Rule { - private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + private static final Pattern PATTERN = Pattern.typeOf(AggregationNode.class); @Override public Pattern getPattern() @@ -66,7 +62,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { AggregationNode aggregation = (AggregationNode) node; @@ -91,7 +87,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato if (call.getFilter().isPresent()) { Expression filter = call.getFilter().get(); - Symbol symbol = symbolAllocator.newSymbol(filter, BOOLEAN); + Symbol symbol = context.getSymbolAllocator().newSymbol(filter, BOOLEAN); newAssignments.put(symbol, filter); mask = Optional.of(symbol); } @@ -106,9 +102,9 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato return Optional.of( new AggregationNode( - idAllocator.getNextId(), + context.getIdAllocator().getNextId(), new ProjectNode( - idAllocator.getNextId(), + context.getIdAllocator().getNextId(), aggregation.getSource(), newAssignments.build()), aggregations.build(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java index d9207137a04f..2fd6206cd838 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java @@ -13,14 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.DependencyExtractor; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.ExpressionSymbolInliner; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -29,6 +25,7 @@ import com.facebook.presto.sql.tree.Literal; import com.facebook.presto.sql.tree.TryExpression; import com.facebook.presto.sql.util.AstUtils; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; import java.util.Map; @@ -48,7 +45,7 @@ public class InlineProjections implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.typeOf(ProjectNode.class); @Override public Pattern getPattern() @@ -57,11 +54,11 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { ProjectNode parent = (ProjectNode) node; - PlanNode source = lookup.resolve(parent.getSource()); + PlanNode source = context.getLookup().resolve(parent.getSource()); if (!(source instanceof ProjectNode)) { return Optional.empty(); } @@ -89,7 +86,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato .entrySet().stream() .filter(entry -> targets.contains(entry.getKey())) .map(Map.Entry::getValue) - .flatMap(entry -> DependencyExtractor.extractAll(entry).stream()) + .flatMap(entry -> SymbolsExtractor.extractAll(entry).stream()) .collect(toSet()); Assignments.Builder childAssignments = Assignments.builder(); @@ -134,10 +131,14 @@ private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectN // a. are not inputs to try() expressions // b. appear only once across all expressions // c. are not identity projections + // which come from the child, as opposed to an enclosing scope. + + Set childOutputSet = ImmutableSet.copyOf(child.getOutputSymbols()); Map dependencies = parent.getAssignments() .getExpressions().stream() - .flatMap(expression -> DependencyExtractor.extractAll(expression).stream()) + .flatMap(expression -> SymbolsExtractor.extractAll(expression).stream()) + .filter(childOutputSet::contains) .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); // find references to simple constants @@ -167,7 +168,7 @@ private Set extractTryArguments(Expression expression) return AstUtils.preOrder(expression) .filter(TryExpression.class::isInstance) .map(TryExpression.class::cast) - .flatMap(tryExpression -> DependencyExtractor.extractAll(tryExpression).stream()) + .flatMap(tryExpression -> SymbolsExtractor.extractAll(tryExpression).stream()) .collect(toSet()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeAdjacentWindows.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeAdjacentWindows.java index 38d3724c176d..e136da0aa583 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeAdjacentWindows.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeAdjacentWindows.java @@ -13,12 +13,8 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; @@ -31,7 +27,7 @@ public class MergeAdjacentWindows implements Rule { - private static final Pattern PATTERN = Pattern.node(WindowNode.class); + private static final Pattern PATTERN = Pattern.typeOf(WindowNode.class); @Override public Pattern getPattern() @@ -40,7 +36,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { if (!(node instanceof WindowNode)) { return Optional.empty(); @@ -48,7 +44,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato WindowNode parent = (WindowNode) node; - PlanNode source = lookup.resolve(parent.getSource()); + PlanNode source = context.getLookup().resolve(parent.getSource()); if (!(source instanceof WindowNode)) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeFilters.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeFilters.java index c70cbed39298..0b63d30f8004 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeFilters.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeFilters.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -29,7 +25,7 @@ public class MergeFilters implements Rule { - private static final Pattern PATTERN = Pattern.node(FilterNode.class); + private static final Pattern PATTERN = Pattern.typeOf(FilterNode.class); @Override public Pattern getPattern() @@ -38,11 +34,11 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { FilterNode parent = (FilterNode) node; - PlanNode source = lookup.resolve(parent.getSource()); + PlanNode source = context.getLookup().resolve(parent.getSource()); if (!(source instanceof FilterNode)) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java index 616dd88cdad4..a7f536374503 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.DistinctLimitNode; @@ -29,7 +25,7 @@ public class MergeLimitWithDistinct implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.typeOf(LimitNode.class); @Override public Pattern getPattern() @@ -38,11 +34,11 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { LimitNode parent = (LimitNode) node; - PlanNode input = lookup.resolve(parent.getSource()); + PlanNode input = context.getLookup().resolve(parent.getSource()); if (!(input instanceof AggregationNode)) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithSort.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithSort.java index f6d466acc564..cff354d15265 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithSort.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithSort.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -29,7 +25,7 @@ public class MergeLimitWithSort implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.typeOf(LimitNode.class); @Override public Pattern getPattern() @@ -38,11 +34,11 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { LimitNode parent = (LimitNode) node; - PlanNode source = lookup.resolve(parent.getSource()); + PlanNode source = context.getLookup().resolve(parent.getSource()); if (!(source instanceof SortNode)) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithTopN.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithTopN.java index cf91fbacfd17..9b7689ec0811 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithTopN.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithTopN.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -28,7 +24,7 @@ public class MergeLimitWithTopN implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.typeOf(LimitNode.class); @Override public Pattern getPattern() @@ -37,11 +33,11 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { LimitNode parent = (LimitNode) node; - PlanNode source = lookup.resolve(parent.getSource()); + PlanNode source = context.getLookup().resolve(parent.getSource()); if (!(source instanceof TopNNode)) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimits.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimits.java index 06e9ef9ccf65..ec3aeefd9dd2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimits.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimits.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -27,7 +23,7 @@ public class MergeLimits implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.typeOf(LimitNode.class); @Override public Pattern getPattern() @@ -36,11 +32,11 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { LimitNode parent = (LimitNode) node; - PlanNode source = lookup.resolve(parent.getSource()); + PlanNode source = context.getLookup().resolve(parent.getSource()); if (!(source instanceof LimitNode)) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java new file mode 100644 index 000000000000..6c11773bb922 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.google.common.collect.ImmutableList; + +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.pruneInputs; + +/** + * @param The node type to look for under the ProjectNode + * Looks for a Project parent over a N child, such that the parent doesn't use all the output columns of the child. + * Given that situation, invokes the pushDownProjectOff helper to possibly rewrite the child to produce fewer outputs. + */ +public abstract class ProjectOffPushDownRule + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(ProjectNode.class); + private final Class targetNodeClass; + + protected ProjectOffPushDownRule(Class targetNodeClass) + { + this.targetNodeClass = targetNodeClass; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + ProjectNode parent = (ProjectNode) node; + + PlanNode child = context.getLookup().resolve(parent.getSource()); + if (!targetNodeClass.isInstance(child)) { + return Optional.empty(); + } + + N targetNode = targetNodeClass.cast(child); + + return pruneInputs(child.getOutputSymbols(), parent.getAssignments().getExpressions()) + .flatMap(prunedOutputs -> this.pushDownProjectOff(context.getIdAllocator(), targetNode, prunedOutputs)) + .map(newChild -> parent.replaceChildren(ImmutableList.of(newChild))); + } + + protected abstract Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, N targetNode, Set referencedOutputs); +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java new file mode 100644 index 000000000000..4c5777ed0812 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.Signature; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.LongLiteral; +import com.google.common.collect.ImmutableList; + +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isScalar; +import static java.util.Objects.requireNonNull; + +/** + * A count over a subquery can be reduced to a VALUES(1) provided + * the subquery is a scalar + */ +public class PruneCountAggregationOverScalar + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(AggregationNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + AggregationNode parent = (AggregationNode) node; + Map assignments = parent.getAggregations(); + if (parent.hasDefaultOutput() && assignments.size() != 1) { + return Optional.empty(); + } + for (Map.Entry entry : assignments.entrySet()) { + AggregationNode.Aggregation aggregation = entry.getValue(); + requireNonNull(aggregation, "aggregation is null"); + Signature signature = aggregation.getSignature(); + FunctionCall functionCall = aggregation.getCall(); + if (!"count".equals(signature.getName()) || !functionCall.getArguments().isEmpty()) { + return Optional.empty(); + } + } + if (!assignments.isEmpty() && isScalar(parent.getSource(), context.getLookup())) { + return Optional.of(new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.of(ImmutableList.of(new LongLiteral("1"))))); + } + return Optional.empty(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java new file mode 100644 index 000000000000..830af39a9baa --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictChildOutputs; + +/** + * Cross joins don't support output symbol selection, so push the project-off through the node. + */ +public class PruneCrossJoinColumns + extends ProjectOffPushDownRule +{ + public PruneCrossJoinColumns() + { + super(JoinNode.class); + } + + @Override + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, JoinNode joinNode, Set referencedOutputs) + { + if (!joinNode.isCrossJoin()) { + return Optional.empty(); + } + + return restrictChildOutputs(idAllocator, joinNode, referencedOutputs, referencedOutputs); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneIndexSourceColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneIndexSourceColumns.java new file mode 100644 index 000000000000..0ae1fa8771c2 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneIndexSourceColumns.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.IndexSourceNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.Maps; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; + +public class PruneIndexSourceColumns + extends ProjectOffPushDownRule +{ + public PruneIndexSourceColumns() + { + super(IndexSourceNode.class); + } + + @Override + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, IndexSourceNode indexSourceNode, Set referencedOutputs) + { + Set prunedLookupSymbols = indexSourceNode.getLookupSymbols().stream() + .filter(referencedOutputs::contains) + .collect(toImmutableSet()); + + Map prunedAssignments = Maps.filterEntries( + indexSourceNode.getAssignments(), + entry -> referencedOutputs.contains(entry.getKey()) || + tupleDomainReferencesColumnHandle(indexSourceNode.getEffectiveTupleDomain(), entry.getValue())); + + List prunedOutputList = + indexSourceNode.getOutputSymbols().stream() + .filter(referencedOutputs::contains) + .collect(toImmutableList()); + + return Optional.of( + new IndexSourceNode( + indexSourceNode.getId(), + indexSourceNode.getIndexHandle(), + indexSourceNode.getTableHandle(), + indexSourceNode.getLayout(), + prunedLookupSymbols, + prunedOutputList, + prunedAssignments, + indexSourceNode.getEffectiveTupleDomain())); + } + + private static boolean tupleDomainReferencesColumnHandle( + TupleDomain tupleDomain, + ColumnHandle columnHandle) + { + return tupleDomain.getDomains() + .map(domains -> domains.containsKey(columnHandle)) + .orElse(false); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java new file mode 100644 index 000000000000..e7d7e127c4a1 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableSet; + +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictChildOutputs; + +/** + * Non-Cross joins support output symbol selection, so make any project-off of child columns explicit in project nodes. + */ +public class PruneJoinChildrenColumns + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(JoinNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + JoinNode joinNode = (JoinNode) node; + if (joinNode.isCrossJoin()) { + return Optional.empty(); + } + + Set globallyUsableInputs = ImmutableSet.builder() + .addAll(joinNode.getOutputSymbols()) + .addAll( + joinNode.getFilter() + .map(SymbolsExtractor::extractUnique) + .orElse(ImmutableSet.of())) + .build(); + + Set leftUsableInputs = ImmutableSet.builder() + .addAll(globallyUsableInputs) + .addAll( + joinNode.getCriteria().stream() + .map(JoinNode.EquiJoinClause::getLeft) + .iterator()) + .addAll(joinNode.getLeftHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of())) + .build(); + + Set rightUsableInputs = ImmutableSet.builder() + .addAll(globallyUsableInputs) + .addAll( + joinNode.getCriteria().stream() + .map(JoinNode.EquiJoinClause::getRight) + .iterator()) + .addAll(joinNode.getRightHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of())) + .build(); + + return restrictChildOutputs(context.getIdAllocator(), joinNode, leftUsableInputs, rightUsableInputs); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java new file mode 100644 index 000000000000..dc121e62b2ab --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.util.MoreLists.filteredCopy; + +/** + * Non-cross joins support output symbol selection, so absorb any project-off into the node. + */ +public class PruneJoinColumns + extends ProjectOffPushDownRule +{ + public PruneJoinColumns() + { + super(JoinNode.class); + } + + @Override + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, JoinNode joinNode, Set referencedOutputs) + { + if (joinNode.isCrossJoin()) { + return Optional.empty(); + } + + return Optional.of( + new JoinNode( + joinNode.getId(), + joinNode.getType(), + joinNode.getLeft(), + joinNode.getRight(), + joinNode.getCriteria(), + filteredCopy(joinNode.getOutputSymbols(), referencedOutputs::contains), + joinNode.getFilter(), + joinNode.getLeftHashSymbol(), + joinNode.getRightHashSymbol(), + joinNode.getDistributionType())); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java new file mode 100644 index 000000000000..04e954589d73 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.MarkDistinctNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.Streams; + +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictChildOutputs; +import static com.google.common.collect.ImmutableSet.toImmutableSet; + +public class PruneMarkDistinctColumns + extends ProjectOffPushDownRule +{ + public PruneMarkDistinctColumns() + { + super(MarkDistinctNode.class); + } + + @Override + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, MarkDistinctNode markDistinctNode, Set referencedOutputs) + { + if (!referencedOutputs.contains(markDistinctNode.getMarkerSymbol())) { + return Optional.of(markDistinctNode.getSource()); + } + + Set requiredInputs = Streams.concat( + referencedOutputs.stream() + .filter(symbol -> !symbol.equals(markDistinctNode.getMarkerSymbol())), + markDistinctNode.getDistinctSymbols().stream(), + markDistinctNode.getHashSymbol().map(Stream::of).orElse(Stream.empty())) + .collect(toImmutableSet()); + + return restrictChildOutputs(idAllocator, markDistinctNode, requiredInputs); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOutputColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOutputColumns.java new file mode 100644 index 000000000000..74784a8a3201 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOutputColumns.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.OutputNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableSet; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictChildOutputs; + +public class PruneOutputColumns + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(OutputNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + OutputNode outputNode = (OutputNode) node; + + return restrictChildOutputs( + context.getIdAllocator(), + outputNode, + ImmutableSet.copyOf(outputNode.getOutputSymbols())); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java new file mode 100644 index 000000000000..5337321c62a1 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Streams; + +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; +import static com.google.common.collect.ImmutableSet.toImmutableSet; + +public class PruneSemiJoinColumns + extends ProjectOffPushDownRule +{ + public PruneSemiJoinColumns() + { + super(SemiJoinNode.class); + } + + @Override + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, SemiJoinNode semiJoinNode, Set referencedOutputs) + { + if (!referencedOutputs.contains(semiJoinNode.getSemiJoinOutput())) { + return Optional.of(semiJoinNode.getSource()); + } + + Set requiredSourceInputs = Streams.concat( + referencedOutputs.stream() + .filter(symbol -> !symbol.equals(semiJoinNode.getSemiJoinOutput())), + Stream.of(semiJoinNode.getSourceJoinSymbol()), + semiJoinNode.getSourceHashSymbol().map(Stream::of).orElse(Stream.empty())) + .collect(toImmutableSet()); + + return restrictOutputs(idAllocator, semiJoinNode.getSource(), requiredSourceInputs) + .map(newSource -> + semiJoinNode.replaceChildren(ImmutableList.of( + newSource, semiJoinNode.getFilteringSource()))); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java new file mode 100644 index 000000000000..8093917a1b3c --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Streams; + +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; +import static com.google.common.collect.ImmutableSet.toImmutableSet; + +public class PruneSemiJoinFilteringSourceColumns + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(SemiJoinNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + SemiJoinNode semiJoinNode = (SemiJoinNode) node; + + Set requiredFilteringSourceInputs = Streams.concat( + Stream.of(semiJoinNode.getFilteringSourceJoinSymbol()), + semiJoinNode.getFilteringSourceHashSymbol().map(Stream::of).orElse(Stream.empty())) + .collect(toImmutableSet()); + + return restrictOutputs(context.getIdAllocator(), semiJoinNode.getFilteringSource(), requiredFilteringSourceInputs) + .map(newFilteringSource -> + semiJoinNode.replaceChildren(ImmutableList.of( + semiJoinNode.getSource(), newFilteringSource))); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java index d1c1b6dd4916..17a2cd9537d0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java @@ -13,65 +13,36 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; -import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.TableScanNode; -import java.util.List; import java.util.Optional; -import java.util.function.Function; -import java.util.stream.Collectors; +import java.util.Set; -import static com.facebook.presto.sql.planner.iterative.rule.Util.pruneInputs; +import static com.facebook.presto.util.MoreLists.filteredCopy; +import static com.google.common.collect.Maps.filterKeys; public class PruneTableScanColumns - implements Rule + extends ProjectOffPushDownRule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); - - @Override - public Pattern getPattern() + public PruneTableScanColumns() { - return PATTERN; + super(TableScanNode.class); } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, TableScanNode tableScanNode, Set referencedOutputs) { - ProjectNode parent = (ProjectNode) node; - - PlanNode source = lookup.resolve(parent.getSource()); - if (!(source instanceof TableScanNode)) { - return Optional.empty(); - } - - TableScanNode child = (TableScanNode) source; - - Optional> dependencies = pruneInputs(child.getOutputSymbols(), parent.getAssignments().getExpressions()); - if (!dependencies.isPresent()) { - return Optional.empty(); - } - - List newOutputs = dependencies.get(); return Optional.of( - new ProjectNode( - parent.getId(), - new TableScanNode( - child.getId(), - child.getTable(), - newOutputs, - newOutputs.stream() - .collect(Collectors.toMap(Function.identity(), e -> child.getAssignments().get(e))), - child.getLayout(), - child.getCurrentConstraint(), - child.getOriginalConstraint()), - parent.getAssignments())); + new TableScanNode( + tableScanNode.getId(), + tableScanNode.getTable(), + filteredCopy(tableScanNode.getOutputSymbols(), referencedOutputs::contains), + filterKeys(tableScanNode.getAssignments(), referencedOutputs::contains), + tableScanNode.getLayout(), + tableScanNode.getCurrentConstraint(), + tableScanNode.getOriginalConstraint())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java index 67c4272b4dcb..c3fe56036508 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java @@ -13,15 +13,9 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; -import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; @@ -29,57 +23,37 @@ import java.util.Arrays; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; -import static com.facebook.presto.sql.planner.iterative.rule.Util.pruneInputs; +import static com.facebook.presto.util.MoreLists.filteredCopy; public class PruneValuesColumns - implements Rule + extends ProjectOffPushDownRule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); - - @Override - public Pattern getPattern() + public PruneValuesColumns() { - return PATTERN; + super(ValuesNode.class); } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, ValuesNode valuesNode, Set referencedOutputs) { - ProjectNode parent = (ProjectNode) node; - - PlanNode child = lookup.resolve(parent.getSource()); - if (!(child instanceof ValuesNode)) { - return Optional.empty(); - } - - ValuesNode values = (ValuesNode) child; - - Optional> dependencies = pruneInputs(child.getOutputSymbols(), parent.getAssignments().getExpressions()); - if (!dependencies.isPresent()) { - return Optional.empty(); - } - - List newOutputs = dependencies.get(); + List newOutputs = filteredCopy(valuesNode.getOutputSymbols(), referencedOutputs::contains); // for each output of project, the corresponding column in the values node int[] mapping = new int[newOutputs.size()]; for (int i = 0; i < mapping.length; i++) { - mapping[i] = values.getOutputSymbols().indexOf(newOutputs.get(i)); + mapping[i] = valuesNode.getOutputSymbols().indexOf(newOutputs.get(i)); } ImmutableList.Builder> rowsBuilder = ImmutableList.builder(); - for (List row : values.getRows()) { + for (List row : valuesNode.getRows()) { rowsBuilder.add(Arrays.stream(mapping) .mapToObj(row::get) .collect(Collectors.toList())); } - return Optional.of( - new ProjectNode( - parent.getId(), - new ValuesNode(values.getId(), newOutputs, rowsBuilder.build()), - parent.getAssignments())); + return Optional.of(new ValuesNode(valuesNode.getId(), newOutputs, rowsBuilder.build())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index 073ebbb74f04..a7d7e33ee549 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -13,13 +13,12 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.ExpressionSymbolInliner; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.Assignments; @@ -86,7 +85,7 @@ public class PushAggregationThroughOuterJoin implements Rule { - private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + private static final Pattern PATTERN = Pattern.typeOf(AggregationNode.class); @Override public Pattern getPattern() @@ -95,9 +94,9 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { - if (!shouldPushAggregationThroughJoin(session)) { + if (!shouldPushAggregationThroughJoin(context.getSession())) { return Optional.empty(); } @@ -106,15 +105,15 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato } AggregationNode aggregation = (AggregationNode) node; - PlanNode source = lookup.resolve(aggregation.getSource()); + PlanNode source = context.getLookup().resolve(aggregation.getSource()); if (!(source instanceof JoinNode)) { return Optional.empty(); } JoinNode join = (JoinNode) source; if (join.getFilter().isPresent() || !(join.getType() == JoinNode.Type.LEFT || join.getType() == JoinNode.Type.RIGHT) - || !groupsOnAllOuterTableColumns(aggregation, lookup.resolve(getOuterTable(join))) - || !isDistinct(lookup.resolve(getOuterTable(join)), lookup::resolve)) { + || !groupsOnAllOuterTableColumns(aggregation, context.getLookup().resolve(getOuterTable(join))) + || !isDistinct(context.getLookup().resolve(getOuterTable(join)), context.getLookup()::resolve)) { return Optional.empty(); } @@ -164,7 +163,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato join.getDistributionType()); } - return Optional.of(coalesceWithNullAggregation(rewrittenAggregation, rewrittenJoin, symbolAllocator, idAllocator, lookup)); + return Optional.of(coalesceWithNullAggregation(rewrittenAggregation, rewrittenJoin, context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup())); } private static PlanNode getInnerTable(JoinNode join) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughMarkDistinct.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughMarkDistinct.java index 9f2120ac0693..4d12eba5495e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughMarkDistinct.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughMarkDistinct.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; @@ -30,7 +26,7 @@ public class PushLimitThroughMarkDistinct implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.typeOf(LimitNode.class); @Override public Pattern getPattern() @@ -39,11 +35,11 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { LimitNode parent = (LimitNode) node; - PlanNode child = lookup.resolve(parent.getSource()); + PlanNode child = context.getLookup().resolve(parent.getSource()); if (!(child instanceof MarkDistinctNode)) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughProject.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughProject.java index 562413ee3c2a..2c030ef030ac 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughProject.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughProject.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -30,7 +26,7 @@ public class PushLimitThroughProject implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.typeOf(LimitNode.class); @Override public Pattern getPattern() @@ -39,11 +35,11 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { LimitNode parent = (LimitNode) node; - PlanNode child = lookup.resolve(parent.getSource()); + PlanNode child = context.getLookup().resolve(parent.getSource()); if (!(child instanceof ProjectNode)) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughSemiJoin.java index 8953b38b0da0..9c5ff53507c3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughSemiJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughSemiJoin.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -30,7 +26,7 @@ public class PushLimitThroughSemiJoin implements Rule { - private static final Pattern PATTERN = Pattern.node(LimitNode.class); + private static final Pattern PATTERN = Pattern.typeOf(LimitNode.class); @Override public Pattern getPattern() @@ -39,11 +35,11 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { LimitNode parent = (LimitNode) node; - PlanNode child = lookup.resolve(parent.getSource()); + PlanNode child = context.getLookup().resolve(parent.getSource()); if (!(child instanceof SemiJoinNode)) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java index 13badd6b5a0c..6f0c2b572ab1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java @@ -13,15 +13,11 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.ExpressionSymbolInliner; import com.facebook.presto.sql.planner.PartitioningScheme; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -30,13 +26,14 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import static com.facebook.presto.sql.planner.plan.Assignments.identity; +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; /** * Transforms: @@ -70,7 +67,7 @@ public class PushProjectionThroughExchange implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.typeOf(ProjectNode.class); @Override public Pattern getPattern() @@ -79,7 +76,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { if (!(node instanceof ProjectNode)) { return Optional.empty(); @@ -87,7 +84,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato ProjectNode project = (ProjectNode) node; - PlanNode child = lookup.resolve(project.getSource()); + PlanNode child = context.getLookup().resolve(project.getSource()); if (!(child instanceof ExchangeNode)) { return Optional.empty(); } @@ -122,12 +119,12 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato } for (Map.Entry projection : project.getAssignments().entrySet()) { Expression translatedExpression = translateExpression(projection.getValue(), outputToInputMap); - Type type = symbolAllocator.getTypes().get(projection.getKey()); - Symbol symbol = symbolAllocator.newSymbol(translatedExpression, type); + Type type = context.getSymbolAllocator().getTypes().get(projection.getKey()); + Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression, type); projections.put(symbol, translatedExpression); inputs.add(symbol); } - newSourceBuilder.add(new ProjectNode(idAllocator.getNextId(), exchange.getSources().get(i), projections.build())); + newSourceBuilder.add(new ProjectNode(context.getIdAllocator().getNextId(), exchange.getSources().get(i), projections.build())); inputsBuilder.add(inputs.build()); } @@ -158,12 +155,8 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato newSourceBuilder.build(), inputsBuilder.build()); - if (!result.getOutputSymbols().equals(project.getOutputSymbols())) { - // we need to strip unnecessary symbols (hash, partitioning columns). - result = new ProjectNode(idAllocator.getNextId(), result, identity(project.getOutputSymbols())); - } - - return Optional.of(result); + // we need to strip unnecessary symbols (hash, partitioning columns). + return Optional.of(restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(project.getOutputSymbols())).orElse(result)); } private boolean isSymbolToSymbolProjection(ProjectNode project) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java index e203eff472a1..fdd4c1a460dd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java @@ -13,14 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.ExpressionSymbolInliner; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -39,7 +35,7 @@ public class PushProjectionThroughUnion implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.typeOf(ProjectNode.class); @Override public Pattern getPattern() @@ -48,7 +44,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { if (!(node instanceof ProjectNode)) { return Optional.empty(); @@ -56,7 +52,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato ProjectNode parent = (ProjectNode) node; - PlanNode child = lookup.resolve(parent.getSource()); + PlanNode child = context.getLookup().resolve(parent.getSource()); if (!(child instanceof UnionNode)) { return Optional.empty(); } @@ -82,12 +78,12 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato // Translate the assignments in the ProjectNode using symbols of the source of the UnionNode for (Map.Entry entry : parent.getAssignments().entrySet()) { Expression translatedExpression = translateExpression(entry.getValue(), outputToInput); - Type type = symbolAllocator.getTypes().get(entry.getKey()); - Symbol symbol = symbolAllocator.newSymbol(translatedExpression, type); + Type type = context.getSymbolAllocator().getTypes().get(entry.getKey()); + Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression, type); assignments.put(symbol, translatedExpression); projectSymbolMapping.put(entry.getKey(), symbol); } - outputSources.add(new ProjectNode(idAllocator.getNextId(), source.getSources().get(i), assignments.build())); + outputSources.add(new ProjectNode(context.getIdAllocator().getNextId(), source.getSources().get(i), assignments.build())); outputLayout.forEach(symbol -> mappings.put(symbol, projectSymbolMapping.get(symbol))); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTableWriteThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTableWriteThroughUnion.java new file mode 100644 index 000000000000..2bdb4ddd1195 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTableWriteThroughUnion.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TableWriterNode; +import com.facebook.presto.sql.planner.plan.UnionNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.isPushTableWriteThroughUnion; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class PushTableWriteThroughUnion + implements Rule +{ + @Override + public Optional apply(PlanNode node, Context context) + { + if (!isPushTableWriteThroughUnion(context.getSession())) { + return Optional.empty(); + } + + if (!(node instanceof TableWriterNode)) { + return Optional.empty(); + } + + TableWriterNode tableWriterNode = (TableWriterNode) node; + if (tableWriterNode.getPartitioningScheme().isPresent()) { + // The primary incentive of this optimizer is to increase the parallelism for table + // write. For a table with partitioning scheme, parallelism for table writing is + // guaranteed regardless of this optimizer. The level of local parallelism will be + // determined by LocalExecutionPlanner separately, and shouldn't be a concern of + // this optimizer. + return Optional.empty(); + } + + PlanNode child = context.getLookup().resolve(tableWriterNode.getSource()); + if (!(child instanceof UnionNode)) { + return Optional.empty(); + } + + UnionNode unionNode = (UnionNode) child; + ImmutableList.Builder rewrittenSources = ImmutableList.builder(); + ImmutableListMultimap.Builder mappings = ImmutableListMultimap.builder(); + for (int i = 0; i < unionNode.getSources().size(); i++) { + int index = i; + ImmutableList.Builder newSymbols = ImmutableList.builder(); + for (Symbol outputSymbol : node.getOutputSymbols()) { + Symbol newSymbol = context.getSymbolAllocator().newSymbol(outputSymbol); + newSymbols.add(newSymbol); + mappings.put(outputSymbol, newSymbol); + } + rewrittenSources.add(new TableWriterNode( + context.getIdAllocator().getNextId(), + unionNode.getSources().get(index), + tableWriterNode.getTarget(), + tableWriterNode.getColumns().stream() + .map(column -> unionNode.getSymbolMapping().get(column).get(index)) + .collect(toImmutableList()), + tableWriterNode.getColumnNames(), + newSymbols.build(), + tableWriterNode.getPartitioningScheme())); + } + + return Optional.of(new UnionNode(context.getIdAllocator().getNextId(), rewrittenSources.build(), mappings.build(), ImmutableList.copyOf(mappings.build().keySet()))); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java index 280cde9c19e9..2c3e90a681ac 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java @@ -13,12 +13,8 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.SymbolMapper; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -37,7 +33,7 @@ public class PushTopNThroughUnion implements Rule { - private static final Pattern PATTERN = Pattern.node(TopNNode.class); + private static final Pattern PATTERN = Pattern.typeOf(TopNNode.class); @Override public Pattern getPattern() @@ -46,7 +42,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { if (!(node instanceof TopNNode)) { return Optional.empty(); @@ -58,7 +54,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato return Optional.empty(); } - PlanNode child = lookup.resolve(topNNode.getSource()); + PlanNode child = context.getLookup().resolve(topNNode.getSource()); if (!(child instanceof UnionNode)) { return Optional.empty(); } @@ -75,7 +71,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato Symbol unionInput = getLast(intersection(inputSymbols, sourceOutputSymbols)); symbolMapper.put(unionOutput, unionInput); } - sources.add(symbolMapper.build().map(topNNode, source, idAllocator.getNextId())); + sources.add(symbolMapper.build().map(topNNode, source, context.getIdAllocator().getNextId())); } return Optional.of(new UnionNode( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java index 6f6527da4fa6..27b2764aa3e6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.DeleteNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -52,7 +48,7 @@ public class RemoveEmptyDelete implements Rule { - private static final Pattern PATTERN = Pattern.node(TableFinishNode.class); + private static final Pattern PATTERN = Pattern.typeOf(TableFinishNode.class); @Override public Pattern getPattern() @@ -61,13 +57,13 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { // TODO split into multiple rules (https://github.com/prestodb/presto/issues/7292) TableFinishNode finish = (TableFinishNode) node; - PlanNode finishSource = lookup.resolve(finish.getSource()); + PlanNode finishSource = context.getLookup().resolve(finish.getSource()); if (!(finishSource instanceof ExchangeNode)) { return Optional.empty(); } @@ -77,13 +73,13 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato return Optional.empty(); } - PlanNode exchangeSource = lookup.resolve(getOnlyElement(exchange.getSources())); + PlanNode exchangeSource = context.getLookup().resolve(getOnlyElement(exchange.getSources())); if (!(exchangeSource instanceof DeleteNode)) { return Optional.empty(); } DeleteNode delete = (DeleteNode) exchangeSource; - PlanNode deleteSource = lookup.resolve(delete.getSource()); + PlanNode deleteSource = context.getLookup().resolve(delete.getSource()); if (!(deleteSource instanceof ValuesNode)) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveFullSample.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveFullSample.java index bc7d079f1cc1..52ce3714985b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveFullSample.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveFullSample.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SampleNode; @@ -30,7 +26,7 @@ public class RemoveFullSample implements Rule { - private static final Pattern PATTERN = Pattern.node(SampleNode.class); + private static final Pattern PATTERN = Pattern.typeOf(SampleNode.class); @Override public Pattern getPattern() @@ -39,7 +35,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { SampleNode sample = (SampleNode) node; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java index 5b0605650aa0..a9eaf28157d1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java @@ -13,11 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -31,7 +27,7 @@ public class RemoveRedundantIdentityProjections implements Rule { - private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + private static final Pattern PATTERN = Pattern.typeOf(ProjectNode.class); @Override public Pattern getPattern() @@ -40,7 +36,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { ProjectNode project = (ProjectNode) node; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveTrivialFilters.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveTrivialFilters.java new file mode 100644 index 000000000000..ada7e96497fe --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveTrivialFilters.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.tree.Expression; + +import java.util.Optional; + +import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL; +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static java.util.Collections.emptyList; +import static java.util.Optional.empty; + +public class RemoveTrivialFilters + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(FilterNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + FilterNode filterNode = (FilterNode) node; + + Expression predicate = filterNode.getPredicate(); + + if (predicate.equals(TRUE_LITERAL)) { + return Optional.of(filterNode.getSource()); + } + + if (predicate.equals(FALSE_LITERAL)) { + return Optional.of(new ValuesNode(context.getIdAllocator().getNextId(), filterNode.getOutputSymbols(), emptyList())); + } + + return empty(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarApplyNodes.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarApplyNodes.java new file mode 100644 index 000000000000..557d37e17d70 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarApplyNodes.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.Optional; + +public class RemoveUnreferencedScalarApplyNodes + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(ApplyNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + ApplyNode applyNode = (ApplyNode) node; + if (applyNode.getSubqueryAssignments().isEmpty()) { + return Optional.of(applyNode.getInput()); + } + return Optional.empty(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java new file mode 100644 index 000000000000..1b6de5d8325a --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isScalar; +import static java.util.Optional.empty; + +public class RemoveUnreferencedScalarLateralNodes + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(LateralJoinNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + LateralJoinNode lateralJoinNode = (LateralJoinNode) node; + PlanNode input = lateralJoinNode.getInput(); + PlanNode subquery = lateralJoinNode.getSubquery(); + + if (isUnreferencedScalar(input, context.getLookup())) { + return Optional.of(subquery); + } + + if (isUnreferencedScalar(subquery, context.getLookup())) { + return Optional.of(input); + } + + return empty(); + } + + private boolean isUnreferencedScalar(PlanNode input, Lookup lookup) + { + return input.getOutputSymbols().isEmpty() && isScalar(input, lookup); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java index 46d94c84adea..b0dd91426f99 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java @@ -13,14 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.StandardTypes; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.Assignments; @@ -45,7 +41,7 @@ public class SimplifyCountOverConstant implements Rule { - private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + private static final Pattern PATTERN = Pattern.typeOf(AggregationNode.class); @Override public Pattern getPattern() @@ -54,11 +50,11 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { AggregationNode parent = (AggregationNode) node; - PlanNode input = lookup.resolve(parent.getSource()); + PlanNode input = context.getLookup().resolve(parent.getSource()); if (!(input instanceof ProjectNode)) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleMarkDistinctToGroupBy.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleMarkDistinctToGroupBy.java index e841e849931a..cc0ffc6ca2e3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleMarkDistinctToGroupBy.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleMarkDistinctToGroupBy.java @@ -13,12 +13,8 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; @@ -53,7 +49,7 @@ public class SingleMarkDistinctToGroupBy implements Rule { - private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + private static final Pattern PATTERN = Pattern.typeOf(AggregationNode.class); @Override public Pattern getPattern() @@ -62,11 +58,11 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { AggregationNode parent = (AggregationNode) node; - PlanNode source = lookup.resolve(parent.getSource()); + PlanNode source = context.getLookup().resolve(parent.getSource()); if (!(source instanceof MarkDistinctNode)) { return Optional.empty(); } @@ -107,9 +103,9 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato return Optional.of( new AggregationNode( - idAllocator.getNextId(), + context.getIdAllocator().getNextId(), new AggregationNode( - idAllocator.getNextId(), + context.getIdAllocator().getNextId(), child.getSource(), Collections.emptyMap(), ImmutableList.of(child.getDistinctSymbols()), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsBySpecifications.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsBySpecifications.java index 460f60612200..4cb3ffdff682 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsBySpecifications.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsBySpecifications.java @@ -13,12 +13,8 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; @@ -32,7 +28,7 @@ public class SwapAdjacentWindowsBySpecifications implements Rule { - private static final Pattern PATTERN = Pattern.node(WindowNode.class); + private static final Pattern PATTERN = Pattern.typeOf(WindowNode.class); @Override public Pattern getPattern() @@ -41,11 +37,11 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { WindowNode parent = (WindowNode) node; - PlanNode child = lookup.resolve(parent.getSource()); + PlanNode child = context.getLookup().resolve(parent.getSource()); if (!(child instanceof WindowNode)) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index 66bd22e3735a..342abb074775 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -13,15 +13,14 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.Signature; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.TransformCorrelatedScalarAggregationToJoin; import com.facebook.presto.sql.planner.optimizations.TransformUncorrelatedInPredicateSubqueryToSemiJoin; @@ -95,7 +94,7 @@ public class TransformCorrelatedInPredicateToJoin implements Rule { - private static final Pattern PATTERN = Pattern.node(ApplyNode.class); + private static final Pattern PATTERN = Pattern.typeOf(ApplyNode.class); @Override public Pattern getPattern() @@ -104,7 +103,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { if (!(node instanceof ApplyNode)) { return Optional.empty(); @@ -128,7 +127,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato InPredicate inPredicate = (InPredicate) assignmentExpression; Symbol inPredicateOutputSymbol = getOnlyElement(subqueryAssignments.getSymbols()); - return apply(apply, inPredicate, inPredicateOutputSymbol, lookup, idAllocator, symbolAllocator); + return apply(apply, inPredicate, inPredicateOutputSymbol, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator()); } private Optional apply( @@ -179,7 +178,7 @@ private PlanNode buildInPredicateEquivalent( idAllocator.getNextId(), decorrelatedBuildSource, Assignments.builder() - .putAll(Assignments.identity(decorrelatedBuildSource.getOutputSymbols())) + .putIdentities(decorrelatedBuildSource.getOutputSymbols()) .put(buildSideKnownNonNull, bigint(0)) .build() ); @@ -236,7 +235,7 @@ private PlanNode buildInPredicateEquivalent( idAllocator.getNextId(), aggregation, Assignments.builder() - .putAll(Assignments.identity(apply.getInput().getOutputSymbols())) + .putIdentities(apply.getInput().getOutputSymbols()) .put(inPredicateOutputSymbol, inPredicateEquivalent) .build() ); @@ -399,7 +398,7 @@ private boolean isCorrelatedRecursively(PlanNode node) private boolean isCorrelatedShallowly(PlanNode node) { - return DependencyExtractor.extractUniqueNonRecursive(node).stream().anyMatch(correlation::contains); + return SymbolsExtractor.extractUniqueNonRecursive(node).stream().anyMatch(correlation::contains); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java index edcd8860a707..5a5710d39154 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java @@ -13,12 +13,9 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionRegistry; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.ScalarAggregationToJoinRewriter; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -30,14 +27,42 @@ import java.util.Optional; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; -import static com.facebook.presto.sql.planner.optimizations.Predicates.isInstanceOfAny; -import static com.facebook.presto.sql.planner.optimizations.ScalarQueryUtil.isScalar; +import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isScalar; +import static com.facebook.presto.util.MorePredicates.isInstanceOfAny; import static java.util.Objects.requireNonNull; +/** + * Scalar aggregation is aggregation with GROUP BY 'a constant' (or empty GROUP BY). + * It always returns single row. + *

+ * This optimizer rewrites correlated scalar aggregation subquery to left outer join in a way described here: + * https://github.com/prestodb/presto/wiki/Correlated-subqueries + *

+ * From: + *

+ * - LateralJoin (with correlation list: [C])
+ *   - (input) plan which produces symbols: [A, B, C]
+ *   - (subquery) Aggregation(GROUP BY (); functions: [sum(F), count(), ...]
+ *     - Filter(D = C AND E > 5)
+ *       - plan which produces symbols: [D, E, F]
+ * 
+ * to: + *
+ * - Aggregation(GROUP BY A, B, C, U; functions: [sum(F), count(non_null), ...]
+ *   - Join(LEFT_OUTER, D = C)
+ *     - AssignUniqueId(adds symbol U)
+ *       - (input) plan which produces symbols: [A, B, C]
+ *     - Filter(E > 5)
+ *       - projection which adds non null symbol used for count() function
+ *         - plan which produces symbols: [D, E, F]
+ * 
+ *

+ * Note that only conjunction predicates in FilterNode are supported + */ public class TransformCorrelatedScalarAggregationToJoin implements Rule { - private static final Pattern PATTERN = Pattern.node(LateralJoinNode.class); + private static final Pattern PATTERN = Pattern.typeOf(LateralJoinNode.class); @Override public Pattern getPattern() @@ -53,25 +78,25 @@ public TransformCorrelatedScalarAggregationToJoin(FunctionRegistry functionRegis } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { if (!(node instanceof LateralJoinNode)) { return Optional.empty(); } LateralJoinNode lateralJoinNode = (LateralJoinNode) node; - PlanNode subquery = lookup.resolve(lateralJoinNode.getSubquery()); + PlanNode subquery = context.getLookup().resolve(lateralJoinNode.getSubquery()); - if (lateralJoinNode.getCorrelation().isEmpty() || !(isScalar(subquery, lookup))) { + if (lateralJoinNode.getCorrelation().isEmpty() || !(isScalar(subquery, context.getLookup()))) { return Optional.empty(); } - Optional aggregation = findAggregation(subquery, lookup); + Optional aggregation = findAggregation(subquery, context.getLookup()); if (!(aggregation.isPresent() && aggregation.get().getGroupingKeys().isEmpty())) { return Optional.empty(); } - ScalarAggregationToJoinRewriter rewriter = new ScalarAggregationToJoinRewriter(functionRegistry, symbolAllocator, idAllocator, lookup); + ScalarAggregationToJoinRewriter rewriter = new ScalarAggregationToJoinRewriter(functionRegistry, context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup()); PlanNode rewrittenNode = rewriter.rewriteScalarAggregation(lateralJoinNode, aggregation.get()); @@ -86,7 +111,7 @@ private static Optional findAggregation(PlanNode rootNode, Look { return searchFrom(rootNode, lookup) .where(AggregationNode.class::isInstance) - .skipOnlyWhen(isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class)) + .recurseOnlyWhen(isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class)) .findFirst(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java index 3b0e82e8a423..4e783ea2fab3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java @@ -13,14 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.Session; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; @@ -59,7 +55,7 @@ public class TransformExistsApplyToLateralNode implements Rule { - private static final Pattern PATTERN = Pattern.node(ApplyNode.class); + private static final Pattern PATTERN = Pattern.typeOf(ApplyNode.class); private static final QualifiedName COUNT = QualifiedName.of("count"); private static final FunctionCall COUNT_CALL = new FunctionCall(COUNT, ImmutableList.of()); private final Signature countSignature; @@ -77,7 +73,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { ApplyNode parent = (ApplyNode) node; @@ -90,7 +86,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato return Optional.empty(); } - Symbol count = symbolAllocator.newSymbol(COUNT.toString(), BIGINT); + Symbol count = context.getSymbolAllocator().newSymbol(COUNT.toString(), BIGINT); Symbol exists = getOnlyElement(parent.getSubqueryAssignments().getSymbols()); return Optional.of( @@ -98,9 +94,9 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato node.getId(), parent.getInput(), new ProjectNode( - idAllocator.getNextId(), + context.getIdAllocator().getNextId(), new AggregationNode( - idAllocator.getNextId(), + context.getIdAllocator().getNextId(), parent.getSubquery(), ImmutableMap.of(count, new Aggregation(COUNT_CALL, countSignature, Optional.empty())), ImmutableList.of(ImmutableList.of()), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java new file mode 100644 index 000000000000..aa0bbbd3afbb --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.InPredicate; + +import java.util.Optional; + +import static com.google.common.collect.Iterables.getOnlyElement; + +/** + * This optimizers looks for InPredicate expressions in ApplyNodes and replaces the nodes with SemiJoin nodes. + *

+ * Plan before optimizer: + *

+ * Filter(a IN b):
+ *   Apply
+ *     - correlation: []  // empty
+ *     - input: some plan A producing symbol a
+ *     - subquery: some plan B producing symbol b
+ * 
+ *

+ * Plan after optimizer: + *

+ * Filter(semijoinresult):
+ *   SemiJoin
+ *     - source: plan A
+ *     - filteringSource: symbol a
+ *     - sourceJoinSymbol: plan B
+ *     - filteringSourceJoinSymbol: symbol b
+ *     - semiJoinOutput: semijoinresult
+ * 
+ */ +public class TransformUncorrelatedInPredicateSubqueryToSemiJoin + implements Rule +{ + @Override + public Pattern getPattern() + { + return Pattern.typeOf(ApplyNode.class); + } + + @Override + public Optional apply(PlanNode node, Context context) + { + ApplyNode applyNode = (ApplyNode) node; + + if (!applyNode.getCorrelation().isEmpty()) { + return Optional.empty(); + } + + if (applyNode.getSubqueryAssignments().size() != 1) { + return Optional.empty(); + } + + Expression expression = getOnlyElement(applyNode.getSubqueryAssignments().getExpressions()); + if (!(expression instanceof InPredicate)) { + return Optional.empty(); + } + + InPredicate inPredicate = (InPredicate) expression; + Symbol semiJoinSymbol = getOnlyElement(applyNode.getSubqueryAssignments().getSymbols()); + + SemiJoinNode replacement = new SemiJoinNode(context.getIdAllocator().getNextId(), + applyNode.getInput(), + applyNode.getSubquery(), + Symbol.from(inPredicate.getValue()), + Symbol.from(inPredicate.getValueList()), + semiJoinSymbol, + Optional.empty(), + Optional.empty(), + Optional.empty() + ); + + return Optional.of(replacement); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java new file mode 100644 index 000000000000..c5b5cf29c38f --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableList; + +import java.util.Optional; + +public class TransformUncorrelatedLateralToJoin + implements Rule +{ + private static final Pattern PATTERN = Pattern.typeOf(LateralJoinNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Context context) + { + LateralJoinNode lateralJoinNode = (LateralJoinNode) node; + + if (!lateralJoinNode.getCorrelation().isEmpty()) { + return Optional.empty(); + } + + return Optional.of(new JoinNode( + context.getIdAllocator().getNextId(), + JoinNode.Type.INNER, + lateralJoinNode.getInput(), + lateralJoinNode.getSubquery(), + ImmutableList.of(), + ImmutableList.builder() + .addAll(lateralJoinNode.getInput().getOutputSymbols()) + .addAll(lateralJoinNode.getSubquery().getOutputSymbols()) + .build(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty())); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java index 5c61227ca478..f6480bd0cc82 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java @@ -13,19 +13,25 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.DependencyExtractor; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; import java.util.Collection; -import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.Set; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; + class Util { private Util() @@ -33,24 +39,20 @@ private Util() } /** - * Prune the list of available inputs to those required by the given expressions. + * Prune the set of available inputs to those required by the given expressions. * * If all inputs are used, return Optional.empty() to indicate that no pruning is necessary. */ - public static Optional> pruneInputs(Collection availableInputs, Collection expressions) + public static Optional> pruneInputs(Collection availableInputs, Collection expressions) { - Set available = new HashSet<>(availableInputs); - Set required = DependencyExtractor.extractUnique(expressions); - - // we need to compute the intersection in case some dependencies are symbols from - // the outer scope (i.e., correlated queries) - Set used = Sets.intersection(required, available); - if (used.size() == available.size()) { - // no need to prune... every available input is being used + Set availableInputsSet = ImmutableSet.copyOf(availableInputs); + Set prunedInputs = Sets.filter(availableInputsSet, SymbolsExtractor.extractUnique(expressions)::contains); + + if (prunedInputs.size() == availableInputsSet.size()) { return Optional.empty(); } - return Optional.of(ImmutableList.copyOf(used)); + return Optional.of(prunedInputs); } /** @@ -62,4 +64,55 @@ public static PlanNode transpose(PlanNode parent, PlanNode child) parent.replaceChildren( child.getSources()))); } + + /** + * @return If the node has outputs not in permittedOutputs, returns an identity projection containing only those node outputs also in permittedOutputs. + */ + public static Optional restrictOutputs(PlanNodeIdAllocator idAllocator, PlanNode node, Set permittedOutputs) + { + List restrictedOutputs = node.getOutputSymbols().stream() + .filter(permittedOutputs::contains) + .collect(toImmutableList()); + + if (restrictedOutputs.size() == node.getOutputSymbols().size()) { + return Optional.empty(); + } + + return Optional.of( + new ProjectNode( + idAllocator.getNextId(), + node, + Assignments.identity(restrictedOutputs))); + } + + /** + * @return The original node, with identity projections possibly inserted between node and each child, limiting the columns to those permitted. + * Returns a present Optional iff at least one child was rewritten. + */ + @SafeVarargs + public static Optional restrictChildOutputs(PlanNodeIdAllocator idAllocator, PlanNode node, Set... permittedChildOutputsArgs) + { + List> permittedChildOutputs = ImmutableList.copyOf(permittedChildOutputsArgs); + + checkArgument( + (node.getSources().size() == permittedChildOutputs.size()), + "Mismatched child (%d) and permitted outputs (%d) sizes", + node.getSources().size(), + permittedChildOutputs.size()); + + ImmutableList.Builder newChildrenBuilder = ImmutableList.builder(); + boolean rewroteChildren = false; + + for (int i = 0; i < node.getSources().size(); ++i) { + PlanNode oldChild = node.getSources().get(i); + Optional newChild = restrictOutputs(idAllocator, oldChild, permittedChildOutputs.get(i)); + rewroteChildren |= newChild.isPresent(); + newChildrenBuilder.add(newChild.orElse(oldChild)); + } + + if (!rewroteChildren) { + return Optional.empty(); + } + return Optional.of(node.replaceChildren(newChildrenBuilder.build())); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index f97e3e8cdd46..1a80930db0d1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -26,7 +26,6 @@ import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.DomainTranslator; import com.facebook.presto.sql.planner.ExpressionInterpreter; import com.facebook.presto.sql.planner.LookupSymbolResolver; @@ -35,6 +34,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.Assignments; @@ -674,7 +674,7 @@ private boolean shouldPrune(Expression predicate, Map assi // If any conjuncts evaluate to FALSE or null, then the whole predicate will never be true and so the partition should be pruned for (Expression expression : conjuncts) { - if (DependencyExtractor.extractUnique(expression).stream().anyMatch(correlations::contains)) { + if (SymbolsExtractor.extractUnique(expression).stream().anyMatch(correlations::contains)) { // expression contains correlated symbol with outer query continue; } @@ -1214,31 +1214,13 @@ private PlanWithProperties arbitraryDistributeUnion( @Override public PlanWithProperties visitApply(ApplyNode node, Context context) { - PlanWithProperties input = node.getInput().accept(this, context); - PlanWithProperties subquery = node.getSubquery().accept(this, context.withCorrelations(node.getCorrelation())); - - ApplyNode rewritten = new ApplyNode( - node.getId(), - input.getNode(), - subquery.getNode(), - node.getSubqueryAssignments(), - node.getCorrelation()); - return new PlanWithProperties(rewritten, deriveProperties(rewritten, ImmutableList.of(input.getProperties(), subquery.getProperties()))); + throw new IllegalStateException("Unexpected node: " + node.getClass().getName()); } @Override public PlanWithProperties visitLateralJoin(LateralJoinNode node, Context context) { - PlanWithProperties input = node.getInput().accept(this, context); - PlanWithProperties subquery = node.getSubquery().accept(this, context.withCorrelations(node.getCorrelation())); - - LateralJoinNode rewritten = new LateralJoinNode( - node.getId(), - input.getNode(), - subquery.getNode(), - node.getCorrelation(), - node.getType()); - return new PlanWithProperties(rewritten, deriveProperties(rewritten, ImmutableList.of(input.getProperties(), subquery.getProperties()))); + throw new IllegalStateException("Unexpected node: " + node.getClass().getName()); } private PlanWithProperties planChild(PlanNode node, Context context) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index 7dc8e3ec0631..237944a7c8f6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -27,12 +27,14 @@ import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.DistinctLimitNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.OutputNode; @@ -126,6 +128,18 @@ protected PlanWithProperties visitPlan(PlanNode node, StreamPreferredProperties parentPreferences.withDefaultParallelism(session)); } + @Override + public PlanWithProperties visitApply(ApplyNode node, StreamPreferredProperties parentPreferences) + { + throw new IllegalStateException("Unexpected node: " + node.getClass().getName()); + } + + @Override + public PlanWithProperties visitLateralJoin(LateralJoinNode node, StreamPreferredProperties parentPreferences) + { + throw new IllegalStateException("Unexpected node: " + node.getClass().getName()); + } + @Override public PlanWithProperties visitOutput(OutputNode node, StreamPreferredProperties parentPreferences) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/BeginTableWrite.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/BeginTableWrite.java index fd36318669fb..d3fc4a23150e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/BeginTableWrite.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/BeginTableWrite.java @@ -44,7 +44,7 @@ import java.util.Optional; import java.util.Set; -import static com.facebook.presto.sql.planner.optimizations.ScalarQueryUtil.isScalar; +import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; @@ -197,9 +197,12 @@ private PlanNode rewriteDeleteTableScan(PlanNode node, TableHandle handle) PlanNode source = rewriteDeleteTableScan(((SemiJoinNode) node).getSource(), handle); return replaceChildren(node, ImmutableList.of(source, ((SemiJoinNode) node).getFilteringSource())); } - if (node instanceof JoinNode && (((JoinNode) node).getType() == JoinNode.Type.INNER) && isScalar(((JoinNode) node).getRight())) { - PlanNode source = rewriteDeleteTableScan(((JoinNode) node).getLeft(), handle); - return replaceChildren(node, ImmutableList.of(source, ((JoinNode) node).getRight())); + if (node instanceof JoinNode) { + JoinNode joinNode = (JoinNode) node; + if (joinNode.getType() == JoinNode.Type.INNER && isAtMostScalar(joinNode.getRight())) { + PlanNode source = rewriteDeleteTableScan(joinNode.getLeft(), handle); + return replaceChildren(node, ImmutableList.of(source, joinNode.getRight())); + } } throw new IllegalArgumentException("Invalid descendant for DeleteNode: " + node.getClass().getName()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CanonicalizeExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CanonicalizeExpressions.java index 6d3f0cb6fa35..2e69c1867cc9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CanonicalizeExpressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CanonicalizeExpressions.java @@ -46,6 +46,7 @@ import static java.util.Objects.requireNonNull; +@Deprecated public class CanonicalizeExpressions implements PlanOptimizer { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java index cb912fbbb092..85d2b1f9a358 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java @@ -28,7 +28,7 @@ import java.util.Optional; import static com.facebook.presto.SystemSessionProperties.isDistributedJoinEnabled; -import static com.facebook.presto.sql.planner.optimizations.ScalarQueryUtil.isScalar; +import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; @@ -122,7 +122,7 @@ private JoinNode.DistributionType getTargetJoinDistributionType(JoinNode node) private static boolean mustBroadcastJoin(JoinNode node) { - return isScalar(node.getRight()) || isCrossJoin(node); + return isAtMostScalar(node.getRight()) || isCrossJoin(node); } private static boolean isCrossJoin(JoinNode node) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PartialAggregationPushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PartialAggregationPushDown.java index 4667d3752b8a..06460719c061 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PartialAggregationPushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PartialAggregationPushDown.java @@ -18,12 +18,12 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.Assignments; @@ -261,7 +261,7 @@ private AggregationNode replaceAggregationSource( private boolean allAggregationsOn(Map aggregations, List outputSymbols) { - Set inputs = DependencyExtractor.extractUnique(aggregations.values().stream().map(Aggregation::getCall).collect(toImmutableList())); + Set inputs = SymbolsExtractor.extractUnique(aggregations.values().stream().map(Aggregation::getCall).collect(toImmutableList())); return outputSymbols.containsAll(inputs); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java new file mode 100644 index 000000000000..7c93f9ba1046 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -0,0 +1,223 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.sql.ExpressionUtils; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.LimitNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.facebook.presto.sql.tree.DefaultTraversalVisitor; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.LogicalBinaryExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; +import static com.facebook.presto.util.MorePredicates.isInstanceOfAny; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class PlanNodeDecorrelator +{ + private final PlanNodeIdAllocator idAllocator; + private final Lookup lookup; + + public PlanNodeDecorrelator(PlanNodeIdAllocator idAllocator, Lookup lookup) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.lookup = requireNonNull(lookup, "lookup is null"); + } + + public Optional decorrelateFilters(PlanNode node, List correlation) + { + PlanNodeSearcher filterNodeSearcher = searchFrom(node, lookup) + .where(FilterNode.class::isInstance) + .recurseOnlyWhen(isInstanceOfAny(ProjectNode.class, LimitNode.class)); + List filterNodes = filterNodeSearcher.findAll(); + + if (filterNodes.isEmpty()) { + return decorrelatedNode(ImmutableList.of(), node, correlation); + } + + if (filterNodes.size() > 1) { + return Optional.empty(); + } + + FilterNode filterNode = filterNodes.get(0); + Expression predicate = filterNode.getPredicate(); + + if (!isSupportedPredicate(predicate)) { + return Optional.empty(); + } + + if (!SymbolsExtractor.extractUnique(predicate).containsAll(correlation)) { + return Optional.empty(); + } + + Map> predicates = ExpressionUtils.extractConjuncts(predicate).stream() + .collect(Collectors.partitioningBy(isUsingPredicate(correlation))); + List correlatedPredicates = ImmutableList.copyOf(predicates.get(true)); + List uncorrelatedPredicates = ImmutableList.copyOf(predicates.get(false)); + + node = updateFilterNode(filterNodeSearcher, uncorrelatedPredicates); + + if (!correlatedPredicates.isEmpty()) { + // filterNodes condition has changed so Limit node no longer applies for EXISTS subquery + node = removeLimitNode(node); + } + + node = ensureJoinSymbolsAreReturned(node, correlatedPredicates); + + return decorrelatedNode(correlatedPredicates, node, correlation); + } + + private static boolean isSupportedPredicate(Expression predicate) + { + AtomicBoolean isSupported = new AtomicBoolean(true); + new DefaultTraversalVisitor() + { + @Override + protected Void visitLogicalBinaryExpression(LogicalBinaryExpression node, AtomicBoolean context) + { + if (node.getType() != LogicalBinaryExpression.Type.AND) { + context.set(false); + } + return null; + } + }.process(predicate, isSupported); + return isSupported.get(); + } + + private Predicate isUsingPredicate(List symbols) + { + return expression -> symbols.stream().anyMatch(SymbolsExtractor.extractUnique(expression)::contains); + } + + private PlanNode updateFilterNode(PlanNodeSearcher filterNodeSearcher, List newPredicates) + { + if (newPredicates.isEmpty()) { + return filterNodeSearcher.removeAll(); + } + FilterNode oldFilterNode = Iterables.getOnlyElement(filterNodeSearcher.findAll()); + FilterNode newFilterNode = new FilterNode( + idAllocator.getNextId(), + oldFilterNode.getSource(), + ExpressionUtils.combineConjuncts(newPredicates)); + return filterNodeSearcher.replaceAll(newFilterNode); + } + + private PlanNode removeLimitNode(PlanNode node) + { + node = searchFrom(node, lookup) + .where(LimitNode.class::isInstance) + .recurseOnlyWhen(ProjectNode.class::isInstance) + .removeFirst(); + return node; + } + + private PlanNode ensureJoinSymbolsAreReturned(PlanNode scalarAggregationSource, List joinPredicate) + { + Set joinExpressionSymbols = SymbolsExtractor.extractUnique(joinPredicate); + ExtendProjectionRewriter extendProjectionRewriter = new ExtendProjectionRewriter( + idAllocator, + joinExpressionSymbols); + return rewriteWith(extendProjectionRewriter, scalarAggregationSource); + } + + private Optional decorrelatedNode( + List correlatedPredicates, + PlanNode node, + List correlation) + { + if (SymbolsExtractor.extractUnique(node, lookup).stream().anyMatch(correlation::contains)) { + // node is still correlated ; / + return Optional.empty(); + } + return Optional.of(new DecorrelatedNode(correlatedPredicates, node)); + } + + public static class DecorrelatedNode + { + private final List correlatedPredicates; + private final PlanNode node; + + public DecorrelatedNode(List correlatedPredicates, PlanNode node) + { + requireNonNull(correlatedPredicates, "correlatedPredicates is null"); + this.correlatedPredicates = ImmutableList.copyOf(correlatedPredicates); + this.node = requireNonNull(node, "node is null"); + } + + Optional getCorrelatedPredicates() + { + if (correlatedPredicates.isEmpty()) { + return Optional.empty(); + } + return Optional.of(ExpressionUtils.and(correlatedPredicates)); + } + + public PlanNode getNode() + { + return node; + } + } + + private static class ExtendProjectionRewriter + extends SimplePlanRewriter + { + private final PlanNodeIdAllocator idAllocator; + private final Set symbols; + + ExtendProjectionRewriter(PlanNodeIdAllocator idAllocator, Set symbols) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.symbols = requireNonNull(symbols, "symbols is null"); + } + + @Override + public PlanNode visitProject(ProjectNode node, RewriteContext context) + { + ProjectNode rewrittenNode = (ProjectNode) context.defaultRewrite(node, context.get()); + + List symbolsToAdd = symbols.stream() + .filter(rewrittenNode.getSource().getOutputSymbols()::contains) + .filter(symbol -> !rewrittenNode.getOutputSymbols().contains(symbol)) + .collect(toImmutableList()); + + Assignments assignments = Assignments.builder() + .putAll(rewrittenNode.getAssignments()) + .putIdentities(symbolsToAdd) + .build(); + + return new ProjectNode(idAllocator.getNextId(), rewrittenNode.getSource(), assignments); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeSearcher.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeSearcher.java index 2eca4cd7d399..4fadc572b767 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeSearcher.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeSearcher.java @@ -22,21 +22,25 @@ import java.util.function.Predicate; import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; -import static com.facebook.presto.sql.planner.optimizations.Predicates.alwaysTrue; import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Predicates.alwaysTrue; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; public class PlanNodeSearcher { - @Deprecated public static PlanNodeSearcher searchFrom(PlanNode node) { return searchFrom(node, noLookup()); } + /** + * Use it in optimizer {@link com.facebook.presto.sql.planner.iterative.Rule} only if you truly do not have a better option + * + * TODO: replace it with a support for plan (physical) properties in rules pattern matching + */ public static PlanNodeSearcher searchFrom(PlanNode node, Lookup lookup) { return new PlanNodeSearcher(node, lookup); @@ -45,9 +49,9 @@ public static PlanNodeSearcher searchFrom(PlanNode node, Lookup lookup) private final PlanNode node; private final Lookup lookup; private Predicate where = alwaysTrue(); - private Predicate skipOnly = alwaysTrue(); + private Predicate recurseOnlyWhen = alwaysTrue(); - public PlanNodeSearcher(PlanNode node, Lookup lookup) + private PlanNodeSearcher(PlanNode node, Lookup lookup) { this.node = requireNonNull(node, "node is null"); this.lookup = requireNonNull(lookup, "lookup is null"); @@ -59,9 +63,9 @@ public PlanNodeSearcher where(Predicate where) return this; } - public PlanNodeSearcher skipOnlyWhen(Predicate skipOnly) + public PlanNodeSearcher recurseOnlyWhen(Predicate skipOnly) { - this.skipOnly = requireNonNull(skipOnly, "skipOnly is null"); + this.recurseOnlyWhen = requireNonNull(skipOnly, "recurseOnlyWhen is null"); return this; } @@ -77,7 +81,7 @@ private Optional findFirstRecursive(PlanNode node) if (where.test(node)) { return Optional.of((T) node); } - if (skipOnly.test(node)) { + if (recurseOnlyWhen.test(node)) { for (PlanNode source : node.getSources()) { Optional found = findFirstRecursive(source); if (found.isPresent()) { @@ -88,6 +92,15 @@ private Optional findFirstRecursive(PlanNode node) return Optional.empty(); } + public Optional findSingle() + { + List all = findAll(); + if (all.size() == 1) { + return Optional.of(all.get(0)); + } + return Optional.empty(); + } + public List findAll() { ImmutableList.Builder nodes = ImmutableList.builder(); @@ -116,7 +129,7 @@ private void findAllRecursive(PlanNode node, ImmutableList. if (where.test(node)) { nodes.add((T) node); } - if (skipOnly.test(node)) { + if (recurseOnlyWhen.test(node)) { for (PlanNode source : node.getSources()) { findAllRecursive(source, nodes); } @@ -138,7 +151,7 @@ private PlanNode removeAllRecursive(PlanNode node) "Unable to remove plan node as it contains 0 or more than 1 children"); return node.getSources().get(0); } - if (skipOnly.test(node)) { + if (recurseOnlyWhen.test(node)) { List sources = node.getSources().stream() .map(source -> removeAllRecursive(source)) .collect(toImmutableList()); @@ -162,7 +175,7 @@ private PlanNode removeFirstRecursive(PlanNode node) "Unable to remove plan node as it contains 0 or more than 1 children"); return node.getSources().get(0); } - if (skipOnly.test(node)) { + if (recurseOnlyWhen.test(node)) { List sources = node.getSources(); if (sources.isEmpty()) { return node; @@ -189,7 +202,7 @@ private PlanNode replaceAllRecursive(PlanNode node, PlanNode nodeToReplace) if (where.test(node)) { return nodeToReplace; } - if (skipOnly.test(node)) { + if (recurseOnlyWhen.test(node)) { List sources = node.getSources().stream() .map(source -> replaceAllRecursive(source, nodeToReplace)) .collect(toImmutableList()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java index 7ac41da508fd..e726a6aace18 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java @@ -17,7 +17,6 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.DeterminismEvaluator; import com.facebook.presto.sql.planner.EffectivePredicateExtractor; import com.facebook.presto.sql.planner.EqualityInference; @@ -28,6 +27,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.Assignments; @@ -194,7 +194,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext contex .map(Map.Entry::getKey) .collect(Collectors.toSet()); - Predicate deterministic = conjunct -> DependencyExtractor.extractUnique(conjunct).stream() + Predicate deterministic = conjunct -> SymbolsExtractor.extractUnique(conjunct).stream() .allMatch(deterministicSymbols::contains); Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(deterministic)); @@ -214,13 +214,13 @@ public PlanNode visitProject(ProjectNode node, RewriteContext contex @Override public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) { - checkState(!DependencyExtractor.extractUnique(context.get()).contains(node.getGroupIdSymbol()), "groupId symbol cannot be referenced in predicate"); + checkState(!SymbolsExtractor.extractUnique(context.get()).contains(node.getGroupIdSymbol()), "groupId symbol cannot be referenced in predicate"); Map commonGroupingSymbolMapping = node.getGroupingSetMappings().entrySet().stream() .filter(entry -> node.getCommonGroupingColumns().contains(entry.getKey())) .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); - Predicate pushdownEligiblePredicate = conjunct -> DependencyExtractor.extractUnique(conjunct).stream() + Predicate pushdownEligiblePredicate = conjunct -> SymbolsExtractor.extractUnique(conjunct).stream() .allMatch(commonGroupingSymbolMapping.keySet()::contains); Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(pushdownEligiblePredicate)); @@ -240,7 +240,7 @@ public PlanNode visitGroupId(GroupIdNode node, RewriteContext contex @Override public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext context) { - checkState(!DependencyExtractor.extractUnique(context.get()).contains(node.getMarkerSymbol()), "predicate depends on marker symbol"); + checkState(!SymbolsExtractor.extractUnique(context.get()).contains(node.getMarkerSymbol()), "predicate depends on marker symbol"); return context.defaultRewrite(node, context.get()); } @@ -372,7 +372,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) if (joinEqualityExpression(node.getLeft().getOutputSymbols()).test(conjunct)) { ComparisonExpression equality = (ComparisonExpression) conjunct; - boolean alignedComparison = Iterables.all(DependencyExtractor.extractUnique(equality.getLeft()), in(node.getLeft().getOutputSymbols())); + boolean alignedComparison = Iterables.all(SymbolsExtractor.extractUnique(equality.getLeft()), in(node.getLeft().getOutputSymbols())); Expression leftExpression = (alignedComparison) ? equality.getLeft() : equality.getRight(); Expression rightExpression = (alignedComparison) ? equality.getRight() : equality.getLeft(); @@ -473,8 +473,8 @@ private static PlanNode createJoinNodeWithExpectedOutputs( private static OuterJoinPushDownResult processLimitedOuterJoin(Expression inheritedPredicate, Expression outerEffectivePredicate, Expression innerEffectivePredicate, Expression joinPredicate, Collection outerSymbols) { - checkArgument(Iterables.all(DependencyExtractor.extractUnique(outerEffectivePredicate), in(outerSymbols)), "outerEffectivePredicate must only contain symbols from outerSymbols"); - checkArgument(Iterables.all(DependencyExtractor.extractUnique(innerEffectivePredicate), not(in(outerSymbols))), "innerEffectivePredicate must not contain symbols from outerSymbols"); + checkArgument(Iterables.all(SymbolsExtractor.extractUnique(outerEffectivePredicate), in(outerSymbols)), "outerEffectivePredicate must only contain symbols from outerSymbols"); + checkArgument(Iterables.all(SymbolsExtractor.extractUnique(innerEffectivePredicate), not(in(outerSymbols))), "innerEffectivePredicate must not contain symbols from outerSymbols"); ImmutableList.Builder outerPushdownConjuncts = ImmutableList.builder(); ImmutableList.Builder innerPushdownConjuncts = ImmutableList.builder(); @@ -594,8 +594,8 @@ private Expression getPostJoinPredicate() private static InnerJoinPushDownResult processInnerJoin(Expression inheritedPredicate, Expression leftEffectivePredicate, Expression rightEffectivePredicate, Expression joinPredicate, Collection leftSymbols) { - checkArgument(Iterables.all(DependencyExtractor.extractUnique(leftEffectivePredicate), in(leftSymbols)), "leftEffectivePredicate must only contain symbols from leftSymbols"); - checkArgument(Iterables.all(DependencyExtractor.extractUnique(rightEffectivePredicate), not(in(leftSymbols))), "rightEffectivePredicate must not contain symbols from leftSymbols"); + checkArgument(Iterables.all(SymbolsExtractor.extractUnique(leftEffectivePredicate), in(leftSymbols)), "leftEffectivePredicate must only contain symbols from leftSymbols"); + checkArgument(Iterables.all(SymbolsExtractor.extractUnique(rightEffectivePredicate), not(in(leftSymbols))), "rightEffectivePredicate must not contain symbols from leftSymbols"); ImmutableList.Builder leftPushDownConjuncts = ImmutableList.builder(); ImmutableList.Builder rightPushDownConjuncts = ImmutableList.builder(); @@ -812,8 +812,8 @@ private static Predicate joinEqualityExpression(final Collection symbols1 = DependencyExtractor.extractUnique(comparison.getLeft()); - Set symbols2 = DependencyExtractor.extractUnique(comparison.getRight()); + Set symbols1 = SymbolsExtractor.extractUnique(comparison.getLeft()); + Set symbols2 = SymbolsExtractor.extractUnique(comparison.getRight()); if (symbols1.isEmpty() || symbols2.isEmpty()) { return false; } @@ -986,7 +986,7 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext co @Override public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext context) { - Set predicateSymbols = DependencyExtractor.extractUnique(context.get()); + Set predicateSymbols = SymbolsExtractor.extractUnique(context.get()); checkState(!predicateSymbols.contains(node.getIdColumn()), "UniqueId in predicate is not yet supported"); return context.defaultRewrite(node, context.get()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index f0418b07b8ee..db0c943df031 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -16,11 +16,11 @@ import com.facebook.presto.Session; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; @@ -78,6 +78,7 @@ import java.util.Set; import java.util.stream.Collectors; +import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static com.google.common.base.Predicates.in; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -177,7 +178,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext> context) Set expectedFilterInputs = new HashSet<>(); if (node.getFilter().isPresent()) { expectedFilterInputs = ImmutableSet.builder() - .addAll(DependencyExtractor.extractUnique(node.getFilter().get())) + .addAll(SymbolsExtractor.extractUnique(node.getFilter().get())) .addAll(context.get()) .build(); } @@ -314,7 +315,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext> context if (context.get().contains(symbol)) { FunctionCall call = function.getFunctionCall(); - expectedInputs.addAll(DependencyExtractor.extractUnique(call)); + expectedInputs.addAll(SymbolsExtractor.extractUnique(call)); functionsBuilder.put(symbol, entry.getValue()); } @@ -412,7 +413,7 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext> c public PlanNode visitFilter(FilterNode node, RewriteContext> context) { Set expectedInputs = ImmutableSet.builder() - .addAll(DependencyExtractor.extractUnique(node.getPredicate())) + .addAll(SymbolsExtractor.extractUnique(node.getPredicate())) .addAll(context.get()) .build(); @@ -503,7 +504,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext> conte Expression expression = node.getAssignments().get(output); if (context.get().contains(output)) { - expectedInputs.addAll(DependencyExtractor.extractUnique(expression)); + expectedInputs.addAll(SymbolsExtractor.extractUnique(expression)); builder.put(output, expression); } } @@ -734,7 +735,7 @@ public PlanNode visitApply(ApplyNode node, RewriteContext> context) Symbol output = entry.getKey(); Expression expression = entry.getValue(); if (context.get().contains(output)) { - subqueryAssignmentsSymbolsBuilder.addAll(DependencyExtractor.extractUnique(expression)); + subqueryAssignmentsSymbolsBuilder.addAll(SymbolsExtractor.extractUnique(expression)); subqueryAssignments.put(output, expression); } } @@ -743,7 +744,7 @@ public PlanNode visitApply(ApplyNode node, RewriteContext> context) PlanNode subquery = context.rewrite(node.getSubquery(), subqueryAssignmentsSymbols); // prune not used correlation symbols - Set subquerySymbols = DependencyExtractor.extractUnique(subquery); + Set subquerySymbols = SymbolsExtractor.extractUnique(subquery); List newCorrelation = node.getCorrelation().stream() .filter(subquerySymbols::contains) .collect(toImmutableList()); @@ -772,12 +773,12 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext subquerySymbols = DependencyExtractor.extractUnique(subquery); + Set subquerySymbols = SymbolsExtractor.extractUnique(subquery); List newCorrelation = node.getCorrelation().stream() .filter(subquerySymbols::contains) .collect(toImmutableList()); @@ -787,6 +788,12 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext extractCardinality(PlanNode node) + { + return extractCardinality(node, noLookup()); + } + + private static Range extractCardinality(PlanNode node, Lookup lookup) + { + return node.accept(new CardinalityExtractorPlanVisitor(lookup), null); + } + + private static final class CardinalityExtractorPlanVisitor + extends PlanVisitor, Void> + { + private final Lookup lookup; + + public CardinalityExtractorPlanVisitor(Lookup lookup) + { + this.lookup = requireNonNull(lookup, "lookup is null"); + } + + @Override + protected Range visitPlan(PlanNode node, Void context) + { + return Range.atLeast(0L); + } + + @Override + public Range visitGroupReference(GroupReference node, Void context) + { + return lookup.resolve(node).accept(this, context); + } + + @Override + public Range visitEnforceSingleRow(EnforceSingleRowNode node, Void context) + { + return Range.singleton(1L); + } + + @Override + public Range visitAggregation(AggregationNode node, Void context) + { + if (node.hasEmptyGroupingSet()) { + return Range.singleton(1L); + } + return Range.atLeast(0L); + } + + @Override + public Range visitExchange(ExchangeNode node, Void context) + { + if (node.getSources().size() == 1) { + return getOnlyElement(node.getSources()).accept(this, null); + } + return Range.atLeast(0L); + } + + @Override + public Range visitProject(ProjectNode node, Void context) + { + return node.getSource().accept(this, null); + } + + @Override + public Range visitFilter(FilterNode node, Void context) + { + Range sourceCardinalityRange = node.getSource().accept(this, null); + if (sourceCardinalityRange.hasUpperBound()) { + return Range.closed(0L, sourceCardinalityRange.upperEndpoint()); + } + return Range.atLeast(0L); + } + + public Range visitValues(ValuesNode node, Void context) + { + return Range.singleton((long) node.getRows().size()); + } + + @Override + public Range visitLimit(LimitNode node, Void context) + { + Range sourceCardinalityRange = node.getSource().accept(this, null); + long upper = node.getCount(); + if (sourceCardinalityRange.hasUpperBound()) { + upper = min(sourceCardinalityRange.upperEndpoint(), node.getCount()); + } + long lower = min(upper, sourceCardinalityRange.lowerEndpoint()); + return Range.closed(lower, upper); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RemoveUnreferencedScalarLateralNodes.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RemoveUnreferencedScalarLateralNodes.java index f354cc4fb44c..30c8de828f94 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RemoveUnreferencedScalarLateralNodes.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RemoveUnreferencedScalarLateralNodes.java @@ -25,12 +25,13 @@ import java.util.Map; -import static com.facebook.presto.sql.planner.optimizations.ScalarQueryUtil.isScalar; +import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; /** * Remove LateralJoinNodes with unreferenced scalar input, e.g: "SELECT (SELECT 1)". */ +@Deprecated public class RemoveUnreferencedScalarLateralNodes implements PlanOptimizer { @@ -46,15 +47,23 @@ private static class Rewriter @Override public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext context) { - if (node.getInput().getOutputSymbols().isEmpty() && isScalar(node.getInput())) { - return context.rewrite(node.getSubquery()); + PlanNode input = node.getInput(); + PlanNode subquery = node.getSubquery(); + + if (isUnreferencedScalar(input)) { + return context.rewrite(subquery); } - if (node.getSubquery().getOutputSymbols().isEmpty() && isScalar(node.getSubquery())) { - return context.rewrite(node.getInput()); + if (isUnreferencedScalar(subquery)) { + return context.rewrite(input); } return context.defaultRewrite(node); } + + private boolean isUnreferencedScalar(PlanNode input) + { + return input.getOutputSymbols().isEmpty() && isScalar(input); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java index ac8f45293ef3..09daf7b3dade 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java @@ -14,49 +14,37 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.metadata.FunctionRegistry; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.TypeSignature; -import com.facebook.presto.sql.ExpressionUtils; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.optimizations.PlanNodeDecorrelator.DecorrelatedNode; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; -import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; -import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; -import com.facebook.presto.sql.tree.DefaultTraversalVisitor; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; -import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Predicate; -import java.util.stream.Collectors; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypeSignatures; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; -import static com.facebook.presto.sql.planner.optimizations.Predicates.isInstanceOfAny; -import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @@ -70,6 +58,7 @@ public class ScalarAggregationToJoinRewriter private final SymbolAllocator symbolAllocator; private final PlanNodeIdAllocator idAllocator; private final Lookup lookup; + private final PlanNodeDecorrelator planNodeDecorrelator; public ScalarAggregationToJoinRewriter(FunctionRegistry functionRegistry, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) { @@ -77,19 +66,20 @@ public ScalarAggregationToJoinRewriter(FunctionRegistry functionRegistry, Symbol this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.lookup = requireNonNull(lookup, "lookup is null"); + this.planNodeDecorrelator = new PlanNodeDecorrelator(idAllocator, lookup); } public PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, AggregationNode aggregation) { List correlation = lateralJoinNode.getCorrelation(); - Optional source = decorrelateFilters(lookup.resolve(aggregation.getSource()), correlation); + Optional source = planNodeDecorrelator.decorrelateFilters(lookup.resolve(aggregation.getSource()), correlation); if (!source.isPresent()) { return lateralJoinNode; } Symbol nonNull = symbolAllocator.newSymbol("non_null", BooleanType.BOOLEAN); Assignments scalarAggregationSourceAssignments = Assignments.builder() - .putAll(Assignments.identity(source.get().getNode().getOutputSymbols())) + .putIdentities(source.get().getNode().getOutputSymbols()) .put(nonNull, TRUE_LITERAL) .build(); ProjectNode scalarAggregationSourceWithNonNullableSymbol = new ProjectNode( @@ -143,14 +133,14 @@ private PlanNode rewriteScalarAggregation( Optional subqueryProjection = searchFrom(lateralJoinNode.getSubquery(), lookup) .where(ProjectNode.class::isInstance) - .skipOnlyWhen(EnforceSingleRowNode.class::isInstance) + .recurseOnlyWhen(EnforceSingleRowNode.class::isInstance) .findFirst(); List aggregationOutputSymbols = getTruncatedAggregationSymbols(lateralJoinNode, aggregationNode.get()); if (subqueryProjection.isPresent()) { Assignments assignments = Assignments.builder() - .putAll(Assignments.identity(aggregationOutputSymbols)) + .putIdentities(aggregationOutputSymbols) .putAll(subqueryProjection.get().getAssignments()) .build(); @@ -160,14 +150,10 @@ private PlanNode rewriteScalarAggregation( assignments); } else { - Assignments assignments = Assignments.builder() - .putAll(Assignments.identity(aggregationOutputSymbols)) - .build(); - return new ProjectNode( idAllocator.getNextId(), aggregationNode.get(), - assignments); + Assignments.identity(aggregationOutputSymbols)); } } @@ -185,7 +171,6 @@ private Optional createAggregationNode( Symbol nonNullableAggregationSourceSymbol) { ImmutableMap.Builder aggregations = ImmutableMap.builder(); - ImmutableMap.Builder functions = ImmutableMap.builder(); for (Map.Entry entry : scalarAggregation.getAggregations().entrySet()) { FunctionCall call = entry.getValue().getCall(); Symbol symbol = entry.getKey(); @@ -216,155 +201,4 @@ private Optional createAggregationNode( scalarAggregation.getHashSymbol(), Optional.empty())); } - - private Optional decorrelateFilters(PlanNode node, List correlation) - { - PlanNodeSearcher filterNodeSearcher = searchFrom(node, lookup) - .where(FilterNode.class::isInstance) - .skipOnlyWhen(isInstanceOfAny(ProjectNode.class)); - List filterNodes = filterNodeSearcher.findAll(); - - if (filterNodes.isEmpty()) { - return decorrelatedNode(ImmutableList.of(), node, correlation); - } - - if (filterNodes.size() > 1) { - return Optional.empty(); - } - - FilterNode filterNode = filterNodes.get(0); - Expression predicate = filterNode.getPredicate(); - - if (!isSupportedPredicate(predicate)) { - return Optional.empty(); - } - - if (!DependencyExtractor.extractUnique(predicate).containsAll(correlation)) { - return Optional.empty(); - } - - Map> predicates = ExpressionUtils.extractConjuncts(predicate).stream() - .collect(Collectors.partitioningBy(isUsingPredicate(correlation))); - List correlatedPredicates = ImmutableList.copyOf(predicates.get(true)); - List uncorrelatedPredicates = ImmutableList.copyOf(predicates.get(false)); - - node = updateFilterNode(filterNodeSearcher, uncorrelatedPredicates); - node = ensureJoinSymbolsAreReturned(node, correlatedPredicates); - - return decorrelatedNode(correlatedPredicates, node, correlation); - } - - private Optional decorrelatedNode( - List correlatedPredicates, - PlanNode node, - List correlation) - { - Set uniqueSymbols = DependencyExtractor.extractUnique(node, lookup); - if (uniqueSymbols.stream().anyMatch(correlation::contains)) { - // node is still correlated ; / - return Optional.empty(); - } - return Optional.of(new DecorrelatedNode(correlatedPredicates, node)); - } - - private static Predicate isUsingPredicate(List symbols) - { - return expression -> symbols.stream().anyMatch(DependencyExtractor.extractUnique(expression)::contains); - } - - private PlanNode updateFilterNode(PlanNodeSearcher filterNodeSearcher, List newPredicates) - { - if (newPredicates.isEmpty()) { - return filterNodeSearcher.removeAll(); - } - FilterNode oldFilterNode = Iterables.getOnlyElement(filterNodeSearcher.findAll()); - FilterNode newFilterNode = new FilterNode( - idAllocator.getNextId(), - oldFilterNode.getSource(), - ExpressionUtils.combineConjuncts(newPredicates)); - return filterNodeSearcher.replaceAll(newFilterNode); - } - - private PlanNode ensureJoinSymbolsAreReturned(PlanNode scalarAggregationSource, List joinPredicate) - { - Set joinExpressionSymbols = DependencyExtractor.extractUnique(joinPredicate); - ExtendProjectionRewriter extendProjectionRewriter = new ExtendProjectionRewriter( - idAllocator, - joinExpressionSymbols); - return rewriteWith(extendProjectionRewriter, scalarAggregationSource); - } - - private static boolean isSupportedPredicate(Expression predicate) - { - AtomicBoolean isSupported = new AtomicBoolean(true); - new DefaultTraversalVisitor() - { - @Override - protected Void visitLogicalBinaryExpression(LogicalBinaryExpression node, AtomicBoolean context) - { - if (node.getType() != LogicalBinaryExpression.Type.AND) { - context.set(false); - } - return null; - } - }.process(predicate, isSupported); - return isSupported.get(); - } - - private static class DecorrelatedNode - { - private final List correlatedPredicates; - private final PlanNode node; - - public DecorrelatedNode(List correlatedPredicates, PlanNode node) - { - requireNonNull(correlatedPredicates, "correlatedPredicates is null"); - this.correlatedPredicates = ImmutableList.copyOf(correlatedPredicates); - this.node = requireNonNull(node, "node is null"); - } - - public Optional getCorrelatedPredicates() - { - if (correlatedPredicates.isEmpty()) { - return Optional.empty(); - } - return Optional.of(ExpressionUtils.and(correlatedPredicates)); - } - - public PlanNode getNode() - { - return node; - } - } - - private static class ExtendProjectionRewriter - extends SimplePlanRewriter - { - private final PlanNodeIdAllocator idAllocator; - private final Set symbols; - - ExtendProjectionRewriter(PlanNodeIdAllocator idAllocator, Set symbols) - { - this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); - this.symbols = requireNonNull(symbols, "symbols is null"); - } - - @Override - public PlanNode visitProject(ProjectNode node, RewriteContext context) - { - ProjectNode rewrittenNode = (ProjectNode) context.defaultRewrite(node, context.get()); - - List symbolsToAdd = symbols.stream() - .filter(rewrittenNode.getSource().getOutputSymbols()::contains) - .filter(symbol -> !rewrittenNode.getOutputSymbols().contains(symbol)) - .collect(toImmutableList()); - - Assignments assignments = Assignments.builder() - .putAll(rewrittenNode.getAssignments()) - .putAll(Assignments.identity(symbolsToAdd)) - .build(); - - return new ProjectNode(idAllocator.getNextId(), rewrittenNode.getSource(), assignments); - } - } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarQueryUtil.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarQueryUtil.java deleted file mode 100644 index 5988b8a2b131..000000000000 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarQueryUtil.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.sql.planner.optimizations; - -import com.facebook.presto.sql.planner.iterative.GroupReference; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; -import com.facebook.presto.sql.planner.plan.ExchangeNode; -import com.facebook.presto.sql.planner.plan.FilterNode; -import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.PlanVisitor; -import com.facebook.presto.sql.planner.plan.ProjectNode; -import com.facebook.presto.sql.planner.plan.ValuesNode; -import com.google.common.collect.ImmutableList; - -import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; -import static com.google.common.collect.Iterables.getOnlyElement; -import static java.util.Objects.requireNonNull; - -public final class ScalarQueryUtil -{ - private ScalarQueryUtil() {} - - public static boolean isScalar(PlanNode node, Lookup lookup) - { - return node.accept(new IsScalarPlanVisitor(lookup), null); - } - - public static boolean isScalar(PlanNode node) - { - return isScalar(node, noLookup()); - } - - private static final class IsScalarPlanVisitor - extends PlanVisitor - { - private final Lookup lookup; - - public IsScalarPlanVisitor(Lookup lookup) - { - this.lookup = requireNonNull(lookup, "lookup is null"); - } - - @Override - protected Boolean visitPlan(PlanNode node, Void context) - { - return false; - } - - @Override - public Boolean visitGroupReference(GroupReference node, Void context) - { - return lookup.resolve(node).accept(this, context); - } - - @Override - public Boolean visitEnforceSingleRow(EnforceSingleRowNode node, Void context) - { - return true; - } - - @Override - public Boolean visitAggregation(AggregationNode node, Void context) - { - return node.getGroupingSets().equals(ImmutableList.of(ImmutableList.of())); - } - - @Override - public Boolean visitExchange(ExchangeNode node, Void context) - { - return (node.getSources().size() == 1) && - getOnlyElement(node.getSources()).accept(this, null); - } - - @Override - public Boolean visitProject(ProjectNode node, Void context) - { - return node.getSource().accept(this, null); - } - - @Override - public Boolean visitFilter(FilterNode node, Void context) - { - return node.getSource().accept(this, null); - } - - public Boolean visitValues(ValuesNode node, Void context) - { - return node.getRows().size() == 1; - } - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyExpressions.java index de5b5d8cd58b..b7eb14ac7be0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyExpressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyExpressions.java @@ -115,6 +115,7 @@ public PlanNode visitFilter(FilterNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); Expression simplified = simplifyExpression(node.getPredicate()); + //When porting this to Rule(s), keep in mind the following logic is already implemented in RemoveTrivialFilters rule if (simplified.equals(TRUE_LITERAL)) { return source; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedNoAggregationSubqueryToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedNoAggregationSubqueryToJoin.java new file mode 100644 index 000000000000..154fc6ab51f9 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedNoAggregationSubqueryToJoin.java @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.optimizations.PlanNodeDecorrelator.DecorrelatedNode; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; +import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; +import static java.util.Objects.requireNonNull; + +/** + * This optimizer can rewrite correlated no aggregation subquery to inner join in a way described here: + * From: + *
+ * - Lateral (with correlation list: [B])
+ *   - (input) plan which produces symbols: [A, B]
+ *   - (subquery)
+ *     - Filter(B = C AND D < 5)
+ *       - plan which produces symbols: [C, D]
+ * 
+ * to: + *
+ *   - Join(INNER, B = C)
+ *       - (input) plan which produces symbols: [A, B]
+ *       - Filter(D < 5)
+ *          - plan which produces symbols: [C, D]
+ * 
+ *

+ * Note only conjunction predicates in FilterNode are supported + */ +public class TransformCorrelatedNoAggregationSubqueryToJoin + implements PlanOptimizer +{ + @Override + public PlanNode optimize( + PlanNode plan, + Session session, + Map types, + SymbolAllocator symbolAllocator, + PlanNodeIdAllocator idAllocator) + { + return rewriteWith(new Rewriter(idAllocator), plan, null); + } + + private static class Rewriter + extends SimplePlanRewriter + { + private final PlanNodeIdAllocator idAllocator; + + public Rewriter(PlanNodeIdAllocator idAllocator) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + } + + @Override + public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext context) + { + LateralJoinNode rewrittenNode = (LateralJoinNode) context.defaultRewrite(node, context.get()); + if (!rewrittenNode.getCorrelation().isEmpty()) { + return rewriteNoAggregationSubquery(rewrittenNode); + } + return rewrittenNode; + } + + private PlanNode rewriteNoAggregationSubquery(LateralJoinNode lateral) + { + List correlation = lateral.getCorrelation(); + PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(idAllocator, noLookup()); + Optional source = decorrelator.decorrelateFilters(lateral.getSubquery(), correlation); + if (!source.isPresent()) { + return lateral; + } + + return new JoinNode( + idAllocator.getNextId(), + JoinNode.Type.INNER, + lateral.getInput(), + source.get().getNode(), + ImmutableList.of(), + lateral.getOutputSymbols(), + source.get().getCorrelatedPredicates(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedScalarAggregationToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedScalarAggregationToJoin.java index 5a5d6c2bce29..9790a137e5e4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedScalarAggregationToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedScalarAggregationToJoin.java @@ -31,8 +31,8 @@ import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; -import static com.facebook.presto.sql.planner.optimizations.Predicates.isInstanceOfAny; import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; +import static com.facebook.presto.util.MorePredicates.isInstanceOfAny; import static java.util.Objects.requireNonNull; /** @@ -106,7 +106,7 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext if (!rewrittenNode.getCorrelation().isEmpty()) { Optional aggregation = searchFrom(rewrittenNode.getSubquery()) .where(AggregationNode.class::isInstance) - .skipOnlyWhen(isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class)) + .recurseOnlyWhen(isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class)) .findFirst(); if (aggregation.isPresent() && aggregation.get().getGroupingKeys().isEmpty()) { ScalarAggregationToJoinRewriter scalarAggregationToJoinRewriter = new ScalarAggregationToJoinRewriter(functionRegistry, symbolAllocator, idAllocator, noLookup()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedSingleRowSubqueryToProject.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedSingleRowSubqueryToProject.java new file mode 100644 index 000000000000..585def2f2768 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedSingleRowSubqueryToProject.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.facebook.presto.sql.planner.plan.ValuesNode; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; +import static java.util.Objects.requireNonNull; + +/** + * This optimizer can rewrite correlated single row subquery to projection in a way described here: + * From: + *

+ * - Lateral(with correlation list: [A, C])
+ *   - (input) plan which produces symbols: [A, B, C]
+ *   - (subquery)
+ *     - Project (A + C)
+ *       - single row VALUES()
+ * 
+ * to: + *
+ *   - Project(A, B, C, A + C)
+ *       - (input) plan which produces symbols: [A, B, C]
+ * 
+ */ +public class TransformCorrelatedSingleRowSubqueryToProject + implements PlanOptimizer +{ + @Override + public PlanNode optimize( + PlanNode plan, + Session session, + Map types, + SymbolAllocator symbolAllocator, + PlanNodeIdAllocator idAllocator) + { + return rewriteWith(new Rewriter(idAllocator), plan, null); + } + + private static class Rewriter + extends SimplePlanRewriter + { + private final PlanNodeIdAllocator idAllocator; + + public Rewriter(PlanNodeIdAllocator idAllocator) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + } + + @Override + public PlanNode visitLateralJoin(LateralJoinNode lateral, RewriteContext context) + { + LateralJoinNode rewrittenLateral = (LateralJoinNode) context.defaultRewrite(lateral, context.get()); + if (rewrittenLateral.getCorrelation().isEmpty()) { + return rewrittenLateral; + } + + Optional values = searchFrom(lateral.getSubquery()) + .recurseOnlyWhen(ProjectNode.class::isInstance) + .where(ValuesNode.class::isInstance) + .findSingle(); + + if (!values.isPresent() || !isSingleRowValuesWithNoColumns(values.get())) { + return rewrittenLateral; + } + + List subqueryProjections = searchFrom(lateral.getSubquery()) + .where(ProjectNode.class::isInstance) + .findAll(); + + if (subqueryProjections.size() == 0) { + return rewrittenLateral.getInput(); + } + else if (subqueryProjections.size() == 1) { + Assignments assignments = Assignments.builder() + .putIdentities(rewrittenLateral.getInput().getOutputSymbols()) + .putAll(subqueryProjections.get(0).getAssignments()) + .build(); + return projectNode(rewrittenLateral.getInput(), assignments); + } + return rewrittenLateral; + } + + private ProjectNode projectNode(PlanNode source, Assignments assignments) + { + return new ProjectNode(idAllocator.getNextId(), source, assignments); + } + + private static boolean isSingleRowValuesWithNoColumns(ValuesNode values) + { + return values.getRows().size() == 1 && values.getRows().get(0).size() == 0; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java index 9db6ecddb875..f2efb629e2a1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -55,6 +55,7 @@ * - semiJoinOutput: semijoinresult * */ +@Deprecated public class TransformUncorrelatedInPredicateSubqueryToSemiJoin implements PlanOptimizer { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedLateralToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedLateralToJoin.java index bf3ce72a1cec..f9fe5e1010cc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedLateralToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedLateralToJoin.java @@ -29,6 +29,7 @@ import static java.util.Objects.requireNonNull; +@Deprecated public class TransformUncorrelatedLateralToJoin implements PlanOptimizer { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 78b6711c3292..9b36d768ba0d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -170,7 +170,7 @@ public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) public PlanNode visitExplainAnalyze(ExplainAnalyzeNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - return new ExplainAnalyzeNode(node.getId(), source, canonicalize(node.getOutputSymbol())); + return new ExplainAnalyzeNode(node.getId(), source, canonicalize(node.getOutputSymbol()), node.isVerbose()); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowNodeUtil.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowNodeUtil.java index 742a7501395c..e032b5f525a7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowNodeUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowNodeUtil.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.optimizations; -import com.facebook.presto.sql.planner.DependencyExtractor; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.WindowNode; public final class WindowNodeUtil @@ -26,7 +26,7 @@ public static boolean dependsOn(WindowNode parent, WindowNode child) || parent.getOrderBy().stream().anyMatch(child.getCreatedSymbols()::contains) || parent.getWindowFunctions().values().stream() .map(WindowNode.Function::getFunctionCall) - .map(DependencyExtractor::extractUnique) + .map(SymbolsExtractor::extractUnique) .flatMap(symbols -> symbols.stream()) .anyMatch(child.getCreatedSymbols()::contains); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java index 45f28476bc69..8b6eeafa1812 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java @@ -78,18 +78,16 @@ public static Assignments of(Symbol symbol1, Expression expression1, Symbol symb } private final Map assignments; - private final List outputs; @JsonCreator public Assignments(@JsonProperty("assignments") Map assignments) { this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); - this.outputs = ImmutableList.copyOf(assignments.keySet()); } public List getOutputs() { - return outputs; + return ImmutableList.copyOf(assignments.keySet()); } @JsonProperty("assignments") @@ -166,6 +164,32 @@ public int size() return assignments.size(); } + public boolean isEmpty() + { + return size() == 0; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Assignments that = (Assignments) o; + + return assignments.equals(that.assignments); + } + + @Override + public int hashCode() + { + return assignments.hashCode(); + } + public static class Builder { private final Map assignments = new LinkedHashMap<>(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java index 5d5da31e2954..bd6315935f3d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java @@ -31,16 +31,19 @@ public class ExplainAnalyzeNode { private final PlanNode source; private final Symbol outputSymbol; + private final boolean verbose; @JsonCreator public ExplainAnalyzeNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("outputSymbol") Symbol outputSymbol) + @JsonProperty("outputSymbol") Symbol outputSymbol, + @JsonProperty("verbose") boolean verbose) { super(id); this.source = requireNonNull(source, "source is null"); this.outputSymbol = requireNonNull(outputSymbol, "outputSymbol is null"); + this.verbose = verbose; } @JsonProperty("outputSymbol") @@ -55,6 +58,12 @@ public PlanNode getSource() return source; } + @JsonProperty("verbose") + public boolean isVerbose() + { + return verbose; + } + @Override public List getOutputSymbols() { @@ -76,6 +85,6 @@ public R accept(PlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new ExplainAnalyzeNode(getId(), Iterables.getOnlyElement(newChildren), outputSymbol); + return new ExplainAnalyzeNode(getId(), Iterables.getOnlyElement(newChildren), outputSymbol, isVerbose()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexSourceNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexSourceNode.java index e1c7eedff9b4..61268f60b4ad 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexSourceNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexSourceNode.java @@ -67,6 +67,11 @@ public IndexSourceNode( checkArgument(!outputSymbols.isEmpty(), "outputSymbols is empty"); checkArgument(assignments.keySet().containsAll(lookupSymbols), "Assignments do not include all lookup symbols"); checkArgument(outputSymbols.containsAll(lookupSymbols), "Lookup symbols need to be part of the output symbols"); + Set assignedColumnHandles = ImmutableSet.copyOf(assignments.values()); + effectiveTupleDomain.getDomains().ifPresent(handleToDomain -> + checkArgument( + assignedColumnHandles.containsAll(handleToDomain.keySet()), + "Tuple domain handles must have assigned symbols")); } @JsonProperty diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java index f9d867800d66..3ecc8862ae35 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java @@ -95,6 +95,9 @@ public JoinNode(@JsonProperty("id") PlanNodeId id, .build(); checkArgument(inputSymbols.containsAll(outputSymbols), "Left and right join inputs do not contain all output symbols"); checkArgument(!isCrossJoin() || inputSymbols.equals(outputSymbols), "Cross join does not support output symbols pruning or reordering"); + + checkArgument(!(criteria.isEmpty() && leftHashSymbol.isPresent()), "Left hash symbol is only valid in an equijoin"); + checkArgument(!(criteria.isEmpty() && rightHashSymbol.isPresent()), "Right hash symbol is only valid in an equijoin"); } public enum DistributionType diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStats.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStats.java index 3356b2d5a39d..bafe8a2f8947 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStats.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStats.java @@ -14,10 +14,12 @@ package com.facebook.presto.sql.planner.planPrinter; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.util.Mergeable; import io.airlift.units.DataSize; import io.airlift.units.Duration; import java.util.Map; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.util.MoreMaps.mergeMaps; @@ -29,7 +31,10 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.stream.Collectors.toMap; +// TODO: break into operator-specific stats classes instead of having a big union-class aggregating all stats together +@Deprecated public class PlanNodeStats + implements Mergeable { private final PlanNodeId planNodeId; @@ -41,6 +46,7 @@ public class PlanNodeStats private final Map operatorInputStats; private final Map operatorHashCollisionsStats; + private final Optional windowOperatorStats; PlanNodeStats( PlanNodeId planNodeId, @@ -50,7 +56,8 @@ public class PlanNodeStats long planNodeOutputPositions, DataSize planNodeOutputDataSize, Map operatorInputStats, - Map operatorHashCollisionsStats) + Map operatorHashCollisionsStats, + Optional windowOperatorStats) { this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); @@ -62,6 +69,7 @@ public class PlanNodeStats this.operatorInputStats = requireNonNull(operatorInputStats, "operatorInputStats is null"); this.operatorHashCollisionsStats = requireNonNull(operatorHashCollisionsStats, "operatorHashCollisionsStats is null"); + this.windowOperatorStats = requireNonNull(windowOperatorStats, "windowOperatorStats is null"); } private static double computedStdDev(double sumSquared, double sum, long n) @@ -161,24 +169,32 @@ public Map getOperatorExpectedCollisionsAverages() entry -> entry.getValue().getWeightedExpectedHashCollisions() / operatorInputStats.get(entry.getKey()).getInputPositions())); } - public static PlanNodeStats merge(PlanNodeStats left, PlanNodeStats right) + public Optional getWindowOperatorStats() { - checkArgument(left.getPlanNodeId().equals(right.getPlanNodeId()), "planNodeIds do not match. %s != %s", left.getPlanNodeId(), right.getPlanNodeId()); + return windowOperatorStats; + } + + @Override + public PlanNodeStats mergeWith(PlanNodeStats other) + { + checkArgument(planNodeId.equals(other.getPlanNodeId()), "planNodeIds do not match. %s != %s", planNodeId, other.getPlanNodeId()); - long planNodeInputPositions = left.planNodeInputPositions + right.planNodeInputPositions; - DataSize planNodeInputDataSize = succinctBytes(left.planNodeInputDataSize.toBytes() + right.planNodeInputDataSize.toBytes()); - long planNodeOutputPositions = left.planNodeOutputPositions + right.planNodeOutputPositions; - DataSize planNodeOutputDataSize = succinctBytes(left.planNodeOutputDataSize.toBytes() + right.planNodeOutputDataSize.toBytes()); + long planNodeInputPositions = this.planNodeInputPositions + other.planNodeInputPositions; + DataSize planNodeInputDataSize = succinctBytes(this.planNodeInputDataSize.toBytes() + other.planNodeInputDataSize.toBytes()); + long planNodeOutputPositions = this.planNodeOutputPositions + other.planNodeOutputPositions; + DataSize planNodeOutputDataSize = succinctBytes(this.planNodeOutputDataSize.toBytes() + other.planNodeOutputDataSize.toBytes()); - Map operatorInputStats = mergeMaps(left.operatorInputStats, right.operatorInputStats, OperatorInputStats::merge); - Map operatorHashCollisionsStats = mergeMaps(left.operatorHashCollisionsStats, right.operatorHashCollisionsStats, OperatorHashCollisionsStats::merge); + Map operatorInputStats = mergeMaps(this.operatorInputStats, other.operatorInputStats, OperatorInputStats::merge); + Map operatorHashCollisionsStats = mergeMaps(this.operatorHashCollisionsStats, other.operatorHashCollisionsStats, OperatorHashCollisionsStats::merge); + Optional windowNodeStats = Mergeable.merge(this.windowOperatorStats, other.windowOperatorStats); return new PlanNodeStats( - left.getPlanNodeId(), - new Duration(left.getPlanNodeWallTime().toMillis() + right.getPlanNodeWallTime().toMillis(), MILLISECONDS), + planNodeId, + new Duration(planNodeWallTime.toMillis() + other.getPlanNodeWallTime().toMillis(), MILLISECONDS), planNodeInputPositions, planNodeInputDataSize, planNodeOutputPositions, planNodeOutputDataSize, operatorInputStats, - operatorHashCollisionsStats); + operatorHashCollisionsStats, + windowNodeStats); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStatsSummarizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStatsSummarizer.java index 382e7e4a30b6..bc3d32cdfcdb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStatsSummarizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStatsSummarizer.java @@ -19,6 +19,7 @@ import com.facebook.presto.operator.OperatorStats; import com.facebook.presto.operator.PipelineStats; import com.facebook.presto.operator.TaskStats; +import com.facebook.presto.operator.WindowInfo; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.collect.ImmutableMap; import io.airlift.units.Duration; @@ -28,6 +29,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.util.MoreMaps.mergeMaps; @@ -51,7 +53,7 @@ public static Map aggregatePlanNodeStats(StageInfo st .flatMap(taskStats -> getPlanNodeStats(taskStats).stream()) .collect(toList()); for (PlanNodeStats stats : planNodeStats) { - aggregatedStats.merge(stats.getPlanNodeId(), stats, PlanNodeStats::merge); + aggregatedStats.merge(stats.getPlanNodeId(), stats, (left, right) -> left.mergeWith(right)); } return aggregatedStats; } @@ -71,6 +73,7 @@ private static List getPlanNodeStats(TaskStats taskStats) Map> operatorInputStats = new HashMap<>(); Map> operatorHashCollisionsStats = new HashMap<>(); + Map windowNodeStats = new HashMap<>(); for (PipelineStats pipelineStats : taskStats.getPipelines()) { // Due to eventual consistently collected stats, these could be empty @@ -118,6 +121,12 @@ private static List getPlanNodeStats(TaskStats taskStats) (map1, map2) -> mergeMaps(map1, map2, OperatorHashCollisionsStats::merge)); } + // The only statistics we have for Window Functions are very low level, thus displayed only in VERBOSE mode + if (operatorStats.getInfo() instanceof WindowInfo) { + WindowInfo windowInfo = (WindowInfo) operatorStats.getInfo(); + windowNodeStats.merge(planNodeId, WindowOperatorStats.create(windowInfo), (left, right) -> left.mergeWith(right)); + } + planNodeInputPositions.merge(planNodeId, operatorStats.getInputPositions(), Long::sum); planNodeInputBytes.merge(planNodeId, operatorStats.getInputDataSize().toBytes(), Long::sum); processedNodes.add(planNodeId); @@ -158,7 +167,8 @@ private static List getPlanNodeStats(TaskStats taskStats) succinctDataSize(planNodeOutputBytes.getOrDefault(planNodeId, 0L), BYTE), operatorInputStats.get(planNodeId), // Only some operators emit hash collisions statistics - operatorHashCollisionsStats.getOrDefault(planNodeId, emptyMap()))); + operatorHashCollisionsStats.getOrDefault(planNodeId, emptyMap()), + Optional.ofNullable(windowNodeStats.get(planNodeId)))); } return stats; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index bcc7d064a13e..b2ccb2040e0a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -135,13 +135,14 @@ public class PlanPrinter private final StringBuilder output = new StringBuilder(); private final Metadata metadata; private final Optional> stats; + private final boolean verbose; private PlanPrinter(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session sesion) { - this(plan, types, metadata, costCalculator, sesion, 0); + this(plan, types, metadata, costCalculator, sesion, 0, false); } - private PlanPrinter(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, int indent) + private PlanPrinter(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, int indent, boolean verbose) { requireNonNull(plan, "plan is null"); requireNonNull(types, "types is null"); @@ -150,13 +151,14 @@ private PlanPrinter(PlanNode plan, Map types, Metadata metadata, C this.metadata = metadata; this.stats = Optional.empty(); + this.verbose = verbose; Map costs = costCalculator.calculateCostForPlan(session, types, plan); Visitor visitor = new Visitor(types, costs, session); plan.accept(visitor, indent); } - private PlanPrinter(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, Map stats, int indent) + private PlanPrinter(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, Map stats, int indent, boolean verbose) { requireNonNull(plan, "plan is null"); requireNonNull(types, "types is null"); @@ -165,6 +167,7 @@ private PlanPrinter(PlanNode plan, Map types, Metadata metadata, C this.metadata = metadata; this.stats = Optional.of(stats); + this.verbose = verbose; Map costs = costCalculator.calculateCostForPlan(session, types, plan); Visitor visitor = new Visitor(types, costs, session); @@ -184,15 +187,25 @@ public static String textLogicalPlan(PlanNode plan, Map types, Met public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, int indent) { - return new PlanPrinter(plan, types, metadata, costCalculator, session, indent).toString(); + return textLogicalPlan(plan, types, metadata, costCalculator, session, indent, false); } - public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, Map stats, int indent) + public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, int indent, boolean verbose) { - return new PlanPrinter(plan, types, metadata, costCalculator, session, stats, indent).toString(); + return new PlanPrinter(plan, types, metadata, costCalculator, session, indent, verbose).toString(); + } + + public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, Map stats, int indent, boolean verbose) + { + return new PlanPrinter(plan, types, metadata, costCalculator, session, stats, indent, verbose).toString(); } public static String textDistributedPlan(StageInfo outputStageInfo, Metadata metadata, CostCalculator costCalculator, Session session) + { + return textDistributedPlan(outputStageInfo, metadata, costCalculator, session, false); + } + + public static String textDistributedPlan(StageInfo outputStageInfo, Metadata metadata, CostCalculator costCalculator, Session session, boolean verbose) { StringBuilder builder = new StringBuilder(); List allStages = outputStageInfo.getSubStages().stream() @@ -200,37 +213,50 @@ public static String textDistributedPlan(StageInfo outputStageInfo, Metadata met .collect(toImmutableList()); for (StageInfo stageInfo : allStages) { Map aggregatedStats = aggregatePlanNodeStats(stageInfo); - builder.append(formatFragment(metadata, costCalculator, session, stageInfo.getPlan(), Optional.of(stageInfo.getStageStats()), Optional.of(aggregatedStats))); + builder.append(formatFragment(metadata, costCalculator, session, stageInfo.getPlan(), Optional.of(stageInfo), Optional.of(aggregatedStats), verbose)); } return builder.toString(); } public static String textDistributedPlan(SubPlan plan, Metadata metadata, CostCalculator costCalculator, Session session) + { + return textDistributedPlan(plan, metadata, costCalculator, session, false); + } + + public static String textDistributedPlan(SubPlan plan, Metadata metadata, CostCalculator costCalculator, Session session, boolean verbose) { StringBuilder builder = new StringBuilder(); for (PlanFragment fragment : plan.getAllFragments()) { - builder.append(formatFragment(metadata, costCalculator, session, fragment, Optional.empty(), Optional.empty())); + builder.append(formatFragment(metadata, costCalculator, session, fragment, Optional.empty(), Optional.empty(), verbose)); } return builder.toString(); } - private static String formatFragment(Metadata metadata, CostCalculator costCalculator, Session session, PlanFragment fragment, Optional stageStats, Optional> planNodeStats) + private static String formatFragment(Metadata metadata, CostCalculator costCalculator, Session session, PlanFragment fragment, Optional stageInfo, Optional> planNodeStats, boolean verbose) { StringBuilder builder = new StringBuilder(); builder.append(format("Fragment %s [%s]\n", fragment.getId(), fragment.getPartitioning())); - if (stageStats.isPresent()) { + if (stageInfo.isPresent()) { + StageStats stageStats = stageInfo.get().getStageStats(); + + double avgPositionsPerTask = stageInfo.get().getTasks().stream().mapToLong(task -> task.getStats().getProcessedInputPositions()).average().orElse(Double.NaN); + double squaredDifferences = stageInfo.get().getTasks().stream().mapToDouble(task -> Math.pow(task.getStats().getProcessedInputPositions() - avgPositionsPerTask, 2)).sum(); + double sdAmongTasks = Math.sqrt(squaredDifferences / stageInfo.get().getTasks().size()); + builder.append(indentString(1)) - .append(format("CPU: %s, Input: %s (%s), Output: %s (%s)\n", - stageStats.get().getTotalCpuTime(), - formatPositions(stageStats.get().getProcessedInputPositions()), - stageStats.get().getProcessedInputDataSize(), - formatPositions(stageStats.get().getOutputPositions()), - stageStats.get().getOutputDataSize())); + .append(format("CPU: %s, Input: %s (%s); per task: avg.: %s std.dev.: %s, Output: %s (%s)\n", + stageStats.getTotalCpuTime(), + formatPositions(stageStats.getProcessedInputPositions()), + stageStats.getProcessedInputDataSize(), + avgPositionsPerTask, + sdAmongTasks, + formatPositions(stageStats.getOutputPositions()), + stageStats.getOutputDataSize())); } PartitioningScheme partitioningScheme = fragment.getPartitioningScheme(); @@ -263,12 +289,12 @@ private static String formatFragment(Metadata metadata, CostCalculator costCalcu formatHash(partitioningScheme.getHashColumn()))); } - if (stageStats.isPresent()) { - builder.append(textLogicalPlan(fragment.getRoot(), fragment.getSymbols(), metadata, costCalculator, session, planNodeStats.get(), 1)) + if (stageInfo.isPresent()) { + builder.append(textLogicalPlan(fragment.getRoot(), fragment.getSymbols(), metadata, costCalculator, session, planNodeStats.get(), 1, verbose)) .append("\n"); } else { - builder.append(textLogicalPlan(fragment.getRoot(), fragment.getSymbols(), metadata, costCalculator, session, 1)) + builder.append(textLogicalPlan(fragment.getRoot(), fragment.getSymbols(), metadata, costCalculator, session, 1, verbose)) .append("\n"); } @@ -359,6 +385,11 @@ private void printStats(int indent, PlanNodeId planNodeId, boolean printInput, b output.append('\n'); printDistributions(indent, nodeStats); + + if (nodeStats.getWindowOperatorStats().isPresent()) { + // TODO: Once PlanNodeStats becomes broken into smaller classes, we should rely on toString() method of WindowOperatorStats here + printWindowOperatorStats(indent, nodeStats.getWindowOperatorStats().get()); + } } private void printDistributions(int indent, PlanNodeStats nodeStats) @@ -426,6 +457,34 @@ private static Map translateOperatorTypes(Set operators) return ImmutableMap.of(); } + private void printWindowOperatorStats(int indent, WindowOperatorStats stats) + { + if (!verbose) { + // these stats are too detailed for non-verbose mode + return; + } + + output.append(indentString(indent)); + output.append(format("Active Drivers: [ %d / %d ]", stats.getActiveDrivers(), stats.getTotalDrivers())); + output.append('\n'); + + output.append(indentString(indent)); + output.append(format("Index size: std.dev.: %s bytes , %s rows", formatDouble(stats.getIndexSizeStdDev()), formatDouble(stats.getIndexPositionsStdDev()))); + output.append('\n'); + + output.append(indentString(indent)); + output.append(format("Index count per driver: std.dev.: %s", formatDouble(stats.getIndexCountPerDriverStdDev()))); + output.append('\n'); + + output.append(indentString(indent)); + output.append(format("Rows per driver: std.dev.: %s", formatDouble(stats.getRowsPerDriverStdDev()))); + output.append('\n'); + + output.append(indentString(indent)); + output.append(format("Size of partition: std.dev.: %s", formatDouble(stats.getPartitionRowsStdDev()))); + output.append('\n'); + } + private static String formatDouble(double value) { if (isFinite(value)) { @@ -1223,10 +1282,17 @@ private String formatDomain(Domain domain) private void printCost(int indent, PlanNode... nodes) { - String costString = Joiner.on("/").join(Arrays.stream(nodes) - .map(this::formatCost) - .collect(toImmutableList())); - print(indent, "Cost: %s", costString); + if (Arrays.stream(nodes).anyMatch(this::isKnownCost)) { + String costString = Joiner.on("/").join(Arrays.stream(nodes) + .map(this::formatCost) + .collect(toImmutableList())); + print(indent, "Cost: %s", costString); + } + } + + private boolean isKnownCost(PlanNode node) + { + return !UNKNOWN_COST.equals(costs.getOrDefault(node.getId(), UNKNOWN_COST)); } private String formatCost(PlanNode node) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/WindowOperatorStats.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/WindowOperatorStats.java new file mode 100644 index 000000000000..7faba0031ff8 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/WindowOperatorStats.java @@ -0,0 +1,171 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.planPrinter; + +import com.facebook.presto.operator.WindowInfo; +import com.facebook.presto.operator.WindowInfo.DriverWindowInfo; +import com.facebook.presto.util.Mergeable; + +import static com.google.common.base.Preconditions.checkArgument; + +class WindowOperatorStats + implements Mergeable +{ + private final int activeDrivers; + private final int totalDrivers; + private final double positionsInIndexesSumSquaredDiffs; + private final double sizeOfIndexesSumSquaredDiffs; + private final double indexCountPerDriverSumSquaredDiffs; + private final double partitionRowsSumSquaredDiffs; + private final double rowCountPerDriverSumSquaredDiffs; + private final long totalRowCount; + private final long totalIndexesCount; + private final long totalPartitionsCount; + + public static WindowOperatorStats create(WindowInfo info) + { + checkArgument(info.getWindowInfos().size() > 0, "WindowInfo cannot have empty list of DriverWindowInfos"); + + int activeDrivers = 0; + int totalDrivers = 0; + + double partitionRowsSumSquaredDiffs = 0.0; + double positionsInIndexesSumSquaredDiffs = 0.0; + double sizeOfIndexesSumSquaredDiffs = 0.0; + double indexCountPerDriverSumSquaredDiffs = 0.0; + double rowCountPerDriverSumSquaredDiffs = 0.0; + long totalRowCount = 0; + long totalIndexesCount = 0; + long totalPartitionsCount = 0; + + double averageNumberOfIndexes = info.getWindowInfos().stream() + .filter(windowInfo -> windowInfo.getTotalRowsCount() > 0) + .mapToLong(DriverWindowInfo::getNumberOfIndexes) + .average() + .getAsDouble(); + + double averageNumberOfRows = info.getWindowInfos().stream() + .filter(windowInfo -> windowInfo.getTotalRowsCount() > 0) + .mapToLong(DriverWindowInfo::getTotalRowsCount) + .average() + .getAsDouble(); + + for (DriverWindowInfo driverWindowInfo : info.getWindowInfos()) { + long driverTotalRowsCount = driverWindowInfo.getTotalRowsCount(); + totalDrivers++; + if (driverTotalRowsCount > 0) { + long numberOfIndexes = driverWindowInfo.getNumberOfIndexes(); + + partitionRowsSumSquaredDiffs += driverWindowInfo.getSumSquaredDifferencesSizeInPartition(); + totalPartitionsCount += driverWindowInfo.getTotalPartitionsCount(); + + totalRowCount += driverWindowInfo.getTotalRowsCount(); + + positionsInIndexesSumSquaredDiffs += driverWindowInfo.getSumSquaredDifferencesPositionsOfIndex(); + sizeOfIndexesSumSquaredDiffs += driverWindowInfo.getSumSquaredDifferencesSizeOfIndex(); + totalIndexesCount += numberOfIndexes; + + indexCountPerDriverSumSquaredDiffs += (Math.pow(numberOfIndexes - averageNumberOfIndexes, 2)); + rowCountPerDriverSumSquaredDiffs += (Math.pow(driverTotalRowsCount - averageNumberOfRows, 2)); + activeDrivers++; + } + } + + return new WindowOperatorStats(partitionRowsSumSquaredDiffs, + positionsInIndexesSumSquaredDiffs, + sizeOfIndexesSumSquaredDiffs, + indexCountPerDriverSumSquaredDiffs, + rowCountPerDriverSumSquaredDiffs, + totalRowCount, + totalIndexesCount, + totalPartitionsCount, + activeDrivers, + totalDrivers); + } + + private WindowOperatorStats( + double partitionRowsSumSquaredDiffs, + double positionsInIndexesSumSquaredDiffs, + double sizeOfIndexesSumSquaredDiffs, + double indexCountPerDriverSumSquaredDiffs, + double rowCountPerDriverSumSquaredDiffs, + long totalRowCount, + long totalIndexesCount, + long totalPartitionsCount, + int activeDrivers, + int totalDrivers) + { + this.partitionRowsSumSquaredDiffs = partitionRowsSumSquaredDiffs; + this.positionsInIndexesSumSquaredDiffs = positionsInIndexesSumSquaredDiffs; + this.sizeOfIndexesSumSquaredDiffs = sizeOfIndexesSumSquaredDiffs; + this.indexCountPerDriverSumSquaredDiffs = indexCountPerDriverSumSquaredDiffs; + this.rowCountPerDriverSumSquaredDiffs = rowCountPerDriverSumSquaredDiffs; + this.totalRowCount = totalRowCount; + this.totalIndexesCount = totalIndexesCount; + this.totalPartitionsCount = totalPartitionsCount; + this.activeDrivers = activeDrivers; + this.totalDrivers = totalDrivers; + } + + @Override + public WindowOperatorStats mergeWith(WindowOperatorStats other) + { + return new WindowOperatorStats( + partitionRowsSumSquaredDiffs + other.partitionRowsSumSquaredDiffs, + positionsInIndexesSumSquaredDiffs + other.positionsInIndexesSumSquaredDiffs, + sizeOfIndexesSumSquaredDiffs + other.sizeOfIndexesSumSquaredDiffs, + indexCountPerDriverSumSquaredDiffs + other.indexCountPerDriverSumSquaredDiffs, + rowCountPerDriverSumSquaredDiffs + other.rowCountPerDriverSumSquaredDiffs, + totalRowCount + other.totalRowCount, + totalIndexesCount + other.totalIndexesCount, + totalPartitionsCount + other.totalPartitionsCount, + activeDrivers + other.activeDrivers, + totalDrivers + other.totalDrivers); + } + + public double getIndexSizeStdDev() + { + return Math.sqrt(sizeOfIndexesSumSquaredDiffs / totalIndexesCount); + } + + public double getIndexPositionsStdDev() + { + return Math.sqrt(positionsInIndexesSumSquaredDiffs / totalIndexesCount); + } + + public double getIndexCountPerDriverStdDev() + { + return Math.sqrt(indexCountPerDriverSumSquaredDiffs / activeDrivers); + } + + public double getPartitionRowsStdDev() + { + return Math.sqrt(partitionRowsSumSquaredDiffs / totalPartitionsCount); + } + + public double getRowsPerDriverStdDev() + { + return Math.sqrt(rowCountPerDriverSumSquaredDiffs / activeDrivers); + } + + public int getActiveDrivers() + { + return activeDrivers; + } + + public int getTotalDrivers() + { + return totalDrivers; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoSubqueryRelatedNodeLeftChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoSubqueryRelatedNodeLeftChecker.java deleted file mode 100644 index 576c00e96974..000000000000 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoSubqueryRelatedNodeLeftChecker.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.sql.planner.sanity; - -import com.facebook.presto.Session; -import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.SimplePlanVisitor; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.plan.ApplyNode; -import com.facebook.presto.sql.planner.plan.LateralJoinNode; -import com.facebook.presto.sql.planner.plan.PlanNode; - -import java.util.List; -import java.util.Map; - -public class NoSubqueryRelatedNodeLeftChecker - implements PlanSanityChecker.Checker -{ - @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, Map types) - { - plan.accept(new SimplePlanVisitor() - { - @Override - public Object visitApply(ApplyNode node, Object context) - { - throw subqueryLeftException(node.getCorrelation()); - } - - @Override - public Object visitLateralJoin(LateralJoinNode node, Object context) - { - throw subqueryLeftException(node.getCorrelation()); - } - - private IllegalArgumentException subqueryLeftException(List correlation) - { - if (correlation.isEmpty()) { - return new IllegalArgumentException("Unsupported subquery type"); - } - else { - return new IllegalArgumentException("Unsupported correlated subquery type"); - } - } - }, null); - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/PlanSanityChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/PlanSanityChecker.java index 6b13da104863..4e3cb8c514ef 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/PlanSanityChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/PlanSanityChecker.java @@ -19,34 +19,53 @@ import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.Multimap; -import java.util.List; import java.util.Map; /** - * It is going to be executed at the end of logical planner, to verify its correctness + * It is going to be executed to verify logical planner correctness */ public final class PlanSanityChecker { - private static final List CHECKERS = ImmutableList.of( - new ValidateDependenciesChecker(), - new TypeValidator(), - new NoSubqueryExpressionLeftChecker(), - new NoDuplicatePlanNodeIdsChecker(), - new NoSubqueryRelatedNodeLeftChecker(), - new VerifyNoFilteredAggregations(), - new VerifyOnlyOneOutputNode()); + private static final Multimap CHECKERS = ImmutableListMultimap.builder() + .putAll( + Stage.INTERMEDIATE, + new ValidateDependenciesChecker(), + new NoDuplicatePlanNodeIdsChecker(), + new TypeValidator(), + new NoSubqueryExpressionLeftChecker(), + new VerifyOnlyOneOutputNode()) + .putAll( + Stage.FINAL, + new ValidateDependenciesChecker(), + new NoDuplicatePlanNodeIdsChecker(), + new TypeValidator(), + new NoSubqueryExpressionLeftChecker(), + new VerifyOnlyOneOutputNode(), + new VerifyNoFilteredAggregations()) + .build(); private PlanSanityChecker() {} - public static void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, Map types) + public static void validateFinalPlan(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, Map types) { - CHECKERS.forEach(checker -> checker.validate(planNode, session, metadata, sqlParser, types)); + CHECKERS.get(Stage.FINAL).forEach(checker -> checker.validate(planNode, session, metadata, sqlParser, types)); + } + + public static void validateIntermediatePlan(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, Map types) + { + CHECKERS.get(Stage.INTERMEDIATE).forEach(checker -> checker.validate(planNode, session, metadata, sqlParser, types)); } public interface Checker { void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, Map types); } + + private enum Stage + { + INTERMEDIATE, FINAL + }; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index 82495ea50e5d..c75a1e9d3c46 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -17,8 +17,8 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; @@ -118,7 +118,7 @@ public Void visitAggregation(AggregationNode node, Set boundSymbols) checkDependencies(inputs, node.getGroupingKeys(), "Invalid node. Grouping key symbols (%s) not in source plan output (%s)", node.getGroupingKeys(), node.getSource().getOutputSymbols()); for (Aggregation aggregation : node.getAggregations().values()) { - Set dependencies = DependencyExtractor.extractUnique(aggregation.getCall()); + Set dependencies = SymbolsExtractor.extractUnique(aggregation.getCall()); checkDependencies(inputs, dependencies, "Invalid node. Aggregation dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols()); } @@ -170,7 +170,7 @@ public Void visitWindow(WindowNode node, Set boundSymbols) checkDependencies(inputs, bounds.build(), "Invalid node. Frame bounds (%s) not in source plan output (%s)", bounds.build(), node.getSource().getOutputSymbols()); for (WindowNode.Function function : node.getWindowFunctions().values()) { - Set dependencies = DependencyExtractor.extractUnique(function.getFunctionCall()); + Set dependencies = SymbolsExtractor.extractUnique(function.getFunctionCall()); checkDependencies(inputs, dependencies, "Invalid node. Window function dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols()); } @@ -210,7 +210,7 @@ public Void visitFilter(FilterNode node, Set boundSymbols) Set inputs = createInputs(source, boundSymbols); checkDependencies(inputs, node.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputSymbols(), node.getSource().getOutputSymbols()); - Set dependencies = DependencyExtractor.extractUnique(node.getPredicate()); + Set dependencies = SymbolsExtractor.extractUnique(node.getPredicate()); checkDependencies(inputs, dependencies, "Invalid node. Predicate dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols()); return null; @@ -233,7 +233,7 @@ public Void visitProject(ProjectNode node, Set boundSymbols) Set inputs = createInputs(source, boundSymbols); for (Expression expression : node.getAssignments().getExpressions()) { - Set dependencies = DependencyExtractor.extractUnique(expression); + Set dependencies = SymbolsExtractor.extractUnique(expression); checkDependencies(inputs, dependencies, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", dependencies, inputs); } @@ -319,7 +319,7 @@ public Void visitJoin(JoinNode node, Set boundSymbols) } node.getFilter().ifPresent(predicate -> { - Set predicateSymbols = DependencyExtractor.extractUnique(predicate); + Set predicateSymbols = SymbolsExtractor.extractUnique(predicate); checkArgument( allInputs.containsAll(predicateSymbols), "Symbol from filter (%s) not in sources (%s)", @@ -545,7 +545,7 @@ public Void visitApply(ApplyNode node, Set boundSymbols) node.getSubquery().accept(this, subqueryCorrelation); // visit child checkDependencies(node.getInput().getOutputSymbols(), node.getCorrelation(), "APPLY input must provide all the necessary correlation symbols for subquery"); - checkDependencies(DependencyExtractor.extractUnique(node.getSubquery()), node.getCorrelation(), "not all APPLY correlation symbols are used in subquery"); + checkDependencies(SymbolsExtractor.extractUnique(node.getSubquery()), node.getCorrelation(), "not all APPLY correlation symbols are used in subquery"); ImmutableSet inputs = ImmutableSet.builder() .addAll(createInputs(node.getSubquery(), boundSymbols)) @@ -553,7 +553,7 @@ public Void visitApply(ApplyNode node, Set boundSymbols) .build(); for (Expression expression : node.getSubqueryAssignments().getExpressions()) { - Set dependencies = DependencyExtractor.extractUnique(expression); + Set dependencies = SymbolsExtractor.extractUnique(expression); checkDependencies(inputs, dependencies, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", dependencies, inputs); } @@ -576,7 +576,7 @@ public Void visitLateralJoin(LateralJoinNode node, Set boundSymbols) node.getCorrelation(), "LATERAL input must provide all the necessary correlation symbols for subquery"); checkDependencies( - DependencyExtractor.extractUnique(node.getSubquery()), + SymbolsExtractor.extractUnique(node.getSubquery()), node.getCorrelation(), "not all LATERAL correlation symbols are used in subquery"); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/Signatures.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/Signatures.java index 02f4740e7db9..10ff4b3250e7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/Signatures.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/Signatures.java @@ -16,6 +16,7 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; @@ -23,7 +24,6 @@ import com.facebook.presto.sql.tree.ComparisonExpressionType; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.type.LikePatternType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java index 2648c9e3e8e0..9968d5ec16f1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java @@ -19,6 +19,8 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.DecimalParseResult; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.RowType; +import com.facebook.presto.spi.type.RowType.RowField; import com.facebook.presto.spi.type.TimeZoneKey; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; @@ -68,8 +70,6 @@ import com.facebook.presto.sql.tree.TimestampLiteral; import com.facebook.presto.sql.tree.TryExpression; import com.facebook.presto.sql.tree.WhenClause; -import com.facebook.presto.type.RowType; -import com.facebook.presto.type.RowType.RowField; import com.facebook.presto.type.UnknownType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ExplainRewrite.java b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ExplainRewrite.java index cabaa3541068..36d3b083d9ca 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ExplainRewrite.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ExplainRewrite.java @@ -80,7 +80,7 @@ protected Node visitExplain(Explain node, Void context) { if (node.isAnalyze()) { Statement statement = (Statement) process(node.getStatement(), context); - return new Explain(statement, node.isAnalyze(), node.getOptions()); + return new Explain(statement, node.isAnalyze(), node.isVerbose(), node.getOptions()); } ExplainType.Type planType = LOGICAL; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowQueriesRewrite.java b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowQueriesRewrite.java index 56eb988648a3..4f82830d1624 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowQueriesRewrite.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowQueriesRewrite.java @@ -174,6 +174,7 @@ protected Node visitExplain(Explain node, Void context) return new Explain( node.getLocation().get(), node.isAnalyze(), + node.isVerbose(), statement, node.getOptions()); } diff --git a/presto-main/src/main/java/com/facebook/presto/testing/MaterializedResult.java b/presto-main/src/main/java/com/facebook/presto/testing/MaterializedResult.java index 5e1532fbb2b0..d40c6b957470 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/MaterializedResult.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/MaterializedResult.java @@ -20,7 +20,10 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.CharType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlDate; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.SqlTime; @@ -30,9 +33,6 @@ import com.facebook.presto.spi.type.TimeZoneKey; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarcharType; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; diff --git a/presto-main/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java b/presto-main/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java index f11413872ec4..7212085437b2 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java @@ -36,6 +36,7 @@ import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; import static com.facebook.presto.spi.security.AccessDeniedException.denyDeleteTable; +import static com.facebook.presto.spi.security.AccessDeniedException.denyDropColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropSchema; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropView; @@ -55,6 +56,7 @@ import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.CREATE_VIEW_WITH_SELECT_TABLE; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.CREATE_VIEW_WITH_SELECT_VIEW; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DELETE_TABLE; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_COLUMN; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_SCHEMA; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_TABLE; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_VIEW; @@ -187,6 +189,15 @@ public void checkCanAddColumns(TransactionId transactionId, Identity identity, Q super.checkCanAddColumns(transactionId, identity, tableName); } + @Override + public void checkCanDropColumn(TransactionId transactionId, Identity identity, QualifiedObjectName tableName) + { + if (shouldDenyPrivilege(identity.getUser(), tableName.getObjectName(), DROP_COLUMN)) { + denyDropColumn(tableName.toString()); + } + super.checkCanDropColumn(transactionId, identity, tableName); + } + @Override public void checkCanRenameColumn(TransactionId transactionId, Identity identity, QualifiedObjectName tableName) { @@ -322,7 +333,7 @@ public enum TestingPrivilegeType SET_USER, CREATE_SCHEMA, DROP_SCHEMA, RENAME_SCHEMA, CREATE_TABLE, DROP_TABLE, RENAME_TABLE, SELECT_TABLE, INSERT_TABLE, DELETE_TABLE, - ADD_COLUMN, RENAME_COLUMN, + ADD_COLUMN, DROP_COLUMN, RENAME_COLUMN, CREATE_VIEW, DROP_VIEW, SELECT_VIEW, CREATE_VIEW_WITH_SELECT_TABLE, CREATE_VIEW_WITH_SELECT_VIEW, SET_SESSION diff --git a/presto-main/src/main/java/com/facebook/presto/type/ArrayParametricType.java b/presto-main/src/main/java/com/facebook/presto/type/ArrayParametricType.java index b67a6aeaec98..6f29a974ba0d 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/ArrayParametricType.java +++ b/presto-main/src/main/java/com/facebook/presto/type/ArrayParametricType.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.type; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.ParameterKind; import com.facebook.presto.spi.type.ParametricType; import com.facebook.presto.spi.type.StandardTypes; diff --git a/presto-main/src/main/java/com/facebook/presto/type/DecimalInequalityOperators.java b/presto-main/src/main/java/com/facebook/presto/type/DecimalInequalityOperators.java index c34792dd868f..738fd3c8f665 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/DecimalInequalityOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/DecimalInequalityOperators.java @@ -20,7 +20,6 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.TypeSignature; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; @@ -41,6 +40,7 @@ import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.UnscaledDecimal128Arithmetic.compare; import static com.facebook.presto.util.Reflection.methodHandle; +import static com.google.common.base.Throwables.throwIfInstanceOf; public class DecimalInequalityOperators { @@ -175,8 +175,8 @@ private static boolean invokeGetResult(MethodHandle getResultMethodHandle, int c return (boolean) getResultMethodHandle.invokeExact(comparisonResult); } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } diff --git a/presto-main/src/main/java/com/facebook/presto/type/DecimalSaturatedFloorCasts.java b/presto-main/src/main/java/com/facebook/presto/type/DecimalSaturatedFloorCasts.java index b9fffa970bd8..9ec747753135 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/DecimalSaturatedFloorCasts.java +++ b/presto-main/src/main/java/com/facebook/presto/type/DecimalSaturatedFloorCasts.java @@ -30,13 +30,10 @@ import static com.facebook.presto.spi.type.Decimals.bigIntegerTenToNth; import static com.facebook.presto.spi.type.Decimals.decodeUnscaledValue; import static com.facebook.presto.spi.type.Decimals.encodeUnscaledValue; -import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; -import static com.facebook.presto.spi.type.RealType.REAL; import static com.facebook.presto.spi.type.SmallintType.SMALLINT; import static com.facebook.presto.spi.type.TinyintType.TINYINT; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; -import static java.lang.Float.intBitsToFloat; import static java.lang.Math.toIntExact; import static java.math.BigInteger.ONE; import static java.math.RoundingMode.FLOOR; @@ -108,64 +105,6 @@ private static BigInteger bigDecimalToBigintFloorSaturatedCast(BigDecimal bigDec return unscaledValue; } - public static final SqlScalarFunction DOUBLE_TO_DECIMAL_SATURATED_FLOOR_CAST = SqlScalarFunction.builder(DecimalSaturatedFloorCasts.class) - .signature(Signature.builder() - .kind(SCALAR) - .operatorType(SATURATED_FLOOR_CAST) - .argumentTypes(DOUBLE.getTypeSignature()) - .returnType(parseTypeSignature("decimal(result_precision,result_scale)", ImmutableSet.of("result_precision", "result_scale"))) - .build() - ) - .implementation(b -> b - .methods("doubleToShortDecimal", "doubleToLongDecimal") - .withExtraParameters((context) -> { - int resultPrecision = toIntExact(context.getLiteral("result_precision")); - int resultScale = toIntExact(context.getLiteral("result_scale")); - return ImmutableList.of(resultPrecision, resultScale); - }) - ).build(); - - @UsedByGeneratedCode - public static long doubleToShortDecimal(double value, int resultPrecision, int resultScale) - { - return bigDecimalToBigintFloorSaturatedCast(new BigDecimal(value), resultPrecision, resultScale).longValueExact(); - } - - @UsedByGeneratedCode - public static Slice doubleToLongDecimal(double value, int resultPrecision, int resultScale) - { - return encodeUnscaledValue(bigDecimalToBigintFloorSaturatedCast(new BigDecimal(value), resultPrecision, resultScale)); - } - - public static final SqlScalarFunction REAL_TO_DECIMAL_SATURATED_FLOOR_CAST = SqlScalarFunction.builder(DecimalSaturatedFloorCasts.class) - .signature(Signature.builder() - .kind(SCALAR) - .operatorType(SATURATED_FLOOR_CAST) - .argumentTypes(REAL.getTypeSignature()) - .returnType(parseTypeSignature("decimal(result_precision,result_scale)", ImmutableSet.of("result_precision", "result_scale"))) - .build() - ) - .implementation(b -> b - .methods("realToShortDecimal", "realToLongDecimal") - .withExtraParameters((context) -> { - int resultPrecision = toIntExact(context.getLiteral("result_precision")); - int resultScale = toIntExact(context.getLiteral("result_scale")); - return ImmutableList.of(resultPrecision, resultScale); - }) - ).build(); - - @UsedByGeneratedCode - public static long realToShortDecimal(long value, int resultPrecision, int resultScale) - { - return bigDecimalToBigintFloorSaturatedCast(new BigDecimal(intBitsToFloat((int) value)), resultPrecision, resultScale).longValueExact(); - } - - @UsedByGeneratedCode - public static Slice realToLongDecimal(long value, int resultPrecision, int resultScale) - { - return encodeUnscaledValue(bigDecimalToBigintFloorSaturatedCast(new BigDecimal(intBitsToFloat((int) value)), resultPrecision, resultScale)); - } - public static final SqlScalarFunction DECIMAL_TO_BIGINT_SATURATED_FLOOR_CAST = decimalToGenericIntegerTypeSaturatedFloorCast(BIGINT, Long.MIN_VALUE, Long.MAX_VALUE); public static final SqlScalarFunction DECIMAL_TO_INTEGER_SATURATED_FLOOR_CAST = decimalToGenericIntegerTypeSaturatedFloorCast(INTEGER, Integer.MIN_VALUE, Integer.MAX_VALUE); public static final SqlScalarFunction DECIMAL_TO_SMALLINT_SATURATED_FLOOR_CAST = decimalToGenericIntegerTypeSaturatedFloorCast(SMALLINT, Short.MIN_VALUE, Short.MAX_VALUE); diff --git a/presto-main/src/main/java/com/facebook/presto/type/DoubleOperators.java b/presto-main/src/main/java/com/facebook/presto/type/DoubleOperators.java index 946158431bf2..6a5ec53dd523 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/DoubleOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/DoubleOperators.java @@ -265,13 +265,6 @@ else if (value >= Float.MAX_VALUE) { return floatToRawIntBits(result); } - @ScalarOperator(SATURATED_FLOOR_CAST) - @SqlType(StandardTypes.BIGINT) - public static long saturatedFloorCastToBigint(@SqlType(StandardTypes.DOUBLE) double value) - { - return saturatedFloorCastToLong(value, Long.MIN_VALUE, MIN_LONG_AS_DOUBLE, Long.MAX_VALUE, MAX_LONG_PLUS_ONE_AS_DOUBLE); - } - @ScalarOperator(SATURATED_FLOOR_CAST) @SqlType(StandardTypes.INTEGER) public static long saturatedFloorCastToInteger(@SqlType(StandardTypes.DOUBLE) double value) diff --git a/presto-main/src/main/java/com/facebook/presto/type/MapParametricType.java b/presto-main/src/main/java/com/facebook/presto/type/MapParametricType.java index c3ca2e432665..9e7a3635b224 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/MapParametricType.java +++ b/presto-main/src/main/java/com/facebook/presto/type/MapParametricType.java @@ -14,6 +14,7 @@ package com.facebook.presto.type; import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.ParameterKind; import com.facebook.presto.spi.type.ParametricType; import com.facebook.presto.spi.type.StandardTypes; @@ -32,10 +33,11 @@ public final class MapParametricType implements ParametricType { - public static final MapParametricType MAP = new MapParametricType(); + private final boolean useNewMapBlock; - private MapParametricType() + public MapParametricType(boolean useNewMapBlock) { + this.useNewMapBlock = useNewMapBlock; } @Override @@ -61,6 +63,12 @@ public Type createType(TypeManager typeManager, List parameters) MethodHandle keyBlockNativeEquals = compose(keyNativeEquals, nativeValueGetter(keyType)); MethodHandle keyNativeHashCode = typeManager.resolveOperator(OperatorType.HASH_CODE, ImmutableList.of(keyType)); MethodHandle keyBlockHashCode = compose(keyNativeHashCode, nativeValueGetter(keyType)); - return new MapType(keyType, valueType, keyBlockNativeEquals, keyNativeHashCode, keyBlockHashCode); + return new MapType( + useNewMapBlock, + keyType, + valueType, + useNewMapBlock ? keyBlockNativeEquals : null, + useNewMapBlock ? keyNativeHashCode : null, + useNewMapBlock ? keyBlockHashCode : null); } } diff --git a/presto-main/src/main/java/com/facebook/presto/type/RealOperators.java b/presto-main/src/main/java/com/facebook/presto/type/RealOperators.java index 80c47c946e31..f9359a762966 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/RealOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/RealOperators.java @@ -247,20 +247,6 @@ public static boolean isDistinctFrom( return notEqual(left, right); } - @ScalarOperator(SATURATED_FLOOR_CAST) - @SqlType(StandardTypes.BIGINT) - public static long saturatedFloorCastToBigint(@SqlType(StandardTypes.REAL) long value) - { - return saturatedFloorCastToLong(value, Long.MIN_VALUE, MIN_LONG_AS_FLOAT, Long.MAX_VALUE, MAX_LONG_PLUS_ONE_AS_FLOAT); - } - - @ScalarOperator(SATURATED_FLOOR_CAST) - @SqlType(StandardTypes.INTEGER) - public static long saturatedFloorCastToInteger(@SqlType(StandardTypes.REAL) long value) - { - return saturatedFloorCastToLong(value, Integer.MIN_VALUE, MIN_INTEGER_AS_FLOAT, Integer.MAX_VALUE, MAX_INTEGER_PLUS_ONE_AS_FLOAT); - } - @ScalarOperator(SATURATED_FLOOR_CAST) @SqlType(StandardTypes.SMALLINT) public static long saturatedFloorCastToSmallint(@SqlType(StandardTypes.REAL) long value) diff --git a/presto-main/src/main/java/com/facebook/presto/type/RowParametricType.java b/presto-main/src/main/java/com/facebook/presto/type/RowParametricType.java index 51d5399c263d..173983ee6fde 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/RowParametricType.java +++ b/presto-main/src/main/java/com/facebook/presto/type/RowParametricType.java @@ -16,6 +16,7 @@ import com.facebook.presto.spi.type.NamedType; import com.facebook.presto.spi.type.ParameterKind; import com.facebook.presto.spi.type.ParametricType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; diff --git a/presto-main/src/main/java/com/facebook/presto/type/TypeJsonUtils.java b/presto-main/src/main/java/com/facebook/presto/type/TypeJsonUtils.java index e08a5c7cd6b3..cc6484979806 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/TypeJsonUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/type/TypeJsonUtils.java @@ -18,9 +18,12 @@ import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Decimals; import com.facebook.presto.spi.type.FixedWidthType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; diff --git a/presto-main/src/main/java/com/facebook/presto/type/TypeRegistry.java b/presto-main/src/main/java/com/facebook/presto/type/TypeRegistry.java index 918a123f37b4..4c6ba635494b 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/TypeRegistry.java +++ b/presto-main/src/main/java/com/facebook/presto/type/TypeRegistry.java @@ -15,8 +15,10 @@ import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.CharType; import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.ParametricType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; @@ -25,6 +27,8 @@ import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.spi.type.VarcharType; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -71,7 +75,6 @@ import static com.facebook.presto.type.JsonType.JSON; import static com.facebook.presto.type.LikePatternType.LIKE_PATTERN; import static com.facebook.presto.type.ListLiteralType.LIST_LITERAL; -import static com.facebook.presto.type.MapParametricType.MAP; import static com.facebook.presto.type.Re2JRegexpType.RE2J_REGEXP; import static com.facebook.presto.type.RowParametricType.ROW; import static com.facebook.presto.type.UnknownType.UNKNOWN; @@ -88,13 +91,20 @@ public final class TypeRegistry private FunctionRegistry functionRegistry; + @VisibleForTesting public TypeRegistry() { - this(ImmutableSet.of()); + this(ImmutableSet.of(), new FeaturesConfig()); } - @Inject + @VisibleForTesting public TypeRegistry(Set types) + { + this(ImmutableSet.of(), new FeaturesConfig()); + } + + @Inject + public TypeRegistry(Set types, FeaturesConfig featuresConfig) { requireNonNull(types, "types is null"); @@ -132,7 +142,7 @@ public TypeRegistry(Set types) addParametricType(DecimalParametricType.DECIMAL); addParametricType(ROW); addParametricType(ARRAY); - addParametricType(MAP); + addParametricType(new MapParametricType(featuresConfig.isNewMapBlock())); addParametricType(FUNCTION); for (Type type : types) { diff --git a/presto-main/src/main/java/com/facebook/presto/util/Failures.java b/presto-main/src/main/java/com/facebook/presto/util/Failures.java index c9dc14c0eb48..b291900455df 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/Failures.java +++ b/presto-main/src/main/java/com/facebook/presto/util/Failures.java @@ -21,6 +21,7 @@ import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.PrestoTransportException; +import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.parser.ParsingException; import com.facebook.presto.sql.tree.NodeLocation; @@ -35,6 +36,7 @@ import static com.facebook.presto.spi.StandardErrorCode.SYNTAX_ERROR; import static com.google.common.base.Functions.toStringFunction; import static com.google.common.base.MoreObjects.firstNonNull; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; import static java.util.Arrays.asList; @@ -132,4 +134,11 @@ private static ErrorCode toErrorCode(@Nullable Throwable throwable) } return GENERIC_INTERNAL_ERROR.toErrorCode(); } + + public static PrestoException internalError(Throwable t) + { + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, PrestoException.class); + return new PrestoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, t); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/util/FastutilSetHelper.java b/presto-main/src/main/java/com/facebook/presto/util/FastutilSetHelper.java index 1c1bf32c66a0..fd5b354a85fd 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/FastutilSetHelper.java +++ b/presto-main/src/main/java/com/facebook/presto/util/FastutilSetHelper.java @@ -16,7 +16,6 @@ import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import it.unimi.dsi.fastutil.Hash; import it.unimi.dsi.fastutil.booleans.BooleanOpenHashSet; @@ -34,6 +33,7 @@ import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.EQUAL; import static com.facebook.presto.spi.function.OperatorType.HASH_CODE; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static java.lang.Math.toIntExact; public final class FastutilSetHelper @@ -103,8 +103,8 @@ public int hashCode(long value) return Long.hashCode((long) hashCodeHandle.invokeExact(value)); } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } @@ -116,8 +116,8 @@ public boolean equals(long a, long b) return (boolean) equalsHandle.invokeExact(a, b); } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } @@ -142,8 +142,8 @@ public int hashCode(double value) return Long.hashCode((long) hashCodeHandle.invokeExact(value)); } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } @@ -155,8 +155,8 @@ public boolean equals(double a, double b) return (boolean) equalsHandle.invokeExact(a, b); } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } @@ -185,8 +185,8 @@ public int hashCode(Object value) return toIntExact(Long.hashCode((long) hashCodeHandle.invokeExact(value))); } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } @@ -198,8 +198,8 @@ public boolean equals(Object a, Object b) return (boolean) equalsHandle.invokeExact(a, b); } catch (Throwable t) { - Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } diff --git a/presto-main/src/main/java/com/facebook/presto/util/JsonUtil.java b/presto-main/src/main/java/com/facebook/presto/util/JsonUtil.java index d3aa40100f70..d2d7e0ec9cf3 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/JsonUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/util/JsonUtil.java @@ -16,13 +16,13 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; import com.facebook.presto.type.UnknownType; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonGenerator; diff --git a/presto-main/src/main/java/com/facebook/presto/util/Mergeable.java b/presto-main/src/main/java/com/facebook/presto/util/Mergeable.java new file mode 100644 index 000000000000..d5dc3b921a46 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/util/Mergeable.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.util; + +import java.util.Optional; + +public interface Mergeable +{ + T mergeWith(T other); + + static > Optional merge(Optional first, Optional second) + { + if (first.isPresent() && second.isPresent()) { + return Optional.of(first.get().mergeWith(second.get())); + } + else if (first.isPresent()) { + return first; + } + + return second; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/util/MoreLists.java b/presto-main/src/main/java/com/facebook/presto/util/MoreLists.java index e8f75f5aaa45..e2a54fd35c31 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/MoreLists.java +++ b/presto-main/src/main/java/com/facebook/presto/util/MoreLists.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.function.Predicate; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @@ -29,5 +30,14 @@ public static List> listOfListsCopy(List> lists) .collect(toImmutableList()); } + public static List filteredCopy(List elements, Predicate predicate) + { + requireNonNull(elements, "elements is null"); + requireNonNull(predicate, "predicate is null"); + return elements.stream() + .filter(predicate) + .collect(toImmutableList()); + } + private MoreLists() {} } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/Predicates.java b/presto-main/src/main/java/com/facebook/presto/util/MorePredicates.java similarity index 71% rename from presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/Predicates.java rename to presto-main/src/main/java/com/facebook/presto/util/MorePredicates.java index 53d21d881f4b..5e8cc3345248 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/Predicates.java +++ b/presto-main/src/main/java/com/facebook/presto/util/MorePredicates.java @@ -11,30 +11,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sql.planner.optimizations; +package com.facebook.presto.util; import java.util.function.Predicate; -public class Predicates +import static com.google.common.base.Predicates.alwaysFalse; + +public class MorePredicates { - private Predicates() {} + private MorePredicates() {} public static Predicate isInstanceOfAny(Class... classes) { - Predicate predicate = alwaysFalse(); + Predicate predicate = alwaysFalse(); for (Class clazz : classes) { predicate = predicate.or(clazz::isInstance); } return predicate; } - - public static Predicate alwaysTrue() - { - return x -> true; - } - - public static Predicate alwaysFalse() - { - return x -> false; - } } diff --git a/presto-main/src/test/java/com/facebook/presto/block/AbstractTestBlock.java b/presto-main/src/test/java/com/facebook/presto/block/AbstractTestBlock.java index 39cd45770c7b..021416c18928 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/AbstractTestBlock.java +++ b/presto-main/src/test/java/com/facebook/presto/block/AbstractTestBlock.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.BlockEncoding; +import com.facebook.presto.spi.block.DictionaryId; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; import io.airlift.slice.DynamicSliceOutput; @@ -26,11 +27,13 @@ import org.openjdk.jol.info.ClassLayout; import org.testng.annotations.Test; +import java.lang.invoke.MethodHandle; import java.lang.reflect.Array; import java.lang.reflect.Field; import java.util.IdentityHashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; @@ -41,6 +44,7 @@ import static io.airlift.slice.SizeOf.SIZE_OF_SHORT; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.toIntExact; +import static java.lang.String.format; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -93,53 +97,60 @@ private void assertRetainedSize(Block block) Field[] fields = block.getClass().getDeclaredFields(); try { for (Field field : fields) { - Class type = field.getType(); + Class type = field.getType(); if (type.isPrimitive()) { continue; } field.setAccessible(true); - if (type.equals(Slice.class)) { + if (type == Slice.class) { retainedSize += ((Slice) field.get(block)).getRetainedSize(); } - else if (type.equals(BlockBuilderStatus.class)) { + else if (type == BlockBuilderStatus.class) { retainedSize += BlockBuilderStatus.INSTANCE_SIZE; } - else if (type.equals(BlockBuilder.class) || type.equals(Block.class)) { + else if (type == BlockBuilder.class || type == Block.class) { retainedSize += ((Block) field.get(block)).getRetainedSizeInBytes(); } - else if (type.equals(Slice[].class)) { + else if (type == Slice[].class) { retainedSize += getSliceArrayRetainedSizeInBytes((Slice[]) field.get(block)); } - else if (type.equals(BlockBuilder[].class) || type.equals(Block[].class)) { + else if (type == BlockBuilder[].class || type == Block[].class) { Block[] blocks = (Block[]) field.get(block); for (Block innerBlock : blocks) { assertRetainedSize(innerBlock); retainedSize += innerBlock.getRetainedSizeInBytes(); } } - else if (type.equals(SliceOutput.class)) { + else if (type == SliceOutput.class) { retainedSize += ((SliceOutput) field.get(block)).getRetainedSize(); } - else if (type.equals(int[].class)) { + else if (type == int[].class) { retainedSize += sizeOf((int[]) field.get(block)); } - else if (type.equals(boolean[].class)) { + else if (type == boolean[].class) { retainedSize += sizeOf((boolean[]) field.get(block)); } - else if (type.equals(byte[].class)) { + else if (type == byte[].class) { retainedSize += sizeOf((byte[]) field.get(block)); } - else if (type.equals(long[].class)) { + else if (type == long[].class) { retainedSize += sizeOf((long[]) field.get(block)); } - else if (type.equals(short[].class)) { + else if (type == short[].class) { retainedSize += sizeOf((short[]) field.get(block)); } + else if (type == DictionaryId.class || BlockEncoding.class.isAssignableFrom(type) || type == AtomicLong.class || type == MethodHandle.class) { + // TODO: Some of these should be accounted in retainedSize + // do nothing + } + else { + throw new IllegalArgumentException(format("Unknown type encountered: %s", type)); + } } } - catch (IllegalAccessException | IllegalArgumentException t) { + catch (IllegalAccessException t) { throw new RuntimeException(t); } assertEquals(block.getRetainedSizeInBytes(), retainedSize); @@ -187,17 +198,17 @@ private void assertBlockSize(Block block) { // Asserting on `block` is not very effective because most blocks passed to this method is compact. // Therefore, we split the `block` into two and assert again. - int expectedBlockSize = copyBlock(block).getSizeInBytes(); + long expectedBlockSize = copyBlock(block).getSizeInBytes(); assertEquals(block.getSizeInBytes(), expectedBlockSize); assertEquals(block.getRegionSizeInBytes(0, block.getPositionCount()), expectedBlockSize); List splitBlock = splitBlock(block, 2); Block firstHalf = splitBlock.get(0); - int expectedFirstHalfSize = copyBlock(firstHalf).getSizeInBytes(); + long expectedFirstHalfSize = copyBlock(firstHalf).getSizeInBytes(); assertEquals(firstHalf.getSizeInBytes(), expectedFirstHalfSize); assertEquals(block.getRegionSizeInBytes(0, firstHalf.getPositionCount()), expectedFirstHalfSize); Block secondHalf = splitBlock.get(1); - int expectedSecondHalfSize = copyBlock(secondHalf).getSizeInBytes(); + long expectedSecondHalfSize = copyBlock(secondHalf).getSizeInBytes(); assertEquals(secondHalf.getSizeInBytes(), expectedSecondHalfSize); assertEquals(block.getRegionSizeInBytes(firstHalf.getPositionCount(), secondHalf.getPositionCount()), expectedSecondHalfSize); } diff --git a/presto-main/src/test/java/com/facebook/presto/block/BlockAssertions.java b/presto-main/src/test/java/com/facebook/presto/block/BlockAssertions.java index 19fa3f9b4f73..c833bf51248d 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/BlockAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/block/BlockAssertions.java @@ -18,9 +18,9 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.DictionaryBlock; import com.facebook.presto.spi.block.RunLengthEncodedBlock; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; import io.airlift.slice.Slice; import java.math.BigDecimal; @@ -158,7 +158,7 @@ public static Block createStringDictionaryBlock(int start, int length) for (int i = 0; i < length; i++) { ids[i] = i % dictionarySize; } - return new DictionaryBlock(length, builder.build(), ids); + return new DictionaryBlock(builder.build(), ids); } public static Block createStringArraysBlock(Iterable> values) @@ -346,7 +346,7 @@ public static Block createLongDictionaryBlock(int start, int length) for (int i = 0; i < length; i++) { ids[i] = i % dictionarySize; } - return new DictionaryBlock(length, builder.build(), ids); + return new DictionaryBlock(builder.build(), ids); } public static Block createLongRepeatBlock(int value, int length) diff --git a/presto-main/src/test/java/com/facebook/presto/block/TestBlockBuilder.java b/presto-main/src/test/java/com/facebook/presto/block/TestBlockBuilder.java index c6564a659c99..9c582c1d8335 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/TestBlockBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/block/TestBlockBuilder.java @@ -17,8 +17,8 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java b/presto-main/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java index 25f4dc6d0eb4..fcc986685204 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java +++ b/presto-main/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java @@ -62,8 +62,8 @@ public void testCopyPositionsWithCompaction() assertEquals(copiedBlock.getDictionary().getPositionCount(), 1); assertEquals(copiedBlock.getPositionCount(), positionsToCopy.size()); - assertBlock(copiedBlock.getDictionary(), new Slice[]{firstExpectedValue}); - assertBlock(copiedBlock, new Slice[]{firstExpectedValue, firstExpectedValue, firstExpectedValue, firstExpectedValue, firstExpectedValue}); + assertBlock(copiedBlock.getDictionary(), new Slice[] {firstExpectedValue}); + assertBlock(copiedBlock, new Slice[] {firstExpectedValue, firstExpectedValue, firstExpectedValue, firstExpectedValue, firstExpectedValue}); } @Test @@ -79,7 +79,7 @@ public void testCopyPositionsWithCompactionsAndReorder() assertEquals(copiedBlock.getDictionary().getPositionCount(), 2); assertEquals(copiedBlock.getPositionCount(), positionsToCopy.size()); - assertBlock(copiedBlock.getDictionary(), new Slice[] { expectedValues[0], expectedValues[5] }); + assertBlock(copiedBlock.getDictionary(), new Slice[] {expectedValues[0], expectedValues[5]}); assertDictionaryIds(copiedBlock, 0, 1, 0, 1, 0); } @@ -96,7 +96,7 @@ public void testCopyPositionsSamePosition() assertEquals(copiedBlock.getDictionary().getPositionCount(), 1); assertEquals(copiedBlock.getPositionCount(), positionsToCopy.size()); - assertBlock(copiedBlock.getDictionary(), new Slice[] { expectedValues[2] }); + assertBlock(copiedBlock.getDictionary(), new Slice[] {expectedValues[2]}); assertDictionaryIds(copiedBlock, 0, 0, 0); } @@ -126,7 +126,7 @@ public void testCompact() assertNotEquals(dictionaryBlock.getDictionarySourceId(), compactBlock.getDictionarySourceId()); assertEquals(compactBlock.getDictionary().getPositionCount(), (expectedValues.length / 2) + 1); - assertBlock(compactBlock.getDictionary(), new Slice[] { expectedValues[0], expectedValues[1], expectedValues[3] }); + assertBlock(compactBlock.getDictionary(), new Slice[] {expectedValues[0], expectedValues[1], expectedValues[3]}); assertDictionaryIds(compactBlock, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2); assertEquals(compactBlock.isCompact(), true); @@ -164,7 +164,7 @@ private static DictionaryBlock createDictionaryBlockWithUnreferencedKeys(Slice[] } ids[i] = index; } - return new DictionaryBlock(positionCount, new SliceArrayBlock(dictionarySize, expectedValues), ids); + return new DictionaryBlock(new SliceArrayBlock(dictionarySize, expectedValues), ids); } private static DictionaryBlock createDictionaryBlock(Slice[] expectedValues, int positionCount) @@ -175,7 +175,7 @@ private static DictionaryBlock createDictionaryBlock(Slice[] expectedValues, int for (int i = 0; i < positionCount; i++) { ids[i] = i % dictionarySize; } - return new DictionaryBlock(positionCount, new SliceArrayBlock(dictionarySize, expectedValues), ids); + return new DictionaryBlock(new SliceArrayBlock(dictionarySize, expectedValues), ids); } private static void assertDictionaryIds(DictionaryBlock dictionaryBlock, int... expected) diff --git a/presto-main/src/test/java/com/facebook/presto/block/TestInterleavedBlock.java b/presto-main/src/test/java/com/facebook/presto/block/TestInterleavedBlock.java index f02332919f75..53c83b6b06fa 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/TestInterleavedBlock.java +++ b/presto-main/src/test/java/com/facebook/presto/block/TestInterleavedBlock.java @@ -80,11 +80,11 @@ private void testGetSizeInBytes() InterleavedBlock block = blockBuilder.build(); List splitQuarter = splitBlock(block, 4); - int sizeInBytes = block.getSizeInBytes(); - int quarter1size = splitQuarter.get(0).getSizeInBytes(); - int quarter2size = splitQuarter.get(1).getSizeInBytes(); - int quarter3size = splitQuarter.get(2).getSizeInBytes(); - int quarter4size = splitQuarter.get(3).getSizeInBytes(); + long sizeInBytes = block.getSizeInBytes(); + long quarter1size = splitQuarter.get(0).getSizeInBytes(); + long quarter2size = splitQuarter.get(1).getSizeInBytes(); + long quarter3size = splitQuarter.get(2).getSizeInBytes(); + long quarter4size = splitQuarter.get(3).getSizeInBytes(); double expectedQuarterSizeMin = sizeInBytes * 0.2; double expectedQuarterSizeMax = sizeInBytes * 0.3; assertTrue(quarter1size > expectedQuarterSizeMin && quarter1size < expectedQuarterSizeMax, format("quarter1size is %s, should be between %s and %s", quarter1size, expectedQuarterSizeMin, expectedQuarterSizeMax)); diff --git a/presto-main/src/test/java/com/facebook/presto/block/TestMapBlock.java b/presto-main/src/test/java/com/facebook/presto/block/TestMapBlock.java index a4759824943e..3b49f6588148 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/TestMapBlock.java +++ b/presto-main/src/test/java/com/facebook/presto/block/TestMapBlock.java @@ -18,13 +18,17 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.SingleMapBlock; -import com.facebook.presto.type.MapType; -import com.google.common.collect.ImmutableMap; +import com.facebook.presto.spi.type.MapType; import com.google.common.primitives.Ints; import org.testng.annotations.Test; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; +import static com.facebook.presto.block.BlockAssertions.createLongsBlock; +import static com.facebook.presto.block.BlockAssertions.createStringsBlock; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.util.StructuralTestUtil.mapType; @@ -33,6 +37,7 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; public class TestMapBlock extends AbstractTestBlock @@ -48,11 +53,11 @@ private Map[] createTestMap(int... entryCounts) Map[] result = new Map[entryCounts.length]; for (int rowNumber = 0; rowNumber < entryCounts.length; rowNumber++) { int entryCount = entryCounts[rowNumber]; - ImmutableMap.Builder builder = ImmutableMap.builder(); + Map map = new HashMap<>(); for (int entryNumber = 0; entryNumber < entryCount; entryNumber++) { - builder.put("key" + entryNumber, rowNumber * 100L + entryNumber); + map.put("key" + entryNumber, entryNumber == 5 ? null : rowNumber * 100L + entryNumber); } - result[rowNumber] = builder.build(); + result[rowNumber] = map; } return result; } @@ -68,6 +73,12 @@ private void testWith(Map[] expectedValues) assertBlockFilteredPositions(expectedValues, blockBuilder, Ints.asList(2, 3, 5, 6)); assertBlockFilteredPositions(expectedValues, blockBuilder.build(), Ints.asList(2, 3, 5, 6)); + Block block = createBlockWithValuesFromKeyValueBlock(expectedValues); + + assertBlock(block, expectedValues); + assertBlockFilteredPositions(expectedValues, block, Ints.asList(0, 1, 3, 4, 7)); + assertBlockFilteredPositions(expectedValues, block, Ints.asList(2, 3, 5, 6)); + Map[] expectedValuesWithNull = (Map[]) alternatingNullValues(expectedValues); BlockBuilder blockBuilderWithNull = createBlockBuilderWithValues(expectedValuesWithNull); @@ -77,6 +88,12 @@ private void testWith(Map[] expectedValues) assertBlockFilteredPositions(expectedValuesWithNull, blockBuilderWithNull.build(), Ints.asList(0, 1, 5, 6, 7, 10, 11, 12, 15)); assertBlockFilteredPositions(expectedValuesWithNull, blockBuilderWithNull, Ints.asList(2, 3, 4, 9, 13, 14)); assertBlockFilteredPositions(expectedValuesWithNull, blockBuilderWithNull.build(), Ints.asList(2, 3, 4, 9, 13, 14)); + + Block blockWithNull = createBlockWithValuesFromKeyValueBlock(expectedValuesWithNull); + + assertBlock(blockWithNull, expectedValuesWithNull); + assertBlockFilteredPositions(expectedValuesWithNull, blockWithNull, Ints.asList(0, 1, 5, 6, 7, 10, 11, 12, 15)); + assertBlockFilteredPositions(expectedValuesWithNull, blockWithNull, Ints.asList(2, 3, 4, 9, 13, 14)); } private BlockBuilder createBlockBuilderWithValues(Map[] maps) @@ -89,6 +106,29 @@ private BlockBuilder createBlockBuilderWithValues(Map[] maps) return mapBlockBuilder; } + private Block createBlockWithValuesFromKeyValueBlock(Map[] maps) + { + List keys = new ArrayList<>(); + List values = new ArrayList<>(); + int[] offsets = new int[maps.length + 1]; + boolean[] mapIsNull = new boolean[maps.length]; + for (int i = 0; i < maps.length; i++) { + Map map = maps[i]; + mapIsNull[i] = map == null; + if (map == null) { + offsets[i + 1] = offsets[i]; + } + else { + for (Map.Entry entry : map.entrySet()) { + keys.add(entry.getKey()); + values.add(entry.getValue()); + } + offsets[i + 1] = offsets[i] + map.size(); + } + } + return mapType(VARCHAR, BIGINT).createBlockFromKeyValue(mapIsNull, offsets, createStringsBlock(keys), createLongsBlock(values)); + } + private void createBlockBuilderWithValues(Map map, BlockBuilder mapBlockBuilder) { if (map == null) { @@ -98,7 +138,12 @@ private void createBlockBuilderWithValues(Map map, BlockBuilder ma BlockBuilder elementBlockBuilder = mapBlockBuilder.beginBlockEntry(); for (Map.Entry entry : map.entrySet()) { VARCHAR.writeSlice(elementBlockBuilder, utf8Slice(entry.getKey())); - BIGINT.writeLong(elementBlockBuilder, entry.getValue()); + if (entry.getValue() == null) { + elementBlockBuilder.appendNull(); + } + else { + BIGINT.writeLong(elementBlockBuilder, entry.getValue()); + } } mapBlockBuilder.closeEntry(); } @@ -123,15 +168,37 @@ private void assertValue(Block mapBlock, int position, Map map) assertFalse(mapBlock.isNull(position)); SingleMapBlock elementBlock = (SingleMapBlock) mapType.getObject(mapBlock, position); - // assert inserted keys + assertEquals(elementBlock.getPositionCount(), map.size() * 2); + + // Test new/hash-index access: assert inserted keys for (Map.Entry entry : map.entrySet()) { int pos = elementBlock.seekKey(utf8Slice(entry.getKey())); assertNotEquals(pos, -1); - assertEquals(BIGINT.getLong(elementBlock, pos), (long) entry.getValue()); + if (entry.getValue() == null) { + assertTrue(elementBlock.isNull(pos)); + } + else { + assertFalse(elementBlock.isNull(pos)); + assertEquals(BIGINT.getLong(elementBlock, pos), (long) entry.getValue()); + } } - // assert non-existent keys + // Test new/hash-index access: assert non-existent keys for (int i = 0; i < 10; i++) { assertEquals(elementBlock.seekKey(utf8Slice("not-inserted-" + i)), -1); } + + // Test legacy/iterative access + for (int i = 0; i < elementBlock.getPositionCount(); i += 2) { + String actualKey = VARCHAR.getSlice(elementBlock, i).toStringUtf8(); + Long actualValue; + if (elementBlock.isNull(i + 1)) { + actualValue = null; + } + else { + actualValue = BIGINT.getLong(elementBlock, i + 1); + } + assertTrue(map.containsKey(actualKey)); + assertEquals(actualValue, map.get(actualKey)); + } } } diff --git a/presto-main/src/test/java/com/facebook/presto/block/TestVariableWidthBlock.java b/presto-main/src/test/java/com/facebook/presto/block/TestVariableWidthBlock.java index 1507cd65cf16..d7e4e82dd257 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/TestVariableWidthBlock.java +++ b/presto-main/src/test/java/com/facebook/presto/block/TestVariableWidthBlock.java @@ -95,11 +95,11 @@ private void testGetSizeInBytes() Block block = blockBuilder.build(); List splitQuarter = splitBlock(block, 4); - int sizeInBytes = block.getSizeInBytes(); - int quarter1size = splitQuarter.get(0).getSizeInBytes(); - int quarter2size = splitQuarter.get(1).getSizeInBytes(); - int quarter3size = splitQuarter.get(2).getSizeInBytes(); - int quarter4size = splitQuarter.get(3).getSizeInBytes(); + long sizeInBytes = block.getSizeInBytes(); + long quarter1size = splitQuarter.get(0).getSizeInBytes(); + long quarter2size = splitQuarter.get(1).getSizeInBytes(); + long quarter3size = splitQuarter.get(2).getSizeInBytes(); + long quarter4size = splitQuarter.get(3).getSizeInBytes(); double expectedQuarterSizeMin = sizeInBytes * 0.2; double expectedQuarterSizeMax = sizeInBytes * 0.3; assertTrue(quarter1size > expectedQuarterSizeMin && quarter1size < expectedQuarterSizeMax, format("quarter1size is %s, should be between %s and %s", quarter1size, expectedQuarterSizeMin, expectedQuarterSizeMax)); diff --git a/presto-main/src/test/java/com/facebook/presto/execution/MockQueryExecution.java b/presto-main/src/test/java/com/facebook/presto/execution/MockQueryExecution.java index 3c3eb3137b8a..189043086863 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/MockQueryExecution.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/MockQueryExecution.java @@ -18,12 +18,15 @@ import com.facebook.presto.memory.VersionedMemoryPoolId; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.memory.MemoryPoolId; +import com.facebook.presto.spi.resourceGroups.QueryType; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.sql.planner.Plan; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.airlift.units.DataSize; import io.airlift.units.Duration; +import org.joda.time.DateTime; import java.net.URI; import java.util.ArrayList; @@ -36,8 +39,10 @@ import static com.facebook.presto.execution.QueryState.QUEUED; import static com.facebook.presto.execution.QueryState.RUNNING; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; public class MockQueryExecution implements QueryExecution @@ -96,7 +101,50 @@ public QueryInfo getQueryInfo() URI.create("http://test"), ImmutableList.of(), "SELECT 1", - new QueryStats(), + new QueryStats( + new DateTime(1), + new DateTime(2), + new DateTime(3), + new DateTime(4), + new Duration(6, NANOSECONDS), + new Duration(5, NANOSECONDS), + new Duration(7, NANOSECONDS), + new Duration(8, NANOSECONDS), + + new Duration(100, NANOSECONDS), + new Duration(200, NANOSECONDS), + + 9, + 10, + 11, + + 12, + 13, + 15, + 30, + 16, + + 17.0, + new DataSize(18, BYTE), + new DataSize(19, BYTE), + + true, + new Duration(20, NANOSECONDS), + new Duration(21, NANOSECONDS), + new Duration(22, NANOSECONDS), + new Duration(23, NANOSECONDS), + false, + ImmutableSet.of(), + + new DataSize(24, BYTE), + 25, + + new DataSize(26, BYTE), + 27, + + new DataSize(28, BYTE), + 29, + ImmutableList.of()), ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), @@ -229,6 +277,12 @@ public void addFinalQueryInfoListener(StateChangeListener stateChange throw new UnsupportedOperationException(); } + @Override + public Optional getQueryType() + { + return Optional.empty(); + } + private void fireStateChange() { for (StateChangeListener listener : listeners) { diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TaskExecutorSimulator.java b/presto-main/src/test/java/com/facebook/presto/execution/TaskExecutorSimulator.java deleted file mode 100644 index 8ed9f7c156d9..000000000000 --- a/presto-main/src/test/java/com/facebook/presto/execution/TaskExecutorSimulator.java +++ /dev/null @@ -1,414 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.execution; - -import com.facebook.presto.execution.executor.TaskExecutor; -import com.facebook.presto.execution.executor.TaskHandle; -import com.google.common.base.Throwables; -import com.google.common.base.Ticker; -import com.google.common.collect.ArrayListMultimap; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Multimap; -import com.google.common.collect.Multimaps; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.ListeningExecutorService; -import com.google.common.util.concurrent.SettableFuture; -import io.airlift.stats.Distribution; -import io.airlift.units.Duration; - -import java.io.Closeable; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Set; -import java.util.TreeMap; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; - -import static com.google.common.collect.Sets.newConcurrentHashSet; -import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; -import static io.airlift.concurrent.Threads.threadsNamed; -import static java.util.concurrent.Executors.newCachedThreadPool; -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static java.util.concurrent.TimeUnit.NANOSECONDS; - -public class TaskExecutorSimulator - implements Closeable -{ - private static final boolean PRINT_TASK_COMPLETION = false; - private static final boolean PRINT_SPLIT_COMPLETION = false; - - public static void main(String[] args) - throws Exception - { - try (TaskExecutorSimulator simulator = new TaskExecutorSimulator()) { - simulator.run(); - } - } - - private final ListeningExecutorService executor; - private final TaskExecutor taskExecutor; - - public TaskExecutorSimulator() - { - executor = listeningDecorator(newCachedThreadPool(threadsNamed(getClass().getSimpleName() + "-%s"))); - - taskExecutor = new TaskExecutor(24, 48, new Ticker() - { - private final long start = System.nanoTime(); - - @Override - public long read() - { - // run 10 times faster than reality - long now = System.nanoTime(); - return (now - start) * 100; - } - }); - taskExecutor.start(); - } - - @Override - public void close() - { - taskExecutor.stop(); - executor.shutdownNow(); - } - - public void run() - throws Exception - { - Multimap tasks = Multimaps.synchronizedListMultimap(ArrayListMultimap.create()); - Set> finishFutures = newConcurrentHashSet(); - AtomicBoolean done = new AtomicBoolean(); - - long start = System.nanoTime(); - - // large tasks - for (int userId = 0; userId < 2; userId++) { - ListenableFuture future = createUser("large_" + userId, 100, taskExecutor, done, tasks); - finishFutures.add(future); - } - - // small tasks - for (int userId = 0; userId < 4; userId++) { - ListenableFuture future = createUser("small_" + userId, 5, taskExecutor, done, tasks); - finishFutures.add(future); - } - - // tiny tasks - for (int userId = 0; userId < 1; userId++) { - ListenableFuture future = createUser("tiny_" + userId, 1, taskExecutor, done, tasks); - finishFutures.add(future); - } - - // warm up - for (int i = 0; i < 30; i++) { - MILLISECONDS.sleep(1000); - System.out.println(taskExecutor); - } - tasks.clear(); - - // run - for (int i = 0; i < 60; i++) { - MILLISECONDS.sleep(1000); - System.out.println(taskExecutor); - } - - // capture finished tasks - Map> middleTasks; - synchronized (tasks) { - middleTasks = new TreeMap<>(tasks.asMap()); - } - - // wait for finish - done.set(true); - Futures.allAsList(finishFutures).get(1, TimeUnit.MINUTES); - - Duration runtime = Duration.nanosSince(start).convertToMostSuccinctTimeUnit(); - synchronized (this) { - System.out.println(); - System.out.println("Simulation finished in " + runtime); - System.out.println(); - - for (Entry> entry : middleTasks.entrySet()) { - Distribution durationDistribution = new Distribution(); - Distribution taskParallelismDistribution = new Distribution(); - - for (SimulationTask task : entry.getValue()) { - long taskStart = Long.MAX_VALUE; - long taskEnd = 0; - long totalCpuTime = 0; - - for (SimulationSplit split : task.getSplits()) { - taskStart = Math.min(taskStart, split.getStartNanos()); - taskEnd = Math.max(taskEnd, split.getDoneNanos()); - totalCpuTime += MILLISECONDS.toNanos(split.getRequiredProcessMillis()); - } - - Duration taskDuration = new Duration(taskEnd - taskStart, NANOSECONDS).convertTo(MILLISECONDS); - durationDistribution.add(taskDuration.toMillis()); - - double taskParallelism = 1.0 * totalCpuTime / (taskEnd - taskStart); - taskParallelismDistribution.add((long) (taskParallelism * 100)); - } - - System.out.println("Splits " + entry.getKey() + ": Completed " + entry.getValue().size()); - - Map durationPercentiles = durationDistribution.getPercentiles(); - System.out.printf(" wall time ms :: p01 %4s :: p05 %4s :: p10 %4s :: p97 %4s :: p50 %4s :: p75 %4s :: p90 %4s :: p95 %4s :: p99 %4s\n", - durationPercentiles.get(0.01), - durationPercentiles.get(0.05), - durationPercentiles.get(0.10), - durationPercentiles.get(0.25), - durationPercentiles.get(0.50), - durationPercentiles.get(0.75), - durationPercentiles.get(0.90), - durationPercentiles.get(0.95), - durationPercentiles.get(0.99)); - - Map parallelismPercentiles = taskParallelismDistribution.getPercentiles(); - System.out.printf(" parallelism :: p99 %4.2f :: p95 %4.2f :: p90 %4.2f :: p75 %4.2f :: p50 %4.2f :: p25 %4.2f :: p10 %4.2f :: p05 %4.2f :: p01 %4.2f\n", - parallelismPercentiles.get(0.99) / 100.0, - parallelismPercentiles.get(0.95) / 100.0, - parallelismPercentiles.get(0.90) / 100.0, - parallelismPercentiles.get(0.75) / 100.0, - parallelismPercentiles.get(0.50) / 100.0, - parallelismPercentiles.get(0.25) / 100.0, - parallelismPercentiles.get(0.10) / 100.0, - parallelismPercentiles.get(0.05) / 100.0, - parallelismPercentiles.get(0.01) / 100.0); - } - } - Thread.sleep(10); - } - - private ListenableFuture createUser(String userId, - int splitsPerTask, - TaskExecutor taskExecutor, - AtomicBoolean done, - Multimap tasks) - { - return executor.submit((Callable) () -> { - int taskId = 0; - while (!done.get()) { - SimulationTask task = new SimulationTask(taskExecutor, new TaskId(userId, 0, taskId++)); - task.schedule(splitsPerTask, executor, new Duration(0, MILLISECONDS)).get(); - task.destroy(); - - printTaskCompletion(task); - - tasks.put(splitsPerTask, task); - } - return null; - }); - } - - private synchronized void printTaskCompletion(SimulationTask task) - { - if (!PRINT_TASK_COMPLETION) { - return; - } - - long taskStart = Long.MAX_VALUE; - long taskEnd = 0; - long taskQueuedTime = 0; - long totalCpuTime = 0; - - for (SimulationSplit split : task.getSplits()) { - taskStart = Math.min(taskStart, split.getStartNanos()); - taskEnd = Math.max(taskEnd, split.getDoneNanos()); - taskQueuedTime += split.getQueuedNanos(); - totalCpuTime += MILLISECONDS.toNanos(split.getRequiredProcessMillis()); - } - - System.out.printf("%-12s %8s %8s %.2f\n", - task.getTaskId() + ":", - new Duration(taskQueuedTime, NANOSECONDS).convertTo(MILLISECONDS), - new Duration(taskEnd - taskStart, NANOSECONDS).convertTo(MILLISECONDS), - 1.0 * totalCpuTime / (taskEnd - taskStart) - ); - - // print split info - if (PRINT_SPLIT_COMPLETION) { - for (SimulationSplit split : task.getSplits()) { - Duration totalQueueTime = new Duration(split.getQueuedNanos(), NANOSECONDS).convertTo(MILLISECONDS); - Duration executionWallTime = new Duration(split.getDoneNanos() - split.getStartNanos(), NANOSECONDS).convertTo(MILLISECONDS); - Duration totalWallTime = new Duration(split.getDoneNanos() - split.getCreatedNanos(), NANOSECONDS).convertTo(MILLISECONDS); - System.out.printf(" %8s %8s %8s\n", totalQueueTime, executionWallTime, totalWallTime); - } - - System.out.println(); - } - } - - private static class SimulationTask - { - private final long createdNanos = System.nanoTime(); - - private final TaskExecutor taskExecutor; - private final Object taskId; - - private final List splits = new ArrayList<>(); - private final List> splitFutures = new ArrayList<>(); - private final TaskHandle taskHandle; - - private SimulationTask(TaskExecutor taskExecutor, TaskId taskId) - { - this.taskExecutor = taskExecutor; - this.taskId = taskId; - taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS)); - } - - public void destroy() - { - taskExecutor.removeTask(taskHandle); - } - - public ListenableFuture schedule(int splits, ExecutorService executor, Duration entryDelay) - { - SettableFuture future = SettableFuture.create(); - - executor.submit((Runnable) () -> { - try { - for (int splitId = 0; splitId < splits; splitId++) { - SimulationSplit split = new SimulationSplit(new Duration(80, MILLISECONDS), new Duration(1, MILLISECONDS)); - SimulationTask.this.splits.add(split); - splitFutures.addAll(taskExecutor.enqueueSplits(taskHandle, false, ImmutableList.of(split))); - Thread.sleep(entryDelay.toMillis()); - } - - Futures.allAsList(splitFutures).get(); - future.set(null); - } - catch (Throwable e) { - future.setException(e); - throw Throwables.propagate(e); - } - }); - - return future; - } - - private Object getTaskId() - { - return taskId; - } - - private long getCreatedNanos() - { - return createdNanos; - } - - private List getSplits() - { - return splits; - } - } - - private static class SimulationSplit - implements SplitRunner - { - private final long requiredProcessMillis; - private final long processMillisPerCall; - private final AtomicLong completedProcessMillis = new AtomicLong(); - - private final AtomicInteger calls = new AtomicInteger(0); - private final long createdNanos = System.nanoTime(); - private final AtomicLong startNanos = new AtomicLong(-1); - private final AtomicLong doneNanos = new AtomicLong(-1); - - private final AtomicLong queuedNanos = new AtomicLong(); - - private long lastCallNanos = createdNanos; - - private SimulationSplit(Duration requiredProcessTime, Duration processTimePerCall) - { - this.requiredProcessMillis = requiredProcessTime.toMillis(); - this.processMillisPerCall = processTimePerCall.toMillis(); - } - - private long getRequiredProcessMillis() - { - return requiredProcessMillis; - } - - private long getCreatedNanos() - { - return createdNanos; - } - - private long getStartNanos() - { - return startNanos.get(); - } - - private long getDoneNanos() - { - return doneNanos.get(); - } - - private long getQueuedNanos() - { - return queuedNanos.get(); - } - - @Override - public boolean isFinished() - { - return doneNanos.get() >= 0; - } - - @Override - public void close() - { - } - - @Override - public ListenableFuture processFor(Duration duration) - throws Exception - { - long callStart = System.nanoTime(); - startNanos.compareAndSet(-1, callStart); - calls.incrementAndGet(); - queuedNanos.addAndGet(callStart - lastCallNanos); - - long processMillis = Math.min(requiredProcessMillis - completedProcessMillis.get(), processMillisPerCall); - MILLISECONDS.sleep(processMillis); - long completedMillis = completedProcessMillis.addAndGet(processMillis); - - boolean isFinished = completedMillis >= requiredProcessMillis; - long callEnd = System.nanoTime(); - lastCallNanos = callEnd; - if (isFinished) { - doneNanos.compareAndSet(-1, callEnd); - } - - return Futures.immediateCheckedFuture(null); - } - - @Override - public String getInfo() - { - return "simulation-split"; - } - } -} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestCreateTableTask.java b/presto-main/src/test/java/com/facebook/presto/execution/TestCreateTableTask.java new file mode 100644 index 000000000000..394d4cbe7dc4 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/execution/TestCreateTableTask.java @@ -0,0 +1,189 @@ + +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution; + +import com.facebook.presto.Session; +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.metadata.AbstractMockMetadata; +import com.facebook.presto.metadata.Catalog; +import com.facebook.presto.metadata.CatalogManager; +import com.facebook.presto.metadata.QualifiedObjectName; +import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.metadata.TablePropertyManager; +import com.facebook.presto.security.AllowAllAccessControl; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.sql.tree.ColumnDefinition; +import com.facebook.presto.sql.tree.CreateTable; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.transaction.TransactionManager; +import com.facebook.presto.type.TypeRegistry; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; +import static com.facebook.presto.spi.session.PropertyMetadata.stringSessionProperty; +import static com.facebook.presto.testing.TestingSession.createBogusTestingCatalog; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.transaction.TransactionManager.createTestTransactionManager; +import static io.airlift.concurrent.MoreFutures.getFutureValue; +import static java.util.Collections.emptyList; +import static java.util.Objects.requireNonNull; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; + +@Test(singleThreaded = true) +public class TestCreateTableTask +{ + private static final String CATALOG_NAME = "catalog"; + private CatalogManager catalogManager; + private TypeManager typeManager; + private TransactionManager transactionManager; + private TablePropertyManager tablePropertyManager; + private Catalog testCatalog; + private Session testSession; + private MockMetadata metadata; + + @BeforeMethod + public void setUp() + throws Exception + { + catalogManager = new CatalogManager(); + typeManager = new TypeRegistry(); + transactionManager = createTestTransactionManager(catalogManager); + tablePropertyManager = new TablePropertyManager(); + testCatalog = createBogusTestingCatalog(CATALOG_NAME); + catalogManager.registerCatalog(testCatalog); + tablePropertyManager.addProperties(testCatalog.getConnectorId(), + ImmutableList.of(stringSessionProperty("baz", "test property", null, false))); + testSession = testSessionBuilder() + .setTransactionId(transactionManager.beginTransaction(false)) + .build(); + metadata = new MockMetadata(typeManager, + tablePropertyManager, + testCatalog.getConnectorId()); + } + + @Test + public void testCreateTableNotExistsTrue() + throws Exception + { + CreateTable statement = new CreateTable(QualifiedName.of("test_table"), + ImmutableList.of(new ColumnDefinition("a", "BIGINT", Optional.empty())), + true, + ImmutableMap.of(), + Optional.empty()); + + getFutureValue(new CreateTableTask().internalExecute(statement, metadata, new AllowAllAccessControl(), testSession, emptyList())); + assertEquals(metadata.getCreateTableCallCount(), 1); + } + + @Test + public void testCreateTableNotExistsFalse() + throws Exception + { + CreateTable statement = new CreateTable(QualifiedName.of("test_table"), + ImmutableList.of(new ColumnDefinition("a", "BIGINT", Optional.empty())), + false, + ImmutableMap.of(), + Optional.empty()); + + try { + getFutureValue(new CreateTableTask().internalExecute(statement, metadata, new AllowAllAccessControl(), testSession, emptyList())); + fail("expected exception"); + } + catch (RuntimeException e) { + // Expected + assertTrue(e instanceof PrestoException); + PrestoException prestoException = (PrestoException) e; + assertTrue(prestoException.getErrorCode().equals(ALREADY_EXISTS.toErrorCode())); + } + assertEquals(metadata.getCreateTableCallCount(), 1); + } + + private static class MockMetadata + extends AbstractMockMetadata + { + private final TypeManager typeManager; + private final TablePropertyManager tablePropertyManager; + private final ConnectorId catalogHandle; + private AtomicInteger createTableCallCount = new AtomicInteger(); + + public MockMetadata( + TypeManager typeManager, + TablePropertyManager tablePropertyManager, + ConnectorId catalogHandle) + { + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.tablePropertyManager = requireNonNull(tablePropertyManager, "tablePropertyManager is null"); + this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null"); + } + + @Override + public void createTable(Session session, String catalogName, ConnectorTableMetadata tableMetadata) + { + createTableCallCount.incrementAndGet(); + throw new PrestoException(ALREADY_EXISTS, "Table already exists"); + } + + @Override + public TablePropertyManager getTablePropertyManager() + { + return tablePropertyManager; + } + + @Override + public Type getType(TypeSignature signature) + { + return typeManager.getType(signature); + } + + @Override + public Optional getCatalogHandle(Session session, String catalogName) + { + if (catalogHandle.getCatalogName().equals(catalogName)) { + return Optional.of(catalogHandle); + } + return Optional.empty(); + } + + @Override + public Optional getTableHandle(Session session, QualifiedObjectName tableName) + { + return Optional.empty(); + } + + public int getCreateTableCallCount() + { + return createTableCallCount.get(); + } + + @Override + public void dropColumn(Session session, TableHandle tableHandle, ColumnHandle column) + { + throw new UnsupportedOperationException(); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestTaskExecutor.java b/presto-main/src/test/java/com/facebook/presto/execution/TestTaskExecutor.java deleted file mode 100644 index 7798318a3b1a..000000000000 --- a/presto-main/src/test/java/com/facebook/presto/execution/TestTaskExecutor.java +++ /dev/null @@ -1,233 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.execution; - -import com.facebook.presto.execution.executor.TaskExecutor; -import com.facebook.presto.execution.executor.TaskHandle; -import com.google.common.collect.ImmutableList; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.testing.TestingTicker; -import io.airlift.units.Duration; -import org.testng.annotations.Test; - -import java.util.concurrent.Phaser; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; - -import static com.google.common.collect.Iterables.getOnlyElement; -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static org.testng.Assert.assertEquals; - -public class TestTaskExecutor -{ - @Test(invocationCount = 100) - public void test() - throws Exception - { - TestingTicker ticker = new TestingTicker(); - TaskExecutor taskExecutor = new TaskExecutor(4, 8, ticker); - taskExecutor.start(); - ticker.increment(20, MILLISECONDS); - - try { - TaskId taskId = new TaskId("test", 0, 0); - TaskHandle taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS)); - - Phaser beginPhase = new Phaser(); - beginPhase.register(); - Phaser verificationComplete = new Phaser(); - verificationComplete.register(); - - // add two jobs - TestingJob driver1 = new TestingJob(beginPhase, verificationComplete, 10); - ListenableFuture future1 = getOnlyElement(taskExecutor.enqueueSplits(taskHandle, true, ImmutableList.of(driver1))); - TestingJob driver2 = new TestingJob(beginPhase, verificationComplete, 10); - ListenableFuture future2 = getOnlyElement(taskExecutor.enqueueSplits(taskHandle, true, ImmutableList.of(driver2))); - assertEquals(driver1.getCompletedPhases(), 0); - assertEquals(driver2.getCompletedPhases(), 0); - - // verify worker have arrived but haven't processed yet - beginPhase.arriveAndAwaitAdvance(); - assertEquals(driver1.getCompletedPhases(), 0); - assertEquals(driver2.getCompletedPhases(), 0); - ticker.increment(10, MILLISECONDS); - assertEquals(taskExecutor.getMaxActiveSplitTime(), 10); - verificationComplete.arriveAndAwaitAdvance(); - - // advance one phase and verify - beginPhase.arriveAndAwaitAdvance(); - assertEquals(driver1.getCompletedPhases(), 1); - assertEquals(driver2.getCompletedPhases(), 1); - - verificationComplete.arriveAndAwaitAdvance(); - - // add one more job - TestingJob driver3 = new TestingJob(beginPhase, verificationComplete, 10); - ListenableFuture future3 = getOnlyElement(taskExecutor.enqueueSplits(taskHandle, false, ImmutableList.of(driver3))); - - // advance one phase and verify - beginPhase.arriveAndAwaitAdvance(); - assertEquals(driver1.getCompletedPhases(), 2); - assertEquals(driver2.getCompletedPhases(), 2); - assertEquals(driver3.getCompletedPhases(), 0); - verificationComplete.arriveAndAwaitAdvance(); - - // advance to the end of the first two task and verify - beginPhase.arriveAndAwaitAdvance(); - for (int i = 0; i < 7; i++) { - verificationComplete.arriveAndAwaitAdvance(); - beginPhase.arriveAndAwaitAdvance(); - assertEquals(beginPhase.getPhase(), verificationComplete.getPhase() + 1); - } - assertEquals(driver1.getCompletedPhases(), 10); - assertEquals(driver2.getCompletedPhases(), 10); - assertEquals(driver3.getCompletedPhases(), 8); - future1.get(1, TimeUnit.SECONDS); - future2.get(1, TimeUnit.SECONDS); - verificationComplete.arriveAndAwaitAdvance(); - - // advance two more times and verify - beginPhase.arriveAndAwaitAdvance(); - verificationComplete.arriveAndAwaitAdvance(); - beginPhase.arriveAndAwaitAdvance(); - assertEquals(driver1.getCompletedPhases(), 10); - assertEquals(driver2.getCompletedPhases(), 10); - assertEquals(driver3.getCompletedPhases(), 10); - future3.get(1, TimeUnit.SECONDS); - verificationComplete.arriveAndAwaitAdvance(); - - assertEquals(driver1.getFirstPhase(), 0); - assertEquals(driver2.getFirstPhase(), 0); - assertEquals(driver3.getFirstPhase(), 2); - - assertEquals(driver1.getLastPhase(), 10); - assertEquals(driver2.getLastPhase(), 10); - assertEquals(driver3.getLastPhase(), 12); - - // no splits remaining - ticker.increment(30, MILLISECONDS); - assertEquals(taskExecutor.getMaxActiveSplitTime(), 0); - } - finally { - taskExecutor.stop(); - } - } - - @Test - public void testTaskHandle() - throws Exception - { - TaskExecutor taskExecutor = new TaskExecutor(4, 8); - taskExecutor.start(); - - try { - TaskId taskId = new TaskId("test", 0, 0); - TaskHandle taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS)); - - Phaser beginPhase = new Phaser(); - beginPhase.register(); - Phaser verificationComplete = new Phaser(); - verificationComplete.register(); - TestingJob driver1 = new TestingJob(beginPhase, verificationComplete, 10); - TestingJob driver2 = new TestingJob(beginPhase, verificationComplete, 10); - - // force enqueue a split - taskExecutor.enqueueSplits(taskHandle, true, ImmutableList.of(driver1)); - assertEquals(taskHandle.getRunningLeafSplits(), 0); - - // normal enqueue a split - taskExecutor.enqueueSplits(taskHandle, false, ImmutableList.of(driver2)); - assertEquals(taskHandle.getRunningLeafSplits(), 1); - - // let the split continue to run - beginPhase.arriveAndDeregister(); - verificationComplete.arriveAndDeregister(); - } - finally { - taskExecutor.stop(); - } - } - - private static class TestingJob - implements SplitRunner - { - private final Phaser awaitWorkers; - private final Phaser awaitVerifiers; - private final int requiredPhases; - private final AtomicInteger completedPhases = new AtomicInteger(); - - private final AtomicInteger firstPhase = new AtomicInteger(-1); - private final AtomicInteger lastPhase = new AtomicInteger(-1); - - public TestingJob(Phaser awaitWorkers, Phaser awaitVerifiers, int requiredPhases) - { - this.awaitWorkers = awaitWorkers; - this.awaitVerifiers = awaitVerifiers; - this.requiredPhases = requiredPhases; - awaitWorkers.register(); - awaitVerifiers.register(); - } - - private int getFirstPhase() - { - return firstPhase.get(); - } - - private int getLastPhase() - { - return lastPhase.get(); - } - - private int getCompletedPhases() - { - return completedPhases.get(); - } - - @Override - public ListenableFuture processFor(Duration duration) - throws Exception - { - int phase = awaitWorkers.arriveAndAwaitAdvance(); - firstPhase.compareAndSet(-1, phase - 1); - lastPhase.set(phase); - awaitVerifiers.arriveAndAwaitAdvance(); - - completedPhases.getAndIncrement(); - return Futures.immediateFuture(null); - } - - @Override - public String getInfo() - { - return "testing-split"; - } - - @Override - public boolean isFinished() - { - boolean isFinished = completedPhases.get() >= requiredPhases; - if (isFinished) { - awaitVerifiers.arriveAndDeregister(); - awaitWorkers.arriveAndDeregister(); - } - return isFinished; - } - - @Override - public void close() - { - } - } -} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestTaskManagerConfig.java b/presto-main/src/test/java/com/facebook/presto/execution/TestTaskManagerConfig.java index 5f398bda3c18..948567b52b39 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TestTaskManagerConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TestTaskManagerConfig.java @@ -18,6 +18,7 @@ import io.airlift.units.Duration; import org.testng.annotations.Test; +import java.math.BigDecimal; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -51,7 +52,10 @@ public void testDefaults() .setTaskConcurrency(16) .setHttpResponseThreads(100) .setHttpTimeoutThreads(3) - .setTaskNotificationThreads(5)); + .setTaskNotificationThreads(5) + .setLevelAbsolutePriority(true) + .setLevelTimeMultiplier(new BigDecimal("2")) + .setLegacySchedulingBehavior(true)); } @Test @@ -78,6 +82,9 @@ public void testExplicitPropertyMappings() .put("task.http-response-threads", "4") .put("task.http-timeout-threads", "10") .put("task.task-notification-threads", "13") + .put("task.level-absolute-priority", "false") + .put("task.level-time-multiplier", "2.1") + .put("task.legacy-scheduling-behavior", "false") .build(); TaskManagerConfig expected = new TaskManagerConfig() @@ -100,7 +107,10 @@ public void testExplicitPropertyMappings() .setTaskConcurrency(8) .setHttpResponseThreads(4) .setHttpTimeoutThreads(10) - .setTaskNotificationThreads(13); + .setTaskNotificationThreads(13) + .setLevelAbsolutePriority(false) + .setLevelTimeMultiplier(new BigDecimal("2.1")) + .setLegacySchedulingBehavior(false); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/execution/executor/Histogram.java b/presto-main/src/test/java/com/facebook/presto/execution/executor/Histogram.java new file mode 100644 index 000000000000..7f6ffbe5948c --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/execution/executor/Histogram.java @@ -0,0 +1,186 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.executor; + +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.SortedMap; +import java.util.TreeMap; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkArgument; + +class Histogram> +{ + private final List buckets; + private final boolean discrete; + + private Histogram(Collection buckets, boolean discrete) + { + this.buckets = new ArrayList<>(buckets); + this.discrete = discrete; + Collections.sort(this.buckets); + } + + public static > Histogram fromDiscrete(Collection buckets) + { + return new Histogram<>(buckets, true); + } + + public static > Histogram fromContinuous(Collection buckets) + { + return new Histogram<>(buckets, false); + } + + public static Histogram fromContinuous(Collection initialData, Function keyFunction) + { + if (initialData.isEmpty()) { + return new Histogram<>(ImmutableList.of(), false); + } + + int numBuckets = Math.min(10, (int) Math.sqrt(initialData.size())); + long min = initialData.stream() + .mapToLong(keyFunction::apply) + .min() + .getAsLong(); + long max = initialData.stream() + .mapToLong(keyFunction::apply) + .max() + .getAsLong(); + + checkArgument(max > min); + + long bucketSize = (max - min) / numBuckets; + long bucketRemainder = (max - min) % numBuckets; + + List minimums = new ArrayList<>(); + + long currentMin = min; + for (int i = 0; i < numBuckets; i++) { + minimums.add(currentMin); + long currentMax = currentMin + bucketSize; + if (bucketRemainder > 0) { + currentMax++; + bucketRemainder--; + } + currentMin = currentMax + 1; + } + + minimums.add(numBuckets, currentMin); + + return new Histogram<>(minimums, false); + } + + public void printDistribution( + Collection data, + Function keyFunction, + Function keyFormatter) + { + if (buckets.isEmpty()) { + System.out.println("No buckets"); + return; + } + + if (data.isEmpty()) { + System.out.println("No data"); + return; + } + + long[] bucketData = new long[buckets.size()]; + + for (D datum : data) { + K key = keyFunction.apply(datum); + + for (int i = 0; i < buckets.size(); i++) { + if (key.compareTo(buckets.get(i)) >= 0 && (i == (buckets.size() - 1) || key.compareTo(buckets.get(i + 1)) < 0)) { + bucketData[i]++; + break; + } + } + } + + if (!discrete) { + for (int i = 0; i < bucketData.length - 1; i++) { + System.out.printf("%8s - %8s : (%5s values)\n", + keyFormatter.apply(buckets.get(i)), + keyFormatter.apply(buckets.get(i + 1)), + bucketData[i]); + } + } + else { + for (int i = 0; i < bucketData.length; i++) { + System.out.printf("%8s : (%5s values)\n", + keyFormatter.apply(buckets.get(i)), + bucketData[i]); + } + } + } + + public void printDistribution( + Collection data, + Function keyFunction, + Function valueFunction, + Function keyFormatter, + Function, G> valueFormatter) + { + if (buckets.isEmpty()) { + System.out.println("No buckets"); + return; + } + + if (data.isEmpty()) { + System.out.println("No data"); + return; + } + + SortedMap> bucketData = new TreeMap<>(); + for (int i = 0; i < buckets.size(); i++) { + bucketData.put(i, new ArrayList<>()); + } + + for (D datum : data) { + K key = keyFunction.apply(datum); + V value = valueFunction.apply(datum); + + for (int i = 0; i < buckets.size(); i++) { + if (key.compareTo(buckets.get(i)) >= 0 && (i == (buckets.size() - 1) || key.compareTo(buckets.get(i + 1)) < 0)) { + bucketData.get(i).add(value); + break; + } + } + } + + if (!discrete) { + for (int i = 0; i < bucketData.size() - 1; i++) { + System.out.printf("%8s - %8s : (%5s values) %s\n", + keyFormatter.apply(buckets.get(i)), + keyFormatter.apply(buckets.get(i + 1)), + bucketData.get(i).size(), + valueFormatter.apply(bucketData.get(i))); + } + } + else { + for (int i = 0; i < bucketData.size(); i++) { + System.out.printf("%19s : (%5s values) %s\n", + keyFormatter.apply(buckets.get(i)), + bucketData.get(i).size(), + valueFormatter.apply(bucketData.get(i))); + } + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/executor/SimulationController.java b/presto-main/src/test/java/com/facebook/presto/execution/executor/SimulationController.java new file mode 100644 index 000000000000..bc5dddc68078 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/execution/executor/SimulationController.java @@ -0,0 +1,272 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.executor; + +import com.facebook.presto.execution.TaskId; +import com.facebook.presto.execution.executor.SimulationTask.IntermediateTask; +import com.facebook.presto.execution.executor.SimulationTask.LeafTask; +import com.facebook.presto.execution.executor.SplitGenerators.SplitGenerator; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.Multimaps; + +import java.util.Map; +import java.util.OptionalInt; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; + +import static com.facebook.presto.execution.executor.SimulationController.TaskSpecification.Type.LEAF; +import static com.facebook.presto.execution.executor.TaskExecutor.GUARANTEED_SPLITS_PER_TASK; +import static java.util.concurrent.Executors.newSingleThreadExecutor; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +class SimulationController +{ + private final TaskExecutor taskExecutor; + private final BiConsumer callback; + + private final ExecutorService controllerExecutor = newSingleThreadExecutor(); + + private final Map specificationEnabled = new ConcurrentHashMap<>(); + private final ListMultimap runningTasks = Multimaps.synchronizedListMultimap(ArrayListMultimap.create()); + + private final ListMultimap completedTasks = Multimaps.synchronizedListMultimap(ArrayListMultimap.create()); + private final AtomicBoolean clearPendingQueue = new AtomicBoolean(); + + private final AtomicBoolean stopped = new AtomicBoolean(); + + public SimulationController(TaskExecutor taskExecutor, BiConsumer callback) + { + this.taskExecutor = taskExecutor; + this.callback = callback; + } + + public synchronized void addTaskSpecification(TaskSpecification spec) + { + specificationEnabled.put(spec, false); + } + + public synchronized void clearPendingQueue() + { + System.out.println("Clearing pending queue.."); + clearPendingQueue.set(true); + } + + public synchronized void stop() + { + stopped.set(true); + controllerExecutor.shutdownNow(); + taskExecutor.stop(); + } + + public synchronized void enableSpecification(TaskSpecification specification) + { + specificationEnabled.replace(specification, false, true); + startSpec(specification); + } + + public synchronized void disableSpecification(TaskSpecification specification) + { + if (specificationEnabled.replace(specification, true, false) && callback != null) { + runCallback(); + } + } + + public synchronized void runCallback() + { + callback.accept(this, taskExecutor); + } + + public void run() + { + controllerExecutor.submit(() -> { + while (!stopped.get()) { + replaceCompletedTasks(); + scheduleSplitsForRunningTasks(); + + try { + MILLISECONDS.sleep(500); + } + catch (InterruptedException e) { + return; + } + } + }); + } + + private synchronized void scheduleSplitsForRunningTasks() + { + if (clearPendingQueue.get()) { + if (taskExecutor.getWaitingSplits() > (taskExecutor.getIntermediateSplits() - taskExecutor.getBlockedSplits())) { + return; + } + + System.out.println("Cleared pending queue."); + clearPendingQueue.set(false); + } + + for (TaskSpecification specification : specificationEnabled.keySet()) { + if (!specificationEnabled.get(specification)) { + continue; + } + + for (SimulationTask task : runningTasks.get(specification)) { + if (specification.getType() == LEAF) { + int remainingSplits = specification.getNumSplitsPerTask() - (task.getRunningSplits().size() + task.getCompletedSplits().size()); + int candidateSplits = GUARANTEED_SPLITS_PER_TASK - task.getRunningSplits().size(); + for (int i = 0; i < Math.min(remainingSplits, candidateSplits); i++) { + task.schedule(taskExecutor, 1); + } + } + else { + int remainingSplits = specification.getNumSplitsPerTask() - (task.getRunningSplits().size() + task.getCompletedSplits().size()); + task.schedule(taskExecutor, remainingSplits); + } + } + } + } + + private synchronized void replaceCompletedTasks() + { + boolean moved; + do { + moved = false; + + for (TaskSpecification specification : specificationEnabled.keySet()) { + if (specification.getTotalTasks().isPresent() && + specificationEnabled.get(specification) && + specification.getTotalTasks().getAsInt() <= completedTasks.get(specification).size() + runningTasks.get(specification).size()) { + System.out.println(); + System.out.println(specification.getName() + " disabled for reaching target count " + specification.getTotalTasks()); + System.out.println(); + disableSpecification(specification); + continue; + } + for (SimulationTask task : runningTasks.get(specification)) { + if (task.getCompletedSplits().size() >= specification.getNumSplitsPerTask()) { + completedTasks.put(specification, task); + runningTasks.remove(specification, task); + taskExecutor.removeTask(task.getTaskHandle()); + + if (!specificationEnabled.get(specification)) { + continue; + } + + createTask(specification); + moved = true; + break; + } + } + } + } while (moved); + } + + private void createTask(TaskSpecification specification) + { + if (specification.getType() == LEAF) { + runningTasks.put(specification, new LeafTask( + taskExecutor, + specification, + new TaskId(specification.getName(), 0, runningTasks.get(specification).size() + completedTasks.get(specification).size()))); + } + else { + runningTasks.put(specification, new IntermediateTask( + taskExecutor, + specification, + new TaskId(specification.getName(), 0, runningTasks.get(specification).size() + completedTasks.get(specification).size()))); + } + } + + public Map getSpecificationEnabled() + { + return specificationEnabled; + } + + public ListMultimap getRunningTasks() + { + return runningTasks; + } + + public ListMultimap getCompletedTasks() + { + return completedTasks; + } + + private void startSpec(TaskSpecification specification) + { + if (!specificationEnabled.get(specification)) { + return; + } + for (int i = 0; i < specification.getNumConcurrentTasks(); i++) { + createTask(specification); + } + } + + public static class TaskSpecification + { + enum Type { + LEAF, + INTERMEDIATE + } + + private final Type type; + private final String name; + private final OptionalInt totalTasks; + private final int numConcurrentTasks; + private final int numSplitsPerTask; + private final SplitGenerator splitGenerator; + + TaskSpecification(Type type, String name, OptionalInt totalTasks, int numConcurrentTasks, int numSplitsPerTask, SplitGenerator splitGenerator) + { + this.type = type; + this.name = name; + this.totalTasks = totalTasks; + this.numConcurrentTasks = numConcurrentTasks; + this.numSplitsPerTask = numSplitsPerTask; + this.splitGenerator = splitGenerator; + } + + Type getType() + { + return type; + } + + String getName() + { + return name; + } + + int getNumConcurrentTasks() + { + return numConcurrentTasks; + } + + int getNumSplitsPerTask() + { + return numSplitsPerTask; + } + + OptionalInt getTotalTasks() + { + return totalTasks; + } + + SplitSpecification nextSpecification() + { + return splitGenerator.next(); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/executor/SimulationSplit.java b/presto-main/src/test/java/com/facebook/presto/execution/executor/SimulationSplit.java new file mode 100644 index 000000000000..1ea94a672bea --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/execution/executor/SimulationSplit.java @@ -0,0 +1,292 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.executor; + +import com.facebook.presto.execution.SplitRunner; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.airlift.units.Duration; + +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +import static com.facebook.presto.operator.Operator.NOT_BLOCKED; +import static io.airlift.units.Duration.succinctNanos; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +abstract class SimulationSplit + implements SplitRunner +{ + private final SimulationTask task; + + private final AtomicInteger calls = new AtomicInteger(0); + + private final long createdNanos = System.nanoTime(); + private final AtomicLong completedProcessNanos = new AtomicLong(); + private final AtomicLong startNanos = new AtomicLong(-1); + private final AtomicLong doneNanos = new AtomicLong(-1); + private final AtomicLong waitNanos = new AtomicLong(); + private final AtomicLong lastReadyTime = new AtomicLong(-1); + private final AtomicBoolean killed = new AtomicBoolean(false); + + private final long scheduledTimeNanos; + + SimulationSplit(SimulationTask task, long scheduledTimeNanos) + { + this.task = requireNonNull(task, "task is null"); + this.scheduledTimeNanos = scheduledTimeNanos; + } + + long getCreatedNanos() + { + return createdNanos; + } + + long getCompletedProcessNanos() + { + return completedProcessNanos.get(); + } + + long getStartNanos() + { + return startNanos.get(); + } + + long getDoneNanos() + { + return doneNanos.get(); + } + + long getWaitNanos() + { + return waitNanos.get(); + } + + int getCalls() + { + return calls.get(); + } + + long getScheduledTimeNanos() + { + return scheduledTimeNanos; + } + + String getTaskId() + { + return task.getTaskId().toString(); + } + + SimulationTask getTask() + { + return task; + } + + boolean isKilled() + { + return killed.get(); + } + + void setKilled() + { + waitNanos.addAndGet(System.nanoTime() - lastReadyTime.get()); + killed.set(true); + task.setKilled(); + } + + @Override + public boolean isFinished() + { + return doneNanos.get() >= 0; + } + + @Override + public void close() + { + } + + abstract boolean process(); + + abstract ListenableFuture getProcessResult(); + + void setSplitReady() + { + lastReadyTime.set(System.nanoTime()); + } + + @Override + public ListenableFuture processFor(Duration duration) + { + calls.incrementAndGet(); + + long callStart = System.nanoTime(); + startNanos.compareAndSet(-1, callStart); + lastReadyTime.compareAndSet(-1, callStart); + waitNanos.addAndGet(callStart - lastReadyTime.get()); + + boolean done = process(); + + long callEnd = System.nanoTime(); + + completedProcessNanos.addAndGet(callEnd - callStart); + + if (done) { + doneNanos.compareAndSet(-1, callEnd); + + if (!isKilled()) { + task.splitComplete(this); + } + + return Futures.immediateCheckedFuture(null); + } + + ListenableFuture processResult = getProcessResult(); + if (processResult.isDone()) { + setSplitReady(); + } + + return processResult; + } + + static class LeafSplit + extends SimulationSplit + { + private final long perQuantaNanos; + + public LeafSplit(SimulationTask task, long scheduledTimeNanos, long perQuantaNanos) + { + super(task, scheduledTimeNanos); + this.perQuantaNanos = perQuantaNanos; + } + + public boolean process() + { + if (getCompletedProcessNanos() >= super.scheduledTimeNanos) { + return true; + } + + long processNanos = Math.min(super.scheduledTimeNanos - getCompletedProcessNanos(), perQuantaNanos); + if (processNanos > 0) { + try { + NANOSECONDS.sleep(processNanos); + } + catch (InterruptedException e) { + setKilled(); + return true; + } + } + + return false; + } + + public ListenableFuture getProcessResult() + { + return NOT_BLOCKED; + } + + @Override + public String getInfo() + { + double pct = (100.0 * getCompletedProcessNanos() / super.scheduledTimeNanos); + return String.format("leaf %3s%% done (total: %8s, per quanta: %8s)", + (int) (pct > 100.00 ? 100.0 : pct), + succinctNanos(super.scheduledTimeNanos), + succinctNanos(perQuantaNanos)); + } + } + + static class IntermediateSplit + extends SimulationSplit + { + private final long wallTimeNanos; + private final long numQuantas; + private final long perQuantaNanos; + private final long betweenQuantaNanos; + + private final ScheduledExecutorService executorService; + + private SettableFuture future = SettableFuture.create(); + private SettableFuture doneFuture = SettableFuture.create(); + + public IntermediateSplit(SimulationTask task, long scheduledTimeNanos, long wallTimeNanos, long numQuantas, long perQuantaNanos, long betweenQuantaNanos, ScheduledExecutorService executorService) + { + super(task, scheduledTimeNanos); + this.wallTimeNanos = wallTimeNanos; + this.numQuantas = numQuantas; + this.perQuantaNanos = perQuantaNanos; + this.betweenQuantaNanos = betweenQuantaNanos; + this.executorService = executorService; + + doneFuture.set(null); + } + + public boolean process() + { + try { + if (getCalls() < numQuantas) { + NANOSECONDS.sleep(perQuantaNanos); + return false; + } + } + catch (InterruptedException ignored) { + setKilled(); + return true; + } + + return true; + } + + public ListenableFuture getProcessResult() + { + future = SettableFuture.create(); + try { + executorService.schedule(() -> { + try { + if (!executorService.isShutdown()) { + future.set(null); + } + else { + setKilled(); + } + setSplitReady(); + } + catch (RuntimeException ignored) { + setKilled(); + } + }, betweenQuantaNanos, NANOSECONDS); + } + catch (RejectedExecutionException ignored) { + setKilled(); + return doneFuture; + } + return future; + } + + @Override + public String getInfo() + { + double pct = (100.0 * getCalls() / numQuantas); + return String.format("intr %3s%% done (wall: %9s, per quanta: %8s, between quanta: %8s)", + (int) (pct > 100.00 ? 100.0 : pct), + succinctNanos(wallTimeNanos), + succinctNanos(perQuantaNanos), + succinctNanos(betweenQuantaNanos)); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/executor/SimulationTask.java b/presto-main/src/test/java/com/facebook/presto/execution/executor/SimulationTask.java new file mode 100644 index 000000000000..b4e292869fbc --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/execution/executor/SimulationTask.java @@ -0,0 +1,170 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.executor; + +import com.facebook.presto.execution.TaskId; +import com.facebook.presto.execution.executor.SimulationController.TaskSpecification; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; +import io.airlift.units.Duration; + +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; + +import static java.util.concurrent.TimeUnit.SECONDS; + +abstract class SimulationTask +{ + private final TaskSpecification specification; + private final TaskId taskId; + + private final Set runningSplits = Sets.newConcurrentHashSet(); + private final Set completedSplits = Sets.newConcurrentHashSet(); + + private final TaskHandle taskHandle; + private final AtomicBoolean killed = new AtomicBoolean(); + + public SimulationTask(TaskExecutor taskExecutor, TaskSpecification specification, TaskId taskId) + { + this.specification = specification; + this.taskId = taskId; + taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, SECONDS)); + } + + public void setKilled() + { + killed.set(true); + } + + public boolean isKilled() + { + return killed.get(); + } + + public Set getCompletedSplits() + { + return completedSplits; + } + + TaskId getTaskId() + { + return taskId; + } + + public TaskHandle getTaskHandle() + { + return taskHandle; + } + + public Set getRunningSplits() + { + return runningSplits; + } + + public synchronized void splitComplete(SimulationSplit split) + { + runningSplits.remove(split); + completedSplits.add(split); + } + + public TaskSpecification getSpecification() + { + return specification; + } + + public long getTotalWaitTimeNanos() + { + long runningWaitTime = runningSplits.stream() + .mapToLong(SimulationSplit::getWaitNanos) + .sum(); + + long completedWaitTime = completedSplits.stream() + .mapToLong(SimulationSplit::getWaitNanos) + .sum(); + + return runningWaitTime + completedWaitTime; + } + + public long getProcessedTimeNanos() + { + long runningProcessedTime = runningSplits.stream() + .mapToLong(SimulationSplit::getCompletedProcessNanos) + .sum(); + + long completedProcessedTime = completedSplits.stream() + .mapToLong(SimulationSplit::getCompletedProcessNanos) + .sum(); + + return runningProcessedTime + completedProcessedTime; + } + + public long getScheduledTimeNanos() + { + long runningWallTime = runningSplits.stream() + .mapToLong(SimulationSplit::getScheduledTimeNanos) + .sum(); + + long completedWallTime = completedSplits.stream() + .mapToLong(SimulationSplit::getScheduledTimeNanos) + .sum(); + + return runningWallTime + completedWallTime; + } + + public abstract void schedule(TaskExecutor taskExecutor, int numSplits); + + public static class LeafTask + extends SimulationTask + { + private final TaskSpecification taskSpecification; + + public LeafTask(TaskExecutor taskExecutor, TaskSpecification specification, TaskId taskId) + { + super(taskExecutor, specification, taskId); + this.taskSpecification = specification; + } + + public void schedule(TaskExecutor taskExecutor, int numSplits) + { + ImmutableList.Builder splits = ImmutableList.builder(); + for (int i = 0; i < numSplits; i++) { + splits.add(taskSpecification.nextSpecification().instantiate(this)); + } + super.runningSplits.addAll(splits.build()); + taskExecutor.enqueueSplits(getTaskHandle(), false, splits.build()); + } + } + + public static class IntermediateTask + extends SimulationTask + { + private final SplitSpecification splitSpecification; + + public IntermediateTask(TaskExecutor taskExecutor, TaskSpecification specification, TaskId taskId) + { + super(taskExecutor, specification, taskId); + this.splitSpecification = specification.nextSpecification(); + } + + public void schedule(TaskExecutor taskExecutor, int numSplits) + { + ImmutableList.Builder splits = ImmutableList.builder(); + for (int i = 0; i < numSplits; i++) { + splits.add(splitSpecification.instantiate(this)); + } + super.runningSplits.addAll(splits.build()); + taskExecutor.enqueueSplits(getTaskHandle(), true, splits.build()); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/executor/SplitGenerators.java b/presto-main/src/test/java/com/facebook/presto/execution/executor/SplitGenerators.java new file mode 100644 index 000000000000..2423e16a239d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/execution/executor/SplitGenerators.java @@ -0,0 +1,347 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.executor; + +import com.facebook.presto.execution.executor.SplitSpecification.IntermediateSplitSpecification; +import com.facebook.presto.execution.executor.SplitSpecification.LeafSplitSpecification; +import com.google.common.collect.ImmutableList; +import io.airlift.units.Duration; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadLocalRandom; + +import static com.facebook.presto.execution.executor.Histogram.fromContinuous; +import static java.util.concurrent.TimeUnit.DAYS; +import static java.util.concurrent.TimeUnit.MICROSECONDS; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; + +class SplitGenerators +{ + private SplitGenerators() {} + + public static void main(String[] args) + { + Histogram bins = fromContinuous(ImmutableList.of( + MILLISECONDS.toNanos(0), + MILLISECONDS.toNanos(1), + MILLISECONDS.toNanos(10), + MILLISECONDS.toNanos(100), + MILLISECONDS.toNanos(1_000), + MILLISECONDS.toNanos(10_000), + MILLISECONDS.toNanos(60_000), + MILLISECONDS.toNanos(300_000), + MINUTES.toNanos(20), + DAYS.toNanos(1))); + + IntermediateSplitGenerator intermediateSplitGenerator = new IntermediateSplitGenerator(null); + List intermediateSpecs = new ArrayList<>(); + for (int i = 0; i < 10_000; i++) { + IntermediateSplitSpecification next = intermediateSplitGenerator.next(); + intermediateSpecs.add(next); + } + + System.out.println("Scheduled time distributions"); + System.out.println("============================"); + System.out.println(); + System.out.println("Tasks with 8x " + IntermediateSplitGenerator.class.getSimpleName()); + bins.printDistribution(intermediateSpecs, t -> t.getScheduledTimeNanos() * 8, a -> 1, Duration::succinctNanos, a -> ""); + + List leafSplitGenerators = ImmutableList.of( + new FastLeafSplitGenerator(), + new SlowLeafSplitGenerator(), + new L4LeafSplitGenerator(), + new QuantaExceedingSplitGenerator(), + new AggregatedLeafSplitGenerator()); + + for (SplitGenerator generator : leafSplitGenerators) { + List leafSpecs = new ArrayList<>(); + for (int i = 0; i < 17000; i++) { + leafSpecs.add(generator.next()); + } + + System.out.println(); + System.out.println("Tasks with 4x " + generator.getClass().getSimpleName()); + bins.printDistribution(leafSpecs, t -> t.getScheduledTimeNanos() * 4, Duration::succinctNanos); + + System.out.println("Per quanta:"); + bins.printDistribution(leafSpecs, SplitSpecification::getPerQuantaNanos, Duration::succinctNanos); + } + } + + interface SplitGenerator + { + SplitSpecification next(); + } + + public static class IntermediateSplitGenerator + implements SplitGenerator + { + private final ScheduledExecutorService wakeupExecutor; + + IntermediateSplitGenerator(ScheduledExecutorService wakeupExecutor) + { + this.wakeupExecutor = wakeupExecutor; + } + + public IntermediateSplitSpecification next() + { + long numQuanta = generateIntermediateSplitNumQuanta(0, 1); + + long wallNanos = MILLISECONDS.toNanos(generateIntermediateSplitWallTimeMs(0, 1)); + long scheduledNanos = MILLISECONDS.toNanos(generateIntermediateSplitScheduledTimeMs(0, 1)); + + long blockedNanos = (long) (ThreadLocalRandom.current().nextDouble(0.97, 0.99) * wallNanos); + + long perQuantaNanos = scheduledNanos / numQuanta; + long betweenQuantaNanos = blockedNanos / numQuanta; + + return new IntermediateSplitSpecification(scheduledNanos, wallNanos, numQuanta, perQuantaNanos, betweenQuantaNanos, wakeupExecutor); + } + } + + public static class AggregatedLeafSplitGenerator + implements SplitGenerator + { + public LeafSplitSpecification next() + { + long totalNanos = MILLISECONDS.toNanos(generateLeafSplitScheduledTimeMs(0, 1)); + long quantaNanos = Math.min(totalNanos, MICROSECONDS.toNanos(generateLeafSplitPerCallMicros(0, 1))); + + return new LeafSplitSpecification(totalNanos, quantaNanos); + } + } + + public static class FastLeafSplitGenerator + implements SplitGenerator + { + public LeafSplitSpecification next() + { + long totalNanos = MILLISECONDS.toNanos(generateLeafSplitScheduledTimeMs(0, 0.75)); + long quantaNanos = Math.min(totalNanos, MICROSECONDS.toNanos(generateLeafSplitPerCallMicros(0, 1))); + + return new LeafSplitSpecification(totalNanos, quantaNanos); + } + } + + public static class SlowLeafSplitGenerator + implements SplitGenerator + { + public LeafSplitSpecification next() + { + long totalNanos = MILLISECONDS.toNanos(generateLeafSplitScheduledTimeMs(0.75, 1)); + long quantaNanos = Math.min(totalNanos, MICROSECONDS.toNanos(generateLeafSplitPerCallMicros(0, 1))); + + return new LeafSplitSpecification(totalNanos, quantaNanos); + } + } + + public static class L4LeafSplitGenerator + implements SplitGenerator + { + public LeafSplitSpecification next() + { + long totalNanos = MILLISECONDS.toNanos(generateLeafSplitScheduledTimeMs(0.99, 1)); + long quantaNanos = Math.min(totalNanos, MICROSECONDS.toNanos(generateLeafSplitPerCallMicros(0, 0.9))); + + return new LeafSplitSpecification(totalNanos, quantaNanos); + } + } + + public static class QuantaExceedingSplitGenerator + implements SplitGenerator + { + public LeafSplitSpecification next() + { + long totalNanos = MILLISECONDS.toNanos(generateLeafSplitScheduledTimeMs(0.99, 1)); + long quantaNanos = Math.min(totalNanos, MICROSECONDS.toNanos(generateLeafSplitPerCallMicros(0.75, 1))); + + return new LeafSplitSpecification(totalNanos, quantaNanos); + } + } + + public static class SimpleLeafSplitGenerator + implements SplitGenerator + { + private final long totalNanos; + private final long quantaNanos; + + public SimpleLeafSplitGenerator(long totalNanos, long quantaNanos) + { + this.totalNanos = totalNanos; + this.quantaNanos = quantaNanos; + } + + public LeafSplitSpecification next() + { + return new LeafSplitSpecification(totalNanos, quantaNanos); + } + } + + // these numbers come from real world stats + private static long generateLeafSplitScheduledTimeMs(double origin, double bound) + { + ThreadLocalRandom generator = ThreadLocalRandom.current(); + double value = generator.nextDouble(origin, bound); + // in reality, max is several hours, but this would make the simulation too slow + if (value > 0.998) { + return generator.nextLong(5 * 60 * 1000, 10 * 60 * 1000); + } + + if (value > 0.99) { + return generator.nextLong(60 * 1000, 5 * 60 * 1000); + } + + if (value > 0.95) { + return generator.nextLong(10_000, 60 * 1000); + } + + if (value > 0.50) { + return generator.nextLong(1000, 10_000); + } + + if (value > 0.25) { + return generator.nextLong(100, 1000); + } + + if (value > 0.10) { + return generator.nextLong(10, 100); + } + + return generator.nextLong(1, 10); + } + + private static long generateLeafSplitPerCallMicros(double origin, double bound) + { + ThreadLocalRandom generator = ThreadLocalRandom.current(); + double value = generator.nextDouble(origin, bound); + if (value > 0.9999) { + return 200_000_000; + } + + if (value > 0.99) { + return generator.nextLong(3_000_000, 15_000_000); + } + + if (value > 0.95) { + return generator.nextLong(2_000_000, 5_000_000); + } + + if (value > 0.90) { + return generator.nextLong(1_500_000, 5_000_000); + } + + if (value > 0.75) { + return generator.nextLong(1_000_000, 2_000_000); + } + + if (value > 0.50) { + return generator.nextLong(500_000, 1_000_000); + } + + if (value > 0.1) { + return generator.nextLong(100_000, 500_000); + } + + return generator.nextLong(250, 500); + } + + private static long generateIntermediateSplitScheduledTimeMs(double origin, double bound) + { + ThreadLocalRandom generator = ThreadLocalRandom.current(); + double value = generator.nextDouble(origin, bound); + // in reality, max is several hours, but this would make the simulation too slow + + if (value > 0.999) { + return generator.nextLong(5 * 60 * 1000, 10 * 60 * 1000); + } + + if (value > 0.99) { + return generator.nextLong(60 * 1000, 5 * 60 * 1000); + } + + if (value > 0.95) { + return generator.nextLong(10_000, 60 * 1000); + } + + if (value > 0.75) { + return generator.nextLong(1000, 10_000); + } + + if (value > 0.45) { + return generator.nextLong(100, 1000); + } + + if (value > 0.20) { + return generator.nextLong(10, 100); + } + + return generator.nextLong(1, 10); + } + + private static long generateIntermediateSplitWallTimeMs(double origin, double bound) + { + ThreadLocalRandom generator = ThreadLocalRandom.current(); + double value = generator.nextDouble(origin, bound); + // in reality, max is several hours, but this would make the simulation too slow + + if (value > 0.90) { + return generator.nextLong(400_000, 800_000); + } + + if (value > 0.75) { + return generator.nextLong(100_000, 200_000); + } + + if (value > 0.50) { + return generator.nextLong(50_000, 100_000); + } + + if (value > 0.40) { + return generator.nextLong(30_000, 50_000); + } + + if (value > 0.30) { + return generator.nextLong(20_000, 30_000); + } + + if (value > 0.20) { + return generator.nextLong(10_000, 15_000); + } + + if (value > 0.10) { + return generator.nextLong(5_000, 10_000); + } + + return generator.nextLong(1_000, 5_000); + } + + private static long generateIntermediateSplitNumQuanta(double origin, double bound) + { + ThreadLocalRandom generator = ThreadLocalRandom.current(); + double value = generator.nextDouble(origin, bound); + + if (value > 0.95) { + return generator.nextLong(2000, 20_000); + } + + if (value > 0.90) { + return generator.nextLong(1_000, 2_000); + } + + return generator.nextLong(10, 1000); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/executor/SplitSpecification.java b/presto-main/src/test/java/com/facebook/presto/execution/executor/SplitSpecification.java new file mode 100644 index 000000000000..1dccbb25aa73 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/execution/executor/SplitSpecification.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.executor; + +import com.facebook.presto.execution.executor.SimulationSplit.IntermediateSplit; +import com.facebook.presto.execution.executor.SimulationSplit.LeafSplit; + +import java.util.concurrent.ScheduledExecutorService; + +abstract class SplitSpecification +{ + private final long scheduledTimeNanos; + private final long perQuantaNanos; + + private SplitSpecification(long scheduledTimeNanos, long perQuantaNanos) + { + this.scheduledTimeNanos = scheduledTimeNanos; + this.perQuantaNanos = perQuantaNanos; + } + + public long getScheduledTimeNanos() + { + return scheduledTimeNanos; + } + + public long getPerQuantaNanos() + { + return perQuantaNanos; + } + + public abstract SimulationSplit instantiate(SimulationTask task); + + public static class LeafSplitSpecification + extends SplitSpecification + { + public LeafSplitSpecification(long scheduledTimeNanos, long perQuantaNanos) + { + super(scheduledTimeNanos, perQuantaNanos); + } + + public LeafSplit instantiate(SimulationTask task) + { + return new LeafSplit(task, super.getScheduledTimeNanos(), super.getPerQuantaNanos()); + } + } + + public static class IntermediateSplitSpecification + extends SplitSpecification + { + private final long wallTimeNanos; + private final long numQuantas; + private final long betweenQuantaNanos; + private final ScheduledExecutorService wakeupExecutor; + + public IntermediateSplitSpecification( + long scheduledTimeNanos, + long perQuantaNanos, + long wallTimeNanos, + long numQuantas, + long betweenQuantaNanos, + ScheduledExecutorService wakeupExecutor) + { + super(scheduledTimeNanos, perQuantaNanos); + this.wallTimeNanos = wallTimeNanos; + this.numQuantas = numQuantas; + this.betweenQuantaNanos = betweenQuantaNanos; + this.wakeupExecutor = wakeupExecutor; + } + + public IntermediateSplit instantiate(SimulationTask task) + { + return new IntermediateSplit(task, wallTimeNanos, numQuantas, super.getPerQuantaNanos(), betweenQuantaNanos, super.getScheduledTimeNanos(), wakeupExecutor); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/executor/TaskExecutorSimulator.java b/presto-main/src/test/java/com/facebook/presto/execution/executor/TaskExecutorSimulator.java new file mode 100644 index 000000000000..3d357cf1d26f --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/execution/executor/TaskExecutorSimulator.java @@ -0,0 +1,446 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.executor; + +import com.facebook.presto.execution.executor.SimulationController.TaskSpecification; +import com.facebook.presto.execution.executor.SplitGenerators.AggregatedLeafSplitGenerator; +import com.facebook.presto.execution.executor.SplitGenerators.FastLeafSplitGenerator; +import com.facebook.presto.execution.executor.SplitGenerators.IntermediateSplitGenerator; +import com.facebook.presto.execution.executor.SplitGenerators.L4LeafSplitGenerator; +import com.facebook.presto.execution.executor.SplitGenerators.QuantaExceedingSplitGenerator; +import com.facebook.presto.execution.executor.SplitGenerators.SimpleLeafSplitGenerator; +import com.facebook.presto.execution.executor.SplitGenerators.SlowLeafSplitGenerator; +import com.google.common.base.Ticker; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ListMultimap; +import com.google.common.util.concurrent.ListeningExecutorService; +import io.airlift.units.Duration; +import org.joda.time.DateTime; + +import java.io.Closeable; +import java.util.List; +import java.util.LongSummaryStatistics; +import java.util.Map; +import java.util.OptionalInt; +import java.util.Set; +import java.util.concurrent.ScheduledExecutorService; +import java.util.stream.Collectors; + +import static com.facebook.presto.execution.executor.Histogram.fromContinuous; +import static com.facebook.presto.execution.executor.Histogram.fromDiscrete; +import static com.facebook.presto.execution.executor.SimulationController.TaskSpecification.Type.INTERMEDIATE; +import static com.facebook.presto.execution.executor.SimulationController.TaskSpecification.Type.LEAF; +import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; +import static io.airlift.concurrent.Threads.threadsNamed; +import static io.airlift.units.Duration.nanosSince; +import static io.airlift.units.Duration.succinctNanos; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; +import static java.util.concurrent.TimeUnit.DAYS; +import static java.util.concurrent.TimeUnit.HOURS; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.SECONDS; +import static java.util.function.Function.identity; + +public class TaskExecutorSimulator + implements Closeable +{ + public static void main(String[] args) + throws Exception + { + try (TaskExecutorSimulator simulator = new TaskExecutorSimulator()) { + simulator.run(); + } + } + + private final ListeningExecutorService submissionExecutor = listeningDecorator(newCachedThreadPool(threadsNamed(getClass().getSimpleName() + "-%s"))); + private final ScheduledExecutorService overallStatusPrintExecutor = newSingleThreadScheduledExecutor(); + private final ScheduledExecutorService runningSplitsPrintExecutor = newSingleThreadScheduledExecutor(); + private final ScheduledExecutorService wakeupExecutor = newScheduledThreadPool(32); + + private final TaskExecutor taskExecutor; + + private TaskExecutorSimulator() + { + taskExecutor = new TaskExecutor(36, 72, Ticker.systemTicker()); + taskExecutor.start(); + } + + @Override + public void close() + { + submissionExecutor.shutdownNow(); + overallStatusPrintExecutor.shutdownNow(); + runningSplitsPrintExecutor.shutdownNow(); + wakeupExecutor.shutdownNow(); + taskExecutor.stop(); + } + + public void run() + throws Exception + { + long start = System.nanoTime(); + scheduleStatusPrinter(start); + + SimulationController controller = new SimulationController(taskExecutor, TaskExecutorSimulator::printSummaryStats); + + // Uncomment one of these: + // runExperimentOverloadedCluster(controller); + // runExperimentMisbehavingQuanta(controller); + // runExperimentStarveSlowSplits(controller); + runExperimentWithinLevelFairness(controller); + + System.out.println("Stopped scheduling new tasks. Ending simulation.."); + controller.stop(); + close(); + + SECONDS.sleep(5); + + System.out.println(); + System.out.println("Simulation finished at " + DateTime.now() + ". Runtime: " + nanosSince(start)); + System.out.println(); + + printSummaryStats(controller, taskExecutor); + } + + private void runExperimentOverloadedCluster(SimulationController controller) + throws InterruptedException + { + /* + Designed to simulate a somewhat overloaded Hive cluster. + The following data is a point-in-time snapshot representative production cluster: + - 60 running queries => 45 queries/node + - 80 tasks/node + - 600 splits scheduled/node (80% intermediate => ~480, 20% leaf => 120) + - Only 60% intermediate splits will ever get data (~300) + + Desired result: + This experiment should demonstrate the trade-offs that will be made during periods when a + node is under heavy load. Ideally, the different classes of tasks should each accumulate + scheduled time, and not spend disproportionately long waiting. + */ + + System.out.println("Overload experiment started."); + TaskSpecification leafSpec = new TaskSpecification(LEAF, "leaf", OptionalInt.empty(), 16, 30, new AggregatedLeafSplitGenerator()); + controller.addTaskSpecification(leafSpec); + + TaskSpecification slowLeafSpec = new TaskSpecification(LEAF, "slow_leaf", OptionalInt.empty(), 16, 10, new SlowLeafSplitGenerator()); + controller.addTaskSpecification(slowLeafSpec); + + TaskSpecification intermediateSpec = new TaskSpecification(INTERMEDIATE, "intermediate", OptionalInt.empty(), 8, 40, new IntermediateSplitGenerator(wakeupExecutor)); + controller.addTaskSpecification(intermediateSpec); + + controller.enableSpecification(leafSpec); + controller.enableSpecification(slowLeafSpec); + controller.enableSpecification(intermediateSpec); + controller.run(); + + SECONDS.sleep(30); + + // this gets the executor into a more realistic point-in-time state, where long running tasks start to make progress + for (int i = 0; i < 20; i++) { + controller.clearPendingQueue(); + MINUTES.sleep(1); + } + + System.out.println("Overload experiment completed."); + } + + private void runExperimentStarveSlowSplits(SimulationController controller) + throws InterruptedException + { + /* + Designed to simulate how higher level admission control affects short-term scheduling decisions. + A fixed, large number of tasks (120) are submitted at approximately the same time. + + Desired result: + Presto is designed to prioritize fast, short tasks at the expense of longer slower tasks. + This experiment allows us to quantify exactly how this preference manifests itself. It is + expected that shorter tasks will complete faster, however, longer tasks should not starve + for more than a couple of minutes at a time. + */ + + System.out.println("Starvation experiment started."); + TaskSpecification slowLeafSpec = new TaskSpecification(LEAF, "slow_leaf", OptionalInt.of(600), 40, 4, new SlowLeafSplitGenerator()); + controller.addTaskSpecification(slowLeafSpec); + + TaskSpecification intermediateSpec = new TaskSpecification(INTERMEDIATE, "intermediate", OptionalInt.of(400), 40, 8, new IntermediateSplitGenerator(wakeupExecutor)); + controller.addTaskSpecification(intermediateSpec); + + TaskSpecification fastLeafSpec = new TaskSpecification(LEAF, "fast_leaf", OptionalInt.of(600), 40, 4, new FastLeafSplitGenerator()); + controller.addTaskSpecification(fastLeafSpec); + + controller.enableSpecification(slowLeafSpec); + controller.enableSpecification(fastLeafSpec); + controller.enableSpecification(intermediateSpec); + + controller.run(); + + for (int i = 0; i < 60; i++) { + SECONDS.sleep(20); + controller.clearPendingQueue(); + } + + System.out.println("Starvation experiment completed."); + } + + private void runExperimentMisbehavingQuanta(SimulationController controller) + throws InterruptedException + { + /* + Designed to simulate how Presto allocates resources in scenarios where there is variance in + quanta run-time between tasks. + + Desired result: + Variance in quanta run time should not affect total accrued scheduled time. It is + acceptable, however, to penalize tasks that use extremely short quanta, as each quanta + incurs scheduling overhead. + */ + + System.out.println("Misbehaving quanta experiment started."); + + TaskSpecification slowLeafSpec = new TaskSpecification(LEAF, "good_leaf", OptionalInt.empty(), 16, 4, new L4LeafSplitGenerator()); + controller.addTaskSpecification(slowLeafSpec); + + TaskSpecification misbehavingLeafSpec = new TaskSpecification(LEAF, "bad_leaf", OptionalInt.empty(), 16, 4, new QuantaExceedingSplitGenerator()); + controller.addTaskSpecification(misbehavingLeafSpec); + + controller.enableSpecification(slowLeafSpec); + controller.enableSpecification(misbehavingLeafSpec); + + controller.run(); + + for (int i = 0; i < 120; i++) { + controller.clearPendingQueue(); + SECONDS.sleep(20); + } + + System.out.println("Misbehaving quanta experiment completed."); + } + + private void runExperimentWithinLevelFairness(SimulationController controller) + throws InterruptedException + { + /* + Designed to simulate how Presto allocates resources to tasks at the same level of the + feedback queue when there is large variance in accrued scheduled time. + + Desired result: + Scheduling within levels should be fair - total accrued time should not affect what + fraction of resources tasks are allocated as long as they are in the same level. + */ + + System.out.println("Level fairness experiment started."); + + TaskSpecification longLeafSpec = new TaskSpecification(INTERMEDIATE, "l4_long", OptionalInt.empty(), 2, 16, new SimpleLeafSplitGenerator(MINUTES.toNanos(4), SECONDS.toNanos(1))); + controller.addTaskSpecification(longLeafSpec); + + TaskSpecification shortLeafSpec = new TaskSpecification(INTERMEDIATE, "l4_short", OptionalInt.empty(), 2, 16, new SimpleLeafSplitGenerator(MINUTES.toNanos(2), SECONDS.toNanos(1))); + controller.addTaskSpecification(shortLeafSpec); + + controller.enableSpecification(longLeafSpec); + controller.run(); + + // wait until long tasks are all well into L4 + MINUTES.sleep(1); + controller.runCallback(); + + // start short leaf tasks + controller.enableSpecification(shortLeafSpec); + + // wait until short tasks hit L4 + SECONDS.sleep(25); + controller.runCallback(); + + // now watch for L4 fairness at this point + MINUTES.sleep(2); + + System.out.println("Level fairness experiment completed."); + } + + private void scheduleStatusPrinter(long start) + { + overallStatusPrintExecutor.scheduleAtFixedRate(() -> { + try { + System.out.printf( + "%6s -- %4s splits (R: %2s L: %3s I: %3s B: %3s W: %3s C: %5s) | %3s tasks (%3s %3s %3s %3s %3s) | Selections: %4s %4s %4s %4s %3s\n", + nanosSince(start), + taskExecutor.getTotalSplits(), + taskExecutor.getRunningSplits(), + taskExecutor.getTotalSplits() - taskExecutor.getIntermediateSplits(), + taskExecutor.getIntermediateSplits(), + taskExecutor.getBlockedSplits(), + taskExecutor.getWaitingSplits(), + taskExecutor.getCompletedSplitsLevel0() + taskExecutor.getCompletedSplitsLevel1() + taskExecutor.getCompletedSplitsLevel2() + taskExecutor.getCompletedSplitsLevel3() + taskExecutor.getCompletedSplitsLevel4(), + taskExecutor.getTasks(), + taskExecutor.getRunningTasksLevel0(), + taskExecutor.getRunningTasksLevel1(), + taskExecutor.getRunningTasksLevel2(), + taskExecutor.getRunningTasksLevel3(), + taskExecutor.getRunningTasksLevel4(), + (int) taskExecutor.getSelectedCountLevel0().getOneMinute().getRate(), + (int) taskExecutor.getSelectedCountLevel1().getOneMinute().getRate(), + (int) taskExecutor.getSelectedCountLevel2().getOneMinute().getRate(), + (int) taskExecutor.getSelectedCountLevel3().getOneMinute().getRate(), + (int) taskExecutor.getSelectedCountLevel4().getOneMinute().getRate()); + } + catch (Exception ignored) { + } + }, 1, 1, SECONDS); + } + + private static void printSummaryStats(SimulationController controller, TaskExecutor taskExecutor) + { + Map specEnabled = controller.getSpecificationEnabled(); + + ListMultimap completedTasks = controller.getCompletedTasks(); + ListMultimap runningTasks = controller.getRunningTasks(); + Set allTasks = ImmutableSet.builder().addAll(completedTasks.values()).addAll(runningTasks.values()).build(); + + long completedSplits = completedTasks.values().stream().mapToInt(t -> t.getCompletedSplits().size()).sum(); + long runningSplits = runningTasks.values().stream().mapToInt(t -> t.getCompletedSplits().size()).sum(); + + System.out.println("Completed tasks : " + completedTasks.size()); + System.out.println("Remaining tasks : " + runningTasks.size()); + System.out.println("Completed splits: " + completedSplits); + System.out.println("Remaining splits: " + runningSplits); + System.out.println(); + System.out.println("Completed tasks L0: " + taskExecutor.getCompletedTasksLevel0()); + System.out.println("Completed tasks L1: " + taskExecutor.getCompletedTasksLevel1()); + System.out.println("Completed tasks L2: " + taskExecutor.getCompletedTasksLevel2()); + System.out.println("Completed tasks L3: " + taskExecutor.getCompletedTasksLevel3()); + System.out.println("Completed tasks L4: " + taskExecutor.getCompletedTasksLevel4()); + System.out.println(); + System.out.println("Completed splits L0: " + taskExecutor.getCompletedSplitsLevel0()); + System.out.println("Completed splits L1: " + taskExecutor.getCompletedSplitsLevel1()); + System.out.println("Completed splits L2: " + taskExecutor.getCompletedSplitsLevel2()); + System.out.println("Completed splits L3: " + taskExecutor.getCompletedSplitsLevel3()); + System.out.println("Completed splits L4: " + taskExecutor.getCompletedSplitsLevel4()); + + Histogram levelsHistogram = fromContinuous(ImmutableList.of( + MILLISECONDS.toNanos(0L), + MILLISECONDS.toNanos(1_000), + MILLISECONDS.toNanos(10_000L), + MILLISECONDS.toNanos(60_000L), + MILLISECONDS.toNanos(300_000L), + HOURS.toNanos(1), + DAYS.toNanos(1))); + + System.out.println(); + System.out.println("Levels - Completed Task Processed Time"); + levelsHistogram.printDistribution( + completedTasks.values().stream().filter(t -> t.getSpecification().getType() == LEAF).collect(Collectors.toList()), + SimulationTask::getScheduledTimeNanos, + SimulationTask::getProcessedTimeNanos, + Duration::succinctNanos, + TaskExecutorSimulator::formatNanos); + + System.out.println(); + System.out.println("Levels - Running Task Processed Time"); + levelsHistogram.printDistribution( + runningTasks.values().stream().filter(t -> t.getSpecification().getType() == LEAF).collect(Collectors.toList()), + SimulationTask::getScheduledTimeNanos, + SimulationTask::getProcessedTimeNanos, + Duration::succinctNanos, + TaskExecutorSimulator::formatNanos); + + System.out.println(); + System.out.println("Levels - All Task Wait Time"); + levelsHistogram.printDistribution( + runningTasks.values().stream().filter(t -> t.getSpecification().getType() == LEAF).collect(Collectors.toList()), + SimulationTask::getScheduledTimeNanos, + SimulationTask::getTotalWaitTimeNanos, + Duration::succinctNanos, + TaskExecutorSimulator::formatNanos); + + System.out.println(); + System.out.println("Specification - Processed time"); + Set specifications = runningTasks.values().stream().map(t -> t.getSpecification().getName()).collect(Collectors.toSet()); + fromDiscrete(specifications).printDistribution( + allTasks, + t -> t.getSpecification().getName(), + SimulationTask::getProcessedTimeNanos, + identity(), + TaskExecutorSimulator::formatNanos); + + System.out.println(); + System.out.println("Specification - Wait time"); + fromDiscrete(specifications).printDistribution( + allTasks, + t -> t.getSpecification().getName(), + SimulationTask::getTotalWaitTimeNanos, + identity(), + TaskExecutorSimulator::formatNanos); + + System.out.println(); + System.out.println("Breakdown by specification"); + System.out.println("##########################"); + for (TaskSpecification specification : specEnabled.keySet()) { + List allSpecificationTasks = ImmutableList.builder() + .addAll(completedTasks.get(specification)) + .addAll(runningTasks.get(specification)) + .build(); + + System.out.println(specification.getName()); + System.out.println("============================="); + System.out.println("Completed tasks : " + completedTasks.get(specification).size()); + System.out.println("In-progress tasks : " + runningTasks.get(specification).size()); + System.out.println("Total tasks : " + specification.getTotalTasks()); + System.out.println("Splits/task : " + specification.getNumSplitsPerTask()); + System.out.println("Current required time : " + succinctNanos(allSpecificationTasks.stream().mapToLong(SimulationTask::getScheduledTimeNanos).sum())); + System.out.println("Completed scheduled time : " + succinctNanos(allSpecificationTasks.stream().mapToLong(SimulationTask::getProcessedTimeNanos).sum())); + System.out.println("Total wait time : " + succinctNanos(allSpecificationTasks.stream().mapToLong(SimulationTask::getTotalWaitTimeNanos).sum())); + + System.out.println(); + System.out.println("All Tasks by Scheduled time - Processed Time"); + levelsHistogram.printDistribution( + allSpecificationTasks, + SimulationTask::getScheduledTimeNanos, + SimulationTask::getProcessedTimeNanos, + Duration::succinctNanos, + TaskExecutorSimulator::formatNanos); + + System.out.println(); + System.out.println("All Tasks by Scheduled time - Wait Time"); + levelsHistogram.printDistribution( + allSpecificationTasks, + SimulationTask::getScheduledTimeNanos, + SimulationTask::getTotalWaitTimeNanos, + Duration::succinctNanos, + TaskExecutorSimulator::formatNanos); + + System.out.println(); + System.out.println("Complete Tasks by Scheduled time - Wait Time"); + levelsHistogram.printDistribution( + completedTasks.get(specification), + SimulationTask::getScheduledTimeNanos, + SimulationTask::getTotalWaitTimeNanos, + Duration::succinctNanos, + TaskExecutorSimulator::formatNanos); + } + } + + private static String formatNanos(List list) + { + LongSummaryStatistics stats = list.stream().mapToLong(Long::new).summaryStatistics(); + return String.format("Min: %8s Max: %8s Avg: %8s Sum: %8s", + succinctNanos(stats.getMin() == Long.MAX_VALUE ? 0 : stats.getMin()), + succinctNanos(stats.getMax() == Long.MIN_VALUE ? 0 : stats.getMax()), + succinctNanos((long) stats.getAverage()), + succinctNanos(stats.getSum())); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/executor/TestTaskExecutor.java b/presto-main/src/test/java/com/facebook/presto/execution/executor/TestTaskExecutor.java new file mode 100644 index 000000000000..3c3d6bfdb4a1 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/execution/executor/TestTaskExecutor.java @@ -0,0 +1,520 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.executor; + +import com.facebook.presto.execution.SplitRunner; +import com.facebook.presto.execution.TaskId; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.airlift.testing.TestingTicker; +import io.airlift.units.Duration; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.concurrent.Future; +import java.util.concurrent.Phaser; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.facebook.presto.execution.executor.MultilevelSplitQueue.LEVEL_CONTRIBUTION_CAP; +import static com.facebook.presto.execution.executor.MultilevelSplitQueue.LEVEL_THRESHOLD_SECONDS; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.testing.Assertions.assertGreaterThan; +import static io.airlift.testing.Assertions.assertGreaterThanOrEqual; +import static io.airlift.testing.Assertions.assertLessThan; +import static io.airlift.testing.Assertions.assertLessThanOrEqual; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestTaskExecutor +{ + @Test(invocationCount = 100) + public void testTasksComplete() + throws Exception + { + TestingTicker ticker = new TestingTicker(); + TaskExecutor taskExecutor = new TaskExecutor(4, 8, ticker); + taskExecutor.start(); + ticker.increment(20, MILLISECONDS); + + try { + TaskId taskId = new TaskId("test", 0, 0); + TaskHandle taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS)); + + Phaser beginPhase = new Phaser(); + beginPhase.register(); + Phaser verificationComplete = new Phaser(); + verificationComplete.register(); + + // add two jobs + TestingJob driver1 = new TestingJob(ticker, new Phaser(1), beginPhase, verificationComplete, 10, 0); + ListenableFuture future1 = getOnlyElement(taskExecutor.enqueueSplits(taskHandle, true, ImmutableList.of(driver1))); + TestingJob driver2 = new TestingJob(ticker, new Phaser(1), beginPhase, verificationComplete, 10, 0); + ListenableFuture future2 = getOnlyElement(taskExecutor.enqueueSplits(taskHandle, true, ImmutableList.of(driver2))); + assertEquals(driver1.getCompletedPhases(), 0); + assertEquals(driver2.getCompletedPhases(), 0); + + // verify worker have arrived but haven't processed yet + beginPhase.arriveAndAwaitAdvance(); + assertEquals(driver1.getCompletedPhases(), 0); + assertEquals(driver2.getCompletedPhases(), 0); + ticker.increment(10, MILLISECONDS); + assertEquals(taskExecutor.getMaxActiveSplitTime(), 10); + verificationComplete.arriveAndAwaitAdvance(); + + // advance one phase and verify + beginPhase.arriveAndAwaitAdvance(); + assertEquals(driver1.getCompletedPhases(), 1); + assertEquals(driver2.getCompletedPhases(), 1); + + verificationComplete.arriveAndAwaitAdvance(); + + // add one more job + TestingJob driver3 = new TestingJob(ticker, new Phaser(1), beginPhase, verificationComplete, 10, 0); + ListenableFuture future3 = getOnlyElement(taskExecutor.enqueueSplits(taskHandle, false, ImmutableList.of(driver3))); + + // advance one phase and verify + beginPhase.arriveAndAwaitAdvance(); + assertEquals(driver1.getCompletedPhases(), 2); + assertEquals(driver2.getCompletedPhases(), 2); + assertEquals(driver3.getCompletedPhases(), 0); + verificationComplete.arriveAndAwaitAdvance(); + + // advance to the end of the first two task and verify + beginPhase.arriveAndAwaitAdvance(); + for (int i = 0; i < 7; i++) { + verificationComplete.arriveAndAwaitAdvance(); + beginPhase.arriveAndAwaitAdvance(); + assertEquals(beginPhase.getPhase(), verificationComplete.getPhase() + 1); + } + assertEquals(driver1.getCompletedPhases(), 10); + assertEquals(driver2.getCompletedPhases(), 10); + assertEquals(driver3.getCompletedPhases(), 8); + future1.get(1, SECONDS); + future2.get(1, SECONDS); + verificationComplete.arriveAndAwaitAdvance(); + + // advance two more times and verify + beginPhase.arriveAndAwaitAdvance(); + verificationComplete.arriveAndAwaitAdvance(); + beginPhase.arriveAndAwaitAdvance(); + assertEquals(driver1.getCompletedPhases(), 10); + assertEquals(driver2.getCompletedPhases(), 10); + assertEquals(driver3.getCompletedPhases(), 10); + future3.get(1, SECONDS); + verificationComplete.arriveAndAwaitAdvance(); + + assertEquals(driver1.getFirstPhase(), 0); + assertEquals(driver2.getFirstPhase(), 0); + assertEquals(driver3.getFirstPhase(), 2); + + assertEquals(driver1.getLastPhase(), 10); + assertEquals(driver2.getLastPhase(), 10); + assertEquals(driver3.getLastPhase(), 12); + + // no splits remaining + ticker.increment(30, MILLISECONDS); + assertEquals(taskExecutor.getMaxActiveSplitTime(), 0); + } + finally { + taskExecutor.stop(); + } + } + + @Test(invocationCount = 100) + public void testQuantaFairness() + throws Exception + { + TestingTicker ticker = new TestingTicker(); + TaskExecutor taskExecutor = new TaskExecutor(1, 2, ticker); + taskExecutor.start(); + ticker.increment(20, MILLISECONDS); + + try { + TaskHandle shortQuantaTaskHandle = taskExecutor.addTask(new TaskId("shortQuanta", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS)); + TaskHandle longQuantaTaskHandle = taskExecutor.addTask(new TaskId("longQuanta", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS)); + + Phaser globalPhaser = new Phaser(); + + TestingJob shortQuantaDriver = new TestingJob(ticker, new Phaser(), new Phaser(), globalPhaser, 10, 10); + TestingJob longQuantaDriver = new TestingJob(ticker, new Phaser(), new Phaser(), globalPhaser, 10, 20); + + taskExecutor.enqueueSplits(shortQuantaTaskHandle, true, ImmutableList.of(shortQuantaDriver)); + taskExecutor.enqueueSplits(longQuantaTaskHandle, true, ImmutableList.of(longQuantaDriver)); + + for (int i = 0; i < 11; i++) { + globalPhaser.arriveAndAwaitAdvance(); + } + + assertTrue(shortQuantaDriver.getCompletedPhases() >= 7 && shortQuantaDriver.getCompletedPhases() <= 8); + assertTrue(longQuantaDriver.getCompletedPhases() >= 3 && longQuantaDriver.getCompletedPhases() <= 4); + + globalPhaser.arriveAndDeregister(); + } + finally { + taskExecutor.stop(); + } + } + + @Test(invocationCount = 100) + public void testLevelMovement() + throws Exception + { + TestingTicker ticker = new TestingTicker(); + TaskExecutor taskExecutor = new TaskExecutor(2, 2, ticker); + taskExecutor.start(); + ticker.increment(20, MILLISECONDS); + + try { + TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId("test", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS)); + + Phaser globalPhaser = new Phaser(); + globalPhaser.bulkRegister(3); + + int quantaTimeMills = 500; + int phasesPerSecond = 1000 / quantaTimeMills; + int totalPhases = LEVEL_THRESHOLD_SECONDS[LEVEL_THRESHOLD_SECONDS.length - 1] * phasesPerSecond; + TestingJob driver1 = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), totalPhases, quantaTimeMills); + TestingJob driver2 = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), totalPhases, quantaTimeMills); + + taskExecutor.enqueueSplits(testTaskHandle, true, ImmutableList.of(driver1, driver2)); + + int completedPhases = 0; + for (int i = 0; i < (LEVEL_THRESHOLD_SECONDS.length - 1); i++) { + for (; (completedPhases / phasesPerSecond) < LEVEL_THRESHOLD_SECONDS[i + 1]; completedPhases++) { + globalPhaser.arriveAndAwaitAdvance(); + } + + assertEquals(testTaskHandle.getPriority().getLevel(), i + 1); + } + + globalPhaser.arriveAndDeregister(); + } + finally { + taskExecutor.stop(); + } + } + + @Test(invocationCount = 100) + public void testNoInstantaneousFairness() + throws Exception + { + TestingTicker ticker = new TestingTicker(); + TaskExecutor taskExecutor = new TaskExecutor(1, 2, 2, true, true, ticker); + taskExecutor.start(); + ticker.increment(20, MILLISECONDS); + + try { + for (int i = 2; i < (LEVEL_THRESHOLD_SECONDS.length - 1); i++) { + int levelStartMillis = (int) SECONDS.toMillis(LEVEL_THRESHOLD_SECONDS[i]); + int nextLevelStartMillis = (int) SECONDS.toMillis(LEVEL_THRESHOLD_SECONDS[i + 1]); + + TaskHandle longJob = taskExecutor.addTask(new TaskId("longTask", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS)); + TestingJob longSplit = new TestingJob(ticker, new Phaser(), new Phaser(), new Phaser(), (nextLevelStartMillis / 100) - 20, 100); + + taskExecutor.enqueueSplits(longJob, true, ImmutableList.of(longSplit)); + longSplit.getCompletedFuture().get(); + + TaskHandle shortJob = taskExecutor.addTask(new TaskId("shortTask", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS)); + TestingJob shortSplit = new TestingJob(ticker, new Phaser(), new Phaser(), new Phaser(), (levelStartMillis / 100) + 1, 100); + + taskExecutor.enqueueSplits(shortJob, true, ImmutableList.of(shortSplit)); + shortSplit.getCompletedFuture().get(); + + Phaser globalPhaser = new Phaser(2); + TestingJob shortSplit1 = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), 20, 10); + TestingJob shortSplit2 = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), 20, 10); + TestingJob longSplit1 = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), 20, 10); + TestingJob longSplit2 = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), 20, 10); + + taskExecutor.enqueueSplits(longJob, true, ImmutableList.of(longSplit1, longSplit2)); + taskExecutor.enqueueSplits(shortJob, true, ImmutableList.of(shortSplit1, shortSplit2)); + + for (int j = 0; j < 10; j++) { + globalPhaser.arriveAndAwaitAdvance(); + } + + assertLessThanOrEqual(longSplit1.getCompletedPhases() + longSplit2.getCompletedPhases(), 2); + assertGreaterThanOrEqual(shortSplit1.getCompletedPhases() + shortSplit2.getCompletedPhases(), 8); + + globalPhaser.arriveAndDeregister(); + longSplit1.getCompletedFuture().get(); + longSplit2.getCompletedFuture().get(); + shortSplit1.getCompletedFuture().get(); + shortSplit2.getCompletedFuture().get(); + longJob.destroy(); + shortJob.destroy(); + } + } + finally { + taskExecutor.stop(); + } + } + + @Test(invocationCount = 100) + public void testLevelMultipliers() + throws Exception + { + TestingTicker ticker = new TestingTicker(); + TaskExecutor taskExecutor = new TaskExecutor(1, 3, 2, false, false, ticker); + taskExecutor.start(); + ticker.increment(20, MILLISECONDS); + + try { + for (int i = 0; i < (LEVEL_THRESHOLD_SECONDS.length - 1); i++) { + TaskHandle[] taskHandles = { + taskExecutor.addTask(new TaskId("test1", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS)), + taskExecutor.addTask(new TaskId("test2", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS)), + taskExecutor.addTask(new TaskId("test3", 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS)) + }; + + // move task 0 to next level + TestingJob task0Job = new TestingJob(ticker, new Phaser(1), new Phaser(), new Phaser(), 1, LEVEL_THRESHOLD_SECONDS[i + 1] * 1000); + taskExecutor.enqueueSplits( + taskHandles[0], + true, + ImmutableList.of(task0Job)); + // move tasks 1 and 2 to this level + TestingJob task1Job = new TestingJob(ticker, new Phaser(1), new Phaser(), new Phaser(), 1, LEVEL_THRESHOLD_SECONDS[i] * 1000); + taskExecutor.enqueueSplits( + taskHandles[1], + true, + ImmutableList.of(task1Job)); + TestingJob task2Job = new TestingJob(ticker, new Phaser(1), new Phaser(), new Phaser(), 1, LEVEL_THRESHOLD_SECONDS[i] * 1000); + taskExecutor.enqueueSplits( + taskHandles[2], + true, + ImmutableList.of(task2Job)); + + task0Job.getCompletedFuture().get(); + task1Job.getCompletedFuture().get(); + task2Job.getCompletedFuture().get(); + + // then, start new drivers for all tasks + Phaser globalPhaser = new Phaser(2); + int phasesForNextLevel = LEVEL_THRESHOLD_SECONDS[i + 1] - LEVEL_THRESHOLD_SECONDS[i]; + TestingJob[] drivers = new TestingJob[6]; + for (int j = 0; j < 6; j++) { + drivers[j] = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), phasesForNextLevel, 1000); + } + + taskExecutor.enqueueSplits(taskHandles[0], true, ImmutableList.of(drivers[0], drivers[1])); + taskExecutor.enqueueSplits(taskHandles[1], true, ImmutableList.of(drivers[2], drivers[3])); + taskExecutor.enqueueSplits(taskHandles[2], true, ImmutableList.of(drivers[4], drivers[5])); + + // run all three drivers + int lowerLevelStart = drivers[2].getCompletedPhases() + drivers[3].getCompletedPhases() + drivers[4].getCompletedPhases() + drivers[5].getCompletedPhases(); + int higherLevelStart = drivers[0].getCompletedPhases() + drivers[1].getCompletedPhases(); + while (Arrays.stream(drivers).noneMatch(TestingJob::isFinished)) { + globalPhaser.arriveAndAwaitAdvance(); + + int lowerLevelEnd = drivers[2].getCompletedPhases() + drivers[3].getCompletedPhases() + drivers[4].getCompletedPhases() + drivers[5].getCompletedPhases(); + int lowerLevelTime = lowerLevelEnd - lowerLevelStart; + int higherLevelEnd = drivers[0].getCompletedPhases() + drivers[1].getCompletedPhases(); + int higherLevelTime = higherLevelEnd - higherLevelStart; + + if (higherLevelTime > 20) { + assertGreaterThan(lowerLevelTime, (higherLevelTime * 2) - 10); + assertLessThan(higherLevelTime, (lowerLevelTime * 2) + 10); + } + } + + try { + globalPhaser.arriveAndDeregister(); + } + catch (IllegalStateException e) { + // under high concurrency sometimes the deregister call can occur after completion + // this is not a real problem + } + taskExecutor.removeTask(taskHandles[0]); + taskExecutor.removeTask(taskHandles[1]); + taskExecutor.removeTask(taskHandles[2]); + } + } + finally { + taskExecutor.stop(); + } + } + + @Test + public void testTaskHandle() + throws Exception + { + TestingTicker ticker = new TestingTicker(); + TaskExecutor taskExecutor = new TaskExecutor(4, 8, ticker); + taskExecutor.start(); + + try { + TaskId taskId = new TaskId("test", 0, 0); + TaskHandle taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS)); + + Phaser beginPhase = new Phaser(); + beginPhase.register(); + Phaser verificationComplete = new Phaser(); + verificationComplete.register(); + + TestingJob driver1 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); + TestingJob driver2 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); + + // force enqueue a split + taskExecutor.enqueueSplits(taskHandle, true, ImmutableList.of(driver1)); + assertEquals(taskHandle.getRunningLeafSplits(), 0); + + // normal enqueue a split + taskExecutor.enqueueSplits(taskHandle, false, ImmutableList.of(driver2)); + assertEquals(taskHandle.getRunningLeafSplits(), 1); + + // let the split continue to run + beginPhase.arriveAndDeregister(); + verificationComplete.arriveAndDeregister(); + } + finally { + taskExecutor.stop(); + } + } + + @Test + public void testLevelContributionCap() + throws Exception + { + MultilevelSplitQueue splitQueue = new MultilevelSplitQueue(false, 2); + TaskHandle handle0 = new TaskHandle(new TaskId("test0", 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS)); + TaskHandle handle1 = new TaskHandle(new TaskId("test1", 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS)); + + for (int i = 0; i < (LEVEL_THRESHOLD_SECONDS.length - 1); i++) { + long levelAdvanceTime = SECONDS.toNanos(LEVEL_THRESHOLD_SECONDS[i + 1] - LEVEL_THRESHOLD_SECONDS[i]); + handle0.addScheduledNanos(levelAdvanceTime); + assertEquals(handle0.getPriority().getLevel(), i + 1); + + handle1.addScheduledNanos(levelAdvanceTime); + assertEquals(handle1.getPriority().getLevel(), i + 1); + + assertEquals(splitQueue.getLevelScheduledTime()[i], 2 * Math.min(levelAdvanceTime, LEVEL_CONTRIBUTION_CAP)); + assertEquals(splitQueue.getLevelScheduledTime()[i + 1], 0); + } + } + + @Test + public void testUpdateLevelWithCap() + throws Exception + { + MultilevelSplitQueue splitQueue = new MultilevelSplitQueue(false, 2); + TaskHandle handle0 = new TaskHandle(new TaskId("test0", 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS)); + + long quantaNanos = MINUTES.toNanos(10); + handle0.addScheduledNanos(quantaNanos); + long cappedNanos = Math.min(quantaNanos, LEVEL_CONTRIBUTION_CAP); + + for (int i = 0; i < (LEVEL_THRESHOLD_SECONDS.length - 1); i++) { + long thisLevelTime = Math.min(SECONDS.toNanos(LEVEL_THRESHOLD_SECONDS[i + 1] - LEVEL_THRESHOLD_SECONDS[i]), cappedNanos); + assertEquals(splitQueue.getLevelScheduledTime()[i], thisLevelTime); + cappedNanos -= thisLevelTime; + } + } + + private static class TestingJob + implements SplitRunner + { + private final TestingTicker ticker; + private final Phaser globalPhaser; + private final Phaser beginQuantaPhaser; + private final Phaser endQuantaPhaser; + private final int requiredPhases; + private final int quantaTimeMillis; + private final AtomicInteger completedPhases = new AtomicInteger(); + + private final AtomicInteger firstPhase = new AtomicInteger(-1); + private final AtomicInteger lastPhase = new AtomicInteger(-1); + + private final SettableFuture completed = SettableFuture.create(); + + public TestingJob(TestingTicker ticker, Phaser globalPhaser, Phaser beginQuantaPhaser, Phaser endQuantaPhaser, int requiredPhases, int quantaTimeMillis) + { + this.ticker = ticker; + this.globalPhaser = globalPhaser; + this.beginQuantaPhaser = beginQuantaPhaser; + this.endQuantaPhaser = endQuantaPhaser; + this.requiredPhases = requiredPhases; + this.quantaTimeMillis = quantaTimeMillis; + + beginQuantaPhaser.register(); + endQuantaPhaser.register(); + + if (globalPhaser.getRegisteredParties() == 0) { + globalPhaser.register(); + } + } + + private int getFirstPhase() + { + return firstPhase.get(); + } + + private int getLastPhase() + { + return lastPhase.get(); + } + + private int getCompletedPhases() + { + return completedPhases.get(); + } + + @Override + public ListenableFuture processFor(Duration duration) + throws Exception + { + ticker.increment(quantaTimeMillis, MILLISECONDS); + globalPhaser.arriveAndAwaitAdvance(); + int phase = beginQuantaPhaser.arriveAndAwaitAdvance(); + firstPhase.compareAndSet(-1, phase - 1); + lastPhase.set(phase); + endQuantaPhaser.arriveAndAwaitAdvance(); + if (completedPhases.incrementAndGet() >= requiredPhases) { + endQuantaPhaser.arriveAndDeregister(); + beginQuantaPhaser.arriveAndDeregister(); + globalPhaser.arriveAndDeregister(); + completed.set(null); + } + + return Futures.immediateFuture(null); + } + + @Override + public String getInfo() + { + return "testing-split"; + } + + @Override + public boolean isFinished() + { + return completed.isDone(); + } + + @Override + public void close() + { + } + + public Future getCompletedFuture() + { + return completed; + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroups.java b/presto-main/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroups.java index a2e3ae612db3..f88bc9d8e1b1 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroups.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroups.java @@ -15,6 +15,8 @@ import com.facebook.presto.execution.MockQueryExecution; import com.facebook.presto.execution.resourceGroups.InternalResourceGroup.RootInternalResourceGroup; +import com.facebook.presto.server.QueryStateInfo; +import com.facebook.presto.server.ResourceGroupStateInfo; import com.facebook.presto.spi.resourceGroups.ResourceGroupInfo; import com.google.common.collect.ImmutableSet; import io.airlift.units.DataSize; @@ -26,6 +28,7 @@ import java.util.HashSet; import java.util.Iterator; import java.util.List; +import java.util.Optional; import java.util.Random; import java.util.Set; import java.util.SortedMap; @@ -34,6 +37,7 @@ import static com.facebook.presto.execution.QueryState.FAILED; import static com.facebook.presto.execution.QueryState.QUEUED; import static com.facebook.presto.execution.QueryState.RUNNING; +import static com.facebook.presto.spi.resourceGroups.ResourceGroupState.CAN_RUN; import static com.facebook.presto.spi.resourceGroups.SchedulingPolicy.QUERY_PRIORITY; import static com.facebook.presto.spi.resourceGroups.SchedulingPolicy.WEIGHTED; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; @@ -42,6 +46,7 @@ import static io.airlift.units.DataSize.Unit.BYTE; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Collections.reverse; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; import static org.testng.Assert.assertEquals; @@ -470,6 +475,59 @@ public void testGetInfo() assertEquals(info.getNumAggregatedQueuedQueries(), 26); } + @Test + public void testGetResourceGroupStateInfo() + { + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> { }, directExecutor()); + root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + root.setMaxQueuedQueries(40); + root.setMaxRunningQueries(10); + root.setSchedulingPolicy(WEIGHTED); + + InternalResourceGroup rootA = root.getOrCreateSubGroup("a"); + rootA.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootA.setMaxQueuedQueries(20); + rootA.setMaxRunningQueries(0); + + InternalResourceGroup rootB = root.getOrCreateSubGroup("b"); + rootB.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootB.setMaxQueuedQueries(20); + rootB.setMaxRunningQueries(1); + rootB.setSchedulingWeight(2); + rootB.setSchedulingPolicy(QUERY_PRIORITY); + + InternalResourceGroup rootAX = rootA.getOrCreateSubGroup("x"); + rootAX.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootAX.setMaxQueuedQueries(10); + rootAX.setMaxRunningQueries(10); + + InternalResourceGroup rootAY = rootA.getOrCreateSubGroup("y"); + rootAY.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootAY.setMaxQueuedQueries(10); + rootAY.setMaxRunningQueries(10); + + Set queries = fillGroupTo(rootAX, ImmutableSet.of(), 5, false); + queries.addAll(fillGroupTo(rootAY, ImmutableSet.of(), 5, false)); + queries.addAll(fillGroupTo(rootB, ImmutableSet.of(), 10, true)); + + ResourceGroupStateInfo stateInfo = root.getStateInfo(); + assertEquals(stateInfo.getId(), root.getId()); + assertEquals(stateInfo.getState(), CAN_RUN); + assertEquals(stateInfo.getSoftMemoryLimit(), root.getSoftMemoryLimit()); + assertEquals(stateInfo.getMemoryUsage(), new DataSize(0, BYTE)); + assertEquals(stateInfo.getSubGroups().size(), 2); + assertEquals(stateInfo.getSubGroups().get(0).getId(), rootA.getId()); + assertEquals(stateInfo.getSubGroups().get(1).getId(), rootB.getId()); + assertEquals(stateInfo.getMaxRunningQueries(), root.getMaxRunningQueries()); + assertEquals(stateInfo.getRunningTimeLimit(), new Duration(Long.MAX_VALUE, MILLISECONDS)); + assertEquals(stateInfo.getMaxQueuedQueries(), root.getMaxQueuedQueries()); + assertEquals(stateInfo.getQueuedTimeLimit(), new Duration(Long.MAX_VALUE, MILLISECONDS)); + assertEquals(stateInfo.getNumQueuedQueries(), 19); + assertEquals(stateInfo.getRunningQueries().size(), 1); + QueryStateInfo queryInfo = stateInfo.getRunningQueries().get(0); + assertEquals(queryInfo.getResourceGroupId(), Optional.of(rootB.getId())); + } + @Test public void testGetBlockedQueuedQueries() { diff --git a/presto-main/src/test/java/com/facebook/presto/failureDetector/TestHeartbeatFailureDetector.java b/presto-main/src/test/java/com/facebook/presto/failureDetector/TestHeartbeatFailureDetector.java index 41147d9f9966..fe45b9a824d9 100644 --- a/presto-main/src/test/java/com/facebook/presto/failureDetector/TestHeartbeatFailureDetector.java +++ b/presto-main/src/test/java/com/facebook/presto/failureDetector/TestHeartbeatFailureDetector.java @@ -14,6 +14,10 @@ package com.facebook.presto.failureDetector; import com.facebook.presto.execution.QueryManagerConfig; +import com.facebook.presto.failureDetector.HeartbeatFailureDetector.Stats; +import com.facebook.presto.server.InternalCommunicationConfig; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.inject.Binder; import com.google.inject.Injector; import com.google.inject.Key; @@ -25,6 +29,7 @@ import io.airlift.jaxrs.JaxrsModule; import io.airlift.jmx.testing.TestingJmxModule; import io.airlift.json.JsonModule; +import io.airlift.json.ObjectMapperProvider; import io.airlift.node.testing.TestingNodeModule; import io.airlift.tracetoken.TraceTokenModule; import org.testng.annotations.Test; @@ -32,11 +37,15 @@ import javax.ws.rs.GET; import javax.ws.rs.Path; +import java.net.SocketTimeoutException; +import java.net.URI; + import static io.airlift.configuration.ConfigBinder.configBinder; import static io.airlift.discovery.client.DiscoveryBinder.discoveryBinder; import static io.airlift.discovery.client.ServiceTypes.serviceType; import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; public class TestHeartbeatFailureDetector @@ -59,6 +68,7 @@ public void testExcludesCurrentNode() @Override public void configure(Binder binder) { + configBinder(binder).bindConfig(InternalCommunicationConfig.class); configBinder(binder).bindConfig(QueryManagerConfig.class); discoveryBinder(binder).bindSelector("presto"); discoveryBinder(binder).bindHttpAnnouncement("presto"); @@ -86,6 +96,23 @@ public void configure(Binder binder) assertTrue(detector.getFailed().isEmpty()); } + @Test + public void testHeartbeatStatsSerialization() + throws Exception + { + ObjectMapper objectMapper = new ObjectMapperProvider().get(); + Stats stats = new Stats(new URI("http://example.com")); + String serialized = objectMapper.writeValueAsString(stats); + JsonNode deserialized = objectMapper.readTree(serialized); + assertFalse(deserialized.has("lastFailureInfo")); + + stats.recordFailure(new SocketTimeoutException("timeout")); + serialized = objectMapper.writeValueAsString(stats); + deserialized = objectMapper.readTree(serialized); + assertFalse(deserialized.get("lastFailureInfo").isNull()); + assertEquals(deserialized.get("lastFailureInfo").get("type").asText(), SocketTimeoutException.class.getName()); + } + @Path("/foo") public static class FooResource { diff --git a/presto-main/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java b/presto-main/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java new file mode 100644 index 000000000000..3c0af3616059 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java @@ -0,0 +1,406 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.Session; +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.spi.CatalogSchemaName; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnIdentity; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.TableIdentity; +import com.facebook.presto.spi.block.BlockEncodingSerde; +import com.facebook.presto.spi.connector.ConnectorOutputMetadata; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.security.GrantInfo; +import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.sql.tree.QualifiedName; +import io.airlift.slice.Slice; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.Set; + +public abstract class AbstractMockMetadata + implements Metadata +{ + @Override + public void verifyComparableOrderableContract() + { + throw new UnsupportedOperationException(); + } + + @Override + public Type getType(TypeSignature signature) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isAggregationFunction(QualifiedName name) + { + throw new UnsupportedOperationException(); + } + + @Override + public List listFunctions() + { + throw new UnsupportedOperationException(); + } + + @Override + public void addFunctions(List functions) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean schemaExists(Session session, CatalogSchemaName schema) + { + throw new UnsupportedOperationException(); + } + + @Override + public List listSchemaNames(Session session, String catalogName) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getTableHandle(Session session, QualifiedObjectName tableName) + { + throw new UnsupportedOperationException(); + } + + @Override + public List getLayouts(Session session, TableHandle tableHandle, Constraint constraint, Optional> desiredColumns) + { + throw new UnsupportedOperationException(); + } + + @Override + public TableLayout getLayout(Session session, TableLayoutHandle handle) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getInfo(Session session, TableLayoutHandle handle) + { + throw new UnsupportedOperationException(); + } + + @Override + public TableMetadata getTableMetadata(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public TableStatistics getTableStatistics(Session session, TableHandle tableHandle, Constraint constraint) + { + throw new UnsupportedOperationException(); + } + + @Override + public List listTables(Session session, QualifiedTablePrefix prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getColumnHandles(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnMetadata getColumnMetadata(Session session, TableHandle tableHandle, ColumnHandle columnHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public Map> listTableColumns(Session session, QualifiedTablePrefix prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public void createSchema(Session session, CatalogSchemaName schema, Map properties) + { + throw new UnsupportedOperationException(); + } + + @Override + public void dropSchema(Session session, CatalogSchemaName schema) + { + throw new UnsupportedOperationException(); + } + + @Override + public void renameSchema(Session session, CatalogSchemaName source, String target) + { + throw new UnsupportedOperationException(); + } + + @Override + public void createTable(Session session, String catalogName, ConnectorTableMetadata tableMetadata) + { + throw new UnsupportedOperationException(); + } + + @Override + public void renameTable(Session session, TableHandle tableHandle, QualifiedObjectName newTableName) + { + throw new UnsupportedOperationException(); + } + + @Override + public void renameColumn(Session session, TableHandle tableHandle, ColumnHandle source, String target) + { + throw new UnsupportedOperationException(); + } + + @Override + public void addColumn(Session session, TableHandle tableHandle, ColumnMetadata column) + { + throw new UnsupportedOperationException(); + } + + @Override + public void dropTable(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public TableIdentity getTableIdentity(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public TableIdentity deserializeTableIdentity(Session session, String catalogName, byte[] bytes) + { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnIdentity getColumnIdentity(Session session, TableHandle tableHandle, ColumnHandle columnHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnIdentity deserializeColumnIdentity(Session session, String catalogName, byte[] bytes) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getNewTableLayout(Session session, String catalogName, ConnectorTableMetadata tableMetadata) + { + throw new UnsupportedOperationException(); + } + + @Override + public OutputTableHandle beginCreateTable(Session session, String catalogName, ConnectorTableMetadata tableMetadata, Optional layout) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional finishCreateTable(Session session, OutputTableHandle tableHandle, Collection fragments) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getInsertLayout(Session session, TableHandle target) + { + throw new UnsupportedOperationException(); + } + + @Override + public void beginQuery(Session session, Set connectors) + { + throw new UnsupportedOperationException(); + } + + @Override + public void cleanupQuery(Session session) + { + throw new UnsupportedOperationException(); + } + + @Override + public InsertTableHandle beginInsert(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional finishInsert(Session session, InsertTableHandle tableHandle, Collection fragments) + { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnHandle getUpdateRowIdColumnHandle(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean supportsMetadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public OptionalLong metadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public TableHandle beginDelete(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public void finishDelete(Session session, TableHandle tableHandle, Collection fragments) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getCatalogHandle(Session session, String catalogName) + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getCatalogNames(Session session) + { + throw new UnsupportedOperationException(); + } + + @Override + public List listViews(Session session, QualifiedTablePrefix prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getViews(Session session, QualifiedTablePrefix prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getView(Session session, QualifiedObjectName viewName) + { + throw new UnsupportedOperationException(); + } + + @Override + public void createView(Session session, QualifiedObjectName viewName, String viewData, boolean replace) + { + throw new UnsupportedOperationException(); + } + + @Override + public void dropView(Session session, QualifiedObjectName viewName) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional resolveIndex(Session session, TableHandle tableHandle, Set indexableColumns, Set outputColumns, TupleDomain tupleDomain) + { + throw new UnsupportedOperationException(); + } + + @Override + public void grantTablePrivileges(Session session, QualifiedObjectName tableName, Set privileges, String grantee, boolean grantOption) + { + throw new UnsupportedOperationException(); + } + + @Override + public void revokeTablePrivileges(Session session, QualifiedObjectName tableName, Set privileges, String grantee, boolean grantOption) + { + throw new UnsupportedOperationException(); + } + + @Override + public List listTablePrivileges(Session session, QualifiedTablePrefix prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionRegistry getFunctionRegistry() + { + throw new UnsupportedOperationException(); + } + + @Override + public ProcedureRegistry getProcedureRegistry() + { + throw new UnsupportedOperationException(); + } + + @Override + public TypeManager getTypeManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public BlockEncodingSerde getBlockEncodingSerde() + { + throw new UnsupportedOperationException(); + } + + @Override + public SessionPropertyManager getSessionPropertyManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public SchemaPropertyManager getSchemaPropertyManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public TablePropertyManager getTablePropertyManager() + { + throw new UnsupportedOperationException(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/metadata/DummyMetadata.java b/presto-main/src/test/java/com/facebook/presto/metadata/DummyMetadata.java index 2a3d7596fe4d..312236889238 100644 --- a/presto-main/src/test/java/com/facebook/presto/metadata/DummyMetadata.java +++ b/presto-main/src/test/java/com/facebook/presto/metadata/DummyMetadata.java @@ -41,7 +41,8 @@ import java.util.OptionalLong; import java.util.Set; -public class DummyMetadata implements Metadata +public class DummyMetadata + implements Metadata { @Override public void verifyComparableOrderableContract() @@ -93,10 +94,10 @@ public Optional getTableHandle(Session session, QualifiedObjectName @Override public List getLayouts( - Session session, - TableHandle tableHandle, - Constraint constraint, - Optional> desiredColumns) + Session session, + TableHandle tableHandle, + Constraint constraint, + Optional> desiredColumns) { throw new UnsupportedOperationException(); } @@ -191,6 +192,12 @@ public void addColumn(Session session, TableHandle tableHandle, ColumnMetadata c throw new UnsupportedOperationException(); } + @Override + public void dropColumn(Session session, TableHandle tableHandle, ColumnHandle column) + { + throw new UnsupportedOperationException(); + } + @Override public void dropTable(Session session, TableHandle tableHandle) { @@ -343,11 +350,11 @@ public void dropView(Session session, QualifiedObjectName viewName) @Override public Optional resolveIndex( - Session session, - TableHandle tableHandle, - Set indexableColumns, - Set outputColumns, - TupleDomain tupleDomain) + Session session, + TableHandle tableHandle, + Set indexableColumns, + Set outputColumns, + TupleDomain tupleDomain) { throw new UnsupportedOperationException(); } diff --git a/presto-main/src/test/java/com/facebook/presto/metadata/TestDiscoveryNodeManager.java b/presto-main/src/test/java/com/facebook/presto/metadata/TestDiscoveryNodeManager.java index 0951026ac06c..9faa2bcbea12 100644 --- a/presto-main/src/test/java/com/facebook/presto/metadata/TestDiscoveryNodeManager.java +++ b/presto-main/src/test/java/com/facebook/presto/metadata/TestDiscoveryNodeManager.java @@ -14,6 +14,7 @@ package com.facebook.presto.metadata; import com.facebook.presto.client.NodeVersion; +import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.server.NoOpFailureDetector; import com.facebook.presto.spi.Node; import com.google.common.collect.ArrayListMultimap; @@ -49,6 +50,7 @@ public class TestDiscoveryNodeManager { private final NodeInfo nodeInfo = new NodeInfo("test"); + private final InternalCommunicationConfig internalCommunicationConfig = new InternalCommunicationConfig(); private NodeVersion expectedVersion; private List activeNodes; private List inactiveNodes; @@ -90,7 +92,7 @@ public void setup() public void testGetAllNodes() throws Exception { - DiscoveryNodeManager manager = new DiscoveryNodeManager(selector, nodeInfo, new NoOpFailureDetector(), expectedVersion, testHttpClient); + DiscoveryNodeManager manager = new DiscoveryNodeManager(selector, nodeInfo, new NoOpFailureDetector(), expectedVersion, testHttpClient, internalCommunicationConfig); AllNodes allNodes = manager.getAllNodes(); Set activeNodes = allNodes.getActiveNodes(); @@ -125,7 +127,7 @@ public void testGetCurrentNode() .setEnvironment("test") .setNodeId(expected.getNodeIdentifier())); - DiscoveryNodeManager manager = new DiscoveryNodeManager(selector, nodeInfo, new NoOpFailureDetector(), expectedVersion, testHttpClient); + DiscoveryNodeManager manager = new DiscoveryNodeManager(selector, nodeInfo, new NoOpFailureDetector(), expectedVersion, testHttpClient, internalCommunicationConfig); assertEquals(manager.getCurrentNode(), expected); } @@ -134,7 +136,7 @@ public void testGetCurrentNode() public void testGetCoordinators() throws Exception { - InternalNodeManager manager = new DiscoveryNodeManager(selector, nodeInfo, new NoOpFailureDetector(), expectedVersion, testHttpClient); + InternalNodeManager manager = new DiscoveryNodeManager(selector, nodeInfo, new NoOpFailureDetector(), expectedVersion, testHttpClient, internalCommunicationConfig); assertEquals(manager.getCoordinators(), ImmutableSet.of(coordinator)); } @@ -142,6 +144,6 @@ public void testGetCoordinators() @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = ".* current node not returned .*") public void testGetCurrentNodeRequired() { - new DiscoveryNodeManager(selector, new NodeInfo("test"), new NoOpFailureDetector(), expectedVersion, testHttpClient); + new DiscoveryNodeManager(selector, new NodeInfo("test"), new NoOpFailureDetector(), expectedVersion, testHttpClient, internalCommunicationConfig); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/PageAssertions.java b/presto-main/src/test/java/com/facebook/presto/operator/PageAssertions.java index 1b183e93255d..7dbf77349159 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/PageAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/PageAssertions.java @@ -31,6 +31,7 @@ public static void assertPageEquals(List types, Page actualPage, { assertEquals(types.size(), actualPage.getChannelCount()); assertEquals(actualPage.getChannelCount(), expectedPage.getChannelCount()); + assertEquals(actualPage.getPositionCount(), expectedPage.getPositionCount()); for (int i = 0; i < actualPage.getChannelCount(); i++) { assertBlockEquals(types.get(i), actualPage.getBlock(i), expectedPage.getBlock(i)); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestUnnestOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestUnnestOperator.java index 80fe230f7027..77820223ea81 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestUnnestOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestUnnestOperator.java @@ -15,10 +15,10 @@ import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.testing.MaterializedResult; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.AfterMethod; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkArrayAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkArrayAggregation.java index 00d882e7fa91..115860fa28a1 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkArrayAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkArrayAggregation.java @@ -19,8 +19,8 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import org.openjdk.jmh.annotations.Benchmark; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximatePercentileAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximatePercentileAggregation.java index 1c3b0cd3fcaf..93ba6b0ea5c5 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximatePercentileAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximatePercentileAggregation.java @@ -18,7 +18,7 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.RunLengthEncodedBlock; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMaxNAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMaxNAggregation.java index 8d76cd96ab4a..e3fc1315b7f7 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMaxNAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMaxNAggregation.java @@ -17,8 +17,8 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMinAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMinAggregation.java index 7defcc9502ef..cbc72565508c 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMinAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMinAggregation.java @@ -16,7 +16,7 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import java.util.List; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestChecksumAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestChecksumAggregation.java index 9f10c3c55230..0b24dbf988aa 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestChecksumAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestChecksumAggregation.java @@ -16,6 +16,7 @@ import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.DecimalType; @@ -24,7 +25,6 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarbinaryType; import com.facebook.presto.spi.type.VarcharType; -import com.facebook.presto.type.ArrayType; import org.testng.annotations.Test; import static com.facebook.presto.block.BlockAssertions.createArrayBigintBlock; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleHistogramAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleHistogramAggregation.java index 931d7be5e81a..7d5c1ed6be99 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleHistogramAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleHistogramAggregation.java @@ -20,9 +20,9 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.MapType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestHistogram.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestHistogram.java index 95458473a594..9a905964b917 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestHistogram.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestHistogram.java @@ -17,12 +17,12 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlTimestampWithTimeZone; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.TimeZoneKey; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.joda.time.DateTime; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapAggAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapAggAggregation.java index 5ce58e61a428..e8fe4268fbb9 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapAggAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapAggAggregation.java @@ -17,10 +17,10 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.StandardTypes; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapUnionAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapUnionAggregation.java index 6ab370ea8580..5423fa84725c 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapUnionAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapUnionAggregation.java @@ -15,8 +15,8 @@ import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.metadata.Signature; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMinMaxByAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMinMaxByAggregation.java index 7603ecda45de..0a2595aa9020 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMinMaxByAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMinMaxByAggregation.java @@ -17,10 +17,10 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.state.StateCompiler; import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMultimapAggAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMultimapAggAggregation.java index 8640ec1a7ab5..1a2cc3284f30 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMultimapAggAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMultimapAggAggregation.java @@ -16,10 +16,10 @@ import com.facebook.presto.RowPageBuilder; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.metadata.Signature; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealHistogramAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealHistogramAggregation.java index ea407048391c..9429b256958f 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealHistogramAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealHistogramAggregation.java @@ -20,9 +20,9 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.MapType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java index 9c3fd0747502..0ace09355d23 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java @@ -18,6 +18,7 @@ import com.facebook.presto.array.ByteBigArray; import com.facebook.presto.array.DoubleBigArray; import com.facebook.presto.array.LongBigArray; +import com.facebook.presto.array.ReferenceCountMap; import com.facebook.presto.array.SliceBigArray; import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.operator.aggregation.state.LongState; @@ -32,9 +33,9 @@ import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.function.AccumulatorStateSerializer; import com.facebook.presto.spi.function.GroupedAccumulatorState; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; import com.facebook.presto.util.Reflection; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -253,6 +254,7 @@ private long getSize(Slice slice) private long getComplexStateRetainedSize(TestComplexState state) { long retainedSize = ClassLayout.parseClass(state.getClass()).instanceSize(); + // reflection is necessary because TestComplexState implementation is generated Field[] fields = state.getClass().getDeclaredFields(); try { for (Field field : fields) { @@ -271,6 +273,34 @@ private long getComplexStateRetainedSize(TestComplexState state) return retainedSize; } + private static long getBlockBigArrayReferenceCountMapOverhead(TestComplexState state) + { + long overhead = 0; + // reflection is necessary because TestComplexState implementation is generated + Field[] stateFields = state.getClass().getDeclaredFields(); + try { + for (Field stateField : stateFields) { + if (stateField.getType() != BlockBigArray.class) { + continue; + } + stateField.setAccessible(true); + Field[] blockBigArrayFields = stateField.getType().getDeclaredFields(); + for (Field blockBigArrayField : blockBigArrayFields) { + if (blockBigArrayField.getType() != ReferenceCountMap.class) { + continue; + } + blockBigArrayField.setAccessible(true); + MethodHandle sizeOf = Reflection.methodHandle(blockBigArrayField.getType(), "sizeOf", null); + overhead += (long) sizeOf.invokeWithArguments(blockBigArrayField.get(stateField.get(state))); + } + } + } + catch (Throwable t) { + throw new RuntimeException(t); + } + return overhead; + } + @Test public void testComplexStateEstimatedSize() { @@ -280,6 +310,9 @@ public void testComplexStateEstimatedSize() TestComplexState groupedState = factory.createGroupedState(); long initialRetainedSize = getComplexStateRetainedSize(groupedState); assertEquals(groupedState.getEstimatedSize(), initialRetainedSize); + // BlockBigArray has an internal map that can grow in size when getting more blocks + // need to handle the map overhead separately + initialRetainedSize -= getBlockBigArrayReferenceCountMapOverhead(groupedState); for (int i = 0; i < 1000; i++) { long retainedSize = 0; ((GroupedAccumulatorState) groupedState).setGroupId(i); @@ -303,7 +336,7 @@ public void testComplexStateEstimatedSize() Block map = mapBlockBuilder.build(); retainedSize += map.getRetainedSizeInBytes(); groupedState.setAnotherBlock(map); - assertEquals(groupedState.getEstimatedSize(), initialRetainedSize + retainedSize * (i + 1)); + assertEquals(groupedState.getEstimatedSize(), initialRetainedSize + retainedSize * (i + 1) + getBlockBigArrayReferenceCountMapOverhead(groupedState)); } for (int i = 0; i < 1000; i++) { @@ -329,7 +362,7 @@ public void testComplexStateEstimatedSize() Block map = mapBlockBuilder.build(); retainedSize += map.getRetainedSizeInBytes(); groupedState.setAnotherBlock(map); - assertEquals(groupedState.getEstimatedSize(), initialRetainedSize + retainedSize * 1000); + assertEquals(groupedState.getEstimatedSize(), initialRetainedSize + retainedSize * 1000 + getBlockBigArrayReferenceCountMapOverhead(groupedState)); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestTypedHistogram.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestTypedHistogram.java index 473943585e83..ad98cfdee0a1 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestTypedHistogram.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestTypedHistogram.java @@ -16,12 +16,14 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.MapType; import org.testng.annotations.Test; import java.util.function.IntUnaryOperator; import java.util.stream.IntStream; import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static org.testng.Assert.assertEquals; public class TestTypedHistogram @@ -42,7 +44,10 @@ public void testMassive() typedHistogram.add(i, inputBlock, 1); } - Block outputBlock = typedHistogram.serialize(); + MapType mapType = mapType(BIGINT, BIGINT); + BlockBuilder out = mapType.createBlockBuilder(new BlockBuilderStatus(), 1); + typedHistogram.serialize(out); + Block outputBlock = mapType.getObject(out, 0); for (int i = 0; i < outputBlock.getPositionCount(); i += 2) { assertEquals(BIGINT.getLong(outputBlock, i + 1), BIGINT.getLong(outputBlock, i)); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/index/TestFieldSetFilteringRecordSet.java b/presto-main/src/test/java/com/facebook/presto/operator/index/TestFieldSetFilteringRecordSet.java index 8c00c58fc55d..1e30485be359 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/index/TestFieldSetFilteringRecordSet.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/index/TestFieldSetFilteringRecordSet.java @@ -17,9 +17,9 @@ import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.InMemoryRecordSet; import com.facebook.presto.spi.RecordCursor; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageFilter.java b/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageFilter.java index ae4a37e4ec3c..7c2ae8ddb32b 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageFilter.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageFilter.java @@ -32,6 +32,7 @@ import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock; import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; public class TestDictionaryAwarePageFilter @@ -70,6 +71,15 @@ private static void testRleBlock(boolean filterRange) testFilter(filter, noMatch, filterRange); } + @Test + public void testRleBlockWithFailure() + throws Exception + { + DictionaryAwarePageFilter filter = createDictionaryAwarePageFilter(true, LongArrayBlock.class); + RunLengthEncodedBlock fail = new RunLengthEncodedBlock(createLongSequenceBlock(-10, -9), 100); + assertThrows(NegativeValueException.class, () -> testFilter(filter, fail, true)); + } + @Test public void testDictionaryBlock() throws Exception @@ -81,7 +91,14 @@ public void testDictionaryBlock() testFilter(createDictionaryBlock(20, 0), LongArrayBlock.class); // match all - testFilter(new DictionaryBlock(100, createLongSequenceBlock(4, 5), new int[100]), LongArrayBlock.class); + testFilter(new DictionaryBlock(createLongSequenceBlock(4, 5), new int[100]), LongArrayBlock.class); + } + + @Test + public void testDictionaryBlockWithFailure() + throws Exception + { + assertThrows(NegativeValueException.class, () -> testFilter(createDictionaryBlockWithFailure(20, 100), LongArrayBlock.class)); } @Test @@ -95,7 +112,7 @@ public void testDictionaryBlockProcessingWithUnusedFailure() testFilter(createDictionaryBlockWithUnusedEntries(20, 0), DictionaryBlock.class); // match all - testFilter(new DictionaryBlock(100, createLongsBlock(4, 5, -1), new int[100]), DictionaryBlock.class); + testFilter(new DictionaryBlock(createLongsBlock(4, 5, -1), new int[100]), DictionaryBlock.class); } @Test @@ -130,7 +147,15 @@ private static DictionaryBlock createDictionaryBlock(int dictionarySize, int blo Block dictionary = createLongSequenceBlock(0, dictionarySize); int[] ids = new int[blockSize]; Arrays.setAll(ids, index -> index % dictionarySize); - return new DictionaryBlock(blockSize, dictionary, ids); + return new DictionaryBlock(dictionary, ids); + } + + private static DictionaryBlock createDictionaryBlockWithFailure(int dictionarySize, int blockSize) + { + Block dictionary = createLongSequenceBlock(-10, dictionarySize - 10); + int[] ids = new int[blockSize]; + Arrays.setAll(ids, index -> index % dictionarySize); + return new DictionaryBlock(dictionary, ids); } private static DictionaryBlock createDictionaryBlockWithUnusedEntries(int dictionarySize, int blockSize) @@ -138,7 +163,7 @@ private static DictionaryBlock createDictionaryBlockWithUnusedEntries(int dictio Block dictionary = createLongSequenceBlock(-10, dictionarySize); int[] ids = new int[blockSize]; Arrays.setAll(ids, index -> (index % dictionarySize) + 10); - return new DictionaryBlock(blockSize, dictionary, ids); + return new DictionaryBlock(dictionary, ids); } private static void testFilter(Block block, Class expectedType) @@ -257,6 +282,7 @@ public SelectedPositions filter(ConnectorSession session, Page page) IntArrayList selectedPositions = new IntArrayList(); for (int position = 0; position < block.getPositionCount(); position++) { long value = block.getLong(position, 0); + verifyPositive(value); boolean selected = isSelected(filterRange, value); if (selected) { @@ -286,5 +312,22 @@ public SelectedPositions filter(ConnectorSession session, Page page) return SelectedPositions.positionsList(selectedPositions.elements(), 3, selectedPositions.size() - 6); } + + private static long verifyPositive(long value) + { + if (value < 0) { + throw new NegativeValueException(value); + } + return value; + } + } + + private static class NegativeValueException + extends RuntimeException + { + public NegativeValueException(long value) + { + super("value is negative: " + value); + } } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageProjection.java b/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageProjection.java index 66342b525c7c..7d20c3f82ca0 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageProjection.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageProjection.java @@ -35,6 +35,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static io.airlift.testing.Assertions.assertInstanceOf; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; public class TestDictionaryAwarePageProjection { @@ -66,6 +67,16 @@ public void testRleBlock() testProject(block, RunLengthEncodedBlock.class); } + @Test + public void testRleBlockWithFailure() + throws Exception + { + Block value = createLongSequenceBlock(-43, -42); + RunLengthEncodedBlock block = new RunLengthEncodedBlock(value, 100); + + testProjectFails(block, RunLengthEncodedBlock.class); + } + @Test public void testDictionaryBlock() throws Exception @@ -75,6 +86,15 @@ public void testDictionaryBlock() testProject(block, DictionaryBlock.class); } + @Test + public void testDictionaryBlockWithFailure() + throws Exception + { + DictionaryBlock block = createDictionaryBlockWithFailure(10, 100); + + testProjectFails(block, DictionaryBlock.class); + } + @Test public void testDictionaryBlockProcessingWithUnusedFailure() throws Exception @@ -115,7 +135,15 @@ private static DictionaryBlock createDictionaryBlock(int dictionarySize, int blo Block dictionary = createLongSequenceBlock(0, dictionarySize); int[] ids = new int[blockSize]; Arrays.setAll(ids, index -> index % dictionarySize); - return new DictionaryBlock(blockSize, dictionary, ids); + return new DictionaryBlock(dictionary, ids); + } + + private static DictionaryBlock createDictionaryBlockWithFailure(int dictionarySize, int blockSize) + { + Block dictionary = createLongSequenceBlock(-10, dictionarySize - 10); + int[] ids = new int[blockSize]; + Arrays.setAll(ids, index -> index % dictionarySize); + return new DictionaryBlock(dictionary, ids); } private static DictionaryBlock createDictionaryBlockWithUnusedEntries(int dictionarySize, int blockSize) @@ -123,7 +151,7 @@ private static DictionaryBlock createDictionaryBlockWithUnusedEntries(int dictio Block dictionary = createLongSequenceBlock(-10, dictionarySize); int[] ids = new int[blockSize]; Arrays.setAll(ids, index -> (index % dictionarySize) + 10); - return new DictionaryBlock(blockSize, dictionary, ids); + return new DictionaryBlock(dictionary, ids); } private static void testProject(Block block, Class expectedResultType) @@ -134,6 +162,14 @@ private static void testProject(Block block, Class expectedResu testProjectList(lazyWrapper(block), expectedResultType, createProjection()); } + private static void testProjectFails(Block block, Class expectedResultType) + { + assertThrows(NegativeValueException.class, () -> testProjectRange(block, expectedResultType, createProjection())); + assertThrows(NegativeValueException.class, () -> testProjectList(block, expectedResultType, createProjection())); + assertThrows(NegativeValueException.class, () -> testProjectRange(lazyWrapper(block), expectedResultType, createProjection())); + assertThrows(NegativeValueException.class, () -> testProjectList(lazyWrapper(block), expectedResultType, createProjection())); + } + private static void testProjectRange(Block block, Class expectedResultType, DictionaryAwarePageProjection projection) { Block result = projection.project(null, new Page(block), SelectedPositions.positionsRange(5, 10)); @@ -212,9 +248,18 @@ public Block project(ConnectorSession session, Page page, SelectedPositions sele private static long verifyPositive(long value) { if (value < 0) { - throw new IllegalArgumentException("value is negative: " + value); + throw new NegativeValueException(value); } return value; } } + + private static class NegativeValueException + extends RuntimeException + { + public NegativeValueException(long value) + { + super("value is negative: " + value); + } + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java index ab900e0de487..8376cf45da03 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java @@ -40,6 +40,7 @@ import static com.facebook.presto.metadata.FunctionRegistry.mangleOperatorName; import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static com.facebook.presto.spi.type.DecimalType.createDecimalType; import static com.facebook.presto.type.UnknownType.UNKNOWN; @@ -198,6 +199,28 @@ protected void assertInvalidCast(String projection, String message) } } + public void assertCachedInstanceHasBoundedRetainedSize(String projection) + { + functionAssertions.assertCachedInstanceHasBoundedRetainedSize(projection); + } + + protected void assertNotSupported(String projection, String message) + { + try { + functionAssertions.executeProjectionWithFullEngine(projection); + fail("expected exception"); + } + catch (PrestoException e) { + assertEquals(e.getErrorCode(), NOT_SUPPORTED.toErrorCode()); + assertEquals(e.getMessage(), message); + } + } + + protected void tryEvaluateWithAll(String projection, Type expectedType) + { + functionAssertions.tryEvaluateWithAll(projection, expectedType); + } + protected void registerScalarFunction(SqlScalarFunction sqlScalarFunction) { Metadata metadata = functionAssertions.getMetadata(); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayDistinct.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayDistinct.java index 01a78e2dfe57..223986fee4c8 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayDistinct.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayDistinct.java @@ -24,11 +24,11 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayFilter.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayFilter.java index ea46c75a949a..e2d50932da7e 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayFilter.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayFilter.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.gen.ExpressionCompiler; @@ -32,7 +33,6 @@ import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.VariableReferenceExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.base.Throwables; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayHashCodeOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayHashCodeOperator.java index 5a0bb6fb1adf..0e8948c96480 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayHashCodeOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayHashCodeOperator.java @@ -28,12 +28,12 @@ import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import org.openjdk.jmh.annotations.Benchmark; @@ -62,13 +62,13 @@ import static com.facebook.presto.operator.scalar.CombineHashFunction.getHash; import static com.facebook.presto.spi.function.OperatorType.HASH_CODE; +import static com.facebook.presto.spi.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.relational.Expressions.field; import static com.facebook.presto.testing.TestingConnectorSession.SESSION; -import static com.facebook.presto.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; import static com.facebook.presto.type.TypeUtils.hashPosition; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayJoin.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayJoin.java index be08057dfc49..227884d58385 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayJoin.java @@ -21,10 +21,10 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import org.openjdk.jmh.annotations.Benchmark; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySort.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySort.java index 52d0512b3f82..97540fbab445 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySort.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySort.java @@ -23,11 +23,11 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySubscript.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySubscript.java index 0a7b2e96e8c7..4bcaf9736944 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySubscript.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySubscript.java @@ -24,11 +24,11 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.DictionaryBlock; import com.facebook.presto.spi.block.SliceArrayBlock; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import org.openjdk.jmh.annotations.Benchmark; @@ -203,7 +203,7 @@ private static Block createDictionaryValueBlock(int positionCount, int mapSize) for (int i = 0; i < keyIds.length; i++) { keyIds[i] = ThreadLocalRandom.current().nextInt(0, dictionarySize); } - return new DictionaryBlock(positionCount * mapSize, dictionaryBlock, keyIds); + return new DictionaryBlock(dictionaryBlock, keyIds); } private static String randomString(int length) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayTransform.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayTransform.java index ea6f4bc81a9c..d45baad21c90 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayTransform.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayTransform.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; @@ -30,7 +31,6 @@ import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.VariableReferenceExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import org.openjdk.jmh.annotations.Benchmark; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapConcat.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapConcat.java index 3117496dce9f..15de523800bd 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapConcat.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapConcat.java @@ -25,10 +25,10 @@ import com.facebook.presto.spi.block.DictionaryBlock; import com.facebook.presto.spi.block.InterleavedBlock; import com.facebook.presto.spi.block.SliceArrayBlock; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import org.openjdk.jmh.annotations.Benchmark; @@ -161,7 +161,7 @@ public Page getPage() private static Block createMapBlock(int positionCount, Block keyBlock, Block valueBlock) { - InterleavedBlock interleavedBlock = new InterleavedBlock(new Block[]{keyBlock, valueBlock}); + InterleavedBlock interleavedBlock = new InterleavedBlock(new Block[] {keyBlock, valueBlock}); int[] offsets = new int[positionCount + 1]; int mapSize = keyBlock.getPositionCount() / positionCount; for (int i = 0; i < offsets.length; i++) { @@ -177,7 +177,7 @@ private static Block createKeyBlock(int positionCount, List keys) for (int i = 0; i < keyIds.length; i++) { keyIds[i] = i % keys.size(); } - return new DictionaryBlock(positionCount * keys.size(), keyDictionaryBlock, keyIds); + return new DictionaryBlock(keyDictionaryBlock, keyIds); } private static Block createValueBlock(int positionCount, int mapSize) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapSubscript.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapSubscript.java index 75328b204eee..63a128621fb8 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapSubscript.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapSubscript.java @@ -25,11 +25,11 @@ import com.facebook.presto.spi.block.DictionaryBlock; import com.facebook.presto.spi.block.InterleavedBlock; import com.facebook.presto.spi.block.SliceArrayBlock; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import org.openjdk.jmh.annotations.Benchmark; @@ -188,7 +188,7 @@ private static Block createKeyBlock(int positionCount, List keys) for (int i = 0; i < keyIds.length; i++) { keyIds[i] = i % keys.size(); } - return new DictionaryBlock(positionCount * keys.size(), keyDictionaryBlock, keyIds); + return new DictionaryBlock(keyDictionaryBlock, keyIds); } private static Block createFixWidthValueBlock(int positionCount, int mapSize) @@ -227,7 +227,7 @@ private static Block createDictionaryValueBlock(int positionCount, int mapSize) for (int i = 0; i < keyIds.length; i++) { keyIds[i] = ThreadLocalRandom.current().nextInt(0, dictionarySize); } - return new DictionaryBlock(positionCount * mapSize, dictionaryBlock, keyIds); + return new DictionaryBlock(dictionaryBlock, keyIds); } private static String randomString(int length) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformKey.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformKey.java index d328e7ac3b5b..7c126d87c2ea 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformKey.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformKey.java @@ -23,12 +23,12 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.VariableReferenceExpression; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformValue.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformValue.java index ea7183d4ee4c..cc4364570e80 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformValue.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformValue.java @@ -23,12 +23,12 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.VariableReferenceExpression; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import org.openjdk.jmh.annotations.Benchmark; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java index 2da00738c010..3c0e6415ec36 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java @@ -31,6 +31,7 @@ import com.facebook.presto.operator.project.InterpretedPageProjection; import com.facebook.presto.operator.project.PageFilter; import com.facebook.presto.operator.project.PageProcessor; +import com.facebook.presto.operator.project.PageProcessorOutput; import com.facebook.presto.operator.project.PageProjection; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorPageSource; @@ -39,6 +40,7 @@ import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.InMemoryRecordSet; import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.RecordPageSource; import com.facebook.presto.spi.RecordSet; import com.facebook.presto.spi.block.Block; @@ -67,14 +69,19 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; import com.google.common.util.concurrent.UncheckedExecutionException; import io.airlift.slice.Slice; import io.airlift.slice.Slices; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; +import org.openjdk.jol.info.ClassLayout; import java.io.Closeable; +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -109,13 +116,16 @@ import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.testing.Assertions.assertInstanceOf; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; public final class FunctionAssertions implements Closeable @@ -247,11 +257,21 @@ private void tryEvaluate(String expression, Type expectedType, Session session) selectUniqueValue(expression, expectedType, session, compiler); } + public void tryEvaluateWithAll(String expression, Type expectedType) + { + tryEvaluateWithAll(expression, expectedType, session); + } + public void tryEvaluateWithAll(String expression, Type expectedType, Session session) { executeProjectionWithAll(expression, expectedType, session, compiler); } + public void executeProjectionWithFullEngine(String projection) + { + MaterializedResult result = runner.execute("SELECT " + projection); + } + private Object selectSingleValue(String projection, Type expectedType, ExpressionCompiler compiler) { return selectUniqueValue(projection, expectedType, session, compiler); @@ -268,6 +288,117 @@ private Object selectUniqueValue(String projection, Type expectedType, Session s return Iterables.getOnlyElement(resultSet); } + public void assertCachedInstanceHasBoundedRetainedSize(String projection) + { + requireNonNull(projection, "projection is null"); + + Expression projectionExpression = createExpression(projection, metadata, SYMBOL_TYPES); + RowExpression projectionRowExpression = toRowExpression(projectionExpression); + PageProcessor processor = compiler.compilePageProcessor(Optional.empty(), ImmutableList.of(projectionRowExpression)).get(); + + // This is a heuristic to detect whether the retained size of cachedInstance is bounded. + // * The test runs at least 1000 iterations. + // * The test passes if max retained size doesn't refresh after + // 4x the number of iterations when max was last updated. + // * The test fails if retained size reaches 1MB. + // Note that 1MB is arbitrarily chosen and may be increased if a function implementation + // legitimately needs more. + + long maxRetainedSize = 0; + int maxIterationCount = 0; + for (int iterationCount = 0; iterationCount < Math.max(1000, maxIterationCount * 4); iterationCount++) { + PageProcessorOutput output = processor.process(session.toConnectorSession(), SOURCE_PAGE); + // consume the iterator + Iterators.getOnlyElement(output); + + long retainedSize = processor.getProjections().stream() + .mapToLong(this::getRetainedSizeOfCachedInstance) + .sum(); + if (retainedSize > maxRetainedSize) { + maxRetainedSize = retainedSize; + maxIterationCount = iterationCount; + } + + if (maxRetainedSize >= 1048576) { + fail(format("The retained size of cached instance of function invocation is likely unbounded: %s", projection)); + } + } + } + + private long getRetainedSizeOfCachedInstance(PageProjection projection) + { + Field[] fields = projection.getClass().getDeclaredFields(); + long retainedSize = 0; + for (Field field : fields) { + field.setAccessible(true); + String fieldName = field.getName(); + if (!fieldName.startsWith("__cachedInstance")) { + continue; + } + try { + retainedSize += getRetainedSizeOf(field.get(projection)); + } + catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + } + return retainedSize; + } + + private long getRetainedSizeOf(Object object) + { + if (object instanceof PageBuilder) { + return ((PageBuilder) object).getRetainedSizeInBytes(); + } + if (object instanceof Block) { + return ((Block) object).getRetainedSizeInBytes(); + } + + Class type = object.getClass(); + if (type.isArray()) { + if (type == int[].class) { + return sizeOf((int[]) object); + } + else if (type == boolean[].class) { + return sizeOf((boolean[]) object); + } + else if (type == byte[].class) { + return sizeOf((byte[]) object); + } + else if (type == long[].class) { + return sizeOf((long[]) object); + } + else if (type == short[].class) { + return sizeOf((short[]) object); + } + else if (type == Block[].class) { + Object[] objects = (Object[]) object; + return Arrays.stream(objects) + .mapToLong(this::getRetainedSizeOf) + .sum(); + } + else { + throw new IllegalArgumentException(format("Unknown type encountered: %s", type)); + } + } + + long retainedSize = ClassLayout.parseClass(type).instanceSize(); + Field[] fields = type.getDeclaredFields(); + for (Field field : fields) { + try { + if (field.getType().isPrimitive() || Modifier.isStatic(field.getModifiers())) { + continue; + } + field.setAccessible(true); + retainedSize += getRetainedSizeOf(field.get(object)); + } + catch (IllegalAccessException t) { + throw new RuntimeException(t); + } + } + return retainedSize; + } + private List executeProjectionWithAll(String projection, Type expectedType, Session session, ExpressionCompiler compiler) { requireNonNull(projection, "projection is null"); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java index eb752d712fc5..8fca03b64875 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFilterFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFilterFunction.java index b853532e9bda..299c35cad730 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFilterFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFilterFunction.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFunctions.java new file mode 100644 index 000000000000..0465499f18f6 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFunctions.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.spi.type.ArrayType; +import com.google.common.base.Joiner; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static java.util.Collections.nCopies; + +public class TestArrayFunctions + extends AbstractTestFunctions +{ + @Test + public void testArrayConstructor() + { + tryEvaluateWithAll("array[" + Joiner.on(", ").join(nCopies(254, "rand()")) + "]", new ArrayType(DOUBLE)); + assertNotSupported( + "array[" + Joiner.on(", ").join(nCopies(255, "rand()")) + "]", + "Too many arguments for array constructor"); + } + + @Test + public void testArrayConcat() + { + assertFunction("CONCAT(" + Joiner.on(", ").join(nCopies(253, "array[1]")) + ")", new ArrayType(INTEGER), nCopies(253, 1)); + assertNotSupported( + "CONCAT(" + Joiner.on(", ").join(nCopies(254, "array[1]")) + ")", + "Too many arguments for vararg function"); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayReduceFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayReduceFunction.java index e11d3225f3b4..86dd9cbc8949 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayReduceFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayReduceFunction.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import org.testng.annotations.Test; import static com.facebook.presto.spi.type.BigintType.BIGINT; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayTransformFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayTransformFunction.java index cb3317a60be8..19801f9bc97d 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayTransformFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayTransformFunction.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestLambdaExpression.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestLambdaExpression.java index 845344e83d63..0223ef04ee04 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestLambdaExpression.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestLambdaExpression.java @@ -14,8 +14,8 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.Session; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.BeforeClass; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapFilterFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapFilterFunction.java index 28473976a0bb..54c7f85e5e26 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapFilterFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapFilterFunction.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformKeyFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformKeyFunction.java index e76645271b66..c4af0a98d8e9 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformKeyFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformKeyFunction.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -33,6 +33,12 @@ public class TestMapTransformKeyFunction extends AbstractTestFunctions { + @Test + public void testRetainedSizeBounded() + { + assertCachedInstanceHasBoundedRetainedSize("transform_keys(map(ARRAY [1, 2, 3, 4], ARRAY [10, 20, 30, 40]), (k, v) -> k + v)"); + } + @Test public void testEmpty() throws Exception diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformValueFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformValueFunction.java index 7bac063a6b9a..db2216b1c716 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformValueFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformValueFunction.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -33,6 +33,12 @@ public class TestMapTransformValueFunction extends AbstractTestFunctions { + @Test + public void testRetainedSizeBounded() + { + assertCachedInstanceHasBoundedRetainedSize("transform_values(map(ARRAY [25, 26, 27], ARRAY [25, 26, 27]), (k, v) -> k + v)"); + } + @Test public void testEmpty() throws Exception diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java index 01d33b135812..5d67a1d9565f 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.VarcharType; +import com.google.common.base.Joiner; import org.testng.annotations.Test; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; @@ -30,6 +31,7 @@ import static com.facebook.presto.spi.type.RealType.REAL; import static com.facebook.presto.spi.type.SmallintType.SMALLINT; import static com.facebook.presto.spi.type.TinyintType.TINYINT; +import static java.util.Collections.nCopies; public class TestMathFunctions extends AbstractTestFunctions @@ -1095,6 +1097,12 @@ public void testGreatest() // invalid assertInvalidFunction("greatest(1.5, 0.0 / 0.0)", "Invalid argument to greatest(): NaN"); + + // argument count limit + tryEvaluateWithAll("greatest(" + Joiner.on(", ").join(nCopies(127, "rand()")) + ")", DOUBLE); + assertNotSupported( + "greatest(" + Joiner.on(", ").join(nCopies(128, "rand()")) + ")", + "Too many arguments for function call greatest()"); } @Test diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestPageProcessorCompiler.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestPageProcessorCompiler.java index 99ed596745a0..70f686aacf06 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestPageProcessorCompiler.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestPageProcessorCompiler.java @@ -20,13 +20,13 @@ import com.facebook.presto.spi.block.DictionaryBlock; import com.facebook.presto.spi.block.RunLengthEncodedBlock; import com.facebook.presto.spi.block.SliceArrayBlock; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.DeterminismEvaluator; import com.facebook.presto.sql.relational.InputReferenceExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; @@ -192,7 +192,7 @@ private static DictionaryBlock createDictionaryBlock(Slice[] expectedValues, int for (int i = 0; i < positionCount; i++) { ids[i] = i % dictionarySize; } - return new DictionaryBlock(positionCount, new SliceArrayBlock(dictionarySize, expectedValues), ids); + return new DictionaryBlock(new SliceArrayBlock(dictionarySize, expectedValues), ids); } private static Slice[] createExpectedValues(int positionCount) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestRegexpFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestRegexpFunctions.java index aebd9697329c..9d1b9bcafd42 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestRegexpFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestRegexpFunctions.java @@ -15,9 +15,9 @@ import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.airlift.slice.Slices; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStringFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStringFunctions.java index 0479a7caad64..81351ca2d721 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStringFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStringFunctions.java @@ -17,11 +17,12 @@ import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.SqlVarbinary; import com.facebook.presto.spi.type.StandardTypes; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.LiteralParameter; -import com.facebook.presto.type.MapType; +import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; @@ -40,6 +41,7 @@ import static com.facebook.presto.util.StructuralTestUtil.mapType; import static com.google.common.base.Strings.repeat; import static java.lang.String.format; +import static java.util.Collections.nCopies; public class TestStringFunctions extends AbstractTestFunctions @@ -101,7 +103,6 @@ public void testCodepoint() @Test public void testConcat() { - assertInvalidFunction("CONCAT()", "There must be two or more concatenation arguments"); assertInvalidFunction("CONCAT('')", "There must be two or more concatenation arguments"); assertFunction("CONCAT('hello', ' world')", VARCHAR, "hello world"); assertFunction("CONCAT('', '')", VARCHAR, ""); @@ -109,12 +110,16 @@ public void testConcat() assertFunction("CONCAT('', 'what')", VARCHAR, "what"); assertFunction("CONCAT(CONCAT('this', ' is'), ' cool')", VARCHAR, "this is cool"); assertFunction("CONCAT('this', CONCAT(' is', ' cool'))", VARCHAR, "this is cool"); - // + // Test concat for non-ASCII assertFunction("CONCAT('hello na\u00EFve', ' world')", VARCHAR, "hello na\u00EFve world"); assertFunction("CONCAT('\uD801\uDC2D', 'end')", VARCHAR, "\uD801\uDC2Dend"); assertFunction("CONCAT('\uD801\uDC2D', 'end', '\uD801\uDC2D')", VARCHAR, "\uD801\uDC2Dend\uD801\uDC2D"); assertFunction("CONCAT(CONCAT('\u4FE1\u5FF5', ',\u7231'), ',\u5E0C\u671B')", VARCHAR, "\u4FE1\u5FF5,\u7231,\u5E0C\u671B"); + + // Test argument count limit + assertFunction("CONCAT(" + Joiner.on(", ").join(nCopies(254, "'1'")) + ")", VARCHAR, Joiner.on("").join(nCopies(254, "1"))); + assertNotSupported("CONCAT(" + Joiner.on(", ").join(nCopies(255, "'1'")) + ")", "Too many arguments for string concatenation"); } @Test diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestVarbinaryFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestVarbinaryFunctions.java index ef6b4af9f051..31867ca27031 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestVarbinaryFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestVarbinaryFunctions.java @@ -48,7 +48,7 @@ public class TestVarbinaryFunctions public void testBinaryLiteral() throws Exception { - assertFunction("X'58F7'", VARBINARY, new SqlVarbinary(new byte[]{(byte) 0x58, (byte) 0xF7})); + assertFunction("X'58F7'", VARBINARY, sqlVarbinaryHex("58F7")); } @Test @@ -60,6 +60,42 @@ public void testLength() assertFunction("length(CAST('abc' AS VARBINARY))", BIGINT, 3L); } + @Test + public void testConcat() + throws Exception + { + assertInvalidFunction("CONCAT(X'')", "There must be two or more concatenation arguments"); + + assertFunction("CAST('foo' AS VARBINARY) || CAST ('bar' AS VARBINARY)", VARBINARY, sqlVarbinary("foo" + "bar")); + assertFunction("CAST('foo' AS VARBINARY) || CAST ('bar' AS VARBINARY) || CAST ('baz' AS VARBINARY)", VARBINARY, sqlVarbinary("foo" + "bar" + "baz")); + assertFunction("CAST(' foo ' AS VARBINARY) || CAST (' bar ' AS VARBINARY) || CAST (' baz ' AS VARBINARY)", VARBINARY, sqlVarbinary(" foo " + " bar " + " baz ")); + assertFunction("CAST('foo' AS VARBINARY) || CAST ('bar' AS VARBINARY) || CAST ('bazbaz' AS VARBINARY)", VARBINARY, sqlVarbinary("foo" + "bar" + "bazbaz")); + + assertFunction("X'000102' || X'AAABAC' || X'FDFEFF'", VARBINARY, sqlVarbinaryHex("000102" + "AAABAC" + "FDFEFF")); + assertFunction("X'CAFFEE' || X'F7' || X'DE58'", VARBINARY, sqlVarbinaryHex("CAFFEE" + "F7" + "DE58")); + + assertFunction("X'58' || X'F7'", VARBINARY, sqlVarbinaryHex("58F7")); + assertFunction("X'' || X'58' || X'F7'", VARBINARY, sqlVarbinaryHex("58F7")); + assertFunction("X'58' || X'' || X'F7'", VARBINARY, sqlVarbinaryHex("58F7")); + assertFunction("X'58' || X'F7' || X''", VARBINARY, sqlVarbinaryHex("58F7")); + assertFunction("X'' || X'58' || X'' || X'F7' || X''", VARBINARY, sqlVarbinaryHex("58F7")); + assertFunction("X'' || X'' || X'' || X'' || X'' || X''", VARBINARY, sqlVarbinaryHex("")); + + assertFunction("CONCAT(CAST('foo' AS VARBINARY), CAST ('bar' AS VARBINARY))", VARBINARY, sqlVarbinary("foo" + "bar")); + assertFunction("CONCAT(CAST('foo' AS VARBINARY), CAST ('bar' AS VARBINARY), CAST ('baz' AS VARBINARY))", VARBINARY, sqlVarbinary("foo" + "bar" + "baz")); + assertFunction("CONCAT(CAST('foo' AS VARBINARY), CAST ('bar' AS VARBINARY), CAST ('bazbaz' AS VARBINARY))", VARBINARY, sqlVarbinary("foo" + "bar" + "bazbaz")); + + assertFunction("CONCAT(X'000102', X'AAABAC', X'FDFEFF')", VARBINARY, sqlVarbinaryHex("000102" + "AAABAC" + "FDFEFF")); + assertFunction("CONCAT(X'CAFFEE', X'F7', X'DE58')", VARBINARY, sqlVarbinaryHex("CAFFEE" + "F7" + "DE58")); + + assertFunction("CONCAT(X'58', X'F7')", VARBINARY, sqlVarbinaryHex("58F7")); + assertFunction("CONCAT(X'', X'58', X'F7')", VARBINARY, sqlVarbinaryHex("58F7")); + assertFunction("CONCAT(X'58', X'', X'F7')", VARBINARY, sqlVarbinaryHex("58F7")); + assertFunction("CONCAT(X'58', X'F7', X'')", VARBINARY, sqlVarbinaryHex("58F7")); + assertFunction("CONCAT(X'', X'58', X'', X'F7', X'')", VARBINARY, sqlVarbinaryHex("58F7")); + assertFunction("CONCAT(X'', X'', X'', X'', X'', X'')", VARBINARY, sqlVarbinaryHex("")); + } + @Test public void testToBase64() throws Exception diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipFunction.java index 687e15367909..a61c042558fb 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipFunction.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.operator.scalar; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; import org.testng.annotations.Test; import java.util.List; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipWithFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipWithFunction.java index 7a9e3c006f38..dd9a177bf6c2 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipWithFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipWithFunction.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/security/TestAccessControlManager.java b/presto-main/src/test/java/com/facebook/presto/security/TestAccessControlManager.java index 877c6b678cb5..87bd8603eeb8 100644 --- a/presto-main/src/test/java/com/facebook/presto/security/TestAccessControlManager.java +++ b/presto-main/src/test/java/com/facebook/presto/security/TestAccessControlManager.java @@ -342,6 +342,12 @@ public void checkCanAddColumn(ConnectorTransactionHandle transactionHandle, Iden throw new UnsupportedOperationException(); } + @Override + public void checkCanDropColumn(ConnectorTransactionHandle transactionHandle, Identity identity, SchemaTableName tableName) + { + throw new UnsupportedOperationException(); + } + @Override public void checkCanRenameColumn(ConnectorTransactionHandle transactionHandle, Identity identity, SchemaTableName tableName) { diff --git a/presto-main/src/test/java/com/facebook/presto/security/TestFileBasedSystemAccessControl.java b/presto-main/src/test/java/com/facebook/presto/security/TestFileBasedSystemAccessControl.java index 5f038491d7f3..7729bce92ce3 100644 --- a/presto-main/src/test/java/com/facebook/presto/security/TestFileBasedSystemAccessControl.java +++ b/presto-main/src/test/java/com/facebook/presto/security/TestFileBasedSystemAccessControl.java @@ -42,12 +42,12 @@ public class TestFileBasedSystemAccessControl private static final QualifiedObjectName aliceTable = new QualifiedObjectName("alice-catalog", "schema", "table"); private static final QualifiedObjectName aliceView = new QualifiedObjectName("alice-catalog", "schema", "view"); private static final CatalogSchemaName aliceSchema = new CatalogSchemaName("alice-catalog", "schema"); - private TransactionManager transactionManager; @Test public void testCatalogOperations() { - AccessControlManager accessControlManager = newAccessControlManager(); + TransactionManager transactionManager = createTestTransactionManager(); + AccessControlManager accessControlManager = newAccessControlManager(transactionManager); transaction(transactionManager, accessControlManager) .execute(transactionId -> { @@ -64,7 +64,8 @@ public void testCatalogOperations() @Test public void testSchemaOperations() { - AccessControlManager accessControlManager = newAccessControlManager(); + TransactionManager transactionManager = createTestTransactionManager(); + AccessControlManager accessControlManager = newAccessControlManager(transactionManager); transaction(transactionManager, accessControlManager) .execute(transactionId -> { @@ -85,7 +86,8 @@ public void testSchemaOperations() @Test public void testTableOperations() { - AccessControlManager accessControlManager = newAccessControlManager(); + TransactionManager transactionManager = createTestTransactionManager(); + AccessControlManager accessControlManager = newAccessControlManager(transactionManager); transaction(transactionManager, accessControlManager) .execute(transactionId -> { @@ -110,7 +112,8 @@ public void testTableOperations() public void testViewOperations() throws Exception { - AccessControlManager accessControlManager = newAccessControlManager(); + TransactionManager transactionManager = createTestTransactionManager(); + AccessControlManager accessControlManager = newAccessControlManager(transactionManager); transaction(transactionManager, accessControlManager) .execute(transactionId -> { @@ -128,10 +131,9 @@ public void testViewOperations() })); } - private AccessControlManager newAccessControlManager() + private AccessControlManager newAccessControlManager(TransactionManager transactionManager) { - transactionManager = createTestTransactionManager(); - AccessControlManager accessControlManager = new AccessControlManager(transactionManager); + AccessControlManager accessControlManager = new AccessControlManager(transactionManager); String path = this.getClass().getClassLoader().getResource("catalog.json").getPath(); accessControlManager.setSystemAccessControl(FileBasedSystemAccessControl.NAME, ImmutableMap.of("security.config-file", path)); diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestQueryStateInfo.java b/presto-main/src/test/java/com/facebook/presto/server/TestQueryStateInfo.java index d7c8938d97c8..c85ea347547e 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestQueryStateInfo.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestQueryStateInfo.java @@ -41,6 +41,8 @@ import static com.facebook.presto.spi.resourceGroups.ResourceGroupState.CAN_QUEUE; import static com.facebook.presto.spi.resourceGroups.ResourceGroupState.CAN_RUN; import static io.airlift.units.DataSize.Unit.BYTE; +import static java.util.concurrent.TimeUnit.DAYS; +import static java.util.concurrent.TimeUnit.HOURS; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -59,7 +61,9 @@ public void testQueryStateInfo() groupRootAX, new DataSize(6000, BYTE), 1, + null, 10, + null, CAN_QUEUE, 0, new DataSize(4000, BYTE), @@ -71,7 +75,9 @@ public void testQueryStateInfo() groupRootAY, new DataSize(8000, BYTE), 1, + new Duration(10, HOURS), 10, + new Duration(1, DAYS), CAN_RUN, 0, new DataSize(0, BYTE), @@ -83,7 +89,9 @@ public void testQueryStateInfo() groupRootA, new DataSize(8000, BYTE), 1, + null, 10, + null, CAN_QUEUE, 1, new DataSize(4000, BYTE), @@ -95,7 +103,9 @@ public void testQueryStateInfo() groupRootB, new DataSize(8000, BYTE), 1, + new Duration(10, HOURS), 10, + new Duration(1, DAYS), CAN_QUEUE, 0, new DataSize(4000, BYTE), @@ -107,7 +117,9 @@ public void testQueryStateInfo() new ResourceGroupId("root"), new DataSize(10000, BYTE), 2, + null, 20, + null, CAN_QUEUE, 0, new DataSize(6000, BYTE), diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestResourceGroupStateInfo.java b/presto-main/src/test/java/com/facebook/presto/server/TestResourceGroupStateInfo.java new file mode 100644 index 000000000000..d37e30754a69 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/TestResourceGroupStateInfo.java @@ -0,0 +1,136 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server; + +import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.resourceGroups.ResourceGroupId; +import com.facebook.presto.spi.resourceGroups.ResourceGroupInfo; +import com.google.common.collect.ImmutableList; +import io.airlift.json.JsonCodec; +import io.airlift.units.DataSize; +import io.airlift.units.Duration; +import org.joda.time.DateTime; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.OptionalDouble; + +import static com.facebook.presto.execution.QueryState.RUNNING; +import static com.facebook.presto.spi.resourceGroups.ResourceGroupState.CAN_RUN; +import static io.airlift.units.DataSize.Unit.BYTE; +import static io.airlift.units.DataSize.Unit.GIGABYTE; +import static java.util.concurrent.TimeUnit.HOURS; +import static org.testng.Assert.assertEquals; + +public class TestResourceGroupStateInfo +{ + @Test + public void testJsonRoundTrip() + { + ResourceGroupId resourceGroupId = new ResourceGroupId(ImmutableList.of("test", "user")); + ResourceGroupId subGroupId = new ResourceGroupId(resourceGroupId, "sub"); + ResourceGroupStateInfo expected = new ResourceGroupStateInfo( + resourceGroupId, + CAN_RUN, + new DataSize(10, GIGABYTE), + new DataSize(100, BYTE), + 10, + 100, + new Duration(1, HOURS), + new Duration(10, HOURS), + ImmutableList.of(new QueryStateInfo( + new QueryId("test_query"), + RUNNING, + Optional.of(resourceGroupId), + "SELECT * FROM t", + DateTime.parse("2017-06-12T21:39:48.658Z"), + "test_user", + Optional.of("catalog"), + Optional.of("schema"), + Optional.empty(), + Optional.of(new QueryProgressStats( + DateTime.parse("2017-06-12T21:39:50.966Z"), + 150060, + 243, + 1541, + 566038, + 1680000, + 24, + 124539, + 8283750, + false, + OptionalDouble.empty())))), + 10, + ImmutableList.of(new ResourceGroupInfo( + subGroupId, + new DataSize(1, GIGABYTE), + 10, + new Duration(1, HOURS), + 100, + new Duration(10, HOURS), + CAN_RUN, + 1, + new DataSize(100, BYTE), + 1, + 10))); + JsonCodec codec = JsonCodec.jsonCodec(ResourceGroupStateInfo.class); + ResourceGroupStateInfo actual = codec.fromJson(codec.toJson(expected)); + + assertEquals(actual.getId(), resourceGroupId); + assertEquals(actual.getState(), CAN_RUN); + assertEquals(actual.getSoftMemoryLimit(), new DataSize(10, GIGABYTE)); + assertEquals(actual.getMemoryUsage(), new DataSize(100, BYTE)); + assertEquals(actual.getMaxRunningQueries(), 10); + assertEquals(actual.getRunningTimeLimit(), new Duration(1, HOURS)); + assertEquals(actual.getMaxQueuedQueries(), 100); + assertEquals(actual.getQueuedTimeLimit(), new Duration(10, HOURS)); + assertEquals(actual.getNumQueuedQueries(), 10); + assertEquals(actual.getRunningQueries().size(), 1); + QueryStateInfo queryStateInfo = actual.getRunningQueries().get(0); + assertEquals(queryStateInfo.getQueryId(), new QueryId("test_query")); + assertEquals(queryStateInfo.getQueryState(), RUNNING); + assertEquals(queryStateInfo.getResourceGroupId(), Optional.of(resourceGroupId)); + assertEquals(queryStateInfo.getQuery(), "SELECT * FROM t"); + assertEquals(queryStateInfo.getCreateTime(), DateTime.parse("2017-06-12T21:39:48.658Z")); + assertEquals(queryStateInfo.getUser(), "test_user"); + assertEquals(queryStateInfo.getCatalog(), Optional.of("catalog")); + assertEquals(queryStateInfo.getSchema(), Optional.of("schema")); + assertEquals(queryStateInfo.getResourceGroupChain(), Optional.empty()); + QueryProgressStats progressStats = queryStateInfo.getProgress().get(); + assertEquals(progressStats.getExecutionStartTime(), DateTime.parse("2017-06-12T21:39:50.966Z")); + assertEquals(progressStats.getElapsedTimeMillis(), 150060); + assertEquals(progressStats.getQueuedTimeMillis(), 243); + assertEquals(progressStats.getCpuTimeMillis(), 1541); + assertEquals(progressStats.getScheduledTimeMillis(), 566038); + assertEquals(progressStats.getBlockedTimeMillis(), 1680000); + assertEquals(progressStats.getPeakMemoryBytes(), 24); + assertEquals(progressStats.getInputRows(), 124539); + assertEquals(progressStats.getInputBytes(), 8283750); + assertEquals(progressStats.isBlocked(), false); + assertEquals(progressStats.getProgressPercentage(), OptionalDouble.empty()); + assertEquals(actual.getSubGroups().size(), 1); + ResourceGroupInfo subGroup = actual.getSubGroups().get(0); + assertEquals(subGroup.getId(), subGroupId); + assertEquals(subGroup.getSoftMemoryLimit(), new DataSize(1, GIGABYTE)); + assertEquals(subGroup.getMaxRunningQueries(), 10); + assertEquals(subGroup.getRunningTimeLimit(), new Duration(1, HOURS)); + assertEquals(subGroup.getMaxQueuedQueries(), 100); + assertEquals(subGroup.getQueuedTimeLimit(), new Duration(10, HOURS)); + assertEquals(subGroup.getState(), CAN_RUN); + assertEquals(subGroup.getNumEligibleSubGroups(), 1); + assertEquals(subGroup.getMemoryUsage(), new DataSize(100, BYTE)); + assertEquals(subGroup.getNumAggregatedRunningQueries(), 1); + assertEquals(subGroup.getNumAggregatedQueuedQueries(), 10); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java index d7cad9c29975..188ee2d1611a 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java +++ b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java @@ -33,6 +33,7 @@ import com.facebook.presto.spi.ErrorCode; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.testing.TestingHandleResolver; import com.facebook.presto.type.TypeDeserializer; import com.facebook.presto.type.TypeRegistry; @@ -80,6 +81,7 @@ import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static io.airlift.configuration.ConfigBinder.configBinder; import static io.airlift.json.JsonBinder.jsonBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; import static java.lang.String.format; @@ -181,6 +183,7 @@ private static HttpRemoteTaskFactory createHttpRemoteTaskFactory(TestingTaskReso public void configure(Binder binder) { binder.bind(JsonMapper.class); + configBinder(binder).bindConfig(FeaturesConfig.class); binder.bind(TypeRegistry.class).in(Scopes.SINGLETON); binder.bind(TypeManager.class).to(TypeRegistry.class).in(Scopes.SINGLETON); jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java index e8e19ecfad47..b73e056c9c5f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java @@ -40,13 +40,13 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.transaction.IsolationLevel; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.NodeLocation; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.testing.TestingMetadata; import com.facebook.presto.transaction.TransactionManager; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import io.airlift.json.JsonCodec; @@ -1431,6 +1431,26 @@ public void testQuantifiedComparisonExpression() assertFails(TYPE_MISMATCH, "SELECT cast(NULL AS HyperLogLog) = ANY (VALUES cast(NULL AS HyperLogLog))"); } + @Test + public void testJoinUnnest() + throws Exception + { + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) CROSS JOIN UNNEST(x)"); + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) LEFT OUTER JOIN UNNEST(x) ON true"); + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) RIGHT OUTER JOIN UNNEST(x) ON true"); + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) FULL OUTER JOIN UNNEST(x) ON true"); + } + + @Test + public void testJoinLateral() + throws Exception + { + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) CROSS JOIN LATERAL(VALUES x)"); + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) LEFT OUTER JOIN LATERAL(VALUES x) ON true"); + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) RIGHT OUTER JOIN LATERAL(VALUES x) ON true"); + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) FULL OUTER JOIN LATERAL(VALUES x) ON true"); + } + @BeforeMethod(alwaysRun = true) public void setup() throws Exception diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 2c0f7146c5f7..4e66a6ed77a5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -49,6 +49,7 @@ public void testDefaults() .setDictionaryAggregation(false) .setLegacyArrayAgg(false) .setLegacyMapSubscript(false) + .setNewMapBlock(true) .setRegexLibrary(JONI) .setRe2JDfaStatesLimit(Integer.MAX_VALUE) .setRe2JDfaRetries(5) @@ -76,6 +77,7 @@ public void testExplicitPropertyMappings() .put("deprecated.legacy-array-agg", "true") .put("deprecated.legacy-order-by", "true") .put("deprecated.legacy-map-subscript", "true") + .put("deprecated.new-map-block", "false") .put("distributed-index-joins-enabled", "true") .put("distributed-joins-enabled", "false") .put("fast-inequality-joins", "false") @@ -107,6 +109,7 @@ public void testExplicitPropertyMappings() .put("deprecated.legacy-array-agg", "true") .put("deprecated.legacy-order-by", "true") .put("deprecated.legacy-map-subscript", "true") + .put("deprecated.new-map-block", "false") .put("distributed-index-joins-enabled", "true") .put("distributed-joins-enabled", "false") .put("fast-inequality-joins", "false") @@ -151,6 +154,7 @@ public void testExplicitPropertyMappings() .setPushAggregationThroughJoin(false) .setLegacyArrayAgg(true) .setLegacyMapSubscript(true) + .setNewMapBlock(false) .setRegexLibrary(RE2J) .setRe2JDfaStatesLimit(42) .setRe2JDfaRetries(42) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDomainTranslator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDomainTranslator.java index df90f78b2853..7ad092cd0400 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDomainTranslator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDomainTranslator.java @@ -798,6 +798,42 @@ void testNonImplictCastOnSymbolSide() assertEquals(result.getTupleDomain(), TupleDomain.all()); } + @Test + void testNoSaturatedFloorCastFromUnsupportedApproximateDomain() + { + Expression originalExpression = equal( + new Cast(C_DECIMAL_12_2.toSymbolReference(), DOUBLE.toString()), + LiteralInterpreter.toExpression(12345.56, DOUBLE)); + + ExtractionResult result = fromPredicate(originalExpression); + assertEquals(result.getRemainingExpression(), originalExpression); + assertEquals(result.getTupleDomain(), TupleDomain.all()); + + originalExpression = equal( + new Cast(C_BIGINT.toSymbolReference(), DOUBLE.toString()), + LiteralInterpreter.toExpression(12345.56, DOUBLE)); + + result = fromPredicate(originalExpression); + assertEquals(result.getRemainingExpression(), originalExpression); + assertEquals(result.getTupleDomain(), TupleDomain.all()); + + originalExpression = equal( + new Cast(C_BIGINT.toSymbolReference(), REAL.toString()), + LiteralInterpreter.toExpression(realValue(12345.56f), REAL)); + + result = fromPredicate(originalExpression); + assertEquals(result.getRemainingExpression(), originalExpression); + assertEquals(result.getTupleDomain(), TupleDomain.all()); + + originalExpression = equal( + new Cast(C_INTEGER.toSymbolReference(), REAL.toString()), + LiteralInterpreter.toExpression(realValue(12345.56f), REAL)); + + result = fromPredicate(originalExpression); + assertEquals(result.getRemainingExpression(), originalExpression); + assertEquals(result.getTupleDomain(), TupleDomain.all()); + } + @Test public void testFromComparisonsWithCoercions() throws Exception @@ -814,73 +850,73 @@ public void testFromComparisonsWithCoercions() assertEquals(result.getRemainingExpression(), TRUE_LITERAL); assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_VARCHAR, Domain.create(ValueSet.ofRanges(Range.greaterThan(VARCHAR, utf8Slice("test"))), false)))); - // A is a long column. Check that it can be compared against doubles - originalExpression = greaterThan(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0)); + // A is a integer column. Check that it can be compared against doubles + originalExpression = greaterThan(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.greaterThan(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.greaterThan(INTEGER, 2L)), false)))); - originalExpression = greaterThan(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1)); + originalExpression = greaterThan(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.greaterThan(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.greaterThan(INTEGER, 2L)), false)))); - originalExpression = greaterThanOrEqual(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0)); + originalExpression = greaterThanOrEqual(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.greaterThanOrEqual(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.greaterThanOrEqual(INTEGER, 2L)), false)))); - originalExpression = greaterThanOrEqual(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1)); + originalExpression = greaterThanOrEqual(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.greaterThan(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.greaterThan(INTEGER, 2L)), false)))); - originalExpression = lessThan(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0)); + originalExpression = lessThan(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThan(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.lessThan(INTEGER, 2L)), false)))); - originalExpression = lessThan(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1)); + originalExpression = lessThan(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(INTEGER, 2L)), false)))); - originalExpression = lessThanOrEqual(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0)); + originalExpression = lessThanOrEqual(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(INTEGER, 2L)), false)))); - originalExpression = lessThanOrEqual(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1)); + originalExpression = lessThanOrEqual(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(INTEGER, 2L)), false)))); - originalExpression = equal(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0)); + originalExpression = equal(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.equal(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.equal(INTEGER, 2L)), false)))); - originalExpression = equal(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1)); + originalExpression = equal(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.none(BIGINT)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.none(INTEGER)))); - originalExpression = notEqual(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0)); + originalExpression = notEqual(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThan(BIGINT, 2L), Range.greaterThan(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.lessThan(INTEGER, 2L), Range.greaterThan(INTEGER, 2L)), false)))); - originalExpression = notEqual(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1)); + originalExpression = notEqual(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.notNull(BIGINT)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.notNull(INTEGER)))); - originalExpression = isDistinctFrom(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0)); + originalExpression = isDistinctFrom(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThan(BIGINT, 2L), Range.greaterThan(BIGINT, 2L)), true)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.lessThan(INTEGER, 2L), Range.greaterThan(INTEGER, 2L)), true)))); - originalExpression = isDistinctFrom(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1)); + originalExpression = isDistinctFrom(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); assertTrue(result.getTupleDomain().isAll()); @@ -899,73 +935,73 @@ public void testFromComparisonsWithCoercions() assertEquals(result.getRemainingExpression(), TRUE_LITERAL); assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_VARCHAR, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(VARCHAR, utf8Slice("test"))), false)))); - // A is a long column. Check that it can be compared against doubles - originalExpression = not(greaterThan(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0))); + // A is a integer column. Check that it can be compared against doubles + originalExpression = not(greaterThan(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(INTEGER, 2L)), false)))); - originalExpression = not(greaterThan(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1))); + originalExpression = not(greaterThan(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(INTEGER, 2L)), false)))); - originalExpression = not(greaterThanOrEqual(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0))); + originalExpression = not(greaterThanOrEqual(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThan(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.lessThan(INTEGER, 2L)), false)))); - originalExpression = not(greaterThanOrEqual(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1))); + originalExpression = not(greaterThanOrEqual(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(INTEGER, 2L)), false)))); - originalExpression = not(lessThan(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0))); + originalExpression = not(lessThan(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.greaterThanOrEqual(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.greaterThanOrEqual(INTEGER, 2L)), false)))); - originalExpression = not(lessThan(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1))); + originalExpression = not(lessThan(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.greaterThan(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.greaterThan(INTEGER, 2L)), false)))); - originalExpression = not(lessThanOrEqual(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0))); + originalExpression = not(lessThanOrEqual(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.greaterThan(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.greaterThan(INTEGER, 2L)), false)))); - originalExpression = not(lessThanOrEqual(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1))); + originalExpression = not(lessThanOrEqual(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.greaterThan(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.greaterThan(INTEGER, 2L)), false)))); - originalExpression = not(equal(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0))); + originalExpression = not(equal(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThan(BIGINT, 2L), Range.greaterThan(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.lessThan(INTEGER, 2L), Range.greaterThan(INTEGER, 2L)), false)))); - originalExpression = not(equal(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1))); + originalExpression = not(equal(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.notNull(BIGINT)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.notNull(INTEGER)))); - originalExpression = not(notEqual(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0))); + originalExpression = not(notEqual(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.equal(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.equal(INTEGER, 2L)), false)))); - originalExpression = not(notEqual(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1))); + originalExpression = not(notEqual(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.none(BIGINT)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.none(INTEGER)))); - originalExpression = not(isDistinctFrom(cast(C_BIGINT, DOUBLE), doubleLiteral(2.0))); + originalExpression = not(isDistinctFrom(cast(C_INTEGER, DOUBLE), doubleLiteral(2.0))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.equal(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.equal(INTEGER, 2L)), false)))); - originalExpression = not(isDistinctFrom(cast(C_BIGINT, DOUBLE), doubleLiteral(2.1))); + originalExpression = not(isDistinctFrom(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); assertTrue(result.getTupleDomain().isNone()); @@ -1062,10 +1098,10 @@ public void testFromBetweenPredicate() assertEquals(result.getRemainingExpression(), TRUE_LITERAL); assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.range(BIGINT, 1L, true, 2L, true)), false)))); - originalExpression = between(cast(C_BIGINT, DOUBLE), cast(bigintLiteral(1L), DOUBLE), doubleLiteral(2.1)); + originalExpression = between(cast(C_INTEGER, DOUBLE), cast(bigintLiteral(1L), DOUBLE), doubleLiteral(2.1)); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.range(BIGINT, 1L, true, 2L, true)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.range(INTEGER, 1L, true, 2L, true)), false)))); originalExpression = between(C_BIGINT, bigintLiteral(1L), nullLiteral(BIGINT)); result = fromPredicate(originalExpression); @@ -1078,10 +1114,10 @@ public void testFromBetweenPredicate() assertEquals(result.getRemainingExpression(), TRUE_LITERAL); assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThan(BIGINT, 1L), Range.greaterThan(BIGINT, 2L)), false)))); - originalExpression = not(between(cast(C_BIGINT, DOUBLE), cast(bigintLiteral(1L), DOUBLE), doubleLiteral(2.1))); + originalExpression = not(between(cast(C_INTEGER, DOUBLE), cast(bigintLiteral(1L), DOUBLE), doubleLiteral(2.1))); result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.lessThan(BIGINT, 1L), Range.greaterThan(BIGINT, 2L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_INTEGER, Domain.create(ValueSet.ofRanges(Range.lessThan(INTEGER, 1L), Range.greaterThan(INTEGER, 2L)), false)))); originalExpression = not(between(C_BIGINT, bigintLiteral(1L), nullLiteral(BIGINT))); result = fromPredicate(originalExpression); @@ -1197,19 +1233,17 @@ public void testExpressionConstantFolding() void testMultipleCoercionsOnSymbolSide() throws Exception { - ComparisonExpression originalExpression = comparison(GREATER_THAN, cast(cast(C_BIGINT, REAL), DOUBLE), doubleLiteral(3.7)); + ComparisonExpression originalExpression = comparison(GREATER_THAN, cast(cast(C_SMALLINT, REAL), DOUBLE), doubleLiteral(3.7)); ExtractionResult result = fromPredicate(originalExpression); assertEquals(result.getRemainingExpression(), TRUE_LITERAL); - assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_BIGINT, Domain.create(ValueSet.ofRanges(Range.greaterThan(BIGINT, 3L)), false)))); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_SMALLINT, Domain.create(ValueSet.ofRanges(Range.greaterThan(SMALLINT, 3L)), false)))); } @Test public void testNumericTypeTranslation() throws Exception { - List translationChain = ImmutableList.of( - new NumericValues<>(C_DOUBLE, -1.0 * Double.MAX_VALUE, -22.0, -44.5556836, 23.0, 44.5556789, Double.MAX_VALUE), - new NumericValues<>(C_REAL, realValue(-1.0f * Float.MAX_VALUE), realValue(-22.0f), realValue(-44.555687f), realValue(23.0f), realValue(44.555676f), realValue(Float.MAX_VALUE)), + testNumericTypeTranslationChain( new NumericValues<>(C_DECIMAL_26_5, longDecimal("-999999999999999999999.99999"), longDecimal("-22.00000"), longDecimal("-44.55569"), longDecimal("23.00000"), longDecimal("44.55567"), longDecimal("999999999999999999999.99999")), new NumericValues<>(C_DECIMAL_23_4, longDecimal("-9999999999999999999.9999"), longDecimal("-22.0000"), longDecimal("-44.5557"), longDecimal("23.0000"), longDecimal("44.5556"), longDecimal("9999999999999999999.9999")), new NumericValues<>(C_BIGINT, Long.MIN_VALUE, -22L, -45L, 23L, 44L, Long.MAX_VALUE), @@ -1220,13 +1254,20 @@ public void testNumericTypeTranslation() new NumericValues<>(C_SMALLINT, (long) Short.MIN_VALUE, -22L, -45L, 23L, 44L, (long) Short.MAX_VALUE), new NumericValues<>(C_DECIMAL_3_0, shortDecimal("-999"), shortDecimal("-22"), shortDecimal("-45"), shortDecimal("23"), shortDecimal("44"), shortDecimal("999")), new NumericValues<>(C_TINYINT, (long) Byte.MIN_VALUE, -22L, -45L, 23L, 44L, (long) Byte.MAX_VALUE), - new NumericValues<>(C_DECIMAL_2_0, shortDecimal("-99"), shortDecimal("-22"), shortDecimal("-45"), shortDecimal("23"), shortDecimal("44"), shortDecimal("99")) + new NumericValues<>(C_DECIMAL_2_0, shortDecimal("-99"), shortDecimal("-22"), shortDecimal("-45"), shortDecimal("23"), shortDecimal("44"), shortDecimal("99"))); + + testNumericTypeTranslationChain( + new NumericValues<>(C_DOUBLE, -1.0 * Double.MAX_VALUE, -22.0, -44.5556836, 23.0, 44.5556789, Double.MAX_VALUE), + new NumericValues<>(C_REAL, realValue(-1.0f * Float.MAX_VALUE), realValue(-22.0f), realValue(-44.555687f), realValue(23.0f), realValue(44.555676f), realValue(Float.MAX_VALUE)) ); + } - for (int literalIndex = 0; literalIndex < translationChain.size(); literalIndex++) { - for (int columnIndex = literalIndex + 1; columnIndex < translationChain.size(); columnIndex++) { - NumericValues literal = translationChain.get(literalIndex); - NumericValues column = translationChain.get(columnIndex); + private void testNumericTypeTranslationChain(NumericValues... translationChain) + { + for (int literalIndex = 0; literalIndex < translationChain.length; literalIndex++) { + for (int columnIndex = literalIndex + 1; columnIndex < translationChain.length; columnIndex++) { + NumericValues literal = translationChain[literalIndex]; + NumericValues column = translationChain[columnIndex]; testNumericTypeTranslation(column, literal); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEqualityInference.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEqualityInference.java index 01db56e686c5..d33a34d7a256 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEqualityInference.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEqualityInference.java @@ -365,13 +365,13 @@ public void testExpressionsThatMayReturnNullOnNonNullInput() private static Predicate matchesSymbolScope(final Predicate symbolScope) { - return expression -> Iterables.all(DependencyExtractor.extractUnique(expression), symbolScope); + return expression -> Iterables.all(SymbolsExtractor.extractUnique(expression), symbolScope); } private static Predicate matchesStraddlingScope(final Predicate symbolScope) { return expression -> { - Set symbols = DependencyExtractor.extractUnique(expression); + Set symbols = SymbolsExtractor.extractUnique(expression); return Iterables.any(symbols, symbolScope) && Iterables.any(symbols, not(symbolScope)); }; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index 48f3d2f5cd39..bca103261b97 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -26,6 +26,7 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.tests.QueryTemplate; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -50,15 +51,17 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.lateral; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictTableScan; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; -import static com.facebook.presto.sql.planner.optimizations.Predicates.isInstanceOfAny; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; import static com.facebook.presto.tests.QueryTemplate.queryTemplate; +import static com.facebook.presto.util.MorePredicates.isInstanceOfAny; import static io.airlift.slice.Slices.utf8Slice; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -370,4 +373,36 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin() project(ImmutableMap.of("NON_NULL", expression("true")), node(ValuesNode.class))))))))))); } + + @Test + public void testRemovesTrivialFilters() + { + assertPlan( + "SELECT * FROM nation WHERE 1 = 1", + output( + tableScan("nation")) + ); + assertPlan( + "SELECT * FROM nation WHERE 1 = 0", + output( + values("nationkey", "name", "regionkey", "comment")) + ); + } + + @Test + public void testPruneCountAggregationOverScalar() + { + assertPlan( + "SELECT count(*) FROM (SELECT sum(orderkey) FROM orders)", + output( + values(ImmutableList.of("_col0"), ImmutableList.of(ImmutableList.of(new LongLiteral("1")))))); + assertPlan( + "SELECT count(s) FROM (SELECT sum(orderkey) AS s FROM orders)", + anyTree( + tableScan("orders"))); + assertPlan( + "SELECT count(*) FROM (SELECT sum(orderkey) FROM orders GROUP BY custkey)", + anyTree( + tableScan("orders"))); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingConnectorIndexHandle.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingConnectorIndexHandle.java new file mode 100644 index 000000000000..2ffce3792266 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingConnectorIndexHandle.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.spi.ConnectorIndexHandle; + +public enum TestingConnectorIndexHandle + implements ConnectorIndexHandle +{ + INSTANCE +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingConnectorTransactionHandle.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingConnectorTransactionHandle.java new file mode 100644 index 000000000000..c26dfe94d933 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingConnectorTransactionHandle.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public enum TestingConnectorTransactionHandle + implements ConnectorTransactionHandle +{ + INSTANCE +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingWriterTarget.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingWriterTarget.java new file mode 100644 index 000000000000..ca45d957f943 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingWriterTarget.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner; + +import com.facebook.presto.sql.planner.plan.TableWriterNode; + +public class TestingWriterTarget + extends TableWriterNode.WriterTarget +{ + @Override + public String toString() + { + return "testing handle"; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java index 49cb8097ad65..4ec853583a6d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java @@ -64,6 +64,9 @@ public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session sessi @Override public String toString() { - return format("bind %s -> %s", alias, matcher); + if (alias.isPresent()) { + return format("bind %s -> %s", alias.get(), matcher); + } + return format("bind %s", matcher); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AssignUniqueIdMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AssignUniqueIdMatcher.java index 440713f68763..eafcbe89034f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AssignUniqueIdMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AssignUniqueIdMatcher.java @@ -21,6 +21,8 @@ import java.util.Optional; +import static com.google.common.base.MoreObjects.toStringHelper; + public class AssignUniqueIdMatcher implements RvalueMatcher { @@ -35,4 +37,11 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada return Optional.of(assignUniqueIdNode.getIdColumn()); } + + @Override + public String toString() + { + return toStringHelper(this) + .toString(); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ColumnReference.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ColumnReference.java index afaf6b4b82c9..490daeee1fbd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ColumnReference.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ColumnReference.java @@ -19,6 +19,7 @@ import com.facebook.presto.metadata.TableMetadata; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.TableScanNode; @@ -44,12 +45,24 @@ public ColumnReference(String tableName, String columnName) @Override public Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) { - if (!(node instanceof TableScanNode)) { + TableHandle tableHandle; + Map assignments; + + if (node instanceof TableScanNode) { + TableScanNode tableScanNode = (TableScanNode) node; + tableHandle = tableScanNode.getTable(); + assignments = tableScanNode.getAssignments(); + } + else if (node instanceof IndexSourceNode) { + IndexSourceNode indexSourceNode = (IndexSourceNode) node; + tableHandle = indexSourceNode.getTableHandle(); + assignments = indexSourceNode.getAssignments(); + } + else { return Optional.empty(); } - TableScanNode tableScanNode = (TableScanNode) node; - TableMetadata tableMetadata = metadata.getTableMetadata(session, tableScanNode.getTable()); + TableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle); String actualTableName = tableMetadata.getTable().getTableName(); // Wrong table -> doesn't match. @@ -57,17 +70,17 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada return Optional.empty(); } - Optional columnHandle = getColumnHandle(tableScanNode.getTable(), session, metadata); + Optional columnHandle = getColumnHandle(tableHandle, session, metadata); checkState(columnHandle.isPresent(), format("Table %s doesn't have column %s. Typo in test?", tableName, columnName)); - return getAssignedSymbol(tableScanNode, columnHandle.get()); + return getAssignedSymbol(assignments, columnHandle.get()); } - private Optional getAssignedSymbol(TableScanNode tableScanNode, ColumnHandle columnHandle) + private Optional getAssignedSymbol(Map assignments, ColumnHandle columnHandle) { Optional result = Optional.empty(); - for (Map.Entry entry : tableScanNode.getAssignments().entrySet()) { + for (Map.Entry entry : assignments.entrySet()) { if (entry.getValue().equals(columnHandle)) { checkState(!result.isPresent(), "Multiple ColumnHandles found for %s:%s in table scan assignments", tableName, columnName); result = Optional.of(entry.getKey()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/EquiJoinClauseProvider.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/EquiJoinClauseProvider.java index 8e7e320e13a2..98943e27fff9 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/EquiJoinClauseProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/EquiJoinClauseProvider.java @@ -37,9 +37,6 @@ public JoinNode.EquiJoinClause getExpectedValue(SymbolAliases aliases) @Override public String toString() { - return "EquiJoinClauseProvider{" + - "left=" + left + - ", right=" + right + - '}'; + return left + " = " + right; } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java index 00e48f578a4f..c3bfa27ee35d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java @@ -22,6 +22,7 @@ import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.DoubleLiteral; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GenericLiteral; import com.facebook.presto.sql.tree.InListExpression; import com.facebook.presto.sql.tree.InPredicate; @@ -31,11 +32,15 @@ import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.NotExpression; +import com.facebook.presto.sql.tree.NullLiteral; +import com.facebook.presto.sql.tree.SimpleCaseExpression; import com.facebook.presto.sql.tree.StringLiteral; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.TryExpression; +import com.facebook.presto.sql.tree.WhenClause; import java.util.List; +import java.util.Optional; import static com.google.common.base.Preconditions.checkState; import static java.lang.String.format; @@ -66,7 +71,7 @@ * */ final class ExpressionVerifier - extends AstVisitor + extends AstVisitor { private final SymbolAliases symbolAliases; @@ -76,13 +81,13 @@ final class ExpressionVerifier } @Override - protected Boolean visitNode(Node node, Expression context) + protected Boolean visitNode(Node node, Node context) { throw new IllegalStateException(format("Node %s is not supported", node)); } @Override - protected Boolean visitTryExpression(TryExpression actual, Expression expected) + protected Boolean visitTryExpression(TryExpression actual, Node expected) { if (!(expected instanceof TryExpression)) { return false; @@ -92,7 +97,7 @@ protected Boolean visitTryExpression(TryExpression actual, Expression expected) } @Override - protected Boolean visitCast(Cast actual, Expression expectedExpression) + protected Boolean visitCast(Cast actual, Node expectedExpression) { if (!(expectedExpression instanceof Cast)) { return false; @@ -108,7 +113,7 @@ protected Boolean visitCast(Cast actual, Expression expectedExpression) } @Override - protected Boolean visitIsNullPredicate(IsNullPredicate actual, Expression expectedExpression) + protected Boolean visitIsNullPredicate(IsNullPredicate actual, Node expectedExpression) { if (!(expectedExpression instanceof IsNullPredicate)) { return false; @@ -120,7 +125,7 @@ protected Boolean visitIsNullPredicate(IsNullPredicate actual, Expression expect } @Override - protected Boolean visitIsNotNullPredicate(IsNotNullPredicate actual, Expression expectedExpression) + protected Boolean visitIsNotNullPredicate(IsNotNullPredicate actual, Node expectedExpression) { if (!(expectedExpression instanceof IsNotNullPredicate)) { return false; @@ -132,7 +137,7 @@ protected Boolean visitIsNotNullPredicate(IsNotNullPredicate actual, Expression } @Override - protected Boolean visitInPredicate(InPredicate actual, Expression expectedExpression) + protected Boolean visitInPredicate(InPredicate actual, Node expectedExpression) { if (expectedExpression instanceof InPredicate) { InPredicate expected = (InPredicate) expectedExpression; @@ -166,7 +171,7 @@ protected Boolean visitInPredicate(InPredicate actual, Expression expectedExpres } @Override - protected Boolean visitComparisonExpression(ComparisonExpression actual, Expression expectedExpression) + protected Boolean visitComparisonExpression(ComparisonExpression actual, Node expectedExpression) { if (expectedExpression instanceof ComparisonExpression) { ComparisonExpression expected = (ComparisonExpression) expectedExpression; @@ -178,7 +183,7 @@ protected Boolean visitComparisonExpression(ComparisonExpression actual, Express } @Override - protected Boolean visitArithmeticBinary(ArithmeticBinaryExpression actual, Expression expectedExpression) + protected Boolean visitArithmeticBinary(ArithmeticBinaryExpression actual, Node expectedExpression) { if (expectedExpression instanceof ArithmeticBinaryExpression) { ArithmeticBinaryExpression expected = (ArithmeticBinaryExpression) expectedExpression; @@ -189,7 +194,7 @@ protected Boolean visitArithmeticBinary(ArithmeticBinaryExpression actual, Expre return false; } - protected Boolean visitGenericLiteral(GenericLiteral actual, Expression expected) + protected Boolean visitGenericLiteral(GenericLiteral actual, Node expected) { if (expected instanceof GenericLiteral) { return getValueFromLiteral(actual).equals(getValueFromLiteral(expected)); @@ -199,7 +204,7 @@ protected Boolean visitGenericLiteral(GenericLiteral actual, Expression expected } @Override - protected Boolean visitLongLiteral(LongLiteral actual, Expression expected) + protected Boolean visitLongLiteral(LongLiteral actual, Node expected) { if (expected instanceof LongLiteral) { return getValueFromLiteral(actual).equals(getValueFromLiteral(expected)); @@ -209,7 +214,7 @@ protected Boolean visitLongLiteral(LongLiteral actual, Expression expected) } @Override - protected Boolean visitDoubleLiteral(DoubleLiteral actual, Expression expected) + protected Boolean visitDoubleLiteral(DoubleLiteral actual, Node expected) { if (expected instanceof DoubleLiteral) { return getValueFromLiteral(actual).equals(getValueFromLiteral(expected)); @@ -219,7 +224,7 @@ protected Boolean visitDoubleLiteral(DoubleLiteral actual, Expression expected) } @Override - protected Boolean visitBooleanLiteral(BooleanLiteral actual, Expression expected) + protected Boolean visitBooleanLiteral(BooleanLiteral actual, Node expected) { if (expected instanceof BooleanLiteral) { return getValueFromLiteral(actual).equals(getValueFromLiteral(expected)); @@ -227,7 +232,7 @@ protected Boolean visitBooleanLiteral(BooleanLiteral actual, Expression expected return false; } - private static String getValueFromLiteral(Expression expression) + private static String getValueFromLiteral(Node expression) { if (expression instanceof LongLiteral) { return String.valueOf(((LongLiteral) expression).getValue()); @@ -247,7 +252,7 @@ else if (expression instanceof GenericLiteral) { } @Override - protected Boolean visitStringLiteral(StringLiteral actual, Expression expectedExpression) + protected Boolean visitStringLiteral(StringLiteral actual, Node expectedExpression) { if (expectedExpression instanceof StringLiteral) { StringLiteral expected = (StringLiteral) expectedExpression; @@ -257,7 +262,7 @@ protected Boolean visitStringLiteral(StringLiteral actual, Expression expectedEx } @Override - protected Boolean visitLogicalBinaryExpression(LogicalBinaryExpression actual, Expression expectedExpression) + protected Boolean visitLogicalBinaryExpression(LogicalBinaryExpression actual, Node expectedExpression) { if (expectedExpression instanceof LogicalBinaryExpression) { LogicalBinaryExpression expected = (LogicalBinaryExpression) expectedExpression; @@ -269,7 +274,7 @@ protected Boolean visitLogicalBinaryExpression(LogicalBinaryExpression actual, E } @Override - protected Boolean visitBetweenPredicate(BetweenPredicate actual, Expression expectedExpression) + protected Boolean visitBetweenPredicate(BetweenPredicate actual, Node expectedExpression) { if (expectedExpression instanceof BetweenPredicate) { BetweenPredicate expected = (BetweenPredicate) expectedExpression; @@ -280,7 +285,7 @@ protected Boolean visitBetweenPredicate(BetweenPredicate actual, Expression expe } @Override - protected Boolean visitNotExpression(NotExpression actual, Expression expected) + protected Boolean visitNotExpression(NotExpression actual, Node expected) { if (expected instanceof NotExpression) { return process(actual.getValue(), ((NotExpression) expected).getValue()); @@ -289,7 +294,7 @@ protected Boolean visitNotExpression(NotExpression actual, Expression expected) } @Override - protected Boolean visitSymbolReference(SymbolReference actual, Expression expected) + protected Boolean visitSymbolReference(SymbolReference actual, Node expected) { if (!(expected instanceof SymbolReference)) { return false; @@ -298,7 +303,7 @@ protected Boolean visitSymbolReference(SymbolReference actual, Expression expect } @Override - protected Boolean visitCoalesceExpression(CoalesceExpression actual, Expression expected) + protected Boolean visitCoalesceExpression(CoalesceExpression actual, Node expected) { if (!(expected instanceof CoalesceExpression)) { return false; @@ -314,4 +319,94 @@ protected Boolean visitCoalesceExpression(CoalesceExpression actual, Expression } return false; } + + @Override + protected Boolean visitSimpleCaseExpression(SimpleCaseExpression actual, Node expected) + { + if (!(expected instanceof SimpleCaseExpression)) { + return false; + } + SimpleCaseExpression expectedCase = (SimpleCaseExpression) expected; + if (!process(actual.getOperand(), expectedCase.getOperand())) { + return false; + } + + if (!process(actual.getWhenClauses(), expectedCase.getWhenClauses())) { + return false; + } + + return process(actual.getDefaultValue(), expectedCase.getDefaultValue()); + } + + @Override + protected Boolean visitWhenClause(WhenClause actual, Node expected) + { + if (!(expected instanceof WhenClause)) { + return false; + } + WhenClause expectedWhenClause = (WhenClause) expected; + + return process(actual.getOperand(), expectedWhenClause.getOperand()) && process(actual.getResult(), expectedWhenClause.getResult()); + } + + @Override + protected Boolean visitFunctionCall(FunctionCall actual, Node expected) + { + if (!(expected instanceof FunctionCall)) { + return false; + } + FunctionCall expectedFunction = (FunctionCall) expected; + + if (actual.isDistinct() != expectedFunction.isDistinct()) { + return false; + } + + if (!actual.getName().equals(expectedFunction.getName())) { + return false; + } + + if (!process(actual.getArguments(), expectedFunction.getArguments())) { + return false; + } + + if (!process(actual.getFilter(), expectedFunction.getFilter())) { + return false; + } + + if (!process(actual.getWindow(), expectedFunction.getWindow())) { + return false; + } + + return true; + } + + @Override + protected Boolean visitNullLiteral(NullLiteral node, Node expected) + { + return expected instanceof NullLiteral; + } + + private boolean process(List actuals, List expecteds) + { + if (actuals.size() != expecteds.size()) { + return false; + } + for (int i = 0; i < actuals.size(); i++) { + if (!process(actuals.get(i), expecteds.get(i))) { + return false; + } + } + return true; + } + + private boolean process(Optional actual, Optional expected) + { + if (actual.isPresent() != expected.isPresent()) { + return false; + } + if (actual.isPresent()) { + return process(actual.get(), expected.get()); + } + return true; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/IndexSourceMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/IndexSourceMatcher.java new file mode 100644 index 000000000000..30e8e200f665 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/IndexSourceMatcher.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableMetadata; +import com.facebook.presto.spi.predicate.Domain; +import com.facebook.presto.sql.planner.plan.IndexSourceNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.Util.domainsMatch; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +final class IndexSourceMatcher + implements Matcher +{ + private final String expectedTableName; + private final Optional> expectedConstraint; + + public IndexSourceMatcher(String expectedTableName) + { + this.expectedTableName = requireNonNull(expectedTableName, "expectedTableName is null"); + expectedConstraint = Optional.empty(); + } + + public IndexSourceMatcher(String expectedTableName, Map expectedConstraint) + { + this.expectedTableName = requireNonNull(expectedTableName, "expectedTableName is null"); + this.expectedConstraint = Optional.of(ImmutableMap.copyOf(expectedConstraint)); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof IndexSourceNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + IndexSourceNode indexSourceNode = (IndexSourceNode) node; + TableMetadata tableMetadata = metadata.getTableMetadata(session, indexSourceNode.getTableHandle()); + String actualTableName = tableMetadata.getTable().getTableName(); + + if (!expectedTableName.equalsIgnoreCase(actualTableName)) { + return NO_MATCH; + } + + if (expectedConstraint.isPresent() && + !domainsMatch( + expectedConstraint, + indexSourceNode.getEffectiveTupleDomain(), + indexSourceNode.getTableHandle(), + session, + metadata)) { + return NO_MATCH; + } + + return match(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("expectedTableName", expectedTableName) + .add("expectedConstraint", expectedConstraint.orElse(null)) + .toString(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java index d73a005c8b22..c455a42edded 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java @@ -98,8 +98,9 @@ public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session sessi public String toString() { return toStringHelper(this) + .omitNullValues() .add("equiCriteria", equiCriteria) - .add("filter", filter) + .add("filter", filter.orElse(null)) .toString(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java new file mode 100644 index 000000000000..d79f38a45f98 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.planner.plan.MarkDistinctNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; + +public class MarkDistinctMatcher + implements Matcher +{ + private final PlanTestSymbol markerSymbol; + private final List distinctSymbols; + private final Optional hashSymbol; + + public MarkDistinctMatcher(PlanTestSymbol markerSymbol, List distinctSymbols, Optional hashSymbol) + { + this.markerSymbol = requireNonNull(markerSymbol, "markerSymbol is null"); + this.distinctSymbols = ImmutableList.copyOf(distinctSymbols); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof MarkDistinctNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + MarkDistinctNode markDistinctNode = (MarkDistinctNode) node; + + if (!markDistinctNode.getHashSymbol().equals(hashSymbol.map(alias -> alias.toSymbol(symbolAliases)))) { + return NO_MATCH; + } + + if (!ImmutableSet.copyOf(markDistinctNode.getDistinctSymbols()) + .equals(distinctSymbols.stream().map(alias -> alias.toSymbol(symbolAliases)).collect(toImmutableSet()))) { + return NO_MATCH; + } + + return match(markerSymbol.toString(), markDistinctNode.getMarkerSymbol().toSymbolReference()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("markerSymbol", markerSymbol) + .add("distinctSymbols", distinctSymbols) + .add("hashSymbol", hashSymbol) + .toString(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index 8d34a77637ee..f059b2679467 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -28,20 +28,24 @@ import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; +import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.IntersectNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; +import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.OutputNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.planner.plan.SortNode; import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.sql.planner.plan.TableWriterNode; import com.facebook.presto.sql.planner.plan.UnionNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FrameBound; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.WindowFrame; @@ -55,6 +59,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Consumer; import java.util.stream.IntStream; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; @@ -107,6 +112,12 @@ public static PlanMatchPattern tableScan(String expectedTableName) return node(TableScanNode.class).with(new TableScanMatcher(expectedTableName)); } + public static PlanMatchPattern tableScan(String expectedTableName, String originalConstraint) + { + Expression expectedOriginalConstraint = rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(originalConstraint)); + return node(TableScanNode.class).with(new TableScanMatcher(expectedTableName, expectedOriginalConstraint)); + } + public static PlanMatchPattern tableScan(String expectedTableName, Map columnReferences) { PlanMatchPattern result = tableScan(expectedTableName); @@ -132,6 +143,13 @@ public static PlanMatchPattern constrainedTableScan(String expectedTableName, Ma return result.addColumnReferences(expectedTableName, columnReferences); } + public static PlanMatchPattern constrainedIndexSource(String expectedTableName, Map constraint, Map columnReferences) + { + return node(IndexSourceNode.class) + .with(new IndexSourceMatcher(expectedTableName, constraint)) + .addColumnReferences(expectedTableName, columnReferences); + } + private PlanMatchPattern addColumnReferences(String expectedTableName, Map columnReferences) { columnReferences.entrySet().forEach( @@ -163,26 +181,49 @@ public static PlanMatchPattern aggregation( return result; } - public static PlanMatchPattern window( - ExpectedValueProvider specification, - List> windowFunctions, + public static PlanMatchPattern markDistinct( + String markerSymbol, + List distinctSymbols, PlanMatchPattern source) { - PlanMatchPattern result = node(WindowNode.class, source).with(new WindowMatcher(specification)); - windowFunctions.forEach( - function -> result.withAlias(Optional.empty(), new WindowFunctionMatcher(function))); - return result; + return node(MarkDistinctNode.class, source).with(new MarkDistinctMatcher( + new SymbolAlias(markerSymbol), + toSymbolAliases(distinctSymbols), + Optional.empty())); } - public static PlanMatchPattern window( - ExpectedValueProvider specification, - Map> assignments, + public static PlanMatchPattern markDistinct( + String markerSymbol, + List distinctSymbols, + String hashSymbol, PlanMatchPattern source) { - PlanMatchPattern result = node(WindowNode.class, source).with(new WindowMatcher(specification)); - assignments.entrySet().forEach( - assignment -> result.withAlias(assignment.getKey(), new WindowFunctionMatcher(assignment.getValue()))); - return result; + return node(MarkDistinctNode.class, source).with(new MarkDistinctMatcher( + new SymbolAlias(markerSymbol), + toSymbolAliases(distinctSymbols), + Optional.of(new SymbolAlias(hashSymbol)))); + } + + public static ExpectedValueProvider windowFrame( + WindowFrame.Type type, + FrameBound.Type startType, + Optional startValue, + FrameBound.Type endType, + Optional endValue) + { + return new WindowFrameProvider( + type, + startType, + startValue.map(SymbolAlias::new), + endType, + endValue.map(SymbolAlias::new)); + } + + public static PlanMatchPattern window(Consumer windowMatcherBuilderConsumer, PlanMatchPattern source) + { + WindowMatcher.Builder windowMatcherBuilder = new WindowMatcher.Builder(source); + windowMatcherBuilderConsumer.accept(windowMatcherBuilder); + return windowMatcherBuilder.build(); } public static PlanMatchPattern sort(PlanMatchPattern source) @@ -360,6 +401,11 @@ public static PlanMatchPattern limit(long limit, PlanMatchPattern source) return node(LimitNode.class, source).with(new LimitMatcher(limit)); } + public static PlanMatchPattern tableWriter(List columns, List columnNames, PlanMatchPattern source) + { + return node(TableWriterNode.class, source).with(new TableWriterMatcher(columns, columnNames)); + } + public PlanMatchPattern(List sourcePatterns) { requireNonNull(sourcePatterns, "sourcePatterns are null"); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableScanMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableScanMatcher.java index 65b39ad9cc47..8a5b7e528d3d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableScanMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableScanMatcher.java @@ -17,35 +17,50 @@ import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.TableMetadata; -import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.Domain; -import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; import java.util.Map; import java.util.Optional; +import static com.facebook.presto.sql.planner.assertions.Util.domainsMatch; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; +import static java.util.Optional.empty; +import static java.util.stream.Collectors.toMap; final class TableScanMatcher implements Matcher { private final String expectedTableName; private final Optional> expectedConstraint; + private final Optional expectedOriginalConstraint; TableScanMatcher(String expectedTableName) { - this.expectedTableName = requireNonNull(expectedTableName, "expectedTableName is null"); - expectedConstraint = Optional.empty(); + this(expectedTableName, empty(), empty()); } public TableScanMatcher(String expectedTableName, Map expectedConstraint) + { + this(expectedTableName, Optional.of(expectedConstraint), empty()); + } + + public TableScanMatcher(String expectedTableName, Expression originalConstraint) + { + this(expectedTableName, empty(), Optional.of(originalConstraint)); + } + + private TableScanMatcher(String expectedTableName, Optional> expectedConstraint, Optional originalConstraint) { this.expectedTableName = requireNonNull(expectedTableName, "expectedTableName is null"); - this.expectedConstraint = Optional.of(requireNonNull(expectedConstraint, "expectedConstraint is null")); + this.expectedConstraint = requireNonNull(expectedConstraint, "expectedConstraint is null"); + this.expectedOriginalConstraint = requireNonNull(originalConstraint, "expectedOriginalConstraint is null"); } @Override @@ -64,35 +79,22 @@ public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session sessi String actualTableName = tableMetadata.getTable().getTableName(); return new MatchResult( expectedTableName.equalsIgnoreCase(actualTableName) && - domainMatches(tableScanNode, session, metadata)); + originalConstraintMatches(tableScanNode) && + ((!expectedConstraint.isPresent()) || + domainsMatch(expectedConstraint, tableScanNode.getCurrentConstraint(), tableScanNode.getTable(), session, metadata))); } - private boolean domainMatches(TableScanNode tableScanNode, Session session, Metadata metadata) + private boolean originalConstraintMatches(TableScanNode node) { - if (!expectedConstraint.isPresent()) { - return true; - } - - TupleDomain actualConstraint = tableScanNode.getCurrentConstraint(); - if (expectedConstraint.isPresent() && !actualConstraint.getDomains().isPresent()) { - return false; - } - - Map columnHandles = metadata.getColumnHandles(session, tableScanNode.getTable()); - for (Map.Entry expectedColumnConstraint : expectedConstraint.get().entrySet()) { - if (!columnHandles.containsKey(expectedColumnConstraint.getKey())) { - return false; - } - ColumnHandle columnHandle = columnHandles.get(expectedColumnConstraint.getKey()); - if (!actualConstraint.getDomains().get().containsKey(columnHandle)) { - return false; - } - if (!expectedColumnConstraint.getValue().contains(actualConstraint.getDomains().get().get(columnHandle))) { - return false; - } - } - - return true; + return expectedOriginalConstraint + .map(expected -> { + Map assignments = node.getOutputSymbols().stream() + .collect(toMap(Symbol::getName, Symbol::toSymbolReference)); + SymbolAliases symbolAliases = SymbolAliases.builder().putAll(assignments).build(); + ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); + return verifier.process(node.getOriginalConstraint(), expected); + }) + .orElse(true); } @Override @@ -102,6 +104,7 @@ public String toString() .omitNullValues() .add("expectedTableName", expectedTableName) .add("expectedConstraint", expectedConstraint.orElse(null)) + .add("expectedOriginalConstraint", expectedOriginalConstraint.orElse(null)) .toString(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableWriterMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableWriterMatcher.java new file mode 100644 index 000000000000..c86ed5d90385 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableWriterMatcher.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TableWriterNode; + +import java.util.List; + +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TableWriterMatcher + implements Matcher +{ + private final List columns; + private final List columnNames; + + public TableWriterMatcher(List columns, List columnNames) + { + this.columns = columns; + this.columnNames = columnNames; + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableWriterNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableWriterNode tableWriterNode = (TableWriterNode) node; + if (!tableWriterNode.getColumnNames().equals(columnNames)) { + return NO_MATCH; + } + + if (!columns.stream() + .map(s -> Symbol.from(symbolAliases.get(s))) + .collect(toImmutableList()) + .equals(tableWriterNode.getColumns())) { + return NO_MATCH; + } + + return match(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("columns", columns) + .add("columnNames", columnNames) + .toString(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Util.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Util.java new file mode 100644 index 000000000000..8e6d6372b84a --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Util.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.predicate.Domain; +import com.facebook.presto.spi.predicate.TupleDomain; + +import java.util.Map; +import java.util.Optional; + +final class Util +{ + private Util() {} + + /** + * @param expectedDomains if empty, the actualConstraint's domains must also be empty. + */ + static boolean domainsMatch( + Optional> expectedDomains, + TupleDomain actualConstraint, + TableHandle tableHandle, + Session session, + Metadata metadata) + { + Optional> actualDomains = actualConstraint.getDomains(); + + if (expectedDomains.isPresent() != actualDomains.isPresent()) { + return false; + } + + if (!expectedDomains.isPresent()) { + return true; + } + + Map columnHandles = metadata.getColumnHandles(session, tableHandle); + for (Map.Entry expectedColumnConstraint : expectedDomains.get().entrySet()) { + if (!columnHandles.containsKey(expectedColumnConstraint.getKey())) { + return false; + } + ColumnHandle columnHandle = columnHandles.get(expectedColumnConstraint.getKey()); + if (!actualDomains.get().containsKey(columnHandle)) { + return false; + } + if (!expectedColumnConstraint.getValue().contains(actualDomains.get().get(columnHandle))) { + return false; + } + } + + return true; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ValuesMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ValuesMatcher.java index f194f1587fe0..58e4c31f6247 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ValuesMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ValuesMatcher.java @@ -75,6 +75,7 @@ public MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Sessi public String toString() { return toStringHelper(this) + .omitNullValues() .add("outputSymbolAliases", outputSymbolAliases) .add("expectedOutputSymbolCount", expectedOutputSymbolCount) .add("expectedRows", expectedRows) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java new file mode 100644 index 000000000000..d6a3ddfe0551 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFrameProvider.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.sql.planner.plan.WindowNode; +import com.facebook.presto.sql.tree.FrameBound; +import com.facebook.presto.sql.tree.WindowFrame; + +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class WindowFrameProvider + implements ExpectedValueProvider +{ + private final WindowFrame.Type type; + private final FrameBound.Type startType; + private final Optional startValue; + private final FrameBound.Type endType; + private final Optional endValue; + + WindowFrameProvider( + WindowFrame.Type type, + FrameBound.Type startType, + Optional startValue, + FrameBound.Type endType, + Optional endValue) + { + this.type = requireNonNull(type, "type is null"); + this.startType = requireNonNull(startType, "startType is null"); + this.startValue = requireNonNull(startValue, "startValue is null"); + this.endType = requireNonNull(endType, "endType is null"); + this.endValue = requireNonNull(endValue, "endValue is null"); + } + + @Override + public WindowNode.Frame getExpectedValue(SymbolAliases aliases) + { + return new WindowNode.Frame( + type, + startType, + startValue.map(alias -> alias.toSymbol(aliases)), + endType, + endValue.map(alias -> alias.toSymbol(aliases))); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("type", this.type) + .add("startType", this.startType) + .add("startValue", this.startValue) + .add("endType", this.endType) + .add("endValue", this.endValue) + .toString(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java index 6c39ceb5f274..6a2de3a054ae 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java @@ -15,25 +15,41 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.Signature; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.FunctionCall; +import java.util.List; import java.util.Map; import java.util.Optional; +import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class WindowFunctionMatcher implements RvalueMatcher { private final ExpectedValueProvider callMaker; + private final Optional signature; + private final Optional> frameMaker; - public WindowFunctionMatcher(ExpectedValueProvider callMaker) + /** + * @param callMaker Always validates the function call + * @param signature Optionally validates the signature + * @param frameMaker Optionally validates the frame + */ + public WindowFunctionMatcher( + ExpectedValueProvider callMaker, + Optional signature, + Optional> frameMaker) { this.callMaker = requireNonNull(callMaker, "functionCall is null"); + this.signature = requireNonNull(signature, "signature is null"); + this.frameMaker = requireNonNull(frameMaker, "frameMaker is null"); } @Override @@ -47,19 +63,33 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada WindowNode windowNode = (WindowNode) node; FunctionCall expectedCall = callMaker.getExpectedValue(symbolAliases); - for (Map.Entry assignment : windowNode.getWindowFunctions().entrySet()) { - if (expectedCall.equals(assignment.getValue().getFunctionCall())) { - checkState(!result.isPresent(), "Ambiguous function calls in %s", windowNode); - result = Optional.of(assignment.getKey()); - } - } + Optional expectedFrame = frameMaker.map(maker -> maker.getExpectedValue(symbolAliases)); + + List matchedOutputs = windowNode.getWindowFunctions().entrySet().stream() + .filter(assignment -> + expectedCall.equals(assignment.getValue().getFunctionCall()) + && signature.map(assignment.getValue().getSignature()::equals).orElse(true) + && expectedFrame.map(assignment.getValue().getFrame()::equals).orElse(true)) + .map(Map.Entry::getKey) + .collect(toImmutableList()); - return result; + checkState(matchedOutputs.size() <= 1, "Ambiguous function calls in %s", windowNode); + + if (matchedOutputs.isEmpty()) { + return Optional.empty(); + } + return Optional.of(matchedOutputs.get(0)); } @Override public String toString() { - return callMaker.toString(); + // Only include fields in the description if they are actual constraints. + return toStringHelper(this) + .omitNullValues() + .add("callMaker", callMaker) + .add("signature", signature.orElse(null)) + .add("frameMaker", frameMaker.orElse(null)) + .toString(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java index 6d291969351a..3e85a06e88bd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java @@ -16,21 +16,47 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.Signature; +import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; +import com.facebook.presto.sql.tree.FunctionCall; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; -final class WindowMatcher +/** + * Optionally validates each of the non-function fields of the node. + */ +public final class WindowMatcher implements Matcher { - private final ExpectedValueProvider specification; + private final Optional> prePartitionedInputs; + private final Optional> specification; + private final Optional preSortedOrderPrefix; + private final Optional> hashSymbol; - WindowMatcher( - ExpectedValueProvider specification) + private WindowMatcher( + Optional> prePartitionedInputs, + Optional> specification, + Optional preSortedOrderPrefix, + Optional> hashSymbol) { - this.specification = specification; + this.prePartitionedInputs = requireNonNull(prePartitionedInputs, "prePartitionedInputs is null"); + this.specification = requireNonNull(specification, "specification is null"); + this.preSortedOrderPrefix = requireNonNull(preSortedOrderPrefix, "preSortedOrderPrefix is null"); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); } @Override @@ -46,19 +72,166 @@ public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session sessi WindowNode windowNode = (WindowNode) node; + if (!prePartitionedInputs + .map(expectedInputs -> expectedInputs.stream() + .map(alias -> alias.toSymbol(symbolAliases)) + .collect(toImmutableSet()) + .equals(windowNode.getPrePartitionedInputs())) + .orElse(true)) { + return NO_MATCH; + } + + if (!specification + .map(expectedSpecification -> + expectedSpecification.getExpectedValue(symbolAliases) + .equals(windowNode.getSpecification())) + .orElse(true)) { + return NO_MATCH; + } + + if (!preSortedOrderPrefix + .map(Integer.valueOf(windowNode.getPreSortedOrderPrefix())::equals) + .orElse(true)) { + return NO_MATCH; + } + + if (!hashSymbol + .map(expectedHashSymbol -> expectedHashSymbol + .map(alias -> alias.toSymbol(symbolAliases)) + .equals(windowNode.getHashSymbol())) + .orElse(true)) { + return NO_MATCH; + } + /* * Window functions produce a symbol (the result of the function call) that we might * want to bind to an alias so we can reference it further up the tree. As such, * they need to be matched with an Alias matcher so we can bind the symbol if desired. */ - return new MatchResult(windowNode.getSpecification().equals(specification.getExpectedValue(symbolAliases))); + return match(); } @Override public String toString() { + // Only include fields in the description if they are actual constraints. return toStringHelper(this) - .add("specification", specification) + .omitNullValues() + .add("prePartitionedInputs", prePartitionedInputs.orElse(null)) + .add("specification", specification.orElse(null)) + .add("preSortedOrderPrefix", preSortedOrderPrefix.orElse(null)) + .add("hashSymbol", hashSymbol.orElse(null)) .toString(); } + + /** + * By default, matches any WindowNode. Users add additional constraints by + * calling the various member functions of the Builder, typically named according + * to the field names of WindowNode. + */ + public static class Builder + { + private final PlanMatchPattern source; + private Optional> prePartitionedInputs = Optional.empty(); + private Optional> specification = Optional.empty(); + private Optional preSortedOrderPrefix = Optional.empty(); + private List windowFunctionMatchers = new LinkedList<>(); + private Optional> hashSymbol = Optional.empty(); + + Builder(PlanMatchPattern source) + { + this.source = requireNonNull(source, "source is null"); + } + + public Builder prePartitionedInputs(Set prePartitionedInputs) + { + requireNonNull(prePartitionedInputs, "prePartitionedInputs is null"); + this.prePartitionedInputs = Optional.of( + prePartitionedInputs.stream() + .map(SymbolAlias::new) + .collect(toImmutableSet())); + return this; + } + + public Builder specification( + List partitionBy, + List orderBy, + Map orderings) + { + return specification(PlanMatchPattern.specification(partitionBy, orderBy, orderings)); + } + + public Builder specification(ExpectedValueProvider specification) + { + requireNonNull(specification, "specification is null"); + this.specification = Optional.of(specification); + return this; + } + + public Builder preSortedOrderPrefix(int preSortedOrderPrefix) + { + this.preSortedOrderPrefix = Optional.of(preSortedOrderPrefix); + return this; + } + + public Builder addFunction(String outputAlias, ExpectedValueProvider functionCall) + { + return addFunction(Optional.of(outputAlias), functionCall); + } + + public Builder addFunction(ExpectedValueProvider functionCall) + { + return addFunction(Optional.empty(), functionCall); + } + + private Builder addFunction(Optional outputAlias, ExpectedValueProvider functionCall) + { + windowFunctionMatchers.add(new AliasMatcher(outputAlias, new WindowFunctionMatcher(functionCall, Optional.empty(), Optional.empty()))); + return this; + } + + public Builder addFunction( + String outputAlias, + ExpectedValueProvider functionCall, + Signature signature, + ExpectedValueProvider frame) + { + windowFunctionMatchers.add( + new AliasMatcher( + Optional.of(outputAlias), + new WindowFunctionMatcher(functionCall, Optional.of(signature), Optional.of(frame)))); + return this; + } + + /** + * Matches only if WindowNode.getHashSymbol() is an empty option. + */ + public Builder hashSymbol() + { + this.hashSymbol = Optional.of(Optional.empty()); + return this; + } + + /** + * Matches only if WindowNode.getHashSymbol() is a non-empty option containing hashSymbol. + */ + public Builder hashSymbol(String hashSymbol) + { + requireNonNull(hashSymbol, "hashSymbol is null"); + this.hashSymbol = Optional.of(Optional.of(new SymbolAlias(hashSymbol))); + return this; + } + + PlanMatchPattern build() + { + PlanMatchPattern result = node(WindowNode.class, source).with( + new WindowMatcher( + prePartitionedInputs, + specification, + preSortedOrderPrefix, + hashSymbol)); + windowFunctionMatchers.forEach(result::with); + return result; + } + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java index 76a75d5df73d..050b9ec99cd5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java @@ -15,9 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.StatsRecorder; -import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -90,7 +88,7 @@ private static class NonConvergingRule // In that case, it will be removed. // Thanks to that approach, it never converges and always produces different node. @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { if (node instanceof ProjectNode) { ProjectNode project = (ProjectNode) node; @@ -99,7 +97,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato } } - PlanNode projectNode = new ProjectNode(idAllocator.getNextId(), node, Assignments.identity(node.getOutputSymbols())); + PlanNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), node, Assignments.identity(node.getOutputSymbols())); return Optional.of(projectNode); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleStore.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestMatchingEngine.java similarity index 56% rename from presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleStore.java rename to presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestMatchingEngine.java index 6097f46a177f..e68a51bb278f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleStore.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestMatchingEngine.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative; -import com.facebook.presto.Session; +import com.facebook.presto.matching.MatchingEngine; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.DummyMetadata; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; @@ -25,28 +25,28 @@ import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.tree.BooleanLiteral; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import org.testng.annotations.Test; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; -import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toSet; import static org.testng.Assert.assertEquals; -public class TestRuleStore +public class TestMatchingEngine { private final PlanBuilder planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), new DummyMetadata()); @Test - public void test() + public void testWithPlanNodeHierarchy() { - Rule projectRule1 = new NoOpRule(Pattern.node(ProjectNode.class)); - Rule projectRule2 = new NoOpRule(Pattern.node(ProjectNode.class)); - Rule filterRule = new NoOpRule(Pattern.node(FilterNode.class)); + Rule projectRule1 = new NoOpRule(Pattern.typeOf(ProjectNode.class)); + Rule projectRule2 = new NoOpRule(Pattern.typeOf(ProjectNode.class)); + Rule filterRule = new NoOpRule(Pattern.typeOf(FilterNode.class)); Rule anyRule = new NoOpRule(Pattern.any()); - RuleStore ruleStore = RuleStore.builder() + MatchingEngine matchingEngine = MatchingEngine.builder() .register(projectRule1) .register(projectRule2) .register(filterRule) @@ -58,14 +58,38 @@ public void test() ValuesNode valuesNode = planBuilder.values(); assertEquals( - ruleStore.getCandidates(projectNode).collect(toList()), - ImmutableList.of(projectRule1, projectRule2, anyRule)); + matchingEngine.getCandidates(projectNode).collect(toSet()), + ImmutableSet.of(projectRule1, projectRule2, anyRule)); assertEquals( - ruleStore.getCandidates(filterNode).collect(toList()), - ImmutableList.of(filterRule, anyRule)); + matchingEngine.getCandidates(filterNode).collect(toSet()), + ImmutableSet.of(filterRule, anyRule)); assertEquals( - ruleStore.getCandidates(valuesNode).collect(toList()), - ImmutableList.of(anyRule)); + matchingEngine.getCandidates(valuesNode).collect(toSet()), + ImmutableSet.of(anyRule)); + } + + @Test + public void testInterfacesHierarchy() + { + Rule a = new NoOpRule(Pattern.typeOf(A.class)); + Rule b = new NoOpRule(Pattern.typeOf(B.class)); + Rule ab = new NoOpRule(Pattern.typeOf(AB.class)); + + MatchingEngine matchingEngine = MatchingEngine.builder() + .register(a) + .register(b) + .register(ab) + .build(); + + assertEquals( + matchingEngine.getCandidates(new A() {}).collect(toSet()), + ImmutableSet.of(a)); + assertEquals( + matchingEngine.getCandidates(new B() {}).collect(toSet()), + ImmutableSet.of(b)); + assertEquals( + matchingEngine.getCandidates(new AB()).collect(toSet()), + ImmutableSet.of(ab, a, b)); } private static class NoOpRule @@ -85,7 +109,7 @@ public Pattern getPattern() } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Optional apply(PlanNode node, Context context) { return Optional.empty(); } @@ -98,4 +122,8 @@ public String toString() .toString(); } } + + private interface A {} + private interface B {} + private static class AB implements A, B {} } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java index 1f858a9574e2..3905f8044fa6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java @@ -15,15 +15,13 @@ import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.tree.FunctionCall; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.util.Optional; @@ -44,45 +42,30 @@ import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestAddIntermediateAggregations + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testBasic() { ExpectedValueProvider aggregationPattern = PlanMatchPattern.functionCall("count", false, ImmutableList.of(anySymbol())); - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT)))))); + p.values(p.symbol("a")))))); })) .matches( aggregation( @@ -123,21 +106,21 @@ public void testNoInputCount() ExpectedValueProvider rawInputCount = PlanMatchPattern.functionCall("count", false, ImmutableList.of()); ExpectedValueProvider partialInputCount = PlanMatchPattern.functionCall("count", false, ImmutableList.of(anySymbol())); - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(*)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(*)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT)))))); + p.values(p.symbol("a")))))); })) .matches( aggregation( @@ -176,13 +159,13 @@ public void testMultipleExchanges() { ExpectedValueProvider aggregationPattern = PlanMatchPattern.functionCall("count", false, ImmutableList.of(anySymbol())); - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, @@ -190,9 +173,9 @@ public void testMultipleExchanges() ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT))))))); + p.values(p.symbol("a"))))))); })) .matches( aggregation( @@ -230,21 +213,21 @@ public void testMultipleExchanges() @Test public void testSessionDisable() { - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "false") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT)))))); + p.values(p.symbol("a")))))); })) .doesNotFire(); } @@ -254,21 +237,21 @@ public void testNoLocalParallel() { ExpectedValueProvider aggregationPattern = PlanMatchPattern.functionCall("count", false, ImmutableList.of(anySymbol())); - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "1") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT)))))); + p.values(p.symbol("a")))))); })) .matches( aggregation( @@ -297,21 +280,21 @@ public void testNoLocalParallel() @Test public void testWithGroups() { - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { - af.addGroupingSet(p.symbol("c", BIGINT)) + af.addGroupingSet(p.symbol("c")) .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, - p.aggregation(ap -> ap.addGroupingSet(p.symbol("b", BIGINT)) + p.aggregation(ap -> ap.addGroupingSet(p.symbol("b")) .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT)))))); + p.values(p.symbol("a")))))); })) .doesNotFire(); } @@ -321,23 +304,23 @@ public void testInterimProject() { ExpectedValueProvider aggregationPattern = PlanMatchPattern.functionCall("count", false, ImmutableList.of(anySymbol())); - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.project( - Assignments.identity(p.symbol("b", BIGINT)), + Assignments.identity(p.symbol("b")), p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT))))))); + p.values(p.symbol("a"))))))); })) .matches( aggregation( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCanonicalizeFilterExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCanonicalizeFilterExpressions.java new file mode 100644 index 000000000000..908578e014bb --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCanonicalizeFilterExpressions.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL; + +public class TestCanonicalizeFilterExpressions + extends BaseRuleTest +{ + @Test + public void testDoesNotFireForExpressionsInCanonicalForm() + { + tester().assertThat(new CanonicalizeFilterExpressions()) + .on(p -> p.filter(FALSE_LITERAL, p.values())) + .doesNotFire(); + } + + @Test + public void testCanonicalizesExpressions() + { + tester().assertThat(new CanonicalizeFilterExpressions()) + .on(p -> p.filter( + p.expression("x IS NOT NULL"), + p.values(p.symbol("x")))) + .matches(filter("NOT (x IS NULL)", values("x"))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCanonicalizeJoinExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCanonicalizeJoinExpressions.java new file mode 100644 index 000000000000..22dd5c8d7e57 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCanonicalizeJoinExpressions.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL; +import static java.util.Collections.emptyList; + +public class TestCanonicalizeJoinExpressions + extends BaseRuleTest +{ + @Test + public void testDoesNotFireForUnfilteredJoin() + { + tester().assertThat(new CanonicalizeJoinExpressions()) + .on(p -> p.join(INNER, p.values(), p.values())) + .doesNotFire(); + } + + @Test + public void testDoesNotFireForCanonicalExpressions() + { + tester().assertThat(new CanonicalizeJoinExpressions()) + .on(p -> p.join(INNER, p.values(), p.values(), FALSE_LITERAL)) + .doesNotFire(); + } + + @Test + public void testCanonicalizesExpressions() + { + tester().assertThat(new CanonicalizeJoinExpressions()) + .on(p -> p.join( + INNER, + p.values(p.symbol("x")), + p.values(), + p.expression("x IS NOT NULL"))) + .matches(join( + INNER, + emptyList(), + Optional.of("NOT (x IS NULL)"), + values("x"), + values())); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCanonicalizeProjectExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCanonicalizeProjectExpressions.java new file mode 100644 index 000000000000..07eb919ad96d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCanonicalizeProjectExpressions.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL; + +public class TestCanonicalizeProjectExpressions + extends BaseRuleTest +{ + @Test + public void testDoesNotFireForExpressionsInCanonicalForm() + { + tester().assertThat(new CanonicalizeProjectExpressions()) + .on(p -> p.project(Assignments.of(p.symbol("x"), FALSE_LITERAL), p.values())) + .doesNotFire(); + } + + @Test + public void testCanonicalizesExpressions() + { + tester().assertThat(new CanonicalizeProjectExpressions()) + .on(p -> p.project( + Assignments.of(p.symbol("y"), p.expression("x IS NOT NULL")), + p.values(p.symbol("x")))) + .matches(project(ImmutableMap.of("y", expression("NOT (x IS NULL)")), values("x"))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCanonicalizeTableScanExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCanonicalizeTableScanExpressions.java new file mode 100644 index 000000000000..87d37eefef5d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCanonicalizeTableScanExpressions.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchTableHandle; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; + +public class TestCanonicalizeTableScanExpressions + extends BaseRuleTest +{ + @Test + public void testDoesNotFireForUnfilteredTableScan() + { + tester().assertThat(new CanonicalizeTableScanExpressions()) + .on(p -> p.tableScan(emptyList(), emptyMap())) + .doesNotFire(); + } + + @Test + public void testDoesNotFireForFilterInCanonicalForm() + { + tester().assertThat(new CanonicalizeTableScanExpressions()) + .on(p -> p.tableScan(emptyList(), emptyMap(), FALSE_LITERAL)) + .doesNotFire(); + } + + @Test + public void testCanonicalizesFilter() + { + tester().assertThat(new CanonicalizeTableScanExpressions()) + .on(p -> p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "nation", TINY_SCALE_FACTOR)), + ImmutableList.of(p.symbol("nationkey")), + ImmutableMap.of(p.symbol("nationkey"), new TpchColumnHandle("nationkey", BIGINT)), + p.expression("nationkey IS NOT NULL"))) + .matches(tableScan("nation", "NOT (nationkey IS NULL)")); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java index 30680e0321ed..e46826bfb567 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -16,8 +16,8 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.GroupReference; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -29,8 +29,6 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.util.Arrays; @@ -38,11 +36,9 @@ import java.util.function.Function; import static com.facebook.presto.SystemSessionProperties.REORDER_JOINS; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; -import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.getJoinOrder; import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.isOriginalOrder; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; @@ -50,47 +46,31 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; -import static io.airlift.testing.Closeables.closeAllRuntimeException; import static org.testng.Assert.assertEquals; import static org.testng.AssertJUnit.assertFalse; import static org.testng.AssertJUnit.assertTrue; @Test(singleThreaded = true) public class TestEliminateCrossJoins + extends BaseRuleTest { - private RuleTester tester; private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testEliminateCrossJoin() { - tester.assertThat(new EliminateCrossJoins()) + tester().assertThat(new EliminateCrossJoins()) .setSystemProperty(REORDER_JOINS, "true") .on(crossJoinAndJoin(INNER)) .matches( - project( + join(INNER, + ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("cySymbol"), new Symbol("bySymbol"))), join(INNER, - ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("cySymbol"), new Symbol("bySymbol"))), - join(INNER, - ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("axSymbol"), new Symbol("cxSymbol"))), - any(), - any() - ), + ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("axSymbol"), new Symbol("cxSymbol"))), + any(), any() - ) + ), + any() ) ); } @@ -98,18 +78,16 @@ public void testEliminateCrossJoin() @Test public void testRetainOutgoingGroupReferences() { - tester.assertThat(new EliminateCrossJoins()) + tester().assertThat(new EliminateCrossJoins()) .setSystemProperty(REORDER_JOINS, "true") .on(crossJoinAndJoin(INNER)) .matches( - any( + node(JoinNode.class, node(JoinNode.class, - node(JoinNode.class, - node(GroupReference.class), - node(GroupReference.class) - ), + node(GroupReference.class), node(GroupReference.class) - ) + ), + node(GroupReference.class) ) ); } @@ -117,7 +95,7 @@ public void testRetainOutgoingGroupReferences() @Test public void testDoNotReorderOuterJoin() { - tester.assertThat(new EliminateCrossJoins()) + tester().assertThat(new EliminateCrossJoins()) .setSystemProperty(REORDER_JOINS, "true") .on(crossJoinAndJoin(JoinNode.Type.LEFT)) .doesNotFire(); @@ -257,10 +235,10 @@ public void testGiveUpOnNonIdentityProjections() private Function crossJoinAndJoin(JoinNode.Type secondJoinType) { return p -> { - Symbol axSymbol = p.symbol("axSymbol", BIGINT); - Symbol bySymbol = p.symbol("bySymbol", BIGINT); - Symbol cxSymbol = p.symbol("cxSymbol", BIGINT); - Symbol cySymbol = p.symbol("cySymbol", BIGINT); + Symbol axSymbol = p.symbol("axSymbol"); + Symbol bySymbol = p.symbol("bySymbol"); + Symbol cxSymbol = p.symbol("cxSymbol"); + Symbol cySymbol = p.symbol("cySymbol"); // (a inner join b) inner join c on c.x = a.x and c.y = b.y return p.join(INNER, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroLimit.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroLimit.java index eff882fd8884..0d43ee717989 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroLimit.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroLimit.java @@ -13,45 +13,27 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expressions; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestEvaluateZeroLimit + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testDoesNotFire() throws Exception { - tester.assertThat(new EvaluateZeroLimit()) + tester().assertThat(new EvaluateZeroLimit()) .on(p -> p.limit( 1, - p.values(p.symbol("a", BIGINT)))) + p.values(p.symbol("a")))) .doesNotFire(); } @@ -59,14 +41,14 @@ public void testDoesNotFire() public void test() throws Exception { - tester.assertThat(new EvaluateZeroLimit()) + tester().assertThat(new EvaluateZeroLimit()) .on(p -> p.limit( 0, p.filter( expression("b > 5"), p.values( - ImmutableList.of(p.symbol("a", BIGINT), p.symbol("b", BIGINT)), + ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( expressions("1", "10"), expressions("2", "11")))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroSample.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroSample.java index 09b7fe272422..7a27936ee446 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroSample.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroSample.java @@ -13,47 +13,29 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.SampleNode.Type; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expressions; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestEvaluateZeroSample + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testDoesNotFire() throws Exception { - tester.assertThat(new EvaluateZeroSample()) + tester().assertThat(new EvaluateZeroSample()) .on(p -> p.sample( 0.15, Type.BERNOULLI, - p.values(p.symbol("a", BIGINT)))) + p.values(p.symbol("a")))) .doesNotFire(); } @@ -61,7 +43,7 @@ public void testDoesNotFire() public void test() throws Exception { - tester.assertThat(new EvaluateZeroSample()) + tester().assertThat(new EvaluateZeroSample()) .on(p -> p.sample( 0, @@ -69,7 +51,7 @@ public void test() p.filter( expression("b > 5"), p.values( - ImmutableList.of(p.symbol("a", BIGINT), p.symbol("b", BIGINT)), + ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( expressions("1", "10"), expressions("2", "11")))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java index c394c450c3e6..b7a64e6c7d64 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java @@ -15,58 +15,40 @@ import com.facebook.presto.sql.planner.assertions.ExpressionMatcher; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.Assignments; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestInlineProjections + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void test() { - tester.assertThat(new InlineProjections()) + tester().assertThat(new InlineProjections()) .on(p -> p.project( Assignments.builder() - .put(p.symbol("identity", BIGINT), expression("symbol")) // identity - .put(p.symbol("multi_complex_1", BIGINT), expression("complex + 1")) // complex expression referenced multiple times - .put(p.symbol("multi_complex_2", BIGINT), expression("complex + 2")) // complex expression referenced multiple times - .put(p.symbol("multi_literal_1", BIGINT), expression("literal + 1")) // literal referenced multiple times - .put(p.symbol("multi_literal_2", BIGINT), expression("literal + 2")) // literal referenced multiple times - .put(p.symbol("single_complex", BIGINT), expression("complex_2 + 2")) // complex expression reference only once - .put(p.symbol("try", BIGINT), expression("try(complex / literal)")) + .put(p.symbol("identity"), expression("symbol")) // identity + .put(p.symbol("multi_complex_1"), expression("complex + 1")) // complex expression referenced multiple times + .put(p.symbol("multi_complex_2"), expression("complex + 2")) // complex expression referenced multiple times + .put(p.symbol("multi_literal_1"), expression("literal + 1")) // literal referenced multiple times + .put(p.symbol("multi_literal_2"), expression("literal + 2")) // literal referenced multiple times + .put(p.symbol("single_complex"), expression("complex_2 + 2")) // complex expression reference only once + .put(p.symbol("try"), expression("try(complex / literal)")) .build(), p.project(Assignments.builder() - .put(p.symbol("symbol", BIGINT), expression("x")) - .put(p.symbol("complex", BIGINT), expression("x * 2")) - .put(p.symbol("literal", BIGINT), expression("1")) - .put(p.symbol("complex_2", BIGINT), expression("x - 1")) + .put(p.symbol("symbol"), expression("x")) + .put(p.symbol("complex"), expression("x * 2")) + .put(p.symbol("literal"), expression("1")) + .put(p.symbol("complex_2"), expression("x - 1")) .build(), - p.values(p.symbol("x", BIGINT))))) + p.values(p.symbol("x"))))) .matches( project( ImmutableMap.builder() @@ -89,13 +71,27 @@ public void test() public void testIdentityProjections() throws Exception { - tester.assertThat(new InlineProjections()) + tester().assertThat(new InlineProjections()) + .on(p -> + p.project( + Assignments.of(p.symbol("output"), expression("value")), + p.project( + Assignments.identity(p.symbol("value")), + p.values(p.symbol("value"))))) + .doesNotFire(); + } + + @Test + public void testSubqueryProjections() + throws Exception + { + tester().assertThat(new InlineProjections()) .on(p -> p.project( - Assignments.of(p.symbol("output", BIGINT), expression("value")), + Assignments.identity(p.symbol("fromOuterScope"), p.symbol("value")), p.project( - Assignments.identity(p.symbol("value", BIGINT)), - p.values(p.symbol("value", BIGINT))))) + Assignments.identity(p.symbol("value")), + p.values(p.symbol("value"))))) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index 0476393b7c72..d117fe516be6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -16,8 +16,8 @@ import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.Signature; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.FunctionCall; @@ -27,8 +27,6 @@ import com.facebook.presto.sql.tree.WindowFrame; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.util.Arrays; @@ -43,25 +41,10 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window; import static com.facebook.presto.sql.tree.FrameBound.Type.CURRENT_ROW; import static com.facebook.presto.sql.tree.FrameBound.Type.UNBOUNDED_PRECEDING; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestMergeAdjacentWindows + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - private static final WindowNode.Frame frame = new WindowNode.Frame(WindowFrame.Type.RANGE, UNBOUNDED_PRECEDING, Optional.empty(), CURRENT_ROW, Optional.empty()); private static final Signature signature = new Signature( @@ -77,8 +60,8 @@ public void tearDown() public void testPlanWithoutWindowNode() throws Exception { - tester.assertThat(new MergeAdjacentWindows()) - .on(p -> p.values(p.symbol("a", BIGINT))) + tester().assertThat(new MergeAdjacentWindows()) + .on(p -> p.values(p.symbol("a"))) .doesNotFire(); } @@ -86,27 +69,27 @@ public void testPlanWithoutWindowNode() public void testPlanWithSingleWindowNode() throws Exception { - tester.assertThat(new MergeAdjacentWindows()) + tester().assertThat(new MergeAdjacentWindows()) .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_1", BIGINT), newWindowNodeFunction("avg", "a")), - p.values(p.symbol("a", BIGINT)))) + ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", "a")), + p.values(p.symbol("a")))) .doesNotFire(); } @Test public void testDistinctAdjacentWindowSpecifications() { - tester.assertThat(new MergeAdjacentWindows()) + tester().assertThat(new MergeAdjacentWindows()) .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_1", BIGINT), newWindowNodeFunction("avg", "a")), + ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", "a")), p.window( newWindowNodeSpecification(p, "b"), - ImmutableMap.of(p.symbol("sum_1", BIGINT), newWindowNodeFunction("sum", "b")), - p.values(p.symbol("b", BIGINT)) + ImmutableMap.of(p.symbol("sum_1"), newWindowNodeFunction("sum", "b")), + p.values(p.symbol("b")) ) )) .doesNotFire(); @@ -115,17 +98,17 @@ public void testDistinctAdjacentWindowSpecifications() @Test public void testNonWindowIntermediateNode() { - tester.assertThat(new MergeAdjacentWindows()) + tester().assertThat(new MergeAdjacentWindows()) .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("lag_1", BIGINT), newWindowNodeFunction("lag", "a", "ONE")), + ImmutableMap.of(p.symbol("lag_1"), newWindowNodeFunction("lag", "a", "ONE")), p.project( - Assignments.copyOf(ImmutableMap.of(p.symbol("ONE", BIGINT), p.expression("CAST(1 AS bigint)"))), + Assignments.copyOf(ImmutableMap.of(p.symbol("ONE"), p.expression("CAST(1 AS bigint)"))), p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_1", BIGINT), newWindowNodeFunction("avg", "a")), - p.values(p.symbol("a", BIGINT)) + ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", "a")), + p.values(p.symbol("a")) ) ) )) @@ -138,15 +121,15 @@ public void testDependentAdjacentWindowsIdenticalSpecifications() { Optional windowA = Optional.of(new Window(ImmutableList.of(new SymbolReference("a")), Optional.empty(), Optional.empty())); - tester.assertThat(new MergeAdjacentWindows()) + tester().assertThat(new MergeAdjacentWindows()) .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_1", BIGINT), newWindowNodeFunction("avg", windowA, "avg_2")), + ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", windowA, "avg_2")), p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_2", BIGINT), newWindowNodeFunction("avg", windowA, "a")), - p.values(p.symbol("a", BIGINT)) + ImmutableMap.of(p.symbol("avg_2"), newWindowNodeFunction("avg", windowA, "a")), + p.values(p.symbol("a")) ) )) .doesNotFire(); @@ -158,15 +141,15 @@ public void testDependentAdjacentWindowsDistinctSpecifications() { Optional windowA = Optional.of(new Window(ImmutableList.of(new SymbolReference("a")), Optional.empty(), Optional.empty())); - tester.assertThat(new MergeAdjacentWindows()) + tester().assertThat(new MergeAdjacentWindows()) .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_1", BIGINT), newWindowNodeFunction("avg", windowA, "avg_2")), + ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", windowA, "avg_2")), p.window( newWindowNodeSpecification(p, "b"), - ImmutableMap.of(p.symbol("avg_2", BIGINT), newWindowNodeFunction("avg", windowA, "a")), - p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT)) + ImmutableMap.of(p.symbol("avg_2"), newWindowNodeFunction("avg", windowA, "a")), + p.values(p.symbol("a"), p.symbol("b")) ) )) .doesNotFire(); @@ -182,23 +165,23 @@ public void testIdenticalAdjacentWindowSpecifications() Optional windowA = Optional.of(new Window(ImmutableList.of(new SymbolReference("a")), Optional.empty(), Optional.empty())); - tester.assertThat(new MergeAdjacentWindows()) + tester().assertThat(new MergeAdjacentWindows()) .on(p -> p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("avg_1", BIGINT), newWindowNodeFunction("avg", windowA, "a")), + ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", windowA, "a")), p.window( newWindowNodeSpecification(p, "a"), - ImmutableMap.of(p.symbol("sum_1", BIGINT), newWindowNodeFunction("sum", windowA, "a")), - p.values(p.symbol("a", BIGINT)) + ImmutableMap.of(p.symbol("sum_1"), newWindowNodeFunction("sum", windowA, "a")), + p.values(p.symbol("a")) ) )) - .matches(window( - specificationA, - ImmutableList.of( - functionCall("avg", Optional.empty(), ImmutableList.of(columnAAlias)), - functionCall("sum", Optional.empty(), ImmutableList.of(columnAAlias))), - values(ImmutableMap.of(columnAAlias, 0)))); + .matches( + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationA) + .addFunction(functionCall("avg", Optional.empty(), ImmutableList.of(columnAAlias))) + .addFunction(functionCall("sum", Optional.empty(), ImmutableList.of(columnAAlias))), + values(ImmutableMap.of(columnAAlias, 0)))); } private static WindowNode.Specification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeFilters.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeFilters.java index 6f9ff8414fb4..9ce71f2c884f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeFilters.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeFilters.java @@ -13,43 +13,25 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestMergeFilters + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void test() { - tester.assertThat(new MergeFilters()) + tester().assertThat(new MergeFilters()) .on(p -> p.filter(expression("b > 44"), p.filter(expression("a < 42"), - p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT))))) + p.values(p.symbol("a"), p.symbol("b"))))) .matches(filter("(a < 42) AND (b > 44)", values(ImmutableMap.of("a", 0, "b", 1)))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java new file mode 100644 index 000000000000..4c5e17c6b9ee --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java @@ -0,0 +1,174 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SymbolReference; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchTableHandle; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; +import static io.airlift.testing.Closeables.closeAllRuntimeException; + +public class TestPruneCountAggregationOverScalar + extends BasePlanTest +{ + private RuleTester tester; + + @BeforeClass + public void setUp() + { + tester = new RuleTester(); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + closeAllRuntimeException(tester); + tester = null; + } + + @Test + public void testDoesNotFireOnNonNestedAggregate() + { + tester.assertThat(new PruneCountAggregationOverScalar()) + .on(p -> + p.aggregation((a) -> a + .globalGrouping() + .addAggregation( + p.symbol("count_1", BigintType.BIGINT), + new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), + ImmutableList.of(BigintType.BIGINT)) + .source( + p.tableScan(ImmutableList.of(), ImmutableMap.of()))) + ).doesNotFire(); + } + + @Test + public void testFiresOnNestedCountAggregate() + { + tester.assertThat(new PruneCountAggregationOverScalar()) + .on(p -> + p.aggregation((a) -> a + .addAggregation( + p.symbol("count_1", BigintType.BIGINT), + new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), ImmutableList.of(BigintType.BIGINT)) + .globalGrouping() + .step(AggregationNode.Step.SINGLE) + .source( + p.aggregation((aggregationBuilder) -> aggregationBuilder + .source(p.tableScan(ImmutableList.of(), ImmutableMap.of())) + .globalGrouping() + .step(AggregationNode.Step.SINGLE))))) + .matches(values(ImmutableMap.of("count_1", 0))); + } + + @Test + public void testFiresOnCountAggregateOverValues() + { + tester.assertThat(new PruneCountAggregationOverScalar()) + .on(p -> + p.aggregation((a) -> a + .addAggregation( + p.symbol("count_1", BigintType.BIGINT), + new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), + ImmutableList.of(BigintType.BIGINT)) + .step(AggregationNode.Step.SINGLE) + .globalGrouping() + .source(p.values(ImmutableList.of(p.symbol("orderkey")), ImmutableList.of(p.expressions("1")))))) + .matches(values(ImmutableMap.of("count_1", 0))); + } + + @Test + public void testFiresOnCountAggregateOverEnforceSingleRow() + { + tester.assertThat(new PruneCountAggregationOverScalar()) + .on(p -> + p.aggregation((a) -> a + .addAggregation( + p.symbol("count_1", BigintType.BIGINT), + new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), + ImmutableList.of(BigintType.BIGINT)) + .step(AggregationNode.Step.SINGLE) + .globalGrouping() + .source(p.enforceSingleRow(p.tableScan(ImmutableList.of(), ImmutableMap.of()))))) + .matches(values(ImmutableMap.of("count_1", 0))); + } + + @Test + public void testDoesNotFireOnNestedCountAggregateWithNonEmptyGroupBy() + { + tester.assertThat(new PruneCountAggregationOverScalar()) + .on(p -> + p.aggregation((a) -> a + .addAggregation( + p.symbol("count_1", BigintType.BIGINT), + new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), + ImmutableList.of(BigintType.BIGINT)) + .step(AggregationNode.Step.SINGLE) + .globalGrouping() + .source( + p.aggregation(aggregationBuilder -> aggregationBuilder + .source(p.tableScan(ImmutableList.of(), ImmutableMap.of())) + .groupingSets(ImmutableList.of(ImmutableList.of(p.symbol("orderkey")))))))) + .doesNotFire(); + } + + @Test + public void testDoesNotFireOnNestedNonCountAggregate() + { + tester.assertThat(new PruneCountAggregationOverScalar()) + .on(p -> { + Symbol totalPrice = p.symbol("total_price", DOUBLE); + AggregationNode inner = p.aggregation((a) -> a + .addAggregation(totalPrice, + new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(new SymbolReference("totalprice"))), + ImmutableList.of(DOUBLE)) + .globalGrouping() + .source( + p.project( + Assignments.of(totalPrice, totalPrice.toSymbolReference()), + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)), + ImmutableList.of(totalPrice), + ImmutableMap.of(totalPrice, new TpchColumnHandle(totalPrice.getName(), DOUBLE)))))); + + return p.aggregation((a) -> a + .addAggregation( + p.symbol("sum_outer", DOUBLE), + new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(new SymbolReference("sum_inner"))), + ImmutableList.of(DOUBLE)) + .globalGrouping() + .source(inner)); + }).doesNotFire(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java new file mode 100644 index 000000000000..32816c19b113 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.base.Predicates; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TestPruneCrossJoinColumns + extends BaseRuleTest +{ + @Test + public void testLeftInputNotReferenced() + { + tester().assertThat(new PruneCrossJoinColumns()) + .on(p -> buildProjectedCrossJoin(p, symbol -> symbol.getName().equals("rightValue"))) + .matches( + strictProject( + ImmutableMap.of("rightValue", PlanMatchPattern.expression("rightValue")), + join( + JoinNode.Type.INNER, + ImmutableList.of(), + Optional.empty(), + strictProject( + ImmutableMap.of(), + values(ImmutableList.of("leftValue"))), + values(ImmutableList.of("rightValue"))) + .withExactOutputs("rightValue"))); + } + + @Test + public void testRightInputNotReferenced() + { + tester().assertThat(new PruneCrossJoinColumns()) + .on(p -> buildProjectedCrossJoin(p, symbol -> symbol.getName().equals("leftValue"))) + .matches( + strictProject( + ImmutableMap.of("leftValue", PlanMatchPattern.expression("leftValue")), + join( + JoinNode.Type.INNER, + ImmutableList.of(), + Optional.empty(), + values(ImmutableList.of("leftValue")), + strictProject( + ImmutableMap.of(), + values(ImmutableList.of("rightValue")))) + .withExactOutputs("leftValue"))); + } + + @Test + public void testAllInputsReferenced() + { + tester().assertThat(new PruneCrossJoinColumns()) + .on(p -> buildProjectedCrossJoin(p, Predicates.alwaysTrue())) + .doesNotFire(); + } + + private static PlanNode buildProjectedCrossJoin(PlanBuilder p, Predicate projectionFilter) + { + Symbol leftValue = p.symbol("leftValue"); + Symbol rightValue = p.symbol("rightValue"); + List outputs = ImmutableList.of(leftValue, rightValue); + return p.project( + Assignments.identity( + outputs.stream() + .filter(projectionFilter) + .collect(toImmutableList())), + p.join( + JoinNode.Type.INNER, + p.values(leftValue), + p.values(rightValue), + ImmutableList.of(), + outputs, + Optional.empty(), + Optional.empty(), + Optional.empty())); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java new file mode 100644 index 000000000000..f52cdab58eb6 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.predicate.Domain; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchTableHandle; +import com.google.common.base.Predicates; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import java.util.function.Predicate; + +import static com.facebook.presto.spi.predicate.NullableValue.asNull; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.constrainedIndexSource; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TestPruneIndexSourceColumns + extends BaseRuleTest +{ + @Test + public void testNotAllOutputsReferenced() + { + tester().assertThat(new PruneIndexSourceColumns()) + .on(p -> buildProjectedIndexSource(p, symbol -> symbol.getName().equals("orderkey"))) + .matches( + strictProject( + ImmutableMap.of("x", expression("orderkey")), + constrainedIndexSource( + "orders", + ImmutableMap.of("totalprice", Domain.onlyNull(DOUBLE)), + ImmutableMap.of( + "orderkey", "orderkey", + "totalprice", "totalprice")))); + } + + @Test + public void testAllOutputsReferenced() + { + tester().assertThat(new PruneIndexSourceColumns()) + .on(p -> buildProjectedIndexSource(p, Predicates.alwaysTrue())) + .doesNotFire(); + } + + private static PlanNode buildProjectedIndexSource(PlanBuilder p, Predicate projectionFilter) + { + Symbol orderkey = p.symbol("orderkey", INTEGER); + Symbol custkey = p.symbol("custkey", INTEGER); + Symbol totalprice = p.symbol("totalprice", DOUBLE); + ColumnHandle orderkeyHandle = new TpchColumnHandle(orderkey.getName(), INTEGER); + ColumnHandle custkeyHandle = new TpchColumnHandle(custkey.getName(), INTEGER); + ColumnHandle totalpriceHandle = new TpchColumnHandle(totalprice.getName(), DOUBLE); + return p.project( + Assignments.identity( + ImmutableList.of(orderkey, custkey, totalprice).stream() + .filter(projectionFilter) + .collect(toImmutableList())), + p.indexSource( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)), + ImmutableSet.of(orderkey, custkey), + ImmutableList.of(orderkey, custkey, totalprice), + ImmutableMap.of( + orderkey, orderkeyHandle, + custkey, custkeyHandle, + totalprice, totalpriceHandle), + TupleDomain.fromFixedValues(ImmutableMap.of(totalpriceHandle, asNull(DOUBLE))))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java new file mode 100644 index 000000000000..935ee820db41 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.base.Predicates; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TestPruneJoinChildrenColumns + extends BaseRuleTest +{ + @Test + public void testNotAllInputsRereferenced() + { + tester().assertThat(new PruneJoinChildrenColumns()) + .on(p -> buildJoin(p, symbol -> symbol.getName().equals("leftValue"))) + .matches( + join( + JoinNode.Type.INNER, + ImmutableList.of(equiJoinClause("leftKey", "rightKey")), + Optional.of("leftValue > 5"), + values("leftKey", "leftKeyHash", "leftValue"), + strictProject( + ImmutableMap.of( + "rightKey", PlanMatchPattern.expression("rightKey"), + "rightKeyHash", PlanMatchPattern.expression("rightKeyHash")), + values("rightKey", "rightKeyHash", "rightValue")))); + } + + @Test + public void testAllInputsReferenced() + { + tester().assertThat(new PruneJoinChildrenColumns()) + .on(p -> buildJoin(p, Predicates.alwaysTrue())) + .doesNotFire(); + } + + @Test + public void testCrossJoinDoesNotFire() + { + tester().assertThat(new PruneJoinColumns()) + .on(p -> { + Symbol leftValue = p.symbol("leftValue"); + Symbol rightValue = p.symbol("rightValue"); + return p.join( + JoinNode.Type.INNER, + p.values(leftValue), + p.values(rightValue), + ImmutableList.of(), + ImmutableList.of(leftValue, rightValue), + Optional.empty(), + Optional.empty(), + Optional.empty()); + }) + .doesNotFire(); + } + + private static PlanNode buildJoin(PlanBuilder p, Predicate joinOutputFilter) + { + Symbol leftKey = p.symbol("leftKey"); + Symbol leftKeyHash = p.symbol("leftKeyHash"); + Symbol leftValue = p.symbol("leftValue"); + Symbol rightKey = p.symbol("rightKey"); + Symbol rightKeyHash = p.symbol("rightKeyHash"); + Symbol rightValue = p.symbol("rightValue"); + List outputs = ImmutableList.of(leftValue, rightValue); + return p.join( + JoinNode.Type.INNER, + p.values(leftKey, leftKeyHash, leftValue), + p.values(rightKey, rightKeyHash, rightValue), + ImmutableList.of(new JoinNode.EquiJoinClause(leftKey, rightKey)), + outputs.stream() + .filter(joinOutputFilter) + .collect(toImmutableList()), + Optional.of(expression("leftValue > 5")), + Optional.of(leftKeyHash), + Optional.of(rightKeyHash)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinColumns.java new file mode 100644 index 000000000000..7bfcc8337221 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinColumns.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.base.Predicates; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TestPruneJoinColumns + extends BaseRuleTest +{ + @Test + public void testNotAllOutputsReferenced() + { + tester().assertThat(new PruneJoinColumns()) + .on(p -> buildProjectedJoin(p, symbol -> symbol.getName().equals("rightValue"))) + .matches( + strictProject( + ImmutableMap.of("rightValue", PlanMatchPattern.expression("rightValue")), + join( + JoinNode.Type.INNER, + ImmutableList.of(equiJoinClause("leftKey", "rightKey")), + Optional.empty(), + values(ImmutableList.of("leftKey", "leftValue")), + values(ImmutableList.of("rightKey", "rightValue"))) + .withExactOutputs("rightValue"))); + } + + @Test + public void testAllInputsReferenced() + { + tester().assertThat(new PruneJoinColumns()) + .on(p -> buildProjectedJoin(p, Predicates.alwaysTrue())) + .doesNotFire(); + } + + @Test + public void testCrossJoinDoesNotFire() + { + tester().assertThat(new PruneJoinColumns()) + .on(p -> { + Symbol leftValue = p.symbol("leftValue"); + Symbol rightValue = p.symbol("rightValue"); + return p.project( + Assignments.of(), + p.join( + JoinNode.Type.INNER, + p.values(leftValue), + p.values(rightValue), + ImmutableList.of(), + ImmutableList.of(leftValue, rightValue), + Optional.empty(), + Optional.empty(), + Optional.empty())); + }) + .doesNotFire(); + } + + private static PlanNode buildProjectedJoin(PlanBuilder p, Predicate projectionFilter) + { + Symbol leftKey = p.symbol("leftKey"); + Symbol leftValue = p.symbol("leftValue"); + Symbol rightKey = p.symbol("rightKey"); + Symbol rightValue = p.symbol("rightValue"); + List outputs = ImmutableList.of(leftKey, leftValue, rightKey, rightValue); + return p.project( + Assignments.identity( + outputs.stream() + .filter(projectionFilter) + .collect(toImmutableList())), + p.join( + JoinNode.Type.INNER, + p.values(leftKey, leftValue), + p.values(rightKey, rightValue), + ImmutableList.of(new JoinNode.EquiJoinClause(leftKey, rightKey)), + outputs, + Optional.empty(), + Optional.empty(), + Optional.empty())); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java new file mode 100644 index 000000000000..3af9467e90a1 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.markDistinct; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneMarkDistinctColumns + extends BaseRuleTest +{ + @Test + public void testMarkerSymbolNotReferenced() + throws Exception + { + tester().assertThat(new PruneMarkDistinctColumns()) + .on(p -> + { + Symbol key = p.symbol("key"); + Symbol key2 = p.symbol("key2"); + Symbol mark = p.symbol("mark"); + Symbol unused = p.symbol("unused"); + return p.project( + Assignments.of(key2, key.toSymbolReference()), + p.markDistinct(mark, ImmutableList.of(key), p.values(key, unused))); + }) + .matches( + strictProject( + ImmutableMap.of("key2", expression("key")), + values(ImmutableList.of("key", "unused")))); + } + + @Test + public void testSourceSymbolNotReferenced() + throws Exception + { + tester().assertThat(new PruneMarkDistinctColumns()) + .on(p -> + { + Symbol key = p.symbol("key"); + Symbol mark = p.symbol("mark"); + Symbol hash = p.symbol("hash"); + Symbol unused = p.symbol("unused"); + return p.project( + Assignments.identity(mark), + p.markDistinct( + mark, + ImmutableList.of(key), + hash, + p.values(key, hash, unused))); + }) + .matches( + strictProject( + ImmutableMap.of("mark", expression("mark")), + markDistinct("mark", ImmutableList.of("key"), "hash", + strictProject( + ImmutableMap.of( + "key", expression("key"), + "hash", expression("hash")), + values(ImmutableList.of("key", "hash", "unused")))))); + } + + @Test + public void testKeySymbolNotReferenced() + throws Exception + { + tester().assertThat(new PruneMarkDistinctColumns()) + .on(p -> + { + Symbol key = p.symbol("key"); + Symbol mark = p.symbol("mark"); + return p.project( + Assignments.identity(mark), + p.markDistinct(mark, ImmutableList.of(key), p.values(key))); + }) + .doesNotFire(); + } + + @Test + public void testAllOutputsReferenced() + throws Exception + { + tester().assertThat(new PruneMarkDistinctColumns()) + .on(p -> + { + Symbol key = p.symbol("key"); + Symbol mark = p.symbol("mark"); + return p.project( + Assignments.identity(key, mark), + p.markDistinct(mark, ImmutableList.of(key), p.values(key))); + }) + .doesNotFire(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOutputColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOutputColumns.java new file mode 100644 index 000000000000..a36c54eacffa --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneOutputColumns.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictOutput; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneOutputColumns + extends BaseRuleTest +{ + @Test + public void testNotAllOutputsReferenced() + throws Exception + { + tester().assertThat(new PruneOutputColumns()) + .on(p -> + { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.output( + ImmutableList.of("B label"), + ImmutableList.of(b), + p.values(a, b)); + }) + .matches( + strictOutput( + ImmutableList.of("b"), + strictProject( + ImmutableMap.of("b", expression("b")), + values("a", "b")))); + } + + @Test + public void testAllOutputsReferenced() + throws Exception + { + tester().assertThat(new PruneOutputColumns()) + .on(p -> + { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.output( + ImmutableList.of("A label", "B label"), + ImmutableList.of(a, b), + p.values(a, b)); + }) + .doesNotFire(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java new file mode 100644 index 000000000000..4dc7a637e27a --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TestPruneSemiJoinColumns + extends BaseRuleTest +{ + @Test + public void testSemiJoinNotNeeded() + { + tester().assertThat(new PruneSemiJoinColumns()) + .on(p -> buildProjectedSemiJoin(p, symbol -> symbol.getName().equals("leftValue"))) + .matches( + strictProject( + ImmutableMap.of("leftValue", expression("leftValue")), + values("leftKey", "leftKeyHash", "leftValue"))); + } + + @Test + public void testAllColumnsNeeded() + { + tester().assertThat(new PruneSemiJoinColumns()) + .on(p -> buildProjectedSemiJoin(p, symbol -> true)) + .doesNotFire(); + } + + @Test + public void testKeysNotNeeded() + { + tester().assertThat(new PruneSemiJoinColumns()) + .on(p -> buildProjectedSemiJoin(p, symbol -> (symbol.getName().equals("leftValue") || symbol.getName().equals("match")))) + .doesNotFire(); + } + + @Test + public void testValueNotNeeded() + { + tester().assertThat(new PruneSemiJoinColumns()) + .on(p -> buildProjectedSemiJoin(p, symbol -> symbol.getName().equals("match"))) + .matches( + strictProject( + ImmutableMap.of("match", expression("match")), + semiJoin("leftKey", "rightKey", "match", + strictProject( + ImmutableMap.of( + "leftKey", expression("leftKey"), + "leftKeyHash", expression("leftKeyHash")), + values("leftKey", "leftKeyHash", "leftValue")), + values("rightKey")))); + } + + private static PlanNode buildProjectedSemiJoin(PlanBuilder p, Predicate projectionFilter) + { + Symbol match = p.symbol("match"); + Symbol leftKey = p.symbol("leftKey"); + Symbol leftKeyHash = p.symbol("leftKeyHash"); + Symbol leftValue = p.symbol("leftValue"); + Symbol rightKey = p.symbol("rightKey"); + List outputs = ImmutableList.of(match, leftKey, leftKeyHash, leftValue); + return p.project( + Assignments.identity( + outputs.stream() + .filter(projectionFilter) + .collect(toImmutableList())), + p.semiJoin( + leftKey, + rightKey, + match, + Optional.of(leftKeyHash), + Optional.empty(), + p.values(leftKey, leftKeyHash, leftValue), + p.values(rightKey))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java new file mode 100644 index 000000000000..d64c103f1e96 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TestPruneSemiJoinFilteringSourceColumns + extends BaseRuleTest +{ + @Test + public void testNotAllColumnsReferenced() + { + tester().assertThat(new PruneSemiJoinFilteringSourceColumns()) + .on(p -> buildSemiJoin(p, symbol -> true)) + .matches( + semiJoin("leftKey", "rightKey", "match", + values("leftKey"), + strictProject( + ImmutableMap.of( + "rightKey", expression("rightKey"), + "rightKeyHash", expression("rightKeyHash")), + values("rightKey", "rightKeyHash", "rightValue")))); + } + + @Test + public void testAllColumnsNeeded() + { + tester().assertThat(new PruneSemiJoinFilteringSourceColumns()) + .on(p -> buildSemiJoin(p, symbol -> !symbol.getName().equals("rightValue"))) + .doesNotFire(); + } + + private static PlanNode buildSemiJoin(PlanBuilder p, Predicate filteringSourceSymbolFilter) + { + Symbol match = p.symbol("match"); + Symbol leftKey = p.symbol("leftKey"); + Symbol rightKey = p.symbol("rightKey"); + Symbol rightKeyHash = p.symbol("rightKeyHash"); + Symbol rightValue = p.symbol("rightValue"); + List filteringSourceSymbols = ImmutableList.of(rightKey, rightKeyHash, rightValue); + return p.semiJoin( + leftKey, + rightKey, + match, + Optional.empty(), + Optional.of(rightKeyHash), + p.values(leftKey), + p.values( + filteringSourceSymbols.stream() + .filter(filteringSourceSymbolFilter) + .collect(toImmutableList()), + ImmutableList.of())); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java index 0fa45489fae9..18957faf46bc 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java @@ -17,53 +17,35 @@ import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; import com.facebook.presto.tpch.TpchColumnHandle; import com.facebook.presto.tpch.TpchTableHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DateType.DATE; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictTableScan; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestPruneTableScanColumns + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testNotAllOutputsReferenced() { - tester.assertThat(new PruneTableScanColumns()) + tester().assertThat(new PruneTableScanColumns()) .on(p -> { Symbol orderdate = p.symbol("orderdate", DATE); Symbol totalprice = p.symbol("totalprice", DOUBLE); return p.project( - Assignments.of(p.symbol("x", BIGINT), totalprice.toSymbolReference()), + Assignments.of(p.symbol("x"), totalprice.toSymbolReference()), p.tableScan( new TableHandle( new ConnectorId("local"), @@ -82,13 +64,13 @@ orderdate, new TpchColumnHandle(orderdate.getName(), DATE), @Test public void testAllOutputsReferenced() { - tester.assertThat(new PruneTableScanColumns()) + tester().assertThat(new PruneTableScanColumns()) .on(p -> p.project( - Assignments.of(p.symbol("y", BIGINT), expression("x")), + Assignments.of(p.symbol("y"), expression("x")), p.tableScan( - ImmutableList.of(p.symbol("x", BIGINT)), - ImmutableMap.of(p.symbol("x", BIGINT), new TestingColumnHandle("x"))))) + ImmutableList.of(p.symbol("x")), + ImmutableMap.of(p.symbol("x"), new TestingColumnHandle("x"))))) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java index 191b7b0f324c..3f78247ffe5f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java @@ -14,47 +14,29 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.Assignments; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestPruneValuesColumns + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testNotAllOutputsReferenced() throws Exception { - tester.assertThat(new PruneValuesColumns()) + tester().assertThat(new PruneValuesColumns()) .on(p -> p.project( - Assignments.of(p.symbol("y", BIGINT), expression("x")), + Assignments.of(p.symbol("y"), expression("x")), p.values( - ImmutableList.of(p.symbol("unused", BIGINT), p.symbol("x", BIGINT)), + ImmutableList.of(p.symbol("unused"), p.symbol("x")), ImmutableList.of( ImmutableList.of(expression("1"), expression("2")), ImmutableList.of(expression("3"), expression("4")))))) @@ -72,11 +54,11 @@ public void testNotAllOutputsReferenced() public void testAllOutputsReferenced() throws Exception { - tester.assertThat(new PruneValuesColumns()) + tester().assertThat(new PruneValuesColumns()) .on(p -> p.project( - Assignments.of(p.symbol("y", BIGINT), expression("x")), - p.values(p.symbol("x", BIGINT)))) + Assignments.of(p.symbol("y"), expression("x")), + p.values(p.symbol("x")))) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java index 51a90827c3eb..4d51671328e5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java @@ -15,8 +15,8 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; import com.facebook.presto.sql.planner.plan.JoinNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -24,7 +24,6 @@ import java.util.Optional; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; @@ -37,25 +36,26 @@ import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; public class TestPushAggregationThroughOuterJoin + extends BaseRuleTest { @Test public void testPushesAggregationThroughLeftJoin() { - new RuleTester().assertThat(new PushAggregationThroughOuterJoin()) + tester().assertThat(new PushAggregationThroughOuterJoin()) .on(p -> p.aggregation(ab -> ab .source( p.join( JoinNode.Type.LEFT, - p.values(ImmutableList.of(p.symbol("COL1", BIGINT)), ImmutableList.of(expressions("10"))), - p.values(p.symbol("COL2", BIGINT)), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL1", BIGINT), p.symbol("COL2", BIGINT))), - ImmutableList.of(p.symbol("COL1", BIGINT), p.symbol("COL2", BIGINT)), + p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(expressions("10"))), + p.values(p.symbol("COL2")), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))), + ImmutableList.of(p.symbol("COL1"), p.symbol("COL2")), Optional.empty(), Optional.empty(), Optional.empty() )) .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) - .addGroupingSet(p.symbol("COL1", BIGINT)))) + .addGroupingSet(p.symbol("COL1")))) .matches( project(ImmutableMap.of( "COL1", expression("COL1"), @@ -82,19 +82,19 @@ public void testPushesAggregationThroughLeftJoin() @Test public void testPushesAggregationThroughRightJoin() { - new RuleTester().assertThat(new PushAggregationThroughOuterJoin()) + tester().assertThat(new PushAggregationThroughOuterJoin()) .on(p -> p.aggregation(ab -> ab .source(p.join( JoinNode.Type.RIGHT, - p.values(p.symbol("COL2", BIGINT)), - p.values(ImmutableList.of(p.symbol("COL1", BIGINT)), ImmutableList.of(expressions("10"))), - ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL2", BIGINT), p.symbol("COL1", BIGINT))), - ImmutableList.of(p.symbol("COL2", BIGINT), p.symbol("COL1", BIGINT)), + p.values(p.symbol("COL2")), + p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(expressions("10"))), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL2"), p.symbol("COL1"))), + ImmutableList.of(p.symbol("COL2"), p.symbol("COL1")), Optional.empty(), Optional.empty(), Optional.empty())) .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) - .addGroupingSet(p.symbol("COL1", BIGINT)))) + .addGroupingSet(p.symbol("COL1")))) .matches( project(ImmutableMap.of( "COALESCE", expression("coalesce(AVG, AVG_NULL)"), @@ -122,11 +122,11 @@ public void testPushesAggregationThroughRightJoin() @Test public void testDoesNotFireWhenNotDistinct() { - new RuleTester().assertThat(new PushAggregationThroughOuterJoin()) + tester().assertThat(new PushAggregationThroughOuterJoin()) .on(p -> p.aggregation(ab -> ab .source(p.join( JoinNode.Type.LEFT, - p.values(ImmutableList.of(p.symbol("COL1", BIGINT)), ImmutableList.of(expressions("10"), expressions("11"))), + p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(expressions("10"), expressions("11"))), p.values(new Symbol("COL2")), ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))), ImmutableList.of(new Symbol("COL1"), new Symbol("COL2")), @@ -141,10 +141,10 @@ public void testDoesNotFireWhenNotDistinct() @Test public void testDoesNotFireWhenGroupingOnInner() { - new RuleTester().assertThat(new PushAggregationThroughOuterJoin()) + tester().assertThat(new PushAggregationThroughOuterJoin()) .on(p -> p.aggregation(ab -> ab .source(p.join(JoinNode.Type.LEFT, - p.values(ImmutableList.of(p.symbol("COL1", BIGINT)), ImmutableList.of(expressions("10"))), + p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(expressions("10"))), p.values(new Symbol("COL2"), new Symbol("COL3")), ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))), ImmutableList.of(new Symbol("COL1"), new Symbol("COL2")), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughMarkDistinct.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughMarkDistinct.java new file mode 100644 index 000000000000..898da02bfea5 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughMarkDistinct.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.LimitNode; +import com.facebook.presto.sql.planner.plan.MarkDistinctNode; +import com.facebook.presto.sql.planner.plan.ValuesNode; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static java.util.Collections.emptyList; + +public class TestPushLimitThroughMarkDistinct + extends BaseRuleTest +{ + @Test + public void test() + throws Exception + { + tester().assertThat(new PushLimitThroughMarkDistinct()) + .on(p -> + p.limit( + 1, + p.markDistinct( + p.values(), p.symbol("foo"), emptyList()))) + .matches( + node(MarkDistinctNode.class, + node(LimitNode.class, + node(ValuesNode.class)))); + } + + @Test + public void testDoesNotFire() + throws Exception + { + tester().assertThat(new PushLimitThroughMarkDistinct()) + .on(p -> + p.markDistinct( + p.limit( + 1, + p.values()), + p.symbol("foo"), + emptyList())) + .doesNotFire(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java index 799181cb4424..e66e450e6b6b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java @@ -14,49 +14,31 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestPushProjectionThroughExchange + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testDoesNotFireNoExchange() throws Exception { - tester.assertThat(new PushProjectionThroughExchange()) + tester().assertThat(new PushProjectionThroughExchange()) .on(p -> p.project( - Assignments.of(p.symbol("x", BIGINT), new LongLiteral("3")), - p.values(p.symbol("a", BIGINT)))) + Assignments.of(p.symbol("x"), new LongLiteral("3")), + p.values(p.symbol("a")))) .doesNotFire(); } @@ -64,11 +46,11 @@ public void testDoesNotFireNoExchange() public void testDoesNotFireNarrowingProjection() throws Exception { - tester.assertThat(new PushProjectionThroughExchange()) + tester().assertThat(new PushProjectionThroughExchange()) .on(p -> { - Symbol a = p.symbol("a", BIGINT); - Symbol b = p.symbol("b", BIGINT); - Symbol c = p.symbol("c", BIGINT); + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); return p.project( Assignments.builder() @@ -87,13 +69,13 @@ public void testDoesNotFireNarrowingProjection() public void testSimpleMultipleInputs() throws Exception { - tester.assertThat(new PushProjectionThroughExchange()) + tester().assertThat(new PushProjectionThroughExchange()) .on(p -> { - Symbol a = p.symbol("a", BIGINT); - Symbol b = p.symbol("b", BIGINT); - Symbol c = p.symbol("c", BIGINT); - Symbol c2 = p.symbol("c2", BIGINT); - Symbol x = p.symbol("x", BIGINT); + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol c2 = p.symbol("c2"); + Symbol x = p.symbol("x"); return p.project( Assignments.of( x, new LongLiteral("3"), @@ -128,14 +110,14 @@ c2, new SymbolReference("c") public void testPartitioningColumnAndHashWithoutIdentityMappingInProjection() throws Exception { - tester.assertThat(new PushProjectionThroughExchange()) + tester().assertThat(new PushProjectionThroughExchange()) .on(p -> { - Symbol a = p.symbol("a", BIGINT); - Symbol b = p.symbol("b", BIGINT); - Symbol h = p.symbol("h", BIGINT); - Symbol aTimes5 = p.symbol("a_times_5", BIGINT); - Symbol bTimes5 = p.symbol("b_times_5", BIGINT); - Symbol hTimes5 = p.symbol("h_times_5", BIGINT); + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol h = p.symbol("h"); + Symbol aTimes5 = p.symbol("a_times_5"); + Symbol bTimes5 = p.symbol("b_times_5"); + Symbol hTimes5 = p.symbol("h_times_5"); return p.project( Assignments.builder() .put(aTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.MULTIPLY, new SymbolReference("a"), new LongLiteral("5"))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java index ad17d765a071..817fc481b023 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java @@ -14,50 +14,32 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.union; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestPushProjectionThroughUnion + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testDoesNotFire() throws Exception { - tester.assertThat(new PushProjectionThroughUnion()) + tester().assertThat(new PushProjectionThroughUnion()) .on(p -> p.project( - Assignments.of(p.symbol("x", BIGINT), new LongLiteral("3")), - p.values(p.symbol("a", BIGINT)))) + Assignments.of(p.symbol("x"), new LongLiteral("3")), + p.values(p.symbol("a")))) .doesNotFire(); } @@ -65,23 +47,23 @@ public void testDoesNotFire() public void test() throws Exception { - tester.assertThat(new PushProjectionThroughUnion()) + tester().assertThat(new PushProjectionThroughUnion()) .on(p -> { - Symbol a = p.symbol("a", BIGINT); - Symbol b = p.symbol("b", BIGINT); - Symbol c = p.symbol("c", BIGINT); - Symbol cTimes3 = p.symbol("c_times_3", BIGINT); + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol cTimes3 = p.symbol("c_times_3"); return p.project( Assignments.of(cTimes3, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.MULTIPLY, c.toSymbolReference(), new LongLiteral("3"))), p.union( - ImmutableList.of( - p.values(a), - p.values(b)), ImmutableListMultimap.builder() .put(c, a) .put(c, b) .build(), - ImmutableList.of(c))); + ImmutableList.of( + p.values(a), + p.values(b)) + )); }) .matches( union( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushTableWriteThroughUnion.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushTableWriteThroughUnion.java new file mode 100644 index 000000000000..ae48bdd1560c --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushTableWriteThroughUnion.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableWriter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.union; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPushTableWriteThroughUnion +{ + @Test + public void testPushThroughUnion() + { + new RuleTester().assertThat(new PushTableWriteThroughUnion()) + .on(p -> + p.tableWriter( + ImmutableList.of(p.symbol("A", BIGINT), p.symbol("B", BIGINT)), ImmutableList.of("a", "b"), + p.union( + ImmutableListMultimap.builder() + .putAll(p.symbol("A", BIGINT), p.symbol("A1", BIGINT), p.symbol("B2", BIGINT)) + .putAll(p.symbol("B", BIGINT), p.symbol("B1", BIGINT), p.symbol("A2", BIGINT)) + .build(), + ImmutableList.of( + p.values(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + p.values(p.symbol("A2", BIGINT), p.symbol("B2", BIGINT)))))) + .matches(union( + tableWriter(ImmutableList.of("A1", "B1"), ImmutableList.of("a", "b"), values(ImmutableMap.of("A1", 0, "B1", 1))), + tableWriter(ImmutableList.of("B2", "A2"), ImmutableList.of("a", "b"), values(ImmutableMap.of("A2", 0, "B2", 1))))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveEmptyDelete.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveEmptyDelete.java index 0e15e514d58a..25f9064ca17f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveEmptyDelete.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveEmptyDelete.java @@ -13,43 +13,33 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.tpch.TpchTableHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static io.airlift.testing.Closeables.closeAllRuntimeException; +import static com.facebook.presto.sql.planner.iterative.rule.test.RuleTester.CATALOG_ID; +import static com.facebook.presto.sql.planner.iterative.rule.test.RuleTester.CONNECTOR_ID; public class TestRemoveEmptyDelete + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testDoesNotFire() throws Exception { - tester.assertThat(new RemoveEmptyDelete()) + tester().assertThat(new RemoveEmptyDelete()) .on(p -> p.tableDelete( new SchemaTableName("sch", "tab"), - p.tableScan(ImmutableList.of(), ImmutableMap.of()), + p.tableScan( + new TableHandle(CONNECTOR_ID, new TpchTableHandle(CATALOG_ID, "nation", 1.0)), + ImmutableList.of(), + ImmutableMap.of()), p.symbol("a", BigintType.BIGINT)) ) .doesNotFire(); @@ -58,7 +48,7 @@ public void testDoesNotFire() @Test public void test() { - tester.assertThat(new RemoveEmptyDelete()) + tester().assertThat(new RemoveEmptyDelete()) .on(p -> p.tableDelete( new SchemaTableName("sch", "tab"), p.values(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveFullSample.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveFullSample.java index 72f20993a085..c64ffd062b6c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveFullSample.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveFullSample.java @@ -13,48 +13,30 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.SampleNode.Type; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expressions; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestRemoveFullSample + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testDoesNotFire() throws Exception { - tester.assertThat(new RemoveFullSample()) + tester().assertThat(new RemoveFullSample()) .on(p -> p.sample( 0.15, Type.BERNOULLI, - p.values(p.symbol("a", BIGINT)))) + p.values(p.symbol("a")))) .doesNotFire(); } @@ -62,7 +44,7 @@ public void testDoesNotFire() public void test() throws Exception { - tester.assertThat(new RemoveFullSample()) + tester().assertThat(new RemoveFullSample()) .on(p -> p.sample( 1.0, @@ -70,7 +52,7 @@ public void test() p.filter( expression("b > 5"), p.values( - ImmutableList.of(p.symbol("a", BIGINT), p.symbol("b", BIGINT)), + ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( expressions("1", "10"), expressions("2", "11")))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveTrivialFilters.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveTrivialFilters.java new file mode 100644 index 000000000000..341f1c22c1e8 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveTrivialFilters.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestRemoveTrivialFilters + extends BaseRuleTest +{ + @Test + public void testDoesNotFire() + { + tester().assertThat(new RemoveTrivialFilters()) + .on(p -> p.filter(p.expression("1 = 1"), p.values())) + .doesNotFire(); + } + + @Test + public void testRemovesTrueFilter() + { + tester().assertThat(new RemoveTrivialFilters()) + .on(p -> p.filter(p.expression("TRUE"), p.values())) + .matches(values()); + } + + @Test + public void testRemovesFalseFilter() + { + tester().assertThat(new RemoveTrivialFilters()) + .on(p -> p.filter( + p.expression("FALSE"), + p.values( + ImmutableList.of(p.symbol("a")), + ImmutableList.of(p.expressions("1"))))) + .matches(values("a")); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java new file mode 100644 index 000000000000..1532466c06cb --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestRemoveUnreferencedScalarApplyNodes + extends BaseRuleTest +{ + @Test + public void testDoesNotFire() + { + tester().assertThat(new RemoveUnreferencedScalarApplyNodes()) + .on(p -> p.apply( + Assignments.of(p.symbol("z"), p.expression("x IN (y)")), + ImmutableList.of(), + p.values(p.symbol("x")), + p.values(p.symbol("y")))) + .doesNotFire(); + } + + @Test + public void testEmptyAssignments() + { + tester().assertThat(new RemoveUnreferencedScalarApplyNodes()) + .on(p -> p.apply( + Assignments.of(), + ImmutableList.of(), + p.values(p.symbol("x")), + p.values(p.symbol("y")))) + .matches(values("x")); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java new file mode 100644 index 000000000000..e7ef0fd1dc2c --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static java.util.Collections.emptyList; + +public class TestRemoveUnreferencedScalarLateralNodes + extends BaseRuleTest +{ + @Test + public void testRemoveUnreferencedInput() + { + tester().assertThat(new RemoveUnreferencedScalarLateralNodes()) + .on(p -> p.lateral( + emptyList(), + p.values(p.symbol("x", BigintType.BIGINT)), + p.values(emptyList(), ImmutableList.of(emptyList())))) + .matches(values("x")); + } + + @Test + public void testRemoveUnreferencedSubquery() + { + tester().assertThat(new RemoveUnreferencedScalarLateralNodes()) + .on(p -> p.lateral( + emptyList(), + p.values(emptyList(), ImmutableList.of(emptyList())), + p.values(p.symbol("x", BigintType.BIGINT)))) + .matches(values("x")); + } + + @Test + public void testDoesNotFire() + { + tester().assertThat(new RemoveUnreferencedScalarLateralNodes()) + .on(p -> p.lateral( + emptyList(), + p.values(), + p.values())) + .doesNotFire(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java index f058b7972484..96dbdd47fa3f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java @@ -16,7 +16,7 @@ import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.Signature; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; @@ -25,8 +25,6 @@ import com.facebook.presto.sql.tree.WindowFrame; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.util.Optional; @@ -39,25 +37,10 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window; import static com.facebook.presto.sql.tree.FrameBound.Type.CURRENT_ROW; import static com.facebook.presto.sql.tree.FrameBound.Type.UNBOUNDED_PRECEDING; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestSwapAdjacentWindowsBySpecifications + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - private WindowNode.Frame frame; private Signature signature; @@ -79,8 +62,8 @@ public TestSwapAdjacentWindowsBySpecifications() public void doesNotFireOnPlanWithoutWindowFunctions() throws Exception { - tester.assertThat(new SwapAdjacentWindowsBySpecifications()) - .on(p -> p.values(p.symbol("a", BIGINT))) + tester().assertThat(new SwapAdjacentWindowsBySpecifications()) + .on(p -> p.values(p.symbol("a"))) .doesNotFire(); } @@ -88,14 +71,14 @@ public void doesNotFireOnPlanWithoutWindowFunctions() public void doesNotFireOnPlanWithSingleWindowNode() throws Exception { - tester.assertThat(new SwapAdjacentWindowsBySpecifications()) + tester().assertThat(new SwapAdjacentWindowsBySpecifications()) .on(p -> p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a", BIGINT)), + ImmutableList.of(p.symbol("a")), ImmutableList.of(), ImmutableMap.of()), - ImmutableMap.of(p.symbol("avg_1", BIGINT), + ImmutableMap.of(p.symbol("avg_1"), new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), ImmutableList.of()), signature, frame)), - p.values(p.symbol("a", BIGINT)))) + p.values(p.symbol("a")))) .doesNotFire(); } @@ -112,26 +95,29 @@ public void subsetComesFirst() Optional windowAB = Optional.of(new Window(ImmutableList.of(new SymbolReference("a"), new SymbolReference("b")), Optional.empty(), Optional.empty())); Optional windowA = Optional.of(new Window(ImmutableList.of(new SymbolReference("a")), Optional.empty(), Optional.empty())); - tester.assertThat(new SwapAdjacentWindowsBySpecifications()) + tester().assertThat(new SwapAdjacentWindowsBySpecifications()) .on(p -> p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a", BIGINT)), + ImmutableList.of(p.symbol("a")), ImmutableList.of(), ImmutableMap.of()), ImmutableMap.of(p.symbol("avg_1", DOUBLE), new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowA, false, ImmutableList.of(new SymbolReference("a"))), signature, frame)), p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a", BIGINT), p.symbol("b", BIGINT)), + ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(), ImmutableMap.of()), ImmutableMap.of(p.symbol("avg_2", DOUBLE), new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowAB, false, ImmutableList.of(new SymbolReference("b"))), signature, frame)), - p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT))))) - .matches(window(specificationAB, - ImmutableList.of(functionCall("avg", Optional.empty(), ImmutableList.of(columnBAlias))), - window(specificationA, - ImmutableList.of(functionCall("avg", Optional.empty(), ImmutableList.of(columnAAlias))), - values(ImmutableMap.of(columnAAlias, 0, columnBAlias, 1))))); + p.values(p.symbol("a"), p.symbol("b"))))) + .matches( + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationAB) + .addFunction(functionCall("avg", Optional.empty(), ImmutableList.of(columnBAlias))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationA) + .addFunction(functionCall("avg", Optional.empty(), ImmutableList.of(columnAAlias))), + values(ImmutableMap.of(columnAAlias, 0, columnBAlias, 1))))); } @Test @@ -140,21 +126,21 @@ public void dependentWindowsAreNotReordered() { Optional windowA = Optional.of(new Window(ImmutableList.of(new SymbolReference("a")), Optional.empty(), Optional.empty())); - tester.assertThat(new SwapAdjacentWindowsBySpecifications()) + tester().assertThat(new SwapAdjacentWindowsBySpecifications()) .on(p -> p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a", BIGINT)), + ImmutableList.of(p.symbol("a")), ImmutableList.of(), ImmutableMap.of()), - ImmutableMap.of(p.symbol("avg_1", BIGINT), + ImmutableMap.of(p.symbol("avg_1"), new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowA, false, ImmutableList.of(new SymbolReference("avg_2"))), signature, frame)), p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a", BIGINT), p.symbol("b", BIGINT)), + ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(), ImmutableMap.of()), - ImmutableMap.of(p.symbol("avg_2", BIGINT), + ImmutableMap.of(p.symbol("avg_2"), new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowA, false, ImmutableList.of(new SymbolReference("a"))), signature, frame)), - p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT))))) + p.values(p.symbol("a"), p.symbol("b"))))) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java index 75e5f7c4df3d..af5e77648d97 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java @@ -55,7 +55,7 @@ public void setUp() public void doesNotFireOnPlanWithoutApplyNode() { tester.assertThat(rule) - .on(p -> p.values(p.symbol("a", BIGINT))) + .on(p -> p.values(p.symbol("a"))) .doesNotFire(); } @@ -64,9 +64,9 @@ public void doesNotFireOnCorrelatedWithoutAggregation() { tester.assertThat(rule) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr", BIGINT)), - p.values(p.symbol("corr", BIGINT)), - p.values(p.symbol("a", BIGINT)))) + ImmutableList.of(p.symbol("corr")), + p.values(p.symbol("corr")), + p.values(p.symbol("a")))) .doesNotFire(); } @@ -76,8 +76,8 @@ public void doesNotFireOnUncorrelated() tester.assertThat(rule) .on(p -> p.lateral( ImmutableList.of(), - p.values(p.symbol("a", BIGINT)), - p.values(p.symbol("b", BIGINT)))) + p.values(p.symbol("a")), + p.values(p.symbol("b")))) .doesNotFire(); } @@ -86,12 +86,12 @@ public void doesNotFireOnCorrelatedWithNonScalarAggregation() { tester.assertThat(rule) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr", BIGINT)), - p.values(p.symbol("corr", BIGINT)), + ImmutableList.of(p.symbol("corr")), + p.values(p.symbol("corr")), p.aggregation(ab -> ab - .source(p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT))) - .addAggregation(p.symbol("sum", BIGINT), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) - .addGroupingSet(p.symbol("b", BIGINT))))) + .source(p.values(p.symbol("a"), p.symbol("b"))) + .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .addGroupingSet(p.symbol("b"))))) .doesNotFire(); } @@ -100,11 +100,11 @@ public void rewritesOnSubqueryWithoutProjection() { tester.assertThat(rule) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr", BIGINT)), - p.values(p.symbol("corr", BIGINT)), + ImmutableList.of(p.symbol("corr")), + p.values(p.symbol("corr")), p.aggregation(ab -> ab - .source(p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT))) - .addAggregation(p.symbol("sum", BIGINT), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .source(p.values(p.symbol("a"), p.symbol("b"))) + .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) .globalGrouping()))) .matches( project(ImmutableMap.of("sum_1", expression("sum_1"), "corr", expression("corr")), @@ -122,12 +122,12 @@ public void rewritesOnSubqueryWithProjection() { tester.assertThat(rule) .on(p -> p.lateral( - ImmutableList.of(p.symbol("corr", BIGINT)), - p.values(p.symbol("corr", BIGINT)), - p.project(Assignments.of(p.symbol("expr", BIGINT), p.expression("sum + 1")), + ImmutableList.of(p.symbol("corr")), + p.values(p.symbol("corr")), + p.project(Assignments.of(p.symbol("expr"), p.expression("sum + 1")), p.aggregation(ab -> ab - .source(p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT))) - .addAggregation(p.symbol("sum", BIGINT), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .source(p.values(p.symbol("a"), p.symbol("b"))) + .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) .globalGrouping())))) .matches( project(ImmutableMap.of("corr", expression("corr"), "expr", expression("(\"sum_1\" + 1)")), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToScalarLateralJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToScalarLateralJoin.java index 705ef9391d72..d4c34a0c913f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToScalarLateralJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToScalarLateralJoin.java @@ -27,7 +27,6 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; @@ -63,15 +62,15 @@ public void tearDown() public void testDoesNotFire() { tester.assertThat(transformExistsApplyToScalarApply) - .on(p -> p.values(p.symbol("a", BIGINT))) + .on(p -> p.values(p.symbol("a"))) .doesNotFire(); tester.assertThat(transformExistsApplyToScalarApply) .on(p -> p.lateral( - ImmutableList.of(p.symbol("a", BIGINT)), - p.values(p.symbol("a", BIGINT)), - p.values(p.symbol("a", BIGINT))) + ImmutableList.of(p.symbol("a")), + p.values(p.symbol("a")), + p.values(p.symbol("a"))) ) .doesNotFire(); } @@ -86,7 +85,7 @@ public void testRewrite() Assignments.of(p.symbol("b", BOOLEAN), expression("EXISTS(SELECT \"a\")")), ImmutableList.of(), p.values(), - p.values(p.symbol("a", BIGINT))) + p.values(p.symbol("a"))) ) .matches(lateral( ImmutableList.of(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java new file mode 100644 index 000000000000..47255b4fc74b --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.facebook.presto.sql.tree.ExistsPredicate; +import com.facebook.presto.sql.tree.InPredicate; +import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.SymbolReference; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static java.util.Collections.emptyList; + +public class TestTransformUncorrelatedInPredicateSubqueryToSemiJoin + extends BaseRuleTest +{ + @Test + public void testDoesNotFireOnNoCorrelation() + { + tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) + .on(p -> p.apply( + Assignments.of(), + emptyList(), + p.values(), + p.values())) + .doesNotFire(); + } + + @Test + public void testDoesNotFireOnNonInPredicateSubquery() + { + tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) + .on(p -> p.apply( + Assignments.of(p.symbol("x"), new ExistsPredicate(new LongLiteral("1"))), + emptyList(), + p.values(), + p.values())) + .doesNotFire(); + } + + @Test + public void testFiresForInPredicate() + { + tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) + .on(p -> p.apply( + Assignments.of( + p.symbol("x"), + new InPredicate( + new SymbolReference("y"), + new SymbolReference("z"))), + emptyList(), + p.values(p.symbol("y")), + p.values(p.symbol("z")))) + .matches(node(SemiJoinNode.class, values("y"), values("z"))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedLateralToJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedLateralToJoin.java new file mode 100644 index 000000000000..9ea4319eb045 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedLateralToJoin.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static java.util.Collections.emptyList; + +public class TestTransformUncorrelatedLateralToJoin + extends BaseRuleTest +{ + @Test + public void test() + { + tester() + .assertThat(new TransformUncorrelatedLateralToJoin()) + .on(p -> p.lateral(emptyList(), p.values(), p.values())) + .matches(join(JoinNode.Type.INNER, emptyList(), values(), values())); + } + + @Test + public void testDoesNotFire() + { + Symbol symbol = new Symbol("x"); + tester() + .assertThat(new TransformUncorrelatedLateralToJoin()) + .on(p -> p.lateral(ImmutableList.of(symbol), p.values(symbol), p.values())) + .doesNotFire(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java new file mode 100644 index 000000000000..0a0f320b2563 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule.test; + +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; + +import static io.airlift.testing.Closeables.closeAllRuntimeException; + +public abstract class BaseRuleTest +{ + private RuleTester tester; + + @BeforeClass + public final void setUp() + { + tester = new RuleTester(); + } + + @AfterClass(alwaysRun = true) + public final void tearDown() + { + closeAllRuntimeException(tester); + tester = null; + } + + protected RuleTester tester() + { + return tester; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 3f746329f7bc..1601737dfb28 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.iterative.rule.test; import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.metadata.IndexHandle; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; @@ -28,20 +29,28 @@ import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.TestingConnectorIndexHandle; +import com.facebook.presto.sql.planner.TestingConnectorTransactionHandle; +import com.facebook.presto.sql.planner.TestingWriterTarget; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.AggregationNode.Step; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.DeleteNode; +import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; +import com.facebook.presto.sql.planner.plan.MarkDistinctNode; +import com.facebook.presto.sql.planner.plan.OutputNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SampleNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.planner.plan.TableFinishNode; import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.TableWriterNode; @@ -63,9 +72,12 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.Consumer; import java.util.stream.Stream; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.google.common.base.Preconditions.checkArgument; @@ -85,6 +97,15 @@ public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata) this.metadata = metadata; } + public OutputNode output(List columnNames, List outputs, PlanNode source) + { + return new OutputNode( + idAllocator.getNextId(), + source, + columnNames, + outputs); + } + public ValuesNode values(Symbol... columns) { return new ValuesNode( @@ -98,11 +119,21 @@ public ValuesNode values(List columns, List> rows) return new ValuesNode(idAllocator.getNextId(), columns, rows); } + public EnforceSingleRowNode enforceSingleRow(PlanNode source) + { + return new EnforceSingleRowNode(idAllocator.getNextId(), source); + } + public LimitNode limit(long limit, PlanNode source) { return new LimitNode(idAllocator.getNextId(), source, limit, false); } + public MarkDistinctNode markDistinct(PlanNode source, Symbol markerSymbol, List distinctSymbols) + { + return new MarkDistinctNode(idAllocator.getNextId(), source, markerSymbol, distinctSymbols, Optional.empty()); + } + public SampleNode sample(double sampleRatio, SampleNode.Type type, PlanNode source) { return new SampleNode(idAllocator.getNextId(), source, sampleRatio, type); @@ -113,6 +144,16 @@ public ProjectNode project(Assignments assignments, PlanNode source) return new ProjectNode(idAllocator.getNextId(), source, assignments); } + public MarkDistinctNode markDistinct(Symbol markerSymbol, List distinctSymbols, PlanNode source) + { + return new MarkDistinctNode(idAllocator.getNextId(), source, markerSymbol, distinctSymbols, Optional.empty()); + } + + public MarkDistinctNode markDistinct(Symbol markerSymbol, List distinctSymbols, Symbol hashSymbol, PlanNode source) + { + return new MarkDistinctNode(idAllocator.getNextId(), source, markerSymbol, distinctSymbols, Optional.of(hashSymbol)); + } + public FilterNode filter(Expression predicate, PlanNode source) { return new FilterNode(idAllocator.getNextId(), source, predicate); @@ -220,14 +261,23 @@ public LateralJoinNode lateral(List correlation, PlanNode input, PlanNod } public TableScanNode tableScan(List symbols, Map assignments) + { + return tableScan(symbols, assignments, null); + } + + public TableScanNode tableScan(List symbols, Map assignments, Expression originalConstraint) { TableHandle tableHandle = new TableHandle(new ConnectorId("testConnector"), new TestingTableHandle()); - return tableScan(tableHandle, symbols, assignments); + return tableScan(tableHandle, symbols, assignments, originalConstraint); } public TableScanNode tableScan(TableHandle tableHandle, List symbols, Map assignments) { - Expression originalConstraint = null; + return tableScan(tableHandle, symbols, assignments, null); + } + + public TableScanNode tableScan(TableHandle tableHandle, List symbols, Map assignments, Expression originalConstraint) + { return new TableScanNode( idAllocator.getNextId(), tableHandle, @@ -274,6 +324,47 @@ public ExchangeNode gatheringExchange(ExchangeNode.Scope scope, PlanNode child) .addInputsSet(child.getOutputSymbols())); } + public SemiJoinNode semiJoin( + Symbol sourceJoinSymbol, + Symbol filteringSourceJoinSymbol, + Symbol semiJoinOutput, + Optional sourceHashSymbol, + Optional filteringSourceHashSymbol, + PlanNode source, + PlanNode filteringSource) + { + return new SemiJoinNode(idAllocator.getNextId(), + source, + filteringSource, + sourceJoinSymbol, + filteringSourceJoinSymbol, + semiJoinOutput, + sourceHashSymbol, + filteringSourceHashSymbol, + Optional.empty()); + } + + public IndexSourceNode indexSource( + TableHandle tableHandle, + Set lookupSymbols, + List outputSymbols, + Map assignments, + TupleDomain effectiveTupleDomain) + { + return new IndexSourceNode( + idAllocator.getNextId(), + new IndexHandle( + tableHandle.getConnectorId(), + TestingConnectorTransactionHandle.INSTANCE, + TestingConnectorIndexHandle.INSTANCE), + tableHandle, + Optional.empty(), + lookupSymbols, + outputSymbols, + assignments, + effectiveTupleDomain); + } + public ExchangeNode exchange(Consumer exchangeBuilderConsumer) { ExchangeBuilder exchangeBuilder = new ExchangeBuilder(); @@ -352,7 +443,17 @@ protected ExchangeNode build() public JoinNode join(JoinNode.Type joinType, PlanNode left, PlanNode right, JoinNode.EquiJoinClause... criteria) { - return new JoinNode(idAllocator.getNextId(), + return join(joinType, left, right, Optional.empty(), criteria); + } + + public JoinNode join(JoinNode.Type joinType, PlanNode left, PlanNode right, Expression filter, JoinNode.EquiJoinClause... criteria) + { + return join(joinType, left, right, Optional.of(filter), criteria); + } + + private JoinNode join(JoinNode.Type joinType, PlanNode left, PlanNode right, Optional filter, JoinNode.EquiJoinClause... criteria) + { + return join( joinType, left, right, @@ -361,11 +462,9 @@ public JoinNode join(JoinNode.Type joinType, PlanNode left, PlanNode right, Join .addAll(left.getOutputSymbols()) .addAll(right.getOutputSymbols()) .build(), + filter, Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty() - ); + Optional.empty()); } public JoinNode join( @@ -381,9 +480,27 @@ public JoinNode join( return new JoinNode(idAllocator.getNextId(), type, left, right, criteria, outputSymbols, filter, leftHashSymbol, rightHashSymbol, Optional.empty()); } - public UnionNode union(List sources, ListMultimap outputsToInputs, List outputs) + public UnionNode union(ListMultimap outputsToInputs, List sources) + { + ImmutableList outputs = outputsToInputs.keySet().stream().collect(toImmutableList()); + return new UnionNode(idAllocator.getNextId(), sources, outputsToInputs, outputs); + } + + public TableWriterNode tableWriter(List columns, List columnNames, PlanNode source) + { + return new TableWriterNode( + idAllocator.getNextId(), + source, + new TestingWriterTarget(), + columns, + columnNames, + ImmutableList.of(symbol("partialrows", BIGINT), symbol("fragment", VARBINARY)), + Optional.empty()); + } + + public Symbol symbol(String name) { - return new UnionNode(idAllocator.getNextId(), (List) sources, outputsToInputs, outputs); + return symbol(name, BIGINT); } public Symbol symbol(String name, Type type) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index a01f266a9277..82c024393a6d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -104,80 +104,110 @@ public void doesNotFire() public void matches(PlanMatchPattern pattern) { - RuleApplication ruleApplication = applyRule(); - Map types = ruleApplication.types; - - if (!ruleApplication.wasRuleApplied()) { - fail(String.format( - "%s did not fire for:\n%s", - rule.getClass().getName(), - formatPlan(plan, types))); - } - - PlanNode actual = ruleApplication.getResult(); - - if (actual == plan) { // plans are not comparable, so we can only ensure they are not the same instance - fail(String.format( - "%s: rule fired but return the original plan:\n%s", - rule.getClass().getName(), - formatPlan(plan, types))); - } - - if (!ImmutableSet.copyOf(plan.getOutputSymbols()).equals(ImmutableSet.copyOf(actual.getOutputSymbols()))) { - fail(String.format( - "%s: output schema of transformed and original plans are not equivalent\n" + - "\texpected: %s\n" + - "\tactual: %s", - rule.getClass().getName(), - plan.getOutputSymbols(), - actual.getOutputSymbols())); - } - - inTransaction(session -> { - Map planNodeCosts = costCalculator.calculateCostForPlan(session, types, actual); - assertPlan(session, metadata, costCalculator, new Plan(actual, types, planNodeCosts), ruleApplication.lookup, pattern); - return null; - }); - } - - private RuleApplication applyRule() - { - SymbolAllocator symbolAllocator = new SymbolAllocator(symbols); - Memo memo = new Memo(idAllocator, plan); - Lookup lookup = Lookup.from(memo::resolve); - - if (!rule.getPattern().matches(plan)) { - return new RuleApplication(lookup, symbolAllocator.getTypes(), Optional.empty()); - } - - Optional result = inTransaction(session -> rule.apply(memo.getNode(memo.getRootGroup()), lookup, idAllocator, symbolAllocator, session)); - - return new RuleApplication(lookup, symbolAllocator.getTypes(), result); - } - - private String formatPlan(PlanNode plan, Map types) - { - return inTransaction(session -> PlanPrinter.textLogicalPlan(plan, types, metadata, costCalculator, session, 2)); - } - - private T inTransaction(Function transactionSessionConsumer) - { - return transaction(transactionManager, accessControl) - .singleStatement() - .execute(session, session -> { - // metadata.getCatalogHandle() registers the catalog for the transaction - session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); - return transactionSessionConsumer.apply(session); - }); - } - - private static class RuleApplication - { - private final Lookup lookup; - private final Map types; - private final Optional result; - - public RuleApplication(Lookup lookup, Map types, Optional result) + RuleApplication ruleApplication = applyRule(); + Map types = ruleApplication.types; + + if (!ruleApplication.wasRuleApplied()) { + fail(String.format( + "%s did not fire for:\n%s", + rule.getClass().getName(), + formatPlan(plan, types))); + } + + PlanNode actual = ruleApplication.getResult(); + + if (actual == plan) { // plans are not comparable, so we can only ensure they are not the same instance + fail(String.format( + "%s: rule fired but return the original plan:\n%s", + rule.getClass().getName(), + formatPlan(plan, types))); + } + + if (!ImmutableSet.copyOf(plan.getOutputSymbols()).equals(ImmutableSet.copyOf(actual.getOutputSymbols()))) { + fail(String.format( + "%s: output schema of transformed and original plans are not equivalent\n" + + "\texpected: %s\n" + + "\tactual: %s", + rule.getClass().getName(), + plan.getOutputSymbols(), + actual.getOutputSymbols())); + } + + inTransaction(session -> { + Map planNodeCosts = costCalculator.calculateCostForPlan(session, types, actual); + assertPlan(session, metadata, costCalculator, new Plan(actual, types, planNodeCosts), ruleApplication.lookup, pattern); + return null; + }); + } + + private RuleApplication applyRule() + { + SymbolAllocator symbolAllocator = new SymbolAllocator(symbols); + Memo memo = new Memo(idAllocator, plan); + Lookup lookup = Lookup.from(memo::resolve); + + if (!rule.getPattern().matches(plan)) { + return new RuleApplication(lookup, symbolAllocator.getTypes(), Optional.empty()); + } + + Optional result = inTransaction(session -> rule.apply(memo.getNode(memo.getRootGroup()), ruleContext(symbolAllocator, lookup, session))); + + return new RuleApplication(lookup, symbolAllocator.getTypes(), result); + } + + private String formatPlan(PlanNode plan, Map types) + { + return inTransaction(session -> PlanPrinter.textLogicalPlan(plan, types, metadata, costCalculator, session, 2)); + } + + private T inTransaction(Function transactionSessionConsumer) + { + return transaction(transactionManager, accessControl) + .singleStatement() + .execute(session, session -> { + // metadata.getCatalogHandle() registers the catalog for the transaction + session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); + return transactionSessionConsumer.apply(session); + }); + } + + private Rule.Context ruleContext(SymbolAllocator symbolAllocator, Lookup lookup, Session session) + { + return new Rule.Context() + { + @Override + public Lookup getLookup() + { + return lookup; + } + + @Override + public PlanNodeIdAllocator getIdAllocator() + { + return idAllocator; + } + + @Override + public SymbolAllocator getSymbolAllocator() + { + return symbolAllocator; + } + + @Override + public Session getSession() + { + return session; + } + }; + } + + private static class RuleApplication + { + private final Lookup lookup; + private final Map types; + private final Optional result; + + public RuleApplication(Lookup lookup, Map types, Optional result) { this.lookup = requireNonNull(lookup, "lookup is null"); this.types = requireNonNull(types, "types is null"); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java index 0a682ae64a4f..aa41b2ed8c15 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.iterative.rule.test; import com.facebook.presto.Session; +import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.security.AccessControl; @@ -30,6 +31,9 @@ public class RuleTester implements Closeable { + public static final String CATALOG_ID = "local"; + public static final ConnectorId CONNECTOR_ID = new ConnectorId(CATALOG_ID); + private final Metadata metadata; private final CostCalculator costCalculator; private final Session session; @@ -40,7 +44,7 @@ public class RuleTester public RuleTester() { session = testSessionBuilder() - .setCatalog("local") + .setCatalog(CATALOG_ID) .setSchema("tiny") .setSystemProperty("task_concurrency", "1") // these tests don't handle exchanges from local parallel .build(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCardinalityExtractorPlanVisitor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCardinalityExtractorPlanVisitor.java new file mode 100644 index 000000000000..8cdc6b33c2c5 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCardinalityExtractorPlanVisitor.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.metadata.DummyMetadata; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Range; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality; +import static java.util.Collections.emptyList; +import static org.testng.Assert.assertEquals; + +public class TestCardinalityExtractorPlanVisitor +{ + @Test + public void testLimitOnTopOfValues() + { + PlanBuilder planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), new DummyMetadata()); + + assertEquals( + extractCardinality(planBuilder.limit(3, planBuilder.values(emptyList(), ImmutableList.of(emptyList())))), + Range.singleton(1L)); + + assertEquals( + extractCardinality(planBuilder.limit(3, planBuilder.values(emptyList(), ImmutableList.of(emptyList(), emptyList(), emptyList(), emptyList())))), + Range.singleton(3L)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java index 3db490c3443b..e959f4156b38 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java @@ -63,8 +63,9 @@ public void testEliminateSorts() PlanMatchPattern pattern = output( - window(windowSpec, - ImmutableList.of(functionCall("row_number", Optional.empty(), ImmutableList.of())), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowSpec) + .addFunction(functionCall("row_number", Optional.empty(), ImmutableList.of())), anyTree(LINEITEM_TABLESCAN_Q))); assertUnitPlan(sql, pattern); @@ -78,8 +79,9 @@ public void testNotEliminateSorts() PlanMatchPattern pattern = anyTree( sort( - window(windowSpec, - ImmutableList.of(functionCall("row_number", Optional.empty(), ImmutableList.of())), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowSpec) + .addFunction(functionCall("row_number", Optional.empty(), ImmutableList.of())), anyTree(LINEITEM_TABLESCAN_Q)))); assertUnitPlan(sql, pattern); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestExpressionEquivalence.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestExpressionEquivalence.java index a6915ac4410e..aad65cf5cf2d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestExpressionEquivalence.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestExpressionEquivalence.java @@ -29,7 +29,7 @@ import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; -import static com.facebook.presto.sql.planner.DependencyExtractor.extractUnique; +import static com.facebook.presto.sql.planner.SymbolsExtractor.extractUnique; import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; import static org.testng.Assert.assertFalse; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java index 3150ac96759d..c9a58c91e1d3 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java @@ -85,9 +85,9 @@ public class TestMergeWindows EXTENDEDPRICE_ALIAS, "extendedprice")); private static final Optional COMMON_FRAME = Optional.of(new WindowFrame( - WindowFrame.Type.ROWS, - new FrameBound(FrameBound.Type.UNBOUNDED_PRECEDING), - Optional.of(new FrameBound(FrameBound.Type.CURRENT_ROW)))); + WindowFrame.Type.ROWS, + new FrameBound(FrameBound.Type.UNBOUNDED_PRECEDING), + Optional.of(new FrameBound(FrameBound.Type.CURRENT_ROW)))); private static final Optional UNSPECIFIED_FRAME = Optional.empty(); @@ -96,7 +96,7 @@ public class TestMergeWindows public TestMergeWindows() { - this(ImmutableMap.of()); + this(ImmutableMap.of()); } public TestMergeWindows(Map sessionProperties) @@ -151,14 +151,14 @@ public void testMergeableWindowsAllOptimizers() PlanMatchPattern pattern = anyTree( - window(specificationA, - ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS)), - functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationA) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), anyTree( - window(specificationB, - ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationB) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), anyNot(WindowNode.class, LINEITEM_TABLESCAN_DOQSS))))); // should be anyTree(LINEITEM_TABLESCAN_DOQSS) but anyTree does not handle zero nodes case correctly @@ -176,13 +176,13 @@ public void testIdenticalWindowSpecificationsABA() assertUnitPlan(sql, anyTree( - window(specificationA, - ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS)), - functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), - window(specificationB, - ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationA) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationB) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), LINEITEM_TABLESCAN_DOQSS)))); } @@ -197,14 +197,16 @@ public void testIdenticalWindowSpecificationsABcpA() assertUnitPlan(sql, anyTree( - window(specificationA, - ImmutableList.of(functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), - window(specificationB, - ImmutableList.of(functionCall("lag", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS, "ONE", "ZERO"))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationA) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationB) + .addFunction(functionCall("lag", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS, "ONE", "ZERO"))), project(ImmutableMap.of("ONE", expression("CAST(1 AS bigint)"), "ZERO", expression("0.0")), - window(specificationA, - ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationA) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), LINEITEM_TABLESCAN_DOQSS)))))); } @@ -219,14 +221,14 @@ public void testIdenticalWindowSpecificationsAAcpA() assertUnitPlan(sql, anyTree( - window(specificationA, - ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS)), - functionCall("lag", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS, "ONE", "ZERO"))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationA) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))) + .addFunction(functionCall("lag", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS, "ONE", "ZERO"))), project(ImmutableMap.of("ONE", expression("CAST(1 AS bigint)"), "ZERO", expression("0.0")), - window(specificationA, - ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationA) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), LINEITEM_TABLESCAN_DOQS))))); } @@ -251,12 +253,13 @@ public void testIdenticalWindowSpecificationsDefaultFrame() assertUnitPlan(sql, anyTree( - window(specificationC, - ImmutableList.of( - functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(QUANTITY_ALIAS)), - functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), - window(specificationD, - ImmutableList.of(functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationC) + .addFunction(functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(QUANTITY_ALIAS))) + .addFunction(functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationD) + .addFunction(functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(QUANTITY_ALIAS))), LINEITEM_TABLESCAN_DOQSS)))); } @@ -286,11 +289,11 @@ public void testMergeDifferentFrames() assertUnitPlan(sql, anyTree( - window(specificationC, - ImmutableList.of( - functionCall("avg", frameD, ImmutableList.of(QUANTITY_ALIAS)), - functionCall("sum", frameC, ImmutableList.of(DISCOUNT_ALIAS)), - functionCall("sum", frameC, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationC) + .addFunction(functionCall("avg", frameD, ImmutableList.of(QUANTITY_ALIAS))) + .addFunction(functionCall("sum", frameC, ImmutableList.of(DISCOUNT_ALIAS))) + .addFunction(functionCall("sum", frameC, ImmutableList.of(QUANTITY_ALIAS))), LINEITEM_TABLESCAN_DOQS))); } @@ -315,11 +318,11 @@ public void testMergeDifferentFramesWithDefault() assertUnitPlan(sql, anyTree( - window(specificationD, - ImmutableList.of( - functionCall("avg", frameD, ImmutableList.of(QUANTITY_ALIAS)), - functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(DISCOUNT_ALIAS)), - functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationD) + .addFunction(functionCall("avg", frameD, ImmutableList.of(QUANTITY_ALIAS))) + .addFunction(functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(DISCOUNT_ALIAS))) + .addFunction(functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(QUANTITY_ALIAS))), LINEITEM_TABLESCAN_DOQS))); } @@ -379,11 +382,15 @@ public void testNotMergeAcrossJoinBranches() filter("SUM = AVG", join(JoinNode.Type.INNER, ImmutableList.of(), any( - window(leftSpecification, ImmutableMap.of("SUM", functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(leftSpecification) + .addFunction("SUM", functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), any( leftTableScan))), any( - window(rightSpecification, ImmutableMap.of("AVG", functionCall("avg", COMMON_FRAME, ImmutableList.of(rQuantityAlias))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(rightSpecification) + .addFunction("AVG", functionCall("avg", COMMON_FRAME, ImmutableList.of(rQuantityAlias))), any( rightTableScan))))))); } @@ -403,10 +410,12 @@ public void testNotMergeDifferentPartition() assertUnitPlan(sql, anyTree( - window(specificationA, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), - window(specificationC, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationA) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationC) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), LINEITEM_TABLESCAN_DOQS)))); } @@ -418,17 +427,19 @@ public void testNotMergeDifferentOrderBy() "SUM(quantity) OVER (PARTITION BY suppkey ORDER BY quantity ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_quantity_C " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(QUANTITY_ALIAS), ImmutableMap.of(QUANTITY_ALIAS, SortOrder.ASC_NULLS_LAST)); assertUnitPlan(sql, anyTree( - window(specificationC, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), - window(specificationA, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationC) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationA) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), LINEITEM_TABLESCAN_DOQS)))); } @@ -441,18 +452,20 @@ public void testNotMergeDifferentOrdering() "SUM(discount) over (PARTITION BY suppkey ORDER BY orderkey ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_discount_A " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.DESC_NULLS_LAST)); assertUnitPlan(sql, anyTree( - window(specificationC, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), - window(specificationA, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(EXTENDEDPRICE_ALIAS)), - functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationC) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationA) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(EXTENDEDPRICE_ALIAS))) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), LINEITEM_TABLESCAN_DEOQS)))); } @@ -465,18 +478,20 @@ public void testNotMergeDifferentNullOrdering() "SUM(discount) OVER (PARTITION BY suppkey ORDER BY orderkey ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_discount_A " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_FIRST)); assertUnitPlan(sql, anyTree( - window(specificationA, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(EXTENDEDPRICE_ALIAS)), - functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), - window(specificationC, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationA) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(EXTENDEDPRICE_ALIAS))) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(specificationC) + .addFunction(functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), LINEITEM_TABLESCAN_DEOQS)))); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java index 543bc29c34df..1fa64932bb6c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java @@ -138,17 +138,17 @@ public void testNonMergeableABAReordersToAABAllOptimizers() PlanMatchPattern pattern = anyTree( - window(windowAp, - ImmutableList.of( - functionCall("min", commonFrame, ImmutableList.of(TAX_ALIAS))), - window(windowA, - ImmutableList.of( - functionCall("sum", commonFrame, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowAp) + .addFunction(functionCall("min", commonFrame, ImmutableList.of(TAX_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowA) + .addFunction(functionCall("sum", commonFrame, ImmutableList.of(QUANTITY_ALIAS))), anyTree( - window(windowB, - ImmutableList.of( - functionCall("avg", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), - anyTree(LINEITEM_TABLESCAN_DOQPRSST)))))); + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowB) + .addFunction(functionCall("avg", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), + anyTree(LINEITEM_TABLESCAN_DOQPRSST)))))); assertPlan(sql, pattern); } @@ -164,15 +164,15 @@ public void testNonMergeableABAReordersToAAB() assertUnitPlan(sql, anyTree( - window(windowAp, - ImmutableList.of( - functionCall("min", commonFrame, ImmutableList.of(TAX_ALIAS))), - window(windowA, - ImmutableList.of( - functionCall("sum", commonFrame, ImmutableList.of(QUANTITY_ALIAS))), - window(windowB, - ImmutableList.of( - functionCall("avg", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowAp) + .addFunction(functionCall("min", commonFrame, ImmutableList.of(TAX_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowA) + .addFunction(functionCall("sum", commonFrame, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowB) + .addFunction(functionCall("avg", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), LINEITEM_TABLESCAN_DOQPRSST))))); // should be anyTree(LINEITEM_TABLESCANE_DOQPRSST) but anyTree does not handle zero nodes case correctly } @@ -186,13 +186,14 @@ public void testPrefixOfPartitionComesFirstRegardlessOfTheirOrderInSQL() "from lineitem"; assertUnitPlan(sql, - anyTree(window(windowApp, - ImmutableList.of( - functionCall("avg", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), - window(windowA, - ImmutableList.of( - functionCall("sum", commonFrame, ImmutableList.of(QUANTITY_ALIAS))), - LINEITEM_TABLESCAN_DOQRST)))); // should be anyTree(LINEITEM_TABLESCAN_DOQRST) but anyTree does not handle zero nodes case correctly + anyTree( + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowApp) + .addFunction(functionCall("avg", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowA) + .addFunction(functionCall("sum", commonFrame, ImmutableList.of(QUANTITY_ALIAS))), + LINEITEM_TABLESCAN_DOQRST)))); // should be anyTree(LINEITEM_TABLESCAN_DOQRST) but anyTree does not handle zero nodes case correctly } { @@ -202,13 +203,14 @@ public void testPrefixOfPartitionComesFirstRegardlessOfTheirOrderInSQL() "from lineitem"; assertUnitPlan(sql, - anyTree(window(windowApp, - ImmutableList.of( - functionCall("avg", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), - window(windowA, - ImmutableList.of( - functionCall("sum", commonFrame, ImmutableList.of(QUANTITY_ALIAS))), - LINEITEM_TABLESCAN_DOQRST)))); // should be anyTree(LINEITEM_TABLESCAN_DOQRST) but anyTree does not handle zero nodes case correctly + anyTree( + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowApp) + .addFunction(functionCall("avg", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowA) + .addFunction(functionCall("sum", commonFrame, ImmutableList.of(QUANTITY_ALIAS))), + LINEITEM_TABLESCAN_DOQRST)))); // should be anyTree(LINEITEM_TABLESCAN_DOQRST) but anyTree does not handle zero nodes case correctly } } @@ -221,14 +223,15 @@ public void testNotReorderAcrossNonWindowNodes() "from lineitem"; assertUnitPlan(sql, - anyTree(window(windowA, - ImmutableList.of( - functionCall("lag", commonFrame, ImmutableList.of(QUANTITY_ALIAS, "ONE"))), - project(ImmutableMap.of("ONE", expression("CAST(1 AS bigint)")), - window(windowApp, - ImmutableList.of( - functionCall("avg", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), - LINEITEM_TABLESCAN_DOQRST))))); // should be anyTree(LINEITEM_TABLESCAN_DOQRST) but anyTree does not handle zero nodes case correctly + anyTree( + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowA) + .addFunction(functionCall("lag", commonFrame, ImmutableList.of(QUANTITY_ALIAS, "ONE"))), + project(ImmutableMap.of("ONE", expression("CAST(1 AS bigint)")), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowApp) + .addFunction(functionCall("avg", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), + LINEITEM_TABLESCAN_DOQRST))))); // should be anyTree(LINEITEM_TABLESCAN_DOQRST) but anyTree does not handle zero nodes case correctly } @Test @@ -253,19 +256,20 @@ public void testReorderBDAC() "from lineitem"; assertUnitPlan(sql, - anyTree(window(windowD, - ImmutableList.of( - functionCall("avg", commonFrame, ImmutableList.of(QUANTITY_ALIAS))), - window(windowA, - ImmutableList.of( - functionCall("avg", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), - window(windowC, - ImmutableList.of( - functionCall("sum", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), - window(windowE, - ImmutableList.of( - functionCall("sum", commonFrame, ImmutableList.of(TAX_ALIAS))), - LINEITEM_TABLESCAN_DOQRST)))))); // should be anyTree(LINEITEM_TABLESCAN_DOQRST) but anyTree does not handle zero nodes case correctly + anyTree( + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowD) + .addFunction(functionCall("avg", commonFrame, ImmutableList.of(QUANTITY_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowA) + .addFunction(functionCall("avg", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowC) + .addFunction(functionCall("sum", commonFrame, ImmutableList.of(DISCOUNT_ALIAS))), + window(windowMatcherBuilder -> windowMatcherBuilder + .specification(windowE) + .addFunction(functionCall("sum", commonFrame, ImmutableList.of(TAX_ALIAS))), + LINEITEM_TABLESCAN_DOQRST)))))); // should be anyTree(LINEITEM_TABLESCAN_DOQRST) but anyTree does not handle zero nodes case correctly } private void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern pattern) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSimplifyExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSimplifyExpressions.java index 30e3e58704ed..6ed6702705fc 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSimplifyExpressions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSimplifyExpressions.java @@ -15,10 +15,10 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.tree.Expression; @@ -139,7 +139,7 @@ private static Expression simplifyExpressions(Expression expression) private static Map booleanSymbolTypeMapFor(Expression expression) { - return DependencyExtractor.extractUnique(expression).stream() + return SymbolsExtractor.extractUnique(expression).stream() .collect(Collectors.toMap(symbol -> symbol, symbol -> BOOLEAN)); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnion.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnion.java index bd805e17d91b..56775a9acd3b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnion.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnion.java @@ -164,7 +164,7 @@ private void assertPlanIsFullyDistributed(Plan plan) { assertTrue( searchFrom(plan.getRoot()) - .skipOnlyWhen(TestUnion::isNotRemoteGatheringExchange) + .recurseOnlyWhen(TestUnion::isNotRemoteGatheringExchange) .findAll() .stream() .noneMatch(this::shouldBeDistributed), @@ -206,7 +206,7 @@ private static void assertAtMostOneAggregationBetweenRemoteExchanges(Plan plan) for (PlanNode fragment : fragments) { List aggregations = searchFrom(fragment) .where(AggregationNode.class::isInstance) - .skipOnlyWhen(TestUnion::isNotRemoteExchange) + .recurseOnlyWhen(TestUnion::isNotRemoteExchange) .findAll(); assertFalse(aggregations.size() > 1, "More than a single AggregationNode between remote exchanges"); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java new file mode 100644 index 000000000000..4f121265984e --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.sql.planner.Symbol; +import com.google.common.collect.ImmutableCollection; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static org.testng.Assert.assertTrue; + +public class TestAssingments +{ + private final Assignments assignments = Assignments.of(new Symbol("test"), TRUE_LITERAL); + + @Test + public void testOutputsImmutable() + { + assertTrue(assignments.getOutputs() instanceof ImmutableCollection); + } + + @Test + public void testOutputsMemoized() + { + assertTrue(assignments.getOutputs() == assignments.getOutputs()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyOnlyOneOutputNode.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyOnlyOneOutputNode.java index 8f42c8dcfd5b..2c22f6181df8 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyOnlyOneOutputNode.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyOnlyOneOutputNode.java @@ -60,7 +60,8 @@ public void testValidateFailed() ), Assignments.of() ), ImmutableList.of(), ImmutableList.of() - ), new Symbol("a") + ), new Symbol("a"), + false ), ImmutableList.of(), ImmutableList.of() ); diff --git a/presto-main/src/test/java/com/facebook/presto/type/AbstractTestType.java b/presto-main/src/test/java/com/facebook/presto/type/AbstractTestType.java index db2dfa64d124..0b4d76118b38 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/AbstractTestType.java +++ b/presto-main/src/test/java/com/facebook/presto/type/AbstractTestType.java @@ -16,6 +16,9 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.google.common.collect.ImmutableMap; import io.airlift.slice.DynamicSliceOutput; diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java index bbe2f97072f8..f7232f851531 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java @@ -20,7 +20,9 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.BooleanType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.SqlTimestamp; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.SemanticErrorCode; @@ -282,6 +284,8 @@ public void testArrayToArrayConcat() catch (RuntimeException e) { // Expected } + + assertCachedInstanceHasBoundedRetainedSize("ARRAY [1, NULL] || ARRAY [3]"); } @Test @@ -313,6 +317,9 @@ public void testElementArrayConcat() catch (RuntimeException e) { // Expected } + + assertCachedInstanceHasBoundedRetainedSize("ARRAY [1, NULL] || 3"); + assertCachedInstanceHasBoundedRetainedSize("3 || ARRAY [1, NULL]"); } @Test @@ -540,6 +547,14 @@ public void testElementAt() assertFunction("ELEMENT_AT(ARRAY [sqrt(-1)], -1)", DOUBLE, NaN); } + @Test + public void testShuffle() + { + // More tests can be found in AbstractTestQueries.testArrayShuffle + + assertCachedInstanceHasBoundedRetainedSize("SHUFFLE(ARRAY[2, 3, 4, 1])"); + } + @Test public void testSort() throws Exception @@ -562,6 +577,8 @@ public void testSort() assertFunction("ARRAY_SORT(ARRAY[1, null, null, -1, 0])", new ArrayType(INTEGER), expected); assertInvalidFunction("ARRAY_SORT(ARRAY[color('red'), color('blue')])", FUNCTION_NOT_FOUND); + + assertCachedInstanceHasBoundedRetainedSize("ARRAY_SORT(ARRAY[2, 3, 4, 1])"); } @Test @@ -575,6 +592,8 @@ public void testReverse() assertFunction("REVERSE(ARRAY['a', 'b', 'c', 'd'])", new ArrayType(createVarcharType(1)), ImmutableList.of("d", "c", "b", "a")); assertFunction("REVERSE(ARRAY[TRUE, FALSE])", new ArrayType(BOOLEAN), ImmutableList.of(false, true)); assertFunction("REVERSE(ARRAY[1.1, 2.2, 3.3, 4.4])", new ArrayType(DOUBLE), ImmutableList.of(4.4, 3.3, 2.2, 1.1)); + + assertCachedInstanceHasBoundedRetainedSize("REVERSE(ARRAY[1.1, 2.2, 3.3, 4.4])"); } @Test @@ -710,6 +729,8 @@ public void testArrayIntersect() assertFunction("ARRAY_INTERSECT(ARRAY [8.3, 1.6, 4.1, 5.2], ARRAY [4.0, 5.2, 8.3, 9.7, 3.5])", new ArrayType(DOUBLE), ImmutableList.of(5.2, 8.3)); assertFunction("ARRAY_INTERSECT(ARRAY [5.1, 7, 3.0, 4.8, 10], ARRAY [6.5, 10.0, 1.9, 5.1, 3.9, 4.8])", new ArrayType(DOUBLE), ImmutableList.of(4.8, 5.1, 10.0)); assertFunction("ARRAY_INTERSECT(ARRAY [ARRAY [4, 5], ARRAY [6, 7]], ARRAY [ARRAY [4, 5], ARRAY [6, 8]])", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(4, 5))); + + assertCachedInstanceHasBoundedRetainedSize("ARRAY_INTERSECT(ARRAY ['foo', 'bar', 'baz'], ARRAY ['foo', 'test', 'bar'])"); } @Test diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestMapOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestMapOperators.java index 50c20355c469..61899a092f07 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestMapOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestMapOperators.java @@ -21,6 +21,8 @@ import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.SqlTimestamp; import com.facebook.presto.spi.type.SqlVarbinary; @@ -142,6 +144,8 @@ public void testConstructor() 100.0)); assertInvalidFunction("MAP(ARRAY [1], ARRAY [2, 4])", "Key and value arrays must be the same length"); + + assertCachedInstanceHasBoundedRetainedSize("MAP(ARRAY ['1','3'], ARRAY [2,4])"); } @Test @@ -570,6 +574,8 @@ public void testMapConcat() assertFunction("MAP_CONCAT(MAP(), MAP(), MAP())", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); assertFunction("MAP_CONCAT(MAP(ARRAY[1], ARRAY[-1]), MAP(), MAP(ARRAY[3], ARRAY[-3]))", mapType(INTEGER, INTEGER), ImmutableMap.of(1, -1, 3, -3)); assertFunction("MAP_CONCAT(MAP(ARRAY[TRUE], ARRAY[1]), MAP(ARRAY[TRUE, FALSE], ARRAY[10, 20]), MAP(ARRAY[FALSE], ARRAY[0]))", mapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 10, false, 0)); + + assertCachedInstanceHasBoundedRetainedSize("MAP_CONCAT(MAP (ARRAY ['1', '2', '3'], ARRAY [1, 2, 3]), MAP (ARRAY ['1', '2', '3', '4'], ARRAY [10, 20, 30, 40]))"); } @Test diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestRowOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestRowOperators.java index 703f0e88eeea..114e31108403 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestRowOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestRowOperators.java @@ -18,6 +18,8 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlTimestamp; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.SemanticErrorCode; diff --git a/presto-main/src/test/java/com/facebook/presto/util/StructuralTestUtil.java b/presto-main/src/test/java/com/facebook/presto/util/StructuralTestUtil.java index a126ffe81b8b..00d0ed3bdb21 100644 --- a/presto-main/src/test/java/com/facebook/presto/util/StructuralTestUtil.java +++ b/presto-main/src/test/java/com/facebook/presto/util/StructuralTestUtil.java @@ -19,12 +19,12 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.MapType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; diff --git a/presto-memory/pom.xml b/presto-memory/pom.xml index 0db8f238a553..33a14399fc04 100644 --- a/presto-memory/pom.xml +++ b/presto-memory/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-memory diff --git a/presto-ml/pom.xml b/presto-ml/pom.xml index a55ff848513d..9d1f11240be4 100644 --- a/presto-ml/pom.xml +++ b/presto-ml/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-ml diff --git a/presto-ml/src/main/java/com/facebook/presto/ml/MLFeaturesFunctions.java b/presto-ml/src/main/java/com/facebook/presto/ml/MLFeaturesFunctions.java new file mode 100644 index 000000000000..aaad89c54dbf --- /dev/null +++ b/presto-ml/src/main/java/com/facebook/presto/ml/MLFeaturesFunctions.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.ml; + +import com.facebook.presto.spi.PageBuilder; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.spi.type.DoubleType; +import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.collect.ImmutableList; +import com.google.common.hash.HashCode; + +public final class MLFeaturesFunctions +{ + private static final Cache MODEL_CACHE = CacheBuilder.newBuilder().maximumSize(5).build(); + private static final String MAP_BIGINT_DOUBLE = "map(bigint,double)"; + + private final PageBuilder pageBuilder; + + public MLFeaturesFunctions(@TypeParameter("map(bigint,double)") Type mapType) + { + pageBuilder = new PageBuilder(ImmutableList.of(mapType)); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1) + { + return featuresHelper(f1); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2) + { + return featuresHelper(f1, f2); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3) + { + return featuresHelper(f1, f2, f3); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4) + { + return featuresHelper(f1, f2, f3, f4); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5) + { + return featuresHelper(f1, f2, f3, f4, f5); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6) + { + return featuresHelper(f1, f2, f3, f4, f5, f6); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7) + { + return featuresHelper(f1, f2, f3, f4, f5, f6, f7); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7, @SqlType(StandardTypes.DOUBLE) double f8) + { + return featuresHelper(f1, f2, f3, f4, f5, f6, f7, f8); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7, @SqlType(StandardTypes.DOUBLE) double f8, @SqlType(StandardTypes.DOUBLE) double f9) + { + return featuresHelper(f1, f2, f3, f4, f5, f6, f7, f8, f9); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7, @SqlType(StandardTypes.DOUBLE) double f8, @SqlType(StandardTypes.DOUBLE) double f9, @SqlType(StandardTypes.DOUBLE) double f10) + { + return featuresHelper(f1, f2, f3, f4, f5, f6, f7, f8, f9, f10); + } + + private Block featuresHelper(double... features) + { + if (pageBuilder.isFull()) { + pageBuilder.reset(); + } + + BlockBuilder mapBlockBuilder = pageBuilder.getBlockBuilder(0); + BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry(); + + for (int i = 0; i < features.length; i++) { + BigintType.BIGINT.writeLong(blockBuilder, i); + DoubleType.DOUBLE.writeDouble(blockBuilder, features[i]); + } + + mapBlockBuilder.closeEntry(); + pageBuilder.declarePosition(); + return mapBlockBuilder.getObject(mapBlockBuilder.getPositionCount() - 1, Block.class); + } +} diff --git a/presto-ml/src/main/java/com/facebook/presto/ml/MLFunctions.java b/presto-ml/src/main/java/com/facebook/presto/ml/MLFunctions.java index 518534b6b1fd..e14f6dad3ba5 100644 --- a/presto-ml/src/main/java/com/facebook/presto/ml/MLFunctions.java +++ b/presto-ml/src/main/java/com/facebook/presto/ml/MLFunctions.java @@ -15,17 +15,11 @@ import com.facebook.presto.ml.type.RegressorType; import com.facebook.presto.spi.block.Block; -import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; -import com.facebook.presto.spi.type.BigintType; -import com.facebook.presto.spi.type.DoubleType; import com.facebook.presto.spi.type.StandardTypes; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; -import com.google.common.collect.ImmutableList; import com.google.common.hash.HashCode; import io.airlift.slice.Slice; import io.airlift.slice.Slices; @@ -89,86 +83,4 @@ private static Model getOrLoadModel(Slice slice) return model; } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1) - { - return featuresHelper(f1); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2) - { - return featuresHelper(f1, f2); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3) - { - return featuresHelper(f1, f2, f3); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4) - { - return featuresHelper(f1, f2, f3, f4); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5) - { - return featuresHelper(f1, f2, f3, f4, f5); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6) - { - return featuresHelper(f1, f2, f3, f4, f5, f6); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7) - { - return featuresHelper(f1, f2, f3, f4, f5, f6, f7); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7, @SqlType(StandardTypes.DOUBLE) double f8) - { - return featuresHelper(f1, f2, f3, f4, f5, f6, f7, f8); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7, @SqlType(StandardTypes.DOUBLE) double f8, @SqlType(StandardTypes.DOUBLE) double f9) - { - return featuresHelper(f1, f2, f3, f4, f5, f6, f7, f8, f9); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7, @SqlType(StandardTypes.DOUBLE) double f8, @SqlType(StandardTypes.DOUBLE) double f9, @SqlType(StandardTypes.DOUBLE) double f10) - { - return featuresHelper(f1, f2, f3, f4, f5, f6, f7, f8, f9, f10); - } - - private static Block featuresHelper(double... features) - { - BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(BigintType.BIGINT, DoubleType.DOUBLE), new BlockBuilderStatus(), features.length); - - for (int i = 0; i < features.length; i++) { - BigintType.BIGINT.writeLong(blockBuilder, i); - DoubleType.DOUBLE.writeDouble(blockBuilder, features[i]); - } - - return blockBuilder.build(); - } } diff --git a/presto-ml/src/main/java/com/facebook/presto/ml/MLPlugin.java b/presto-ml/src/main/java/com/facebook/presto/ml/MLPlugin.java index 7cbfa86aef7c..c7d96bcf374c 100644 --- a/presto-ml/src/main/java/com/facebook/presto/ml/MLPlugin.java +++ b/presto-ml/src/main/java/com/facebook/presto/ml/MLPlugin.java @@ -52,6 +52,7 @@ public Set> getFunctions() .add(LearnLibSvmRegressorAggregation.class) .add(EvaluateClassifierPredictionsAggregation.class) .add(MLFunctions.class) + .add(MLFeaturesFunctions.class) .build(); } } diff --git a/presto-ml/src/test/java/com/facebook/presto/ml/AbstractTestMLFunctions.java b/presto-ml/src/test/java/com/facebook/presto/ml/AbstractTestMLFunctions.java new file mode 100644 index 000000000000..1206bd70f9da --- /dev/null +++ b/presto-ml/src/test/java/com/facebook/presto/ml/AbstractTestMLFunctions.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.ml; + +import com.facebook.presto.operator.scalar.AbstractTestFunctions; +import org.testng.annotations.BeforeClass; + +import static com.facebook.presto.metadata.FunctionExtractor.extractFunctions; + +abstract class AbstractTestMLFunctions + extends AbstractTestFunctions +{ + @BeforeClass + protected void registerFunctions() + { + functionAssertions.getMetadata().addFunctions( + extractFunctions(new MLPlugin().getFunctions())); + } +} diff --git a/presto-ml/src/test/java/com/facebook/presto/ml/TestMLFeaturesFunctions.java b/presto-ml/src/test/java/com/facebook/presto/ml/TestMLFeaturesFunctions.java new file mode 100644 index 000000000000..c2d15c9adccf --- /dev/null +++ b/presto-ml/src/test/java/com/facebook/presto/ml/TestMLFeaturesFunctions.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.ml; + +import org.testng.annotations.Test; + +public class TestMLFeaturesFunctions + extends AbstractTestMLFunctions +{ + @Test + public void testFeatures() + throws Exception + { + // More tests related to `features` function can be found in TestMLQueries + assertCachedInstanceHasBoundedRetainedSize("features(1, 2)"); + } +} diff --git a/presto-mongodb/pom.xml b/presto-mongodb/pom.xml index 9cafda4f787e..29cca6c69338 100644 --- a/presto-mongodb/pom.xml +++ b/presto-mongodb/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-mongodb diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSource.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSource.java index 43ca7cbe9766..f29891953d21 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSource.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSource.java @@ -214,18 +214,18 @@ private void writeBlock(BlockBuilder output, Type type, Object value) { if (isArrayType(type)) { if (value instanceof List) { - BlockBuilder builder = createParametersBlockBuilder(type, ((List) value).size()); + BlockBuilder builder = output.beginBlockEntry(); ((List) value).forEach(element -> appendTo(type.getTypeParameters().get(0), element, builder)); - type.writeObject(output, builder.build()); + output.closeEntry(); return; } } else if (isMapType(type)) { if (value instanceof List) { - BlockBuilder builder = createParametersBlockBuilder(type, ((List) value).size()); + BlockBuilder builder = output.beginBlockEntry(); for (Object element : (List) value) { if (!(element instanceof Map)) { continue; @@ -238,14 +238,14 @@ else if (isMapType(type)) { } } - type.writeObject(output, builder.build()); + output.closeEntry(); return; } } else if (isRowType(type)) { if (value instanceof Map) { Map mapValue = (Map) value; - BlockBuilder builder = createParametersBlockBuilder(type, mapValue.size()); + BlockBuilder builder = output.beginBlockEntry(); List fieldNames = type.getTypeSignature().getParameters().stream() .map(TypeSignatureParameter::getNamedTypeSignature) .map(NamedTypeSignature::getName) @@ -254,12 +254,12 @@ else if (isRowType(type)) { for (int index = 0; index < type.getTypeParameters().size(); index++) { appendTo(type.getTypeParameters().get(index), mapValue.get(fieldNames.get(index).toString()), builder); } - type.writeObject(output, builder.build()); + output.closeEntry(); return; } else if (value instanceof List) { List listValue = (List) value; - BlockBuilder builder = createParametersBlockBuilder(type, listValue.size()); + BlockBuilder builder = output.beginBlockEntry(); for (int index = 0; index < type.getTypeParameters().size(); index++) { if (index < listValue.size()) { appendTo(type.getTypeParameters().get(index), listValue.get(index), builder); @@ -268,7 +268,7 @@ else if (value instanceof List) { builder.appendNull(); } } - type.writeObject(output, builder.build()); + output.closeEntry(); return; } } diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSession.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSession.java index 0a91a95367ed..6fed21164668 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSession.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSession.java @@ -65,6 +65,7 @@ import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.base.Verify.verify; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.HOURS; @@ -232,7 +233,9 @@ private static V getCacheValue(LoadingCache ca } catch (ExecutionException | UncheckedExecutionException e) { Throwable t = e.getCause(); - Throwables.propagateIfInstanceOf(t, exceptionClass); + if (t != null) { + throwIfInstanceOf(t, exceptionClass); + } throw Throwables.propagate(t); } } diff --git a/presto-mysql/pom.xml b/presto-mysql/pom.xml index 9682652322c6..1e188a5e822a 100644 --- a/presto-mysql/pom.xml +++ b/presto-mysql/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-mysql diff --git a/presto-orc/pom.xml b/presto-orc/pom.xml index ac1b2eb9a068..8f5fa1f55d2d 100644 --- a/presto-orc/pom.xml +++ b/presto-orc/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-orc diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java index 1693d4578665..8adba7b45e43 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java @@ -56,6 +56,7 @@ public class OrcReader private final ExceptionWrappingMetadataReader metadataReader; private final DataSize maxMergeDistance; private final DataSize maxReadSize; + private final DataSize maxBlockSize; private final HiveWriterVersion hiveWriterVersion; private final int bufferSize; private final Footer footer; @@ -63,14 +64,15 @@ public class OrcReader private Optional decompressor = Optional.empty(); // This is based on the Apache Hive ORC code - public OrcReader(OrcDataSource orcDataSource, MetadataReader metadataReader, DataSize maxMergeDistance, DataSize maxReadSize) + public OrcReader(OrcDataSource orcDataSource, MetadataReader delegate, DataSize maxMergeDistance, DataSize maxReadSize, DataSize maxBlockSize) throws IOException { orcDataSource = wrapWithCacheIfTiny(requireNonNull(orcDataSource, "orcDataSource is null"), maxMergeDistance); this.orcDataSource = orcDataSource; - this.metadataReader = new ExceptionWrappingMetadataReader(orcDataSource.getId(), requireNonNull(metadataReader, "metadataReader is null")); + this.metadataReader = new ExceptionWrappingMetadataReader(orcDataSource.getId(), requireNonNull(delegate, "delegate is null")); this.maxMergeDistance = requireNonNull(maxMergeDistance, "maxMergeDistance is null"); this.maxReadSize = requireNonNull(maxReadSize, "maxReadSize is null"); + this.maxBlockSize = requireNonNull(maxBlockSize, "maxBlockSize is null"); // // Read the file tail: @@ -213,6 +215,7 @@ public OrcRecordReader createRecordReader( metadataReader, maxMergeDistance, maxReadSize, + maxBlockSize, footer.getUserMetadata(), systemMemoryUsage); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcRecordReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcRecordReader.java index b324d083d684..ab924c00d95f 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcRecordReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcRecordReader.java @@ -53,6 +53,7 @@ import static com.facebook.presto.orc.OrcReader.MAX_BATCH_SIZE; import static com.facebook.presto.orc.OrcRecordReader.LinearProbeRangeFinder.createTinyStripesRangeFinder; import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.Math.toIntExact; import static java.util.Comparator.comparingLong; @@ -64,13 +65,17 @@ public class OrcRecordReader private final OrcDataSource orcDataSource; private final StreamReader[] streamReaders; + private final long[] maxBytesPerCell; + private long maxCombinedBytesPerRow; private final long totalRowCount; private final long splitLength; private final Set presentColumns; + private final long maxBlockBytes; private long currentPosition; private long currentStripePosition; private int currentBatchSize; + private int maxBatchSize = MAX_BATCH_SIZE; private final List stripes; private final StripeReader stripeReader; @@ -107,6 +112,7 @@ public OrcRecordReader( MetadataReader metadataReader, DataSize maxMergeDistance, DataSize maxReadSize, + DataSize maxBlockSize, Map userMetadata, AbstractAggregatedMemoryContext systemMemoryUsage) throws IOException @@ -135,6 +141,8 @@ public OrcRecordReader( } this.presentColumns = presentColumns.build(); + this.maxBlockBytes = requireNonNull(maxBlockSize, "maxBlockSize is null").toBytes(); + // it is possible that old versions of orc use 0 to mean there are no row groups checkArgument(rowsInRowGroup > 0, "rowsInRowGroup must be greater than zero"); @@ -195,6 +203,7 @@ public OrcRecordReader( metadataReader); streamReaders = createStreamReaders(orcDataSource, types, hiveStorageTimeZone, presentColumnsAndTypes.build()); + maxBytesPerCell = new long[streamReaders.length]; } private static boolean splitContainsStripe(long splitOffset, long splitLength, StripeInformation stripe) @@ -279,6 +288,14 @@ public long getSplitLength() return splitLength; } + /** + * Returns the sum of the largest cells in size from each column + */ + public long getMaxCombinedBytesPerRow() + { + return maxCombinedBytesPerRow; + } + @Override public void close() throws IOException @@ -308,7 +325,7 @@ public int nextBatch() } } - currentBatchSize = toIntExact(min(MAX_BATCH_SIZE, currentGroupRowCount - nextRowInGroup)); + currentBatchSize = toIntExact(min(maxBatchSize, currentGroupRowCount - nextRowInGroup)); for (StreamReader column : streamReaders) { if (column != null) { @@ -322,7 +339,16 @@ public int nextBatch() public Block readBlock(Type type, int columnIndex) throws IOException { - return streamReaders[columnIndex].readBlock(type); + Block block = streamReaders[columnIndex].readBlock(type); + if (block.getPositionCount() > 0) { + long bytesPerCell = block.getSizeInBytes() / block.getPositionCount(); + if (maxBytesPerCell[columnIndex] < bytesPerCell) { + maxCombinedBytesPerRow = maxCombinedBytesPerRow - maxBytesPerCell[columnIndex] + bytesPerCell; + maxBytesPerCell[columnIndex] = bytesPerCell; + maxBatchSize = toIntExact(min(maxBatchSize, max(1, maxBlockBytes / maxCombinedBytesPerRow))); + } + } + return block; } public StreamReader getStreamReader(int index) diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/MapStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/MapStreamReader.java index e476d8a27654..3514ede45dbd 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/MapStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/MapStreamReader.java @@ -20,10 +20,9 @@ import com.facebook.presto.orc.stream.InputStreamSource; import com.facebook.presto.orc.stream.InputStreamSources; import com.facebook.presto.orc.stream.LongInputStream; -import com.facebook.presto.spi.block.ArrayBlock; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlock; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import it.unimi.dsi.fastutil.ints.IntArrayList; import org.joda.time.DateTimeZone; @@ -125,8 +124,9 @@ public Block readBlock(Type type) } } - Type keyType = type.getTypeParameters().get(0); - Type valueType = type.getTypeParameters().get(1); + MapType mapType = (MapType) type; + Type keyType = mapType.getKeyType(); + Type valueType = mapType.getValueType(); int entryCount = 0; for (int length : lengthVector) { @@ -146,26 +146,25 @@ public Block readBlock(Type type) values = valueType.createBlockBuilder(new BlockBuilderStatus(), 1).build(); } - InterleavedBlock keyValueBlock = createKeyValueBlock(nextBatchSize, keys, values, lengthVector); + Block[] keyValueBlock = createKeyValueBlock(nextBatchSize, keys, values, lengthVector); // convert lengths into offsets into the keyValueBlock (e.g., two positions per entry) int[] offsets = new int[nextBatchSize + 1]; for (int i = 1; i < offsets.length; i++) { - int length = lengthVector[i - 1] * 2; + int length = lengthVector[i - 1]; offsets[i] = offsets[i - 1] + length; } - ArrayBlock arrayBlock = new ArrayBlock(nextBatchSize, nullVector, offsets, keyValueBlock); readOffset = 0; nextBatchSize = 0; - return arrayBlock; + return mapType.createBlockFromKeyValue(nullVector, offsets, keyValueBlock[0], keyValueBlock[1]); } - private static InterleavedBlock createKeyValueBlock(int positionCount, Block keys, Block values, int[] lengths) + private static Block[] createKeyValueBlock(int positionCount, Block keys, Block values, int[] lengths) { if (!hasNull(keys)) { - return new InterleavedBlock(new Block[] {keys, values}); + return new Block[] {keys, values}; } // @@ -191,7 +190,7 @@ private static InterleavedBlock createKeyValueBlock(int positionCount, Block key Block newKeys = keys.copyPositions(nonNullPositions); Block newValues = values.copyPositions(nonNullPositions); - return new InterleavedBlock(new Block[] {newKeys, newValues}); + return new Block[] {newKeys, newValues}; } private static boolean hasNull(Block keys) diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDictionaryStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDictionaryStreamReader.java index d02660d13d68..09b3709252d4 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDictionaryStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDictionaryStreamReader.java @@ -184,7 +184,6 @@ else if (inDictionary[i]) { } } - // copy ids into a private array for this block since data vector is reused Block block = new DictionaryBlock(nextBatchSize, dictionaryBlock, dataVector); readOffset = 0; diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDirectStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDirectStreamReader.java index 3f43b45e65ed..5db980cb3043 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDirectStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDirectStreamReader.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.type.Type; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import io.airlift.units.DataSize; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -43,12 +44,15 @@ import static com.facebook.presto.spi.type.Varchars.isVarcharType; import static com.facebook.presto.spi.type.Varchars.truncateToLength; import static com.google.common.base.MoreObjects.toStringHelper; +import static io.airlift.units.DataSize.Unit.GIGABYTE; +import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class SliceDirectStreamReader implements StreamReader { private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + private static final int ONE_GIGABYTE = toIntExact(new DataSize(1, GIGABYTE).toBytes()); private final StreamDescriptor streamDescriptor; @@ -137,19 +141,22 @@ public Block readBlock(Type type) } } - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < nextBatchSize; i++) { if (!isNullVector[i]) { totalLength += lengthVector[i]; } } + if (totalLength > ONE_GIGABYTE) { + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Column values too large to process in Presto. %s column values larger than 1GB", nextBatchSize); + } byte[] data = EMPTY_BYTE_ARRAY; if (totalLength > 0) { if (dataStream == null) { throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } - data = dataStream.next(totalLength); + data = dataStream.next(toIntExact(totalLength)); } Slice[] sliceVector = new Slice[nextBatchSize]; diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkOrcDecimalReader.java b/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkOrcDecimalReader.java index e35300987ac0..739a1482ad88 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkOrcDecimalReader.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkOrcDecimalReader.java @@ -117,7 +117,7 @@ private OrcRecordReader createRecordReader() { OrcDataSource dataSource = new FileOrcDataSource(dataPath, new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE)); MetadataReader metadataReader = new OrcMetadataReader(); - OrcReader orcReader = new OrcReader(dataSource, metadataReader, new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE)); + OrcReader orcReader = new OrcReader(dataSource, metadataReader, new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE)); return orcReader.createRecordReader( ImmutableMap.of(0, DECIMAL_TYPE), OrcPredicate.TRUE, diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java b/presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java index fb0b84599cdb..e5e4fbf387ed 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java @@ -559,7 +559,7 @@ static OrcRecordReader createCustomOrcRecordReader(TempFile tempFile, MetadataRe throws IOException { OrcDataSource orcDataSource = new FileOrcDataSource(tempFile.getFile(), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); - OrcReader orcReader = new OrcReader(orcDataSource, metadataReader, new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); + OrcReader orcReader = new OrcReader(orcDataSource, metadataReader, new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); assertEquals(orcReader.getColumnNames(), ImmutableList.of("test")); assertEquals(orcReader.getFooter().getRowsInRowGroup(), 10_000); diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/TestCachingOrcDataSource.java b/presto-orc/src/test/java/com/facebook/presto/orc/TestCachingOrcDataSource.java index cccaf9d91f0a..d9cf0a5d1b6f 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/TestCachingOrcDataSource.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/TestCachingOrcDataSource.java @@ -193,7 +193,7 @@ public void testIntegration() public void doIntegration(TestingOrcDataSource orcDataSource, DataSize maxMergeDistance, DataSize maxReadSize) throws IOException { - OrcReader orcReader = new OrcReader(orcDataSource, new OrcMetadataReader(), maxMergeDistance, maxReadSize); + OrcReader orcReader = new OrcReader(orcDataSource, new OrcMetadataReader(), maxMergeDistance, maxReadSize, new DataSize(1, Unit.MEGABYTE)); // 1 for reading file footer assertEquals(orcDataSource.getReadCount(), 1); List stripes = orcReader.getFooter().getStripes(); diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/TestOrcReaderPositions.java b/presto-orc/src/test/java/com/facebook/presto/orc/TestOrcReaderPositions.java index 839b618d7c0a..46eefb7a3a96 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/TestOrcReaderPositions.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/TestOrcReaderPositions.java @@ -183,7 +183,7 @@ public void testReadUserMetadata() createFileWithOnlyUserMetadata(tempFile.getFile(), metadata); OrcDataSource orcDataSource = new FileOrcDataSource(tempFile.getFile(), new DataSize(1, DataSize.Unit.MEGABYTE), new DataSize(1, DataSize.Unit.MEGABYTE), new DataSize(1, DataSize.Unit.MEGABYTE)); - OrcReader orcReader = new OrcReader(orcDataSource, new OrcMetadataReader(), new DataSize(1, DataSize.Unit.MEGABYTE), new DataSize(1, DataSize.Unit.MEGABYTE)); + OrcReader orcReader = new OrcReader(orcDataSource, new OrcMetadataReader(), new DataSize(1, DataSize.Unit.MEGABYTE), new DataSize(1, DataSize.Unit.MEGABYTE), new DataSize(1, DataSize.Unit.MEGABYTE)); Footer footer = orcReader.getFooter(); Map readMetadata = Maps.transformValues(footer.getUserMetadata(), Slice::toStringAscii); assertEquals(readMetadata, metadata); diff --git a/presto-parser/pom.xml b/presto-parser/pom.xml index 0efeef5da3e7..27d1e819e309 100644 --- a/presto-parser/pom.xml +++ b/presto-parser/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-parser diff --git a/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 b/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 index b82a091641f9..3f3bded218ed 100644 --- a/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 +++ b/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 @@ -48,6 +48,8 @@ statement | ALTER TABLE from=qualifiedName RENAME TO to=qualifiedName #renameTable | ALTER TABLE tableName=qualifiedName RENAME COLUMN from=identifier TO to=identifier #renameColumn + | ALTER TABLE tableName=qualifiedName + DROP COLUMN column=qualifiedName #dropColumn | ALTER TABLE tableName=qualifiedName ADD COLUMN column=columnDefinition #addColumn | CREATE (OR REPLACE)? VIEW qualifiedName AS query #createView @@ -63,7 +65,7 @@ statement ON TABLE? qualifiedName FROM grantee=identifier #revoke | SHOW GRANTS (ON TABLE? qualifiedName)? #showGrants - | EXPLAIN ANALYZE? + | EXPLAIN ANALYZE? VERBOSE? ('(' explainOption (',' explainOption)* ')')? statement #explain | SHOW CREATE TABLE qualifiedName #showCreateTable | SHOW CREATE VIEW qualifiedName #showCreateView @@ -219,7 +221,6 @@ sampledRelation sampleType : BERNOULLI | SYSTEM - | POISSONIZED ; aliasedRelation @@ -234,6 +235,7 @@ relationPrimary : qualifiedName #tableName | '(' query ')' #subqueryRelation | UNNEST '(' expression (',' expression)* ')' (WITH ORDINALITY)? #unnest + | LATERAL '(' query ')' #lateral | '(' relation ')' #parenthesizedRelation ; @@ -448,214 +450,211 @@ number nonReserved // IMPORTANT: this rule must only contain tokens. Nested rules are not supported. See SqlParser.exitNonReserved - : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS | SCHEMAS | CATALOGS | SESSION | STATS - | ADD - | FILTER - | AT - | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY - | TINYINT | SMALLINT | INTEGER | DATE | TIME | TIMESTAMP | INTERVAL | ZONE - | YEAR | MONTH | DAY | HOUR | MINUTE | SECOND - | EXPLAIN | ANALYZE | FORMAT | TYPE | TEXT | GRAPHVIZ | LOGICAL | DISTRIBUTED | VALIDATE - | TABLESAMPLE | SYSTEM | BERNOULLI | POISSONIZED | USE | TO - | SET | RESET - | VIEW | REPLACE - | IF | NULLIF | COALESCE - | NFD | NFC | NFKD | NFKC - | POSITION - | NO | DATA - | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL - | SERIALIZABLE | REPEATABLE | COMMITTED | UNCOMMITTED | READ | WRITE | ONLY - | COMMENT - | CALL - | GRANT | REVOKE | PRIVILEGES | PUBLIC | OPTION | GRANTS - | SUBSTRING - | SCHEMA | CASCADE | RESTRICT - | INPUT | OUTPUT - | INCLUDING | EXCLUDING | PROPERTIES - | ALL | SOME | ANY + : ADD | ALL | ANALYZE | ANY | ARRAY | ASC | AT + | BERNOULLI + | CALL | CASCADE | CATALOGS | COALESCE | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CURRENT + | DATA | DATE | DAY | DESC | DISTRIBUTED + | EXCLUDING | EXPLAIN + | FILTER | FIRST | FOLLOWING | FORMAT | FUNCTIONS + | GRANT | GRANTS | GRAPHVIZ + | HOUR + | IF | INCLUDING | INPUT | INTEGER | INTERVAL | ISOLATION + | LAST | LATERAL | LEVEL | LIMIT | LOGICAL + | MAP | MINUTE | MONTH + | NFC | NFD | NFKC | NFKD | NO | NULLIF | NULLS + | ONLY | OPTION | ORDINALITY | OUTPUT | OVER + | PARTITION | PARTITIONS | POSITION | PRECEDING | PRIVILEGES | PROPERTIES | PUBLIC + | RANGE | READ | RENAME | REPEATABLE | REPLACE | RESET | RESTRICT | REVOKE | ROLLBACK | ROW | ROWS + | SCHEMA | SCHEMAS | SECOND | SERIALIZABLE | SESSION | SET | SETS + | SHOW | SMALLINT | SOME | START | STATS | SUBSTRING | SYSTEM + | TABLES | TABLESAMPLE | TEXT | TIME | TIMESTAMP | TINYINT | TO | TRANSACTION | TRY_CAST | TYPE + | UNBOUNDED | UNCOMMITTED | USE + | VALIDATE | VERBOSE | VIEW + | WORK | WRITE + | YEAR + | ZONE ; -SELECT: 'SELECT'; -FROM: 'FROM'; ADD: 'ADD'; -AS: 'AS'; ALL: 'ALL'; -SOME: 'SOME'; +ALTER: 'ALTER'; +ANALYZE: 'ANALYZE'; +AND: 'AND'; ANY: 'ANY'; -DISTINCT: 'DISTINCT'; -WHERE: 'WHERE'; -GROUP: 'GROUP'; +ARRAY: 'ARRAY'; +AS: 'AS'; +ASC: 'ASC'; +AT: 'AT'; +BERNOULLI: 'BERNOULLI'; +BETWEEN: 'BETWEEN'; BY: 'BY'; -GROUPING: 'GROUPING'; -SETS: 'SETS'; +CALL: 'CALL'; +CASCADE: 'CASCADE'; +CASE: 'CASE'; +CAST: 'CAST'; +CATALOGS: 'CATALOGS'; +COALESCE: 'COALESCE'; +COLUMN: 'COLUMN'; +COLUMNS: 'COLUMNS'; +COMMENT: 'COMMENT'; +COMMIT: 'COMMIT'; +COMMITTED: 'COMMITTED'; +CONSTRAINT: 'CONSTRAINT'; +CREATE: 'CREATE'; +CROSS: 'CROSS'; CUBE: 'CUBE'; -ROLLUP: 'ROLLUP'; -ORDER: 'ORDER'; -HAVING: 'HAVING'; -LIMIT: 'LIMIT'; -AT: 'AT'; -OR: 'OR'; -AND: 'AND'; -IN: 'IN'; -NOT: 'NOT'; -NO: 'NO'; +CURRENT: 'CURRENT'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIME: 'CURRENT_TIME'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +DATA: 'DATA'; +DATE: 'DATE'; +DAY: 'DAY'; +DEALLOCATE: 'DEALLOCATE'; +DELETE: 'DELETE'; +DESC: 'DESC'; +DESCRIBE: 'DESCRIBE'; +DISTINCT: 'DISTINCT'; +DISTRIBUTED: 'DISTRIBUTED'; +DROP: 'DROP'; +ELSE: 'ELSE'; +END: 'END'; +ESCAPE: 'ESCAPE'; +EXCEPT: 'EXCEPT'; +EXCLUDING: 'EXCLUDING'; +EXECUTE: 'EXECUTE'; EXISTS: 'EXISTS'; -BETWEEN: 'BETWEEN'; -LIKE: 'LIKE'; -IS: 'IS'; -NULL: 'NULL'; -TRUE: 'TRUE'; +EXPLAIN: 'EXPLAIN'; +EXTRACT: 'EXTRACT'; FALSE: 'FALSE'; -NULLS: 'NULLS'; +FILTER: 'FILTER'; FIRST: 'FIRST'; -LAST: 'LAST'; -ESCAPE: 'ESCAPE'; -ASC: 'ASC'; -DESC: 'DESC'; -SUBSTRING: 'SUBSTRING'; -POSITION: 'POSITION'; +FOLLOWING: 'FOLLOWING'; FOR: 'FOR'; -TINYINT: 'TINYINT'; -SMALLINT: 'SMALLINT'; +FORMAT: 'FORMAT'; +FROM: 'FROM'; +FULL: 'FULL'; +FUNCTIONS: 'FUNCTIONS'; +GRANT: 'GRANT'; +GRANTS: 'GRANTS'; +GRAPHVIZ: 'GRAPHVIZ'; +GROUP: 'GROUP'; +GROUPING: 'GROUPING'; +HAVING: 'HAVING'; +HOUR: 'HOUR'; +IF: 'IF'; +IN: 'IN'; +INCLUDING: 'INCLUDING'; +INNER: 'INNER'; +INPUT: 'INPUT'; +INSERT: 'INSERT'; INTEGER: 'INTEGER'; -DATE: 'DATE'; -TIME: 'TIME'; -TIMESTAMP: 'TIMESTAMP'; +INTERSECT: 'INTERSECT'; INTERVAL: 'INTERVAL'; -YEAR: 'YEAR'; -MONTH: 'MONTH'; -DAY: 'DAY'; -HOUR: 'HOUR'; -MINUTE: 'MINUTE'; -SECOND: 'SECOND'; -ZONE: 'ZONE'; -CURRENT_DATE: 'CURRENT_DATE'; -CURRENT_TIME: 'CURRENT_TIME'; -CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; -LOCALTIME: 'LOCALTIME'; -LOCALTIMESTAMP: 'LOCALTIMESTAMP'; -EXTRACT: 'EXTRACT'; -CASE: 'CASE'; -WHEN: 'WHEN'; -THEN: 'THEN'; -ELSE: 'ELSE'; -END: 'END'; +INTO: 'INTO'; +IS: 'IS'; +ISOLATION: 'ISOLATION'; JOIN: 'JOIN'; -CROSS: 'CROSS'; -OUTER: 'OUTER'; -INNER: 'INNER'; +LAST: 'LAST'; +LATERAL: 'LATERAL'; LEFT: 'LEFT'; -RIGHT: 'RIGHT'; -FULL: 'FULL'; +LEVEL: 'LEVEL'; +LIKE: 'LIKE'; +LIMIT: 'LIMIT'; +LOCALTIME: 'LOCALTIME'; +LOCALTIMESTAMP: 'LOCALTIMESTAMP'; +LOGICAL: 'LOGICAL'; +MAP: 'MAP'; +MINUTE: 'MINUTE'; +MONTH: 'MONTH'; NATURAL: 'NATURAL'; -USING: 'USING'; +NFC : 'NFC'; +NFD : 'NFD'; +NFKC : 'NFKC'; +NFKD : 'NFKD'; +NO: 'NO'; +NORMALIZE: 'NORMALIZE'; +NOT: 'NOT'; +NULL: 'NULL'; +NULLIF: 'NULLIF'; +NULLS: 'NULLS'; ON: 'ON'; -FILTER: 'FILTER'; +ONLY: 'ONLY'; +OPTION: 'OPTION'; +OR: 'OR'; +ORDER: 'ORDER'; +ORDINALITY: 'ORDINALITY'; +OUTER: 'OUTER'; +OUTPUT: 'OUTPUT'; OVER: 'OVER'; PARTITION: 'PARTITION'; -RANGE: 'RANGE'; -ROWS: 'ROWS'; -UNBOUNDED: 'UNBOUNDED'; +PARTITIONS: 'PARTITIONS'; +POSITION: 'POSITION'; PRECEDING: 'PRECEDING'; -FOLLOWING: 'FOLLOWING'; -CURRENT: 'CURRENT'; -ROW: 'ROW'; -WITH: 'WITH'; -RECURSIVE: 'RECURSIVE'; -VALUES: 'VALUES'; -CREATE: 'CREATE'; -SCHEMA: 'SCHEMA'; -TABLE: 'TABLE'; -COMMENT: 'COMMENT'; -VIEW: 'VIEW'; -REPLACE: 'REPLACE'; -INSERT: 'INSERT'; -DELETE: 'DELETE'; -INTO: 'INTO'; -CONSTRAINT: 'CONSTRAINT'; -DESCRIBE: 'DESCRIBE'; -GRANT: 'GRANT'; -REVOKE: 'REVOKE'; +PREPARE: 'PREPARE'; PRIVILEGES: 'PRIVILEGES'; +PROPERTIES: 'PROPERTIES'; PUBLIC: 'PUBLIC'; -OPTION: 'OPTION'; -GRANTS: 'GRANTS'; -EXPLAIN: 'EXPLAIN'; -ANALYZE: 'ANALYZE'; -FORMAT: 'FORMAT'; -TYPE: 'TYPE'; -TEXT: 'TEXT'; -GRAPHVIZ: 'GRAPHVIZ'; -LOGICAL: 'LOGICAL'; -DISTRIBUTED: 'DISTRIBUTED'; -VALIDATE: 'VALIDATE'; -CAST: 'CAST'; -TRY_CAST: 'TRY_CAST'; -SHOW: 'SHOW'; -TABLES: 'TABLES'; -SCHEMAS: 'SCHEMAS'; -CATALOGS: 'CATALOGS'; -COLUMNS: 'COLUMNS'; -COLUMN: 'COLUMN'; -USE: 'USE'; -PARTITIONS: 'PARTITIONS'; -FUNCTIONS: 'FUNCTIONS'; -DROP: 'DROP'; -UNION: 'UNION'; -EXCEPT: 'EXCEPT'; -INTERSECT: 'INTERSECT'; -TO: 'TO'; -SYSTEM: 'SYSTEM'; -BERNOULLI: 'BERNOULLI'; -POISSONIZED: 'POISSONIZED'; -TABLESAMPLE: 'TABLESAMPLE'; -ALTER: 'ALTER'; +RANGE: 'RANGE'; +READ: 'READ'; +RECURSIVE: 'RECURSIVE'; RENAME: 'RENAME'; -UNNEST: 'UNNEST'; -ORDINALITY: 'ORDINALITY'; -ARRAY: 'ARRAY'; -MAP: 'MAP'; -SET: 'SET'; +REPEATABLE: 'REPEATABLE'; +REPLACE: 'REPLACE'; RESET: 'RESET'; +RESTRICT: 'RESTRICT'; +REVOKE: 'REVOKE'; +RIGHT: 'RIGHT'; +ROLLBACK: 'ROLLBACK'; +ROLLUP: 'ROLLUP'; +ROW: 'ROW'; +ROWS: 'ROWS'; +SCHEMA: 'SCHEMA'; +SCHEMAS: 'SCHEMAS'; +SECOND: 'SECOND'; +SELECT: 'SELECT'; +SERIALIZABLE: 'SERIALIZABLE'; SESSION: 'SESSION'; -DATA: 'DATA'; +SET: 'SET'; +SETS: 'SETS'; +SHOW: 'SHOW'; +SMALLINT: 'SMALLINT'; +SOME: 'SOME'; START: 'START'; +STATS: 'STATS'; +SUBSTRING: 'SUBSTRING'; +SYSTEM: 'SYSTEM'; +TABLE: 'TABLE'; +TABLES: 'TABLES'; +TABLESAMPLE: 'TABLESAMPLE'; +TEXT: 'TEXT'; +THEN: 'THEN'; +TIME: 'TIME'; +TIMESTAMP: 'TIMESTAMP'; +TINYINT: 'TINYINT'; +TO: 'TO'; TRANSACTION: 'TRANSACTION'; -COMMIT: 'COMMIT'; -ROLLBACK: 'ROLLBACK'; -WORK: 'WORK'; -ISOLATION: 'ISOLATION'; -LEVEL: 'LEVEL'; -SERIALIZABLE: 'SERIALIZABLE'; -REPEATABLE: 'REPEATABLE'; -COMMITTED: 'COMMITTED'; +TRUE: 'TRUE'; +TRY_CAST: 'TRY_CAST'; +TYPE: 'TYPE'; +UESCAPE: 'UESCAPE'; +UNBOUNDED: 'UNBOUNDED'; UNCOMMITTED: 'UNCOMMITTED'; -READ: 'READ'; +UNION: 'UNION'; +UNNEST: 'UNNEST'; +USE: 'USE'; +USING: 'USING'; +VALIDATE: 'VALIDATE'; +VALUES: 'VALUES'; +VERBOSE: 'VERBOSE'; +VIEW: 'VIEW'; +WHEN: 'WHEN'; +WHERE: 'WHERE'; +WITH: 'WITH'; +WORK: 'WORK'; WRITE: 'WRITE'; -ONLY: 'ONLY'; -CALL: 'CALL'; -PREPARE: 'PREPARE'; -DEALLOCATE: 'DEALLOCATE'; -EXECUTE: 'EXECUTE'; -INPUT: 'INPUT'; -OUTPUT: 'OUTPUT'; -CASCADE: 'CASCADE'; -RESTRICT: 'RESTRICT'; -INCLUDING: 'INCLUDING'; -EXCLUDING: 'EXCLUDING'; -PROPERTIES: 'PROPERTIES'; -UESCAPE: 'UESCAPE'; -STATS: 'STATS'; - -NORMALIZE: 'NORMALIZE'; -NFD : 'NFD'; -NFC : 'NFC'; -NFKD : 'NFKD'; -NFKC : 'NFKC'; - -IF: 'IF'; -NULLIF: 'NULLIF'; -COALESCE: 'COALESCE'; +YEAR: 'YEAR'; +ZONE: 'ZONE'; EQ : '='; NEQ : '<>' | '!='; diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/ReservedIdentifiers.java b/presto-parser/src/main/java/com/facebook/presto/sql/ReservedIdentifiers.java new file mode 100644 index 000000000000..76add99963ac --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/ReservedIdentifiers.java @@ -0,0 +1,153 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql; + +import com.facebook.presto.sql.parser.ParsingException; +import com.facebook.presto.sql.parser.SqlBaseLexer; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.tree.Identifier; +import com.google.common.collect.ImmutableSet; +import org.antlr.v4.runtime.Vocabulary; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static com.google.common.base.Strings.nullToEmpty; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.String.format; + +public final class ReservedIdentifiers +{ + private static final Pattern IDENTIFIER = Pattern.compile("'([A-Z_]+)'"); + private static final Pattern TABLE_ROW = Pattern.compile("``([A-Z_]+)``.*"); + private static final String TABLE_PREFIX = "============================== "; + + private static final SqlParser PARSER = new SqlParser(); + + private ReservedIdentifiers() {} + + @SuppressWarnings("CallToPrintStackTrace") + public static void main(String[] args) + throws IOException + { + if ((args.length == 2) && args[0].equals("validateDocs")) { + try { + validateDocs(Paths.get(args[1])); + } + catch (Throwable t) { + t.printStackTrace(); + System.exit(100); + } + } + else { + for (String name : reservedIdentifiers()) { + System.out.println(name); + } + } + } + + private static void validateDocs(Path path) + throws IOException + { + System.out.println("Validating " + path); + List lines = Files.readAllLines(path); + + if (lines.stream().filter(s -> s.startsWith(TABLE_PREFIX)).count() != 3) { + throw new RuntimeException("Failed to find exactly one table"); + } + + Iterator iterator = lines.iterator(); + + // find table and skip header + while (!iterator.next().startsWith(TABLE_PREFIX)) { + // skip + } + if (iterator.next().startsWith(TABLE_PREFIX)) { + throw new RuntimeException("Expected to find a header line"); + } + if (!iterator.next().startsWith(TABLE_PREFIX)) { + throw new RuntimeException("Found multiple header lines"); + } + + Set reserved = reservedIdentifiers(); + Set found = new HashSet<>(); + while (true) { + String line = iterator.next(); + if (line.startsWith(TABLE_PREFIX)) { + break; + } + + Matcher matcher = TABLE_ROW.matcher(line); + if (!matcher.matches()) { + throw new RuntimeException("Invalid table line: " + line); + } + String name = matcher.group(1); + + if (!reserved.contains(name)) { + throw new RuntimeException("Documented identifier is not reserved: " + name); + } + if (!found.add(name)) { + throw new RuntimeException("Duplicate documented identifier: " + name); + } + } + + for (String name : reserved) { + if (!found.contains(name)) { + throw new RuntimeException("Reserved identifier is not documented: " + name); + } + } + + System.out.println(format("Validated %s reserved identifiers", reserved.size())); + } + + public static Set reservedIdentifiers() + { + return possibleIdentifiers().stream() + .filter(ReservedIdentifiers::reserved) + .sorted() + .collect(toImmutableSet()); + } + + private static Set possibleIdentifiers() + { + ImmutableSet.Builder names = ImmutableSet.builder(); + Vocabulary vocabulary = SqlBaseLexer.VOCABULARY; + for (int i = 0; i <= vocabulary.getMaxTokenType(); i++) { + String name = nullToEmpty(vocabulary.getLiteralName(i)); + Matcher matcher = IDENTIFIER.matcher(name); + if (matcher.matches()) { + names.add(matcher.group(1)); + } + } + return names.build(); + } + + private static boolean reserved(String name) + { + try { + return !(PARSER.createExpression(name) instanceof Identifier); + } + catch (ParsingException ignored) { + return true; + } + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java b/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java index 23475b70857f..e1621e25b0df 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java @@ -29,6 +29,7 @@ import com.facebook.presto.sql.tree.Delete; import com.facebook.presto.sql.tree.DescribeInput; import com.facebook.presto.sql.tree.DescribeOutput; +import com.facebook.presto.sql.tree.DropColumn; import com.facebook.presto.sql.tree.DropSchema; import com.facebook.presto.sql.tree.DropTable; import com.facebook.presto.sql.tree.DropView; @@ -47,6 +48,7 @@ import com.facebook.presto.sql.tree.JoinCriteria; import com.facebook.presto.sql.tree.JoinOn; import com.facebook.presto.sql.tree.JoinUsing; +import com.facebook.presto.sql.tree.Lateral; import com.facebook.presto.sql.tree.LikeClause; import com.facebook.presto.sql.tree.NaturalJoin; import com.facebook.presto.sql.tree.Node; @@ -155,6 +157,15 @@ protected Void visitUnnest(Unnest node, Integer indent) return null; } + @Override + protected Void visitLateral(Lateral node, Integer indent) + { + append(indent, "LATERAL ("); + process(node.getQuery(), indent + 1); + append(indent, ")"); + return null; + } + @Override protected Void visitPrepare(Prepare node, Integer indent) { @@ -876,6 +887,17 @@ protected Void visitRenameColumn(RenameColumn node, Integer context) return null; } + @Override + protected Void visitDropColumn(DropColumn node, Integer context) + { + builder.append("ALTER TABLE ") + .append(formatName(node.getTable())) + .append(" DROP COLUMN ") + .append(formatName(node.getColumn())); + + return null; + } + @Override protected Void visitAddColumn(AddColumn node, Integer indent) { diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java index f1b5d4bd5c73..3b06656c0f87 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java @@ -48,6 +48,7 @@ import com.facebook.presto.sql.tree.DescribeInput; import com.facebook.presto.sql.tree.DescribeOutput; import com.facebook.presto.sql.tree.DoubleLiteral; +import com.facebook.presto.sql.tree.DropColumn; import com.facebook.presto.sql.tree.DropSchema; import com.facebook.presto.sql.tree.DropTable; import com.facebook.presto.sql.tree.DropView; @@ -84,6 +85,7 @@ import com.facebook.presto.sql.tree.JoinUsing; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.LambdaExpression; +import com.facebook.presto.sql.tree.Lateral; import com.facebook.presto.sql.tree.LikeClause; import com.facebook.presto.sql.tree.LikePredicate; import com.facebook.presto.sql.tree.LogicalBinaryExpression; @@ -311,6 +313,12 @@ public Node visitAddColumn(SqlBaseParser.AddColumnContext context) return new AddColumn(getLocation(context), getQualifiedName(context.qualifiedName()), (ColumnDefinition) visit(context.columnDefinition())); } + @Override + public Node visitDropColumn(SqlBaseParser.DropColumnContext context) + { + return new DropColumn(getLocation(context), getQualifiedName(context.tableName), context.column.getText()); + } + @Override public Node visitCreateView(SqlBaseParser.CreateViewContext context) { @@ -615,7 +623,7 @@ public Node visitInlineTable(SqlBaseParser.InlineTableContext context) @Override public Node visitExplain(SqlBaseParser.ExplainContext context) { - return new Explain(getLocation(context), context.ANALYZE() != null, (Statement) visit(context.statement()), visit(context.explainOption(), ExplainOption.class)); + return new Explain(getLocation(context), context.ANALYZE() != null, context.VERBOSE() != null, (Statement) visit(context.statement()), visit(context.explainOption(), ExplainOption.class)); } @Override @@ -912,6 +920,12 @@ public Node visitUnnest(SqlBaseParser.UnnestContext context) return new Unnest(getLocation(context), visit(context.expression(), Expression.class), context.ORDINALITY() != null); } + @Override + public Node visitLateral(SqlBaseParser.LateralContext context) + { + return new Lateral(getLocation(context), (Query) visit(context.query())); + } + @Override public Node visitParenthesizedRelation(SqlBaseParser.ParenthesizedRelationContext context) { diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java index a74b619ff728..e833ec553934 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java @@ -427,6 +427,11 @@ protected R visitUnnest(Unnest node, C context) return visitRelation(node, context); } + protected R visitLateral(Lateral node, C context) + { + return visitRelation(node, context); + } + protected R visitValues(Values node, C context) { return visitQueryBody(node, context); @@ -552,6 +557,11 @@ protected R visitRenameColumn(RenameColumn node, C context) return visitStatement(node, context); } + protected R visitDropColumn(DropColumn node, C context) + { + return visitStatement(node, context); + } + protected R visitAddColumn(AddColumn node, C context) { return visitStatement(node, context); diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java index 1ae5ce8a57e2..02990e504de6 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java @@ -606,4 +606,12 @@ protected R visitExists(ExistsPredicate node, C context) return null; } + + @Override + protected R visitLateral(Lateral node, C context) + { + process(node.getQuery(), context); + + return super.visitLateral(node, context); + } } diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DropColumn.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DropColumn.java new file mode 100644 index 000000000000..80c6f8017fa8 --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DropColumn.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class DropColumn + extends Statement +{ + private final QualifiedName table; + private final String column; + + public DropColumn(QualifiedName table, String column) + { + this(Optional.empty(), table, column); + } + + public DropColumn(NodeLocation location, QualifiedName table, String column) + { + this(Optional.of(location), table, column); + } + + private DropColumn(Optional location, QualifiedName table, String column) + { + super(location); + this.table = requireNonNull(table, "table is null"); + this.column = requireNonNull(column, "column is null"); + } + + public QualifiedName getTable() + { + return table; + } + + public String getColumn() + { + return column; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitDropColumn(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DropColumn that = (DropColumn) o; + return Objects.equals(table, that.table) && + Objects.equals(column, that.column); + } + + @Override + public int hashCode() + { + return Objects.hash(table, column); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("table", table) + .add("column", column) + .toString(); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/Explain.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Explain.java index 520351d2b912..ef7274d1e0f0 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/Explain.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Explain.java @@ -27,23 +27,25 @@ public class Explain { private final Statement statement; private final boolean analyze; + private final boolean verbose; private final List options; - public Explain(Statement statement, boolean analyze, List options) + public Explain(Statement statement, boolean analyze, boolean verbose, List options) { - this(Optional.empty(), analyze, statement, options); + this(Optional.empty(), analyze, verbose, statement, options); } - public Explain(NodeLocation location, boolean analyze, Statement statement, List options) + public Explain(NodeLocation location, boolean analyze, boolean verbose, Statement statement, List options) { - this(Optional.of(location), analyze, statement, options); + this(Optional.of(location), analyze, verbose, statement, options); } - private Explain(Optional location, boolean analyze, Statement statement, List options) + private Explain(Optional location, boolean analyze, boolean verbose, Statement statement, List options) { super(location); this.statement = requireNonNull(statement, "statement is null"); this.analyze = analyze; + this.verbose = verbose; if (options == null) { this.options = ImmutableList.of(); } @@ -62,6 +64,11 @@ public boolean isAnalyze() return analyze; } + public boolean isVerbose() + { + return verbose; + } + public List getOptions() { return options; diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/Lateral.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Lateral.java new file mode 100644 index 000000000000..289204083997 --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Lateral.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public final class Lateral + extends Relation +{ + private final Query query; + + public Lateral(Query query) + { + this(Optional.empty(), query); + } + + public Lateral(NodeLocation location, Query query) + { + this(Optional.of(location), query); + } + + private Lateral(Optional location, Query query) + { + super(location); + this.query = requireNonNull(query, "query is null"); + } + + public Query getQuery() + { + return query; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitLateral(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(query); + } + + @Override + public String toString() + { + return "LATERAL(" + query + ")"; + } + + @Override + public int hashCode() + { + return Objects.hash(query); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + Lateral other = (Lateral) obj; + return Objects.equals(this.query, other.query); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/util/AstUtils.java b/presto-parser/src/main/java/com/facebook/presto/sql/util/AstUtils.java index 510febf77b8b..b5b63cffa166 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/util/AstUtils.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/util/AstUtils.java @@ -13,67 +13,30 @@ */ package com.facebook.presto.sql.util; -import com.facebook.presto.sql.tree.DefaultTraversalVisitor; import com.facebook.presto.sql.tree.Node; +import com.google.common.collect.TreeTraverser; -import java.util.ArrayDeque; -import java.util.Deque; -import java.util.Iterator; -import java.util.Spliterators; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Stream; -import java.util.stream.StreamSupport; + +import static com.google.common.collect.Iterables.unmodifiableIterable; +import static java.util.Objects.requireNonNull; public class AstUtils { public static boolean nodeContains(Node node, Node subNode) { - return new DefaultTraversalVisitor() - { - @Override - public Boolean process(Node node, AtomicBoolean findResultHolder) - { - if (!findResultHolder.get()) { - if (node == subNode) { - findResultHolder.set(true); - } - else { - super.process(node, findResultHolder); - } - } - return findResultHolder.get(); - } - }.process(node, new AtomicBoolean(false)); - } + requireNonNull(node, "node is null"); + requireNonNull(subNode, "subNode is null"); - public static Stream preOrder(Node node) - { - return StreamSupport.stream(Spliterators.spliteratorUnknownSize(new PreOrderIterator(node), 0), false); + return preOrder(node) + .anyMatch(childNode -> childNode == subNode); } - private static final class PreOrderIterator - implements Iterator + public static Stream preOrder(Node node) { - private final Deque remaining = new ArrayDeque<>(); - - public PreOrderIterator(Node node) - { - remaining.push(node); - } - - @Override - public boolean hasNext() - { - return remaining.size() > 0; - } - - @Override - public Node next() - { - Node node = remaining.pop(); - node.getChildren().forEach(remaining::push); - return node; - } + return TreeTraverser.using((Node n) -> unmodifiableIterable(n.getChildren())) + .preOrderTraversal(requireNonNull(node, "node is null")) + .stream(); } private AstUtils() {} diff --git a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java index 77a69e3956c9..a9a1596f2c98 100644 --- a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java +++ b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.parser; import com.facebook.presto.sql.tree.AddColumn; +import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.AllColumns; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.ArrayConstructor; @@ -42,6 +43,7 @@ import com.facebook.presto.sql.tree.DescribeInput; import com.facebook.presto.sql.tree.DescribeOutput; import com.facebook.presto.sql.tree.DoubleLiteral; +import com.facebook.presto.sql.tree.DropColumn; import com.facebook.presto.sql.tree.DropSchema; import com.facebook.presto.sql.tree.DropTable; import com.facebook.presto.sql.tree.DropView; @@ -68,6 +70,7 @@ import com.facebook.presto.sql.tree.JoinOn; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.LambdaExpression; +import com.facebook.presto.sql.tree.Lateral; import com.facebook.presto.sql.tree.LikeClause; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; @@ -1454,6 +1457,14 @@ public void testAddColumn() assertStatement("ALTER TABLE foo.t ADD COLUMN c bigint", new AddColumn(QualifiedName.of("foo", "t"), new ColumnDefinition("c", "bigint", Optional.empty()))); } + @Test + public void testDropColumn() + throws Exception + { + assertStatement("ALTER TABLE foo.t DROP COLUMN c", new DropColumn(QualifiedName.of("foo", "t"), "c")); + assertStatement("ALTER TABLE \"t x\" DROP COLUMN \"c d\"", new DropColumn(QualifiedName.of("t x"), "c d")); + } + @Test public void testCreateView() throws Exception @@ -1545,27 +1556,69 @@ public void testExplain() throws Exception { assertStatement("EXPLAIN SELECT * FROM t", - new Explain(simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("t"))), false, ImmutableList.of())); + new Explain(simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("t"))), false, false, ImmutableList.of())); assertStatement("EXPLAIN (TYPE LOGICAL) SELECT * FROM t", new Explain( simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("t"))), false, + false, ImmutableList.of(new ExplainType(ExplainType.Type.LOGICAL)))); assertStatement("EXPLAIN (TYPE LOGICAL, FORMAT TEXT) SELECT * FROM t", new Explain( simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("t"))), false, + false, ImmutableList.of( new ExplainType(ExplainType.Type.LOGICAL), new ExplainFormat(ExplainFormat.Type.TEXT)))); } + @Test + public void testExplainVerbose() + throws Exception + { + assertStatement("EXPLAIN VERBOSE SELECT * FROM t", + new Explain(simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("t"))), false, true, ImmutableList.of())); + } + + @Test + public void testExplainVerboseTypeLogical() + throws Exception + { + assertStatement("EXPLAIN VERBOSE (type LOGICAL) SELECT * FROM t", + new Explain(simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("t"))), false, true, ImmutableList.of(new ExplainType(ExplainType.Type.LOGICAL)))); + } + @Test public void testExplainAnalyze() throws Exception { assertStatement("EXPLAIN ANALYZE SELECT * FROM t", - new Explain(simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("t"))), true, ImmutableList.of())); + new Explain(simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("t"))), true, false, ImmutableList.of())); + } + + @Test + public void testExplainAnalyzeTypeDistributed() + throws Exception + { + assertStatement("EXPLAIN ANALYZE (type DISTRIBUTED) SELECT * FROM t", + new Explain(simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("t"))), true, false, ImmutableList.of(new ExplainType(ExplainType.Type.DISTRIBUTED)))); + } + + @Test + public void testExplainAnalyzeVerbose() + throws Exception + { + assertStatement("EXPLAIN ANALYZE VERBOSE SELECT * FROM t", + new Explain(simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("t"))), true, true, ImmutableList.of())); + } + + @Test + public void testExplainAnalyzeVerboseTypeDistributed() + throws Exception + { + assertStatement("EXPLAIN ANALYZE VERBOSE (type DISTRIBUTED) SELECT * FROM t", + new Explain(simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("t"))), true, true, ImmutableList.of(new ExplainType(ExplainType.Type.DISTRIBUTED)))); } @Test @@ -1628,6 +1681,52 @@ public void testUnnest() new Table(QualifiedName.of("t")), new Unnest(ImmutableList.of(new Identifier("a")), true), Optional.empty()))); + assertStatement("SELECT * FROM t FULL JOIN UNNEST(a) ON true", + simpleQuery( + selectList(new AllColumns()), + new Join( + Join.Type.FULL, + new Table(QualifiedName.of("t")), + new Unnest(ImmutableList.of(new Identifier("a")), true), + Optional.of(new JoinOn(BooleanLiteral.TRUE_LITERAL))))); + } + + @Test + public void testLateral() + throws Exception + { + Lateral lateralRelation = new Lateral(new Query( + Optional.empty(), + new Values(ImmutableList.of(new LongLiteral("1"))), + Optional.empty(), + Optional.empty())); + + assertStatement("SELECT * FROM t, LATERAL (VALUES 1) a(x)", + simpleQuery( + selectList(new AllColumns()), + new Join( + Join.Type.IMPLICIT, + new Table(QualifiedName.of("t")), + new AliasedRelation(lateralRelation, "a", ImmutableList.of("x")), + Optional.empty()))); + + assertStatement("SELECT * FROM t CROSS JOIN LATERAL (VALUES 1) ", + simpleQuery( + selectList(new AllColumns()), + new Join( + Join.Type.CROSS, + new Table(QualifiedName.of("t")), + lateralRelation, + Optional.empty()))); + + assertStatement("SELECT * FROM t FULL JOIN LATERAL (VALUES 1) ON true", + simpleQuery( + selectList(new AllColumns()), + new Join( + Join.Type.FULL, + new Table(QualifiedName.of("t")), + lateralRelation, + Optional.of(new JoinOn(BooleanLiteral.TRUE_LITERAL))))); } @Test diff --git a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestStatementBuilder.java b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestStatementBuilder.java index c2e941d9d06c..1161a7ef5da7 100644 --- a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestStatementBuilder.java +++ b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestStatementBuilder.java @@ -186,6 +186,8 @@ public void testStatementBuilder() printStatement("alter table a.b.c add column x bigint"); + printStatement("alter table a.b.c drop column x"); + printStatement("create schema test"); printStatement("create schema if not exists test"); printStatement("create schema test with (a = 'apple', b = 123)"); diff --git a/presto-plugin-toolkit/pom.xml b/presto-plugin-toolkit/pom.xml index 1a5feaed88e3..5ad3f5f8b728 100644 --- a/presto-plugin-toolkit/pom.xml +++ b/presto-plugin-toolkit/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-plugin-toolkit diff --git a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/AllowAllAccessControl.java b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/AllowAllAccessControl.java index d62156444afd..5b6d3a4ca0a7 100644 --- a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/AllowAllAccessControl.java +++ b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/AllowAllAccessControl.java @@ -66,6 +66,11 @@ public void checkCanAddColumn(ConnectorTransactionHandle transaction, Identity i { } + @Override + public void checkCanDropColumn(ConnectorTransactionHandle transactionHandle, Identity identity, SchemaTableName tableName) + { + } + @Override public void checkCanRenameColumn(ConnectorTransactionHandle transaction, Identity identity, SchemaTableName tableName) { diff --git a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/FileBasedAccessControl.java b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/FileBasedAccessControl.java index 18dcf8159543..856acd0ea9a6 100644 --- a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/FileBasedAccessControl.java +++ b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/FileBasedAccessControl.java @@ -41,6 +41,7 @@ import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; import static com.facebook.presto.spi.security.AccessDeniedException.denyDeleteTable; +import static com.facebook.presto.spi.security.AccessDeniedException.denyDropColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropView; import static com.facebook.presto.spi.security.AccessDeniedException.denyGrantTablePrivilege; @@ -125,6 +126,14 @@ public void checkCanAddColumn(ConnectorTransactionHandle transaction, Identity i } } + @Override + public void checkCanDropColumn(ConnectorTransactionHandle transactionHandle, Identity identity, SchemaTableName tableName) + { + if (!checkTablePermission(identity, tableName, OWNERSHIP)) { + denyDropColumn(tableName.toString()); + } + } + @Override public void checkCanRenameColumn(ConnectorTransactionHandle transaction, Identity identity, SchemaTableName tableName) { diff --git a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ReadOnlyAccessControl.java b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ReadOnlyAccessControl.java index da7791eeeed2..d80590eb3f40 100644 --- a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ReadOnlyAccessControl.java +++ b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ReadOnlyAccessControl.java @@ -25,6 +25,7 @@ import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; import static com.facebook.presto.spi.security.AccessDeniedException.denyDeleteTable; +import static com.facebook.presto.spi.security.AccessDeniedException.denyDropColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropView; import static com.facebook.presto.spi.security.AccessDeniedException.denyGrantTablePrivilege; @@ -53,6 +54,12 @@ public void checkCanAddColumn(ConnectorTransactionHandle transaction, Identity i denyAddColumn(tableName.toString()); } + @Override + public void checkCanDropColumn(ConnectorTransactionHandle transactionHandle, Identity identity, SchemaTableName tableName) + { + denyDropColumn(tableName.toString()); + } + @Override public void checkCanCreateTable(ConnectorTransactionHandle transaction, Identity identity, SchemaTableName tableName) { diff --git a/presto-postgresql/pom.xml b/presto-postgresql/pom.xml index 794ead1bf3c7..a9ee0664bec0 100644 --- a/presto-postgresql/pom.xml +++ b/presto-postgresql/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-postgresql diff --git a/presto-postgresql/src/main/java/com/facebook/presto/plugin/postgresql/PostgreSqlClient.java b/presto-postgresql/src/main/java/com/facebook/presto/plugin/postgresql/PostgreSqlClient.java index 8cd63055cbd3..4d8f6896dfec 100644 --- a/presto-postgresql/src/main/java/com/facebook/presto/plugin/postgresql/PostgreSqlClient.java +++ b/presto-postgresql/src/main/java/com/facebook/presto/plugin/postgresql/PostgreSqlClient.java @@ -76,6 +76,6 @@ protected ResultSet getTables(Connection connection, String schemaName, String t connection.getCatalog(), escapeNamePattern(schemaName, escape), escapeNamePattern(tableName, escape), - new String[] {"TABLE", "VIEW", "MATERIALIZED VIEW"}); + new String[] {"TABLE", "VIEW", "MATERIALIZED VIEW", "FOREIGN TABLE"}); } } diff --git a/presto-postgresql/src/test/java/com/facebook/presto/plugin/postgresql/TestPostgreSqlIntegrationSmokeTest.java b/presto-postgresql/src/test/java/com/facebook/presto/plugin/postgresql/TestPostgreSqlIntegrationSmokeTest.java index b01848af7c91..ea17b3244629 100644 --- a/presto-postgresql/src/test/java/com/facebook/presto/plugin/postgresql/TestPostgreSqlIntegrationSmokeTest.java +++ b/presto-postgresql/src/test/java/com/facebook/presto/plugin/postgresql/TestPostgreSqlIntegrationSmokeTest.java @@ -44,6 +44,7 @@ public TestPostgreSqlIntegrationSmokeTest(TestingPostgreSqlServer postgreSqlServ { super(() -> PostgreSqlQueryRunner.createPostgreSqlQueryRunner(postgreSqlServer, ORDERS)); this.postgreSqlServer = postgreSqlServer; + execute("CREATE EXTENSION file_fdw"); } @AfterClass(alwaysRun = true) @@ -73,6 +74,18 @@ public void testMaterializedView() execute("DROP MATERIALIZED VIEW tpch.test_mv"); } + @Test + public void testForeignTable() + throws Exception + { + execute("CREATE SERVER devnull FOREIGN DATA WRAPPER file_fdw"); + execute("CREATE FOREIGN TABLE tpch.test_ft (x bigint) SERVER devnull OPTIONS (filename '/dev/null')"); + assertTrue(getQueryRunner().tableExists(getSession(), "test_ft")); + computeActual("SELECT * FROM test_ft"); + execute("DROP FOREIGN TABLE tpch.test_ft"); + execute("DROP SERVER devnull"); + } + private void execute(String sql) throws SQLException { diff --git a/presto-product-tests/README.md b/presto-product-tests/README.md index 1fc4a832306d..32fbc4cb96b6 100644 --- a/presto-product-tests/README.md +++ b/presto-product-tests/README.md @@ -139,16 +139,17 @@ groups run the following command: presto-product-tests/bin/run_on_docker.sh -x quarantine,big_query,profile_specific_tests ``` -where [profile](#profile) is one of either: +where profile is one of either: +#### Profiles - **multinode** - pseudo-distributed Hadoop installation running on a single Docker container and a distributed Presto installation running on multiple Docker containers. For multinode the default configuration is 1 coordinator and 1 worker. -- **[singlenode](#singlenode)** - pseudo-distributed Hadoop installation running on a +- **singlenode** - pseudo-distributed Hadoop installation running on a single Docker container and a single node installation of Presto also running on a single Docker container. - **singlenode-hdfs-impersonation** - HDFS impersonation enabled on top of the - environment in [singlenode](#singlenode) profile. Presto impersonates the user + environment in singlenode profile. Presto impersonates the user who is running the query when accessing HDFS. - **singlenode-kerberos-hdfs-impersonation** - pseudo-distributed kerberized Hadoop installation running on a single Docker container and a single node @@ -221,6 +222,7 @@ groups. | HDFS impersonation | ``hdfs_impersonation`` | ``singlenode-hdfs-impersonation``, ``singlenode-kerberos-hdfs-impersonation`` | | No HDFS impersonation | ``hdfs_no_impersonation`` | ``singlenode``, ``singlenode-kerberos-hdfs-no_impersonation`` | | LDAP | ``ldap`` | ``singlenode-ldap`` | +| SQL Server | ``sqlserver`` | ``singlenode-sqlserver`` | Below is a list of commands that explain how to run these profile specific tests and also the entire test suite: @@ -247,6 +249,12 @@ and also the entire test suite: ``` presto-product-tests/bin/run_on_docker.sh singlenode-ldap -g ldap ``` +* Run **SQL Server** tests: + + ``` + presto-product-tests/bin/run_on_docker.sh singlenode-sqlserver -g sqlserver + ``` + * Run the **entire test suite** excluding all profile specific tests, where <profile> can be any one of the available profiles: @@ -425,7 +433,7 @@ running the debugger. Use the `docker-compose` (probably using a [wrapper](#use-the-docker-compose-wrappers)) and `docker` utilities to control and troubleshoot containers. -In the following examples ```` is [profile](#profile). +In the following examples ```` is [profiles](#profiles). 1. Use the following command to view output from running containers: diff --git a/presto-product-tests/bin/run_on_docker.sh b/presto-product-tests/bin/run_on_docker.sh index 2035651a5256..fcfbfc83d2c9 100755 --- a/presto-product-tests/bin/run_on_docker.sh +++ b/presto-product-tests/bin/run_on_docker.sh @@ -45,7 +45,7 @@ function run_in_application_runner_container() { function check_presto() { run_in_application_runner_container \ java -jar "/docker/volumes/presto-cli/presto-cli-executable.jar" \ - --server presto-master:8080 \ + ${CLI_ARGUMENTS} \ --execute "SHOW CATALOGS" | grep -i hive } @@ -180,6 +180,13 @@ shift 1 PRESTO_SERVICES="presto-master" if [[ "$ENVIRONMENT" == "multinode" ]]; then PRESTO_SERVICES="${PRESTO_SERVICES} presto-worker" +elif [[ "$ENVIRONMENT" == "multinode-tls" ]]; then + PRESTO_SERVICES="${PRESTO_SERVICES} presto-worker-1 presto-worker-2" +fi + +CLI_ARGUMENTS="--server presto-master:8080" +if [[ "$ENVIRONMENT" == "multinode-tls" ]]; then + CLI_ARGUMENTS="--server https://presto-master.docker.cluster:7778 --keystore-path /docker/volumes/conf/presto/etc/docker.cluster.jks --keystore-password 123456" fi # check docker and docker compose installation diff --git a/presto-product-tests/conf/docker/common/compose-commons.sh b/presto-product-tests/conf/docker/common/compose-commons.sh index 9e8db105c9e0..1e4eae2b129d 100644 --- a/presto-product-tests/conf/docker/common/compose-commons.sh +++ b/presto-product-tests/conf/docker/common/compose-commons.sh @@ -23,7 +23,7 @@ function export_canonical_path() { source ${BASH_SOURCE%/*}/../../../bin/locations.sh -export DOCKER_IMAGES_VERSION=${DOCKER_IMAGES_VERSION:-16} +export DOCKER_IMAGES_VERSION=${DOCKER_IMAGES_VERSION:-20} export HADOOP_MASTER_IMAGE=${HADOOP_MASTER_IMAGE:-"teradatalabs/hdp2.5-hive:${DOCKER_IMAGES_VERSION}"} # The following variables are defined to enable running product tests with arbitrary/downloaded jars diff --git a/presto-product-tests/conf/docker/files/presto-launcher-wrapper.sh b/presto-product-tests/conf/docker/files/presto-launcher-wrapper.sh index 3538338dba6b..9134bb678fe3 100755 --- a/presto-product-tests/conf/docker/files/presto-launcher-wrapper.sh +++ b/presto-product-tests/conf/docker/files/presto-launcher-wrapper.sh @@ -4,13 +4,14 @@ set -e CONFIG="$1" -if [[ "$CONFIG" != "singlenode" && "$CONFIG" != "multinode-master" && "$CONFIG" != "multinode-worker" && "$CONFIG" != "singlenode-kerberized" && "$CONFIG" != "singlenode-ldap" ]]; then - echo "Usage: launcher-wrapper " +PRESTO_CONFIG_DIRECTORY="/docker/volumes/conf/presto/etc" +CONFIG_PROPERTIES_LOCATION="${PRESTO_CONFIG_DIRECTORY}/${CONFIG}.properties" + +if [[ ! -e ${CONFIG_PROPERTIES_LOCATION} ]]; then + echo "${CONFIG_PROPERTIES_LOCATION} does not exist" exit 1 fi -PRESTO_CONFIG_DIRECTORY="/docker/volumes/conf/presto/etc" - shift 1 /docker/volumes/presto-server/bin/launcher \ diff --git a/presto-product-tests/conf/docker/multinode-tls/compose.sh b/presto-product-tests/conf/docker/multinode-tls/compose.sh new file mode 100755 index 000000000000..512e1c57d75a --- /dev/null +++ b/presto-product-tests/conf/docker/multinode-tls/compose.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +SCRIPT_DIRECTORY=${BASH_SOURCE%/*} + +source ${SCRIPT_DIRECTORY}/../common/compose-commons.sh + +docker-compose \ +-f ${SCRIPT_DIRECTORY}/../common/standard.yml \ +-f ${SCRIPT_DIRECTORY}/../common/jdbc_db.yml \ +-f ${BASH_SOURCE%/*}/../common/cassandra.yml \ +-f ${SCRIPT_DIRECTORY}/docker-compose.yml \ +"$@" diff --git a/presto-product-tests/conf/docker/multinode-tls/docker-compose.yml b/presto-product-tests/conf/docker/multinode-tls/docker-compose.yml new file mode 100644 index 000000000000..21c2ceb45731 --- /dev/null +++ b/presto-product-tests/conf/docker/multinode-tls/docker-compose.yml @@ -0,0 +1,49 @@ +version: '2' +services: + + presto-master: + domainname: docker.cluster + hostname: presto-master + command: /docker/volumes/conf/docker/files/presto-launcher-wrapper.sh multinode-tls-master run + ports: + - '7778:7778' + networks: + default: + aliases: + - presto-master.docker.cluster + + presto-worker-1: + domainname: docker.cluster + hostname: presto-worker-1 + extends: + file: ../common/standard.yml + service: java-8-base + command: /docker/volumes/conf/docker/files/presto-launcher-wrapper.sh multinode-tls-worker run + networks: + default: + aliases: + - presto-worker-1.docker.cluster + depends_on: + - presto-master + volumes_from: + - presto-master + + presto-worker-2: + domainname: docker.cluster + hostname: presto-worker-2 + extends: + file: ../common/standard.yml + service: java-8-base + command: /docker/volumes/conf/docker/files/presto-launcher-wrapper.sh multinode-tls-worker run + networks: + default: + aliases: + - presto-worker-2.docker.cluster + depends_on: + - presto-master + volumes_from: + - presto-master + + application-runner: + volumes: + - ../../../conf/tempto/tempto-configuration-for-docker-tls.yaml:/docker/volumes/tempto/tempto-configuration-local.yaml diff --git a/presto-product-tests/conf/presto/etc/catalog/hive.properties b/presto-product-tests/conf/presto/etc/catalog/hive.properties index e633c3888a81..926185f10511 100644 --- a/presto-product-tests/conf/presto/etc/catalog/hive.properties +++ b/presto-product-tests/conf/presto/etc/catalog/hive.properties @@ -9,6 +9,7 @@ connector.name=hive-hadoop2 hive.metastore.uri=thrift://hadoop-master:9083 hive.metastore.thrift.client.socks-proxy=hadoop-master:1180 hive.allow-add-column=true +hive.allow-drop-column=true hive.allow-rename-column=true hive.allow-drop-table=true hive.allow-rename-table=true diff --git a/presto-product-tests/conf/presto/etc/environment-specific-catalogs/singlenode-hdfs-impersonation/hive.properties b/presto-product-tests/conf/presto/etc/environment-specific-catalogs/singlenode-hdfs-impersonation/hive.properties index bcbe4108ffb0..6dba6fd3a847 100644 --- a/presto-product-tests/conf/presto/etc/environment-specific-catalogs/singlenode-hdfs-impersonation/hive.properties +++ b/presto-product-tests/conf/presto/etc/environment-specific-catalogs/singlenode-hdfs-impersonation/hive.properties @@ -10,6 +10,7 @@ hive.metastore.uri=thrift://hadoop-master:9083 hive.metastore.thrift.client.socks-proxy=hadoop-master:1180 hive.allow-drop-table=true hive.allow-add-column=true +hive.allow-drop-column=true hive.allow-rename-table=true hive.allow-rename-column=true hive.metastore-cache-ttl=0s diff --git a/presto-product-tests/conf/presto/etc/environment-specific-catalogs/singlenode-kerberos-hdfs-no-impersonation/hive.properties b/presto-product-tests/conf/presto/etc/environment-specific-catalogs/singlenode-kerberos-hdfs-no-impersonation/hive.properties index bb710526599d..f060dfcd5f7f 100644 --- a/presto-product-tests/conf/presto/etc/environment-specific-catalogs/singlenode-kerberos-hdfs-no-impersonation/hive.properties +++ b/presto-product-tests/conf/presto/etc/environment-specific-catalogs/singlenode-kerberos-hdfs-no-impersonation/hive.properties @@ -12,6 +12,7 @@ hive.allow-drop-table=true hive.allow-rename-table=true hive.metastore-cache-ttl=0s hive.allow-add-column=true +hive.allow-drop-column=true hive.allow-rename-column=true hive.metastore.authentication.type=KERBEROS diff --git a/presto-product-tests/conf/presto/etc/multinode-tls-master.properties b/presto-product-tests/conf/presto/etc/multinode-tls-master.properties new file mode 100644 index 000000000000..47a45a7e9ad3 --- /dev/null +++ b/presto-product-tests/conf/presto/etc/multinode-tls-master.properties @@ -0,0 +1,29 @@ +# +# WARNING +# ^^^^^^^ +# This configuration file is for development only and should NOT be used be +# used in production. For example configuration, see the Presto documentation. +# + +node.id=will-be-overwritten +node.environment=test +node.internal-address-source=FQDN + +coordinator=true +node-scheduler.include-coordinator=true +discovery-server.enabled=true +discovery.uri=https://presto-master.docker.cluster:7778 + +query.max-memory=1GB +query.max-memory-per-node=512MB + +http-server.http.enabled=false +http-server.http.port=8080 +http-server.https.enabled=true +http-server.https.port=7778 +http-server.https.keystore.path=/docker/volumes/conf/presto/etc/docker.cluster.jks +http-server.https.keystore.key=123456 + +internal-communication.https.required=true +internal-communication.https.keystore.path=/docker/volumes/conf/presto/etc/docker.cluster.jks +internal-communication.https.keystore.key=123456 diff --git a/presto-product-tests/conf/presto/etc/multinode-tls-worker.properties b/presto-product-tests/conf/presto/etc/multinode-tls-worker.properties new file mode 100644 index 000000000000..dbe128f62bd9 --- /dev/null +++ b/presto-product-tests/conf/presto/etc/multinode-tls-worker.properties @@ -0,0 +1,28 @@ +# +# WARNING +# ^^^^^^^ +# This configuration file is for development only and should NOT be used be +# used in production. For example configuration, see the Presto documentation. +# + +node.id=will-be-overwritten +node.environment=test +node.internal-address-source=FQDN + +coordinator=false +discovery-server.enabled=false +discovery.uri=https://presto-master.docker.cluster:7778 + +query.max-memory=1GB +query.max-memory-per-node=512MB + +http-server.http.enabled=false +http-server.http.port=8080 +http-server.https.enabled=true +http-server.https.port=7778 +http-server.https.keystore.path=/docker/volumes/conf/presto/etc/docker.cluster.jks +http-server.https.keystore.key=123456 + +internal-communication.https.required=true +internal-communication.https.keystore.path=/docker/volumes/conf/presto/etc/docker.cluster.jks +internal-communication.https.keystore.key=123456 diff --git a/presto-product-tests/conf/tempto/tempto-configuration-for-docker-kerberos.yaml b/presto-product-tests/conf/tempto/tempto-configuration-for-docker-kerberos.yaml index b5a8d0c825a2..8abcf5cb2c77 100644 --- a/presto-product-tests/conf/tempto/tempto-configuration-for-docker-kerberos.yaml +++ b/presto-product-tests/conf/tempto/tempto-configuration-for-docker-kerberos.yaml @@ -22,8 +22,6 @@ databases: host: presto-master.docker.cluster port: 7778 server_address: https://${databases.presto.host}:${databases.presto.port} - # Use the HTTP interface in JDBC, as Kerberos authentication is not yet supported in there. - jdbc_url: jdbc:presto://${databases.presto.host}:8080/hive/${databases.hive.schema} # jdbc_user in here should satisfy two requirements in order to pass SQL standard access control checks in Presto: # 1) It should belong to the "admin" role in hive @@ -39,3 +37,13 @@ databases: cli_kerberos_service_name: presto-server cli_kerberos_use_canonical_hostname: false configured_hdfs_user: hdfs + + jdbc_url: "jdbc:presto://${databases.presto.host}:${databases.presto.port}/hive/${databases.hive.schema}?\ + SSL=true&\ + SSLTrustStorePath=${databases.presto.https_keystore_path}&\ + SSLTrustStorePassword=${databases.presto.https_keystore_password}&\ + KerberosRemoteServiceName=${databases.presto.cli_kerberos_service_name}&\ + KerberosPrincipal=${databases.presto.cli_kerberos_principal}&\ + KerberosUseCanonicalHostname=${databases.presto.cli_kerberos_use_canonical_hostname}&\ + KerberosConfigPath=${databases.presto.cli_kerberos_config_path}&\ + KerberosKeytabPath=${databases.presto.cli_kerberos_keytab}" diff --git a/presto-product-tests/conf/tempto/tempto-configuration-for-docker-tls.yaml b/presto-product-tests/conf/tempto/tempto-configuration-for-docker-tls.yaml new file mode 100644 index 000000000000..a86a12f00efa --- /dev/null +++ b/presto-product-tests/conf/tempto/tempto-configuration-for-docker-tls.yaml @@ -0,0 +1,16 @@ +databases: + hive: + host: hadoop-master + presto: + host: presto-master.docker.cluster + port: 7778 + http_port: 8080 + https_port: ${databases.presto.port} + server_address: https://${databases.presto.host}:${databases.presto.port} + jdbc_url: "jdbc:presto://${databases.presto.host}:${databases.presto.port}/hive/${databases.hive.schema}?\ + SSL=true&\ + SSLTrustStorePath=${databases.presto.https_keystore_path}&\ + SSLTrustStorePassword=${databases.presto.https_keystore_password}" + configured_hdfs_user: hive + https_keystore_path: /docker/volumes/conf/presto/etc/docker.cluster.jks + https_keystore_password: '123456' diff --git a/presto-product-tests/pom.xml b/presto-product-tests/pom.xml index a4e6e93d2ee5..7d55006a1c6f 100644 --- a/presto-product-tests/pom.xml +++ b/presto-product-tests/pom.xml @@ -5,7 +5,7 @@ presto-root com.facebook.presto - 0.179-tw-0.36 + 0.181-tw-0.37 presto-product-tests diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/AlterTableTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/AlterTableTests.java index 6b386b9d2c3b..5888e41a7139 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/AlterTableTests.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/AlterTableTests.java @@ -96,4 +96,17 @@ public void addColumn() assertThat(() -> query(format("ALTER TABLE %s ADD COLUMN n_naTioNkEy BIGINT", TABLE_NAME))) .failsWithMessage("Column 'n_naTioNkEy' already exists"); } + + @Test(groups = {ALTER_TABLE, SMOKE}) + public void dropColumn() + { + query(format("CREATE TABLE %s AS SELECT n_nationkey, n_regionkey FROM nation", TABLE_NAME)); + + assertThat(query(format("SELECT count(n_nationkey) FROM %s", TABLE_NAME))) + .containsExactly(row(25)); + assertThat(query(format("ALTER TABLE %s DROP COLUMN n_nationkey", TABLE_NAME))) + .hasRowsCount(1); + assertThat(() -> query(format("ALTER TABLE %s DROP COLUMN n_regionkey", TABLE_NAME))) + .failsWithMessage("Cannot drop the only column in a table"); + } } diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/TestGroups.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/TestGroups.java index 26e5908fd375..bfa83472587f 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/TestGroups.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/TestGroups.java @@ -29,6 +29,7 @@ public final class TestGroups public static final String SMOKE = "smoke"; public static final String JDBC = "jdbc"; public static final String MYSQL = "mysql"; + public static final String PRESTO_JDBC = "presto_jdbc"; public static final String SIMBA_JDBC = "simba_jdbc"; public static final String QUERY_ENGINE = "qe"; public static final String COMPARISON = "comparison"; @@ -57,6 +58,7 @@ public final class TestGroups public static final String LDAP = "ldap"; public static final String LDAP_CLI = "ldap_cli"; public static final String SKIP_ON_CDH = "skip_on_cdh"; + public static final String TLS = "tls"; private TestGroups() {} } diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/TlsTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/TlsTests.java new file mode 100644 index 000000000000..09e9eecfe13b --- /dev/null +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/TlsTests.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +import com.google.common.base.Throwables; +import com.google.inject.Inject; +import com.google.inject.name.Named; +import com.teradata.tempto.query.QueryResult; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.net.ConnectException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketTimeoutException; +import java.net.URI; +import java.util.List; + +import static com.facebook.presto.tests.TestGroups.PROFILE_SPECIFIC_TESTS; +import static com.facebook.presto.tests.TestGroups.TLS; +import static com.facebook.presto.tests.utils.QueryExecutors.onPresto; +import static java.lang.String.format; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.stream.Collectors.toList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +public class TlsTests +{ + @Inject(optional = true) + @Named("databases.presto.http_port") + private Integer httpPort; + + @Inject(optional = true) + @Named("databases.presto.https_port") + private Integer httpsPort; + + @Test(groups = {TLS, PROFILE_SPECIFIC_TESTS}) + public void testHttpPortIsClosed() + throws Exception + { + assertThat(httpPort).isNotNull(); + assertThat(httpsPort).isNotNull(); + + waitForNodeRefresh(); + List activeNodesUrls = getActiveNodesUrls(); + assertThat(activeNodesUrls).hasSize(3); + + List hosts = activeNodesUrls.stream() + .map((uri) -> URI.create(uri).getHost()) + .collect(toList()); + + for (String host : hosts) { + assertPortIsOpen(host, httpsPort); + assertPortIsClosed(host, httpPort); + } + } + + private void waitForNodeRefresh() + throws InterruptedException + { + long deadline = System.currentTimeMillis() + MINUTES.toMillis(1); + while (System.currentTimeMillis() < deadline) { + if (getActiveNodesUrls().size() == 3) { + return; + } + Thread.sleep(100); + } + fail("Worker nodes haven't been discovered in 1 minutes."); + } + + private List getActiveNodesUrls() + { + QueryResult queryResult = onPresto() + .executeQuery("SELECT http_uri FROM system.runtime.nodes"); + return queryResult.rows() + .stream() + .map((row) -> row.get(0).toString()) + .collect(toList()); + } + + private static void assertPortIsClosed(String host, Integer port) + { + if (isPortOpen(host, port)) { + fail(format("Port %d at %s is expected to be closed", port, host)); + } + } + + private static void assertPortIsOpen(String host, Integer port) + { + if (!isPortOpen(host, port)) { + fail(format("Port %d at %s is expected to be open", port, host)); + } + } + + private static boolean isPortOpen(String host, Integer port) + { + try (Socket socket = new Socket()) { + socket.connect(new InetSocketAddress(InetAddress.getByName(host), port), 1000); + return true; + } + catch (ConnectException | SocketTimeoutException e) { + return false; + } + catch (IOException e) { + throw Throwables.propagate(e); + } + } +} diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/cassandra/TestInsertIntoCassandraTable.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/cassandra/TestInsertIntoCassandraTable.java new file mode 100644 index 000000000000..d32890adfb52 --- /dev/null +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/cassandra/TestInsertIntoCassandraTable.java @@ -0,0 +1,128 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.cassandra; + +import com.teradata.tempto.ProductTest; +import com.teradata.tempto.Requirement; +import com.teradata.tempto.RequirementsProvider; +import com.teradata.tempto.configuration.Configuration; +import com.teradata.tempto.internal.fulfillment.table.TableName; +import com.teradata.tempto.query.QueryResult; +import io.airlift.units.Duration; +import org.testng.annotations.Test; + +import static com.facebook.presto.tests.TestGroups.CASSANDRA; +import static com.facebook.presto.tests.cassandra.DataTypesTableDefinition.CASSANDRA_ALL_TYPES; +import static com.facebook.presto.tests.cassandra.TestConstants.CONNECTOR_NAME; +import static com.facebook.presto.tests.cassandra.TestConstants.KEY_SPACE; +import static com.facebook.presto.tests.utils.QueryAssertions.assertContainsEventually; +import static com.teradata.tempto.assertions.QueryAssert.Row.row; +import static com.teradata.tempto.assertions.QueryAssert.assertThat; +import static com.teradata.tempto.fulfillment.table.MutableTableRequirement.State.CREATED; +import static com.teradata.tempto.fulfillment.table.MutableTablesState.mutableTablesState; +import static com.teradata.tempto.fulfillment.table.TableRequirements.mutableTable; +import static com.teradata.tempto.query.QueryExecutor.query; +import static com.teradata.tempto.util.DateTimeUtils.parseTimestampInUTC; +import static java.lang.String.format; +import static java.util.concurrent.TimeUnit.MINUTES; + +public class TestInsertIntoCassandraTable + extends ProductTest + implements RequirementsProvider +{ + private static final String CASSANDRA_INSERT_TABLE = "Insert_All_Types"; + + @Override + public Requirement getRequirements(Configuration configuration) + { + return mutableTable(CASSANDRA_ALL_TYPES, CASSANDRA_INSERT_TABLE, CREATED); + } + + @Test(groups = CASSANDRA) + public void testInsertIntoValuesToCassandraTableAllSimpleTypes() + throws Exception + { + TableName table = mutableTablesState().get(CASSANDRA_INSERT_TABLE).getTableName(); + String tableNameInDatabase = String.format("%s.%s", CONNECTOR_NAME, table.getNameInDatabase()); + + assertContainsEventually(() -> query(format("SHOW TABLES FROM %s.%s", CONNECTOR_NAME, KEY_SPACE)), + query(format("SELECT '%s'", table.getSchemalessNameInDatabase())), + new Duration(1, MINUTES)); + + QueryResult queryResult = query("SELECT * FROM " + tableNameInDatabase); + assertThat(queryResult).hasNoRows(); + + // TODO Following types are not supported now. We need to change null into the value after fixing it + // blob, frozen>, inet, list, map, set, timeuuid, decimal, uuid, varint + query("INSERT INTO " + tableNameInDatabase + + "(a, b, bl, bo, d, do, f, fr, i, integer, l, m, s, t, ti, tu, u, v, vari) VALUES (" + + "'ascii value', " + + "BIGINT '99999', " + + "null, " + + "true, " + + "null, " + + "123.456789, " + + "REAL '123.45678', " + + "null, " + + "null, " + + "123, " + + "null, " + + "null, " + + "null, " + + "'text value', " + + "timestamp '9999-12-31 23:59:59'," + + "null, " + + "null, " + + "'varchar value'," + + "null)"); + + assertThat(query("SELECT * FROM " + tableNameInDatabase)).containsOnly( + row( + "ascii value", + 99999, + null, + true, + null, + 123.456789, + 123.45678, + null, + null, + 123, + null, + null, + null, + "text value", + parseTimestampInUTC("9999-12-31 23:59:59"), + null, + null, + "varchar value", + null)); + + // insert null for all datatypes + query("INSERT INTO " + tableNameInDatabase + + "(a, b, bl, bo, d, do, f, fr, i, integer, l, m, s, t, ti, tu, u, v, vari) VALUES (" + + "'key 1', null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null) "); + assertThat(query(format("SELECT * FROM %s WHERE a = 'key 1'", tableNameInDatabase))).containsOnly( + row("key 1", null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null)); + + // insert into only a subset of columns + query(format("INSERT INTO %s (a, bo, integer, t) VALUES ('key 2', false, 999, 'text 2')", tableNameInDatabase)); + assertThat(query(format("SELECT * FROM %s WHERE a = 'key 2'", tableNameInDatabase))).containsOnly( + row("key 2", null, null, false, null, null, null, null, null, 999, null, null, null, "text 2", null, null, null, null, null)); + + // negative test: failed to insert null to primary key + assertThat(() -> query(format("INSERT INTO %s (a) VALUES (null) ", tableNameInDatabase))) + .failsWithMessage("Invalid null value in condition for column a"); + } +} diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/cli/PrestoLdapCliTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/cli/PrestoLdapCliTests.java index 1ef7f8cf3ebc..fe2c8694e7a3 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/cli/PrestoLdapCliTests.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/cli/PrestoLdapCliTests.java @@ -202,7 +202,7 @@ public void shouldFailQueryForLdapWithoutPassword() "--truststore-password", ldapTruststorePassword, "--user", ldapUserName, "--execute", "select * from hive.default.nation;"); - assertTrue(trimLines(presto.readRemainingErrorLines()).stream().anyMatch(str -> str.contains("statusMessage=Unauthorized"))); + assertTrue(trimLines(presto.readRemainingErrorLines()).stream().anyMatch(str -> str.contains("Authentication failed: Unauthorized"))); } @Test(groups = {LDAP, LDAP_CLI, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapJdbcTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapJdbcTests.java new file mode 100644 index 000000000000..d5a1a3bf8f20 --- /dev/null +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapJdbcTests.java @@ -0,0 +1,125 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.jdbc; + +import com.google.inject.Inject; +import com.google.inject.name.Named; +import com.teradata.tempto.ProductTest; +import com.teradata.tempto.Requirement; +import com.teradata.tempto.RequirementsProvider; +import com.teradata.tempto.configuration.Configuration; +import com.teradata.tempto.fulfillment.ldap.LdapObjectRequirement; +import com.teradata.tempto.query.QueryResult; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; + +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.AMERICA_ORG; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.ASIA_ORG; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.CHILD_GROUP; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.CHILD_GROUP_USER; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.DEFAULT_GROUP; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.DEFAULT_GROUP_USER; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.ORPHAN_USER; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.PARENT_GROUP; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.PARENT_GROUP_USER; +import static com.google.common.base.Preconditions.checkState; +import static java.lang.String.format; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.fail; + +public abstract class LdapJdbcTests + extends ProductTest + implements RequirementsProvider +{ + protected static final long TIMEOUT = 30 * 1000; // seconds per test + + protected static final String NATION_SELECT_ALL_QUERY = "select * from tpch.tiny.nation"; + + @Inject + @Named("databases.presto.cli_ldap_truststore_path") + protected String ldapTruststorePath; + + @Inject + @Named("databases.presto.cli_ldap_truststore_password") + protected String ldapTruststorePassword; + + @Inject + @Named("databases.presto.cli_ldap_user_name") + protected String ldapUserName; + + @Inject + @Named("databases.presto.cli_ldap_user_password") + protected String ldapUserPassword; + + @Inject + @Named("databases.presto.cli_ldap_server_address") + private String prestoServer; + + @Override + public Requirement getRequirements(Configuration configuration) + { + return new LdapObjectRequirement( + Arrays.asList( + AMERICA_ORG, ASIA_ORG, + DEFAULT_GROUP, PARENT_GROUP, CHILD_GROUP, + DEFAULT_GROUP_USER, PARENT_GROUP_USER, CHILD_GROUP_USER, ORPHAN_USER + )); + } + + protected void expectQueryToFail(String user, String password, String message) + { + try { + executeLdapQuery(NATION_SELECT_ALL_QUERY, user, password); + fail(); + } + catch (SQLException exception) { + assertEquals(exception.getMessage(), message); + } + } + + protected QueryResult executeLdapQuery(String query, String name, String password) + throws SQLException + { + try (Connection connection = getLdapConnection(name, password)) { + Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery(query); + return QueryResult.forResultSet(rs); + } + } + + private Connection getLdapConnection(String name, String password) + throws SQLException + { + return DriverManager.getConnection(getLdapUrl(), name, password); + } + + protected String prestoServer() + { + String prefix = "https://"; + checkState(prestoServer.startsWith(prefix), "invalid server address: %s", prestoServer); + return prestoServer.substring(prefix.length()); + } + + protected String getLdapUrl() + { + return format(getLdapUrlFormat(), prestoServer(), ldapTruststorePath, ldapTruststorePassword); + } + + protected abstract String getLdapUrlFormat(); +} diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapPrestoJdbcTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapPrestoJdbcTests.java new file mode 100644 index 000000000000..6b9664c70f4c --- /dev/null +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapPrestoJdbcTests.java @@ -0,0 +1,136 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.jdbc; + +import com.teradata.tempto.Requires; +import com.teradata.tempto.fulfillment.table.hive.tpch.ImmutableTpchTablesRequirements.ImmutableNationTable; +import org.testng.annotations.Test; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; + +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.CHILD_GROUP_USER; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.ORPHAN_USER; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.PARENT_GROUP_USER; +import static com.facebook.presto.tests.TestGroups.LDAP; +import static com.facebook.presto.tests.TestGroups.PRESTO_JDBC; +import static com.facebook.presto.tests.TestGroups.PROFILE_SPECIFIC_TESTS; +import static com.facebook.presto.tests.TpchTableResults.PRESTO_NATION_RESULT; +import static com.teradata.tempto.assertions.QueryAssert.assertThat; +import static java.lang.String.format; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.fail; + +public class LdapPrestoJdbcTests + extends LdapJdbcTests +{ + @Override + protected String getLdapUrlFormat() + { + return "jdbc:presto://%s?SSL=true&SSLTrustStorePath=%s&SSLTrustStorePassword=%s"; + } + + @Requires(ImmutableNationTable.class) + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldRunQueryWithLdap() + throws SQLException + { + assertThat(executeLdapQuery(NATION_SELECT_ALL_QUERY, ldapUserName, ldapUserPassword)).matches(PRESTO_NATION_RESULT); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForLdapUserInChildGroup() + { + String name = CHILD_GROUP_USER.getAttributes().get("cn"); + expectQueryToFailForUserNotInGroup(name); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForLdapUserInParentGroup() + { + String name = PARENT_GROUP_USER.getAttributes().get("cn"); + expectQueryToFailForUserNotInGroup(name); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForOrphanLdapUser() + { + String name = ORPHAN_USER.getAttributes().get("cn"); + expectQueryToFailForUserNotInGroup(name); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForWrongLdapPassword() + { + expectQueryToFail(ldapUserName, "wrong_password", "Authentication failed: Invalid credentials: [LDAP: error code 49 - Invalid Credentials]"); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForWrongLdapUser() + { + expectQueryToFail("invalid_user", ldapUserPassword, "Authentication failed: Invalid credentials: [LDAP: error code 49 - Invalid Credentials]"); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForEmptyUser() + { + expectQueryToFail("", ldapUserPassword, "Connection property 'user' value is empty"); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForLdapWithoutPassword() + { + expectQueryToFail(ldapUserName, null, "Authentication failed: Unauthorized"); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForLdapWithoutSsl() + { + try { + DriverManager.getConnection("jdbc:presto://" + prestoServer(), ldapUserName, ldapUserPassword); + fail(); + } + catch (SQLException exception) { + assertEquals(exception.getMessage(), "Authentication using username/password requires SSL to be enabled"); + } + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailForIncorrectTrustStore() + { + try { + String url = format("jdbc:presto://%s?SSL=true&SSLTrustStorePath=%s&SSLTrustStorePassword=%s", prestoServer(), ldapTruststorePath, "wrong_password"); + Connection connection = DriverManager.getConnection(url, ldapUserName, ldapUserPassword); + Statement statement = connection.createStatement(); + statement.executeQuery(NATION_SELECT_ALL_QUERY); + fail(); + } + catch (SQLException exception) { + assertEquals(exception.getMessage(), "Error setting up SSL: Keystore was tampered with, or password was incorrect"); + } + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailForUserWithColon() + { + expectQueryToFail("UserWith:Colon", ldapUserPassword, "Illegal character ':' found in username"); + } + + private void expectQueryToFailForUserNotInGroup(String user) + { + expectQueryToFail(user, ldapUserPassword, format("Authentication failed: Unauthorized user: User %s not a member of the authorized group", user)); + } +} diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapSimbaJdbcTests.java similarity index 65% rename from presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapTests.java rename to presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapSimbaJdbcTests.java index ac9018e6cd2f..29bdfd3bc915 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapTests.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapSimbaJdbcTests.java @@ -13,35 +13,18 @@ */ package com.facebook.presto.tests.jdbc; -import com.google.inject.Inject; -import com.google.inject.name.Named; -import com.teradata.tempto.BeforeTestWithContext; -import com.teradata.tempto.ProductTest; -import com.teradata.tempto.Requirement; -import com.teradata.tempto.RequirementsProvider; import com.teradata.tempto.Requires; -import com.teradata.tempto.configuration.Configuration; -import com.teradata.tempto.fulfillment.ldap.LdapObjectRequirement; import com.teradata.tempto.fulfillment.table.hive.tpch.ImmutableTpchTablesRequirements.ImmutableNationTable; -import com.teradata.tempto.query.QueryResult; import org.testng.annotations.Test; import java.io.IOException; import java.sql.Connection; import java.sql.DriverManager; -import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; -import java.util.Arrays; -import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.AMERICA_ORG; -import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.ASIA_ORG; -import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.CHILD_GROUP; import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.CHILD_GROUP_USER; -import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.DEFAULT_GROUP; -import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.DEFAULT_GROUP_USER; import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.ORPHAN_USER; -import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.PARENT_GROUP; import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.PARENT_GROUP_USER; import static com.facebook.presto.tests.TestGroups.LDAP; import static com.facebook.presto.tests.TestGroups.PROFILE_SPECIFIC_TESTS; @@ -51,14 +34,9 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.fail; -public class LdapTests - extends ProductTest - implements RequirementsProvider - +public class LdapSimbaJdbcTests + extends LdapJdbcTests { - private static final long TIMEOUT = 300 * 1000; // 30 secs per test - - private static final String NATION_SELECT_ALL_QUERY = "select * from tpch.tiny.nation"; private static final String JDBC_URL_FORMAT = "jdbc:presto://%s;AuthenticationType=LDAP Authentication;" + "SSLTrustStorePath=%s;SSLTrustStorePwd=%s;AllowSelfSignedServerCert=1;AllowHostNameCNMismatch=1"; private static final String SSL_CERTIFICATE_ERROR = @@ -72,42 +50,10 @@ public class LdapTests private static final String INVALID_SSL_PROPERTY = "[Teradata][Presto](100200) Connection string is invalid: SSL value is not valid for given AuthenticationType."; - @Inject - @Named("databases.presto.cli_ldap_truststore_path") - private String ldapTruststorePath; - - @Inject - @Named("databases.presto.cli_ldap_truststore_password") - private String ldapTruststorePassword; - - @Inject - @Named("databases.presto.cli_ldap_user_name") - private String ldapUserName; - - @Inject - @Named("databases.presto.cli_ldap_user_password") - private String ldapUserPassword; - - @Inject - @Named("databases.presto.cli_ldap_server_address") - private String prestoServer; - - @BeforeTestWithContext - public void setup() - throws SQLException - { - prestoServer = prestoServer.substring(8); - } - @Override - public Requirement getRequirements(Configuration configuration) + protected String getLdapUrlFormat() { - return new LdapObjectRequirement( - Arrays.asList( - AMERICA_ORG, ASIA_ORG, - DEFAULT_GROUP, PARENT_GROUP, CHILD_GROUP, - DEFAULT_GROUP_USER, PARENT_GROUP_USER, CHILD_GROUP_USER, ORPHAN_USER - )); + return JDBC_URL_FORMAT; } @Requires(ImmutableNationTable.class) @@ -188,10 +134,10 @@ public void shouldFailForIncorrectTrustStore() throws IOException, InterruptedException { try { - String url = String.format(JDBC_URL_FORMAT, prestoServer, ldapTruststorePath, "wrong_password"); + String url = String.format(JDBC_URL_FORMAT, prestoServer(), ldapTruststorePath, "wrong_password"); Connection connection = DriverManager.getConnection(url, ldapUserName, ldapUserPassword); Statement statement = connection.createStatement(); - ResultSet rs = statement.executeQuery(NATION_SELECT_ALL_QUERY); + statement.executeQuery(NATION_SELECT_ALL_QUERY); fail(); } catch (SQLException exception) { @@ -201,7 +147,7 @@ public void shouldFailForIncorrectTrustStore() @Test(groups = {LDAP, SIMBA_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) public void shouldFailForUserWithColon() - throws SQLException, InterruptedException + throws SQLException, InterruptedException { expectQueryToFail("UserWith:Colon", ldapUserPassword, MALFORMED_CREDENTIALS_ERROR); } @@ -210,36 +156,4 @@ private void expectQueryToFailForUserNotInGroup(String user) { expectQueryToFail(user, ldapUserPassword, UNAUTHORIZED_USER_ERROR); } - - private void expectQueryToFail(String user, String password, String message) - { - try { - executeLdapQuery(NATION_SELECT_ALL_QUERY, user, password); - fail(); - } - catch (SQLException exception) { - assertEquals(exception.getMessage(), message); - } - } - - private QueryResult executeLdapQuery(String query, String name, String password) - throws SQLException - { - try (Connection connection = getLdapConnection(name, password)) { - Statement statement = connection.createStatement(); - ResultSet rs = statement.executeQuery(query); - return QueryResult.forResultSet(rs); - } - } - - private Connection getLdapConnection(String name, String password) - throws SQLException - { - return DriverManager.getConnection(getLdapUrl(), name, password); - } - - private String getLdapUrl() - { - return String.format(JDBC_URL_FORMAT, prestoServer, ldapTruststorePath, ldapTruststorePassword); - } } diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/sqlserver/Select.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/sqlserver/Select.java index b2ddd1f061e0..383407a99e85 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/sqlserver/Select.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/sqlserver/Select.java @@ -156,7 +156,7 @@ public void testAllDatatypes() Timestamp.valueOf("9999-12-31 23:59:59.999"), Timestamp.valueOf("2079-06-06 00:00:00"), Double.valueOf("12345678912.3456756"), Float.valueOf("12345678.6557") ), - row(null, null, null, null, null, null, null, null, null, null, null, null, null, null) + row(null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null) ); } diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/utils/QueryAssertions.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/utils/QueryAssertions.java new file mode 100644 index 000000000000..9b941dce3516 --- /dev/null +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/utils/QueryAssertions.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.utils; + +import com.google.common.base.Joiner; +import com.google.common.collect.Iterables; +import com.teradata.tempto.query.QueryResult; +import io.airlift.units.Duration; + +import java.util.function.Supplier; + +import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; +import static io.airlift.units.Duration.nanosSince; +import static java.lang.String.format; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.testng.Assert.fail; + +public class QueryAssertions +{ + public static void assertContainsEventually(Supplier all, QueryResult expectedSubset, Duration timeout) + { + long start = System.nanoTime(); + while (!Thread.currentThread().isInterrupted()) { + try { + assertContains(all.get(), expectedSubset); + return; + } + catch (AssertionError e) { + if (nanosSince(start).compareTo(timeout) > 0) { + throw e; + } + } + sleepUninterruptibly(50, MILLISECONDS); + } + } + + public static void assertContains(QueryResult all, QueryResult expectedSubset) + { + for (Object row : expectedSubset.rows()) { + if (!all.rows().contains(row)) { + fail(format("expected row missing: %s\nAll %s rows:\n %s\nExpected subset %s rows:\n %s\n", + row, + all.getRowsCount(), + Joiner.on("\n ").join(Iterables.limit(all.rows(), 100)), + expectedSubset.getRowsCount(), + Joiner.on("\n ").join(Iterables.limit(expectedSubset.rows(), 100)))); + } + } + } + + private QueryAssertions() {} +} diff --git a/presto-product-tests/src/main/resources/sql-tests/testcases/ml_connector/prediction.sql b/presto-product-tests/src/main/resources/sql-tests/testcases/ml_functions/prediction.sql similarity index 58% rename from presto-product-tests/src/main/resources/sql-tests/testcases/ml_connector/prediction.sql rename to presto-product-tests/src/main/resources/sql-tests/testcases/ml_functions/prediction.sql index 63d7355386cb..0381cf009b44 100644 --- a/presto-product-tests/src/main/resources/sql-tests/testcases/ml_connector/prediction.sql +++ b/presto-product-tests/src/main/resources/sql-tests/testcases/ml_functions/prediction.sql @@ -1,6 +1,6 @@ --- database: presto; groups: ml_connector +-- database: presto; groups: ml_functions --! -SELECT classify(features(1, 2), model) +SELECT classify(features(1, 2 + random(1)), model) FROM ( SELECT learn_classifier(labels, features) AS model FROM (VALUES (1, features(1, 2))) t (labels, features) diff --git a/presto-product-tests/src/main/resources/sql-tests/testcases/ml_connector/varcharPrediction.sql b/presto-product-tests/src/main/resources/sql-tests/testcases/ml_functions/varcharPrediction.sql similarity index 59% rename from presto-product-tests/src/main/resources/sql-tests/testcases/ml_connector/varcharPrediction.sql rename to presto-product-tests/src/main/resources/sql-tests/testcases/ml_functions/varcharPrediction.sql index a143500712f4..2edbc87f4573 100644 --- a/presto-product-tests/src/main/resources/sql-tests/testcases/ml_connector/varcharPrediction.sql +++ b/presto-product-tests/src/main/resources/sql-tests/testcases/ml_functions/varcharPrediction.sql @@ -1,6 +1,6 @@ --- database: presto; groups: ml_connector +-- database: presto; groups: ml_functions --! -SELECT classify(features(1, 2), model) +SELECT classify(features(1, 2 + random(1)), model) FROM ( SELECT learn_classifier(labels, features) AS model FROM (VALUES ('cat', features(1, 2))) t (labels, features) diff --git a/presto-product-tests/src/main/resources/tempto-configuration.yaml b/presto-product-tests/src/main/resources/tempto-configuration.yaml index d6ea5880700f..0eb0898971dd 100644 --- a/presto-product-tests/src/main/resources/tempto-configuration.yaml +++ b/presto-product-tests/src/main/resources/tempto-configuration.yaml @@ -29,7 +29,7 @@ databases: jdbc_driver_class: com.facebook.presto.jdbc.PrestoDriver jdbc_url: jdbc:presto://${databases.presto.host}:${databases.presto.port}/hive/${databases.hive.schema} jdbc_user: hdfs - jdbc_password: na + jdbc_password: "***empty***" jdbc_pooling: false presto_tpcds: @@ -37,7 +37,7 @@ databases: jdbc_driver_class: com.facebook.presto.jdbc.PrestoDriver jdbc_url: jdbc:presto://${databases.presto.host}:${databases.presto.port}/hive/tpcds jdbc_user: hdfs - jdbc_password: na + jdbc_password: "***empty***" jdbc_pooling: false alice@presto: @@ -47,7 +47,7 @@ databases: jdbc_driver_class: ${databases.presto.jdbc_driver_class} jdbc_url: ${databases.presto.jdbc_url} jdbc_user: alice - jdbc_password: na + jdbc_password: "***empty***" jdbc_pooling: false https_keystore_path: ${databases.presto.https_keystore_path} https_keystore_password: ${databases.presto.https_keystore_password} @@ -59,7 +59,7 @@ databases: jdbc_driver_class: ${databases.presto.jdbc_driver_class} jdbc_url: ${databases.presto.jdbc_url} jdbc_user: bob - jdbc_password: na + jdbc_password: "***empty***" jdbc_pooling: false https_keystore_path: ${databases.presto.https_keystore_path} https_keystore_password: ${databases.presto.https_keystore_password} diff --git a/presto-raptor/pom.xml b/presto-raptor/pom.xml index aa2b0244764c..bbebb06643fa 100644 --- a/presto-raptor/pom.xml +++ b/presto-raptor/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-raptor diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/RaptorMetadata.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/RaptorMetadata.java index 47e4e3e01b21..9ea9d64ee871 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/RaptorMetadata.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/RaptorMetadata.java @@ -471,6 +471,43 @@ public void renameColumn(ConnectorSession session, ConnectorTableHandle tableHan }); } + @Override + public void dropColumn(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle column) + { + RaptorTableHandle table = (RaptorTableHandle) tableHandle; + RaptorColumnHandle raptorColumn = (RaptorColumnHandle) column; + + List existingColumns = dao.listTableColumns(table.getSchemaName(), table.getTableName()); + if (existingColumns.size() <= 1) { + throw new PrestoException(NOT_SUPPORTED, "Cannot drop the only column in a table"); + } + long maxColumnId = existingColumns.stream().mapToLong(TableColumn::getColumnId).max().getAsLong(); + if (raptorColumn.getColumnId() == maxColumnId) { + throw new PrestoException(NOT_SUPPORTED, "Cannot drop the column which has the largest column ID in the table"); + } + + if (getBucketColumnHandles(table.getTableId()).contains(column)) { + throw new PrestoException(NOT_SUPPORTED, "Cannot drop bucket columns"); + } + + Optional.ofNullable(dao.getTemporalColumnId(table.getTableId())).ifPresent(tempColumnId -> { + if (raptorColumn.getColumnId() == tempColumnId) { + throw new PrestoException(NOT_SUPPORTED, "Cannot drop the temporal column"); + } + }); + + if (getSortColumnHandles(table.getTableId()).contains(raptorColumn)) { + throw new PrestoException(NOT_SUPPORTED, "Cannot drop sort columns"); + } + + daoTransaction(dbi, MetadataDao.class, dao -> { + dao.dropColumn(table.getTableId(), raptorColumn.getColumnId()); + dao.updateTableVersion(table.getTableId(), session.getStartTime()); + }); + + // TODO: drop column from index table + } + @Override public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, Optional layout) { @@ -696,7 +733,9 @@ public Optional finishInsert(ConnectorSession session, List columns = handle.getColumnHandles().stream().map(ColumnInfo::fromHandle).collect(toList()); long updateTime = session.getStartTime(); - shardManager.commitShards(transactionId, tableId, columns, parseFragments(fragments), externalBatchId, updateTime); + Collection shards = parseFragments(fragments); + log.info("Committing insert into tableId %s (queryId: %s, shards: %s, columns: %s)", handle.getTableId(), session.getQueryId(), shards.size(), columns.size()); + shardManager.commitShards(transactionId, tableId, columns, shards, externalBatchId, updateTime); clearRollback(); diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/DatabaseShardManager.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/DatabaseShardManager.java index 0cbdd4e83a3d..017c7c336e34 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/DatabaseShardManager.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/DatabaseShardManager.java @@ -72,6 +72,7 @@ import static com.facebook.presto.raptor.util.ArrayUtil.intArrayToBytes; import static com.facebook.presto.raptor.util.DatabaseUtil.bindOptionalInt; import static com.facebook.presto.raptor.util.DatabaseUtil.isSyntaxOrAccessError; +import static com.facebook.presto.raptor.util.DatabaseUtil.isTransactionCacheFullError; import static com.facebook.presto.raptor.util.DatabaseUtil.metadataError; import static com.facebook.presto.raptor.util.DatabaseUtil.runIgnoringConstraintViolation; import static com.facebook.presto.raptor.util.DatabaseUtil.runTransaction; @@ -82,7 +83,7 @@ import static com.facebook.presto.spi.StandardErrorCode.SERVER_STARTING_UP; import static com.facebook.presto.spi.StandardErrorCode.TRANSACTION_CONFLICT; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Throwables.propagateIfInstanceOf; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.collect.Iterables.partition; import static java.lang.Boolean.TRUE; import static java.lang.Math.multiplyExact; @@ -365,7 +366,14 @@ private void runCommit(long transactionId, HandleConsumer callback) return; } catch (DBIException e) { - propagateIfInstanceOf(e.getCause(), PrestoException.class); + if (isTransactionCacheFullError(e)) { + throw metadataError(e, "Transaction too large"); + } + + if (e.getCause() != null) { + throwIfInstanceOf(e.getCause(), PrestoException.class); + } + if (attempt == maxAttempts) { throw metadataError(e); } diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/MetadataDao.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/MetadataDao.java index 2c63d85ed1b4..4cfd88947e0c 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/MetadataDao.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/MetadataDao.java @@ -178,6 +178,13 @@ void renameColumn( @Bind("columnId") long columnId, @Bind("target") String target); + @SqlUpdate("DELETE FROM columns\n" + + " WHERE table_id = :tableId\n" + + " AND column_id = :columnId") + void dropColumn( + @Bind("tableId") long tableId, + @Bind("columnId") long column); + @SqlUpdate("INSERT INTO views (schema_name, table_name, data)\n" + "VALUES (:schemaName, :tableName, :data)") void insertView( diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/OrcStorageManager.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/OrcStorageManager.java index 139f92d5fae9..a064a92ada9e 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/OrcStorageManager.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/OrcStorageManager.java @@ -104,11 +104,12 @@ import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Throwables.propagateIfInstanceOf; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static io.airlift.concurrent.MoreFutures.allAsList; import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.json.JsonCodec.jsonCodec; +import static io.airlift.units.DataSize.Unit.PETABYTE; import static java.lang.Math.min; import static java.nio.file.StandardCopyOption.ATOMIC_MOVE; import static java.util.Objects.requireNonNull; @@ -125,6 +126,8 @@ public class OrcStorageManager private static final JsonCodec SHARD_DELTA_CODEC = jsonCodec(ShardDelta.class); private static final long MAX_ROWS = 1_000_000_000; + // TODO: do not limit the max size of blocks to read for now; enable the limit when the Hive connector is ready + private static final DataSize HUGE_MAX_READ_BLOCK_SIZE = new DataSize(1, PETABYTE); private static final JsonCodec METADATA_CODEC = jsonCodec(OrcFileMetadata.class); private final String nodeId; @@ -228,7 +231,7 @@ public ConnectorPageSource getPageSource( AggregatedMemoryContext systemMemoryUsage = new AggregatedMemoryContext(); try { - OrcReader reader = new OrcReader(dataSource, new OrcMetadataReader(), readerAttributes.getMaxMergeDistance(), readerAttributes.getMaxReadSize()); + OrcReader reader = new OrcReader(dataSource, new OrcMetadataReader(), readerAttributes.getMaxMergeDistance(), readerAttributes.getMaxReadSize(), HUGE_MAX_READ_BLOCK_SIZE); Map indexMap = columnIdIndex(reader.getColumnNames()); ImmutableMap.Builder includedColumns = ImmutableMap.builder(); @@ -338,7 +341,9 @@ OrcDataSource openShard(UUID shardUuid, ReaderAttributes readerAttributes) throw Throwables.propagate(e); } catch (ExecutionException e) { - propagateIfInstanceOf(e.getCause(), PrestoException.class); + if (e.getCause() != null) { + throwIfInstanceOf(e.getCause(), PrestoException.class); + } throw new PrestoException(RAPTOR_RECOVERY_ERROR, "Error recovering shard " + shardUuid, e.getCause()); } catch (TimeoutException e) { @@ -368,7 +373,7 @@ private ShardInfo createShardInfo(UUID shardUuid, OptionalInt bucketNumber, File private List computeShardStats(File file) { try (OrcDataSource dataSource = fileOrcDataSource(defaultReaderAttributes, file)) { - OrcReader reader = new OrcReader(dataSource, new OrcMetadataReader(), defaultReaderAttributes.getMaxMergeDistance(), defaultReaderAttributes.getMaxReadSize()); + OrcReader reader = new OrcReader(dataSource, new OrcMetadataReader(), defaultReaderAttributes.getMaxMergeDistance(), defaultReaderAttributes.getMaxReadSize(), HUGE_MAX_READ_BLOCK_SIZE); ImmutableList.Builder list = ImmutableList.builder(); for (ColumnInfo info : getColumnInfo(reader)) { diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/Row.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/Row.java index b21029d28ff0..b8ab117f7476 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/Row.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/Row.java @@ -42,9 +42,9 @@ public class Row { private final List columns; - private final int sizeInBytes; + private final long sizeInBytes; - public Row(List columns, int sizeInBytes) + public Row(List columns, long sizeInBytes) { this.columns = requireNonNull(columns, "columns is null"); checkArgument(sizeInBytes >= 0, "sizeInBytes must be >= 0"); @@ -56,7 +56,7 @@ public List getColumns() return columns; } - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @@ -70,7 +70,7 @@ public static Row extractRow(Page page, int position, List types) for (int channel = 0; channel < page.getChannelCount(); channel++) { Block block = page.getBlock(channel); Type type = types.get(channel); - int size; + long size; Object value = getNativeContainerValue(type, block, position); if (value == null) { size = SIZE_OF_BYTE; @@ -180,7 +180,7 @@ private static Object nativeContainerToOrcValue(Type type, Object nativeValue) private static class RowBuilder { - private int rowSize; + private long rowSize; private final List columns; public RowBuilder(int columnCount) @@ -188,7 +188,7 @@ public RowBuilder(int columnCount) this.columns = new ArrayList<>(columnCount); } - public void add(Object value, int size) + public void add(Object value, long size) { columns.add(value); rowSize += size; diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/util/DatabaseUtil.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/util/DatabaseUtil.java index 6c11ee08af8d..fbbf7aab4e0b 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/util/DatabaseUtil.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/util/DatabaseUtil.java @@ -25,13 +25,16 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; +import java.util.Arrays; import java.util.OptionalInt; import java.util.OptionalLong; import java.util.function.Consumer; +import java.util.function.Predicate; import static com.facebook.presto.raptor.RaptorErrorCode.RAPTOR_METADATA_ERROR; -import static com.google.common.base.Throwables.propagateIfInstanceOf; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.reflect.Reflection.newProxy; +import static com.mysql.jdbc.MysqlErrorNumbers.ER_TRANS_CACHE_FULL; import static java.sql.Types.INTEGER; import static java.util.Objects.requireNonNull; @@ -62,7 +65,9 @@ public static T runTransaction(IDBI dbi, TransactionCallback callback) return dbi.inTransaction(callback); } catch (DBIException e) { - propagateIfInstanceOf(e.getCause(), PrestoException.class); + if (e.getCause() != null) { + throwIfInstanceOf(e.getCause(), PrestoException.class); + } throw metadataError(e); } } @@ -75,9 +80,14 @@ public static void daoTransaction(IDBI dbi, Class daoType, Consumer ca }); } + public static PrestoException metadataError(Throwable cause, String message) + { + return new PrestoException(RAPTOR_METADATA_ERROR, message, cause); + } + public static PrestoException metadataError(Throwable cause) { - return new PrestoException(RAPTOR_METADATA_ERROR, "Failed to perform metadata operation", cause); + return metadataError(cause, "Failed to perform metadata operation"); } /** @@ -134,6 +144,32 @@ public static boolean isSyntaxOrAccessError(Exception e) return sqlCodeStartsWith(e, "42"); } + public static boolean isTransactionCacheFullError(Exception e) + { + return mySqlErrorCodeMatches(e, ER_TRANS_CACHE_FULL); + } + + /** + * Check if an exception is caused by a MySQL exception of certain error code + */ + private static boolean mySqlErrorCodeMatches(Exception e, int errorCode) + { + return Throwables.getCausalChain(e).stream() + .filter(SQLException.class::isInstance) + .map(SQLException.class::cast) + .filter(t -> t.getErrorCode() == errorCode) + .map(Throwable::getStackTrace) + .anyMatch(isMySQLException()); + } + + private static Predicate isMySQLException() + { + // check if the exception is a mysql exception by matching the package name in the stack trace + return s -> Arrays.stream(s) + .map(StackTraceElement::getClassName) + .anyMatch(t -> t.startsWith("com.mysql.jdbc.")); + } + private static boolean sqlCodeStartsWith(Exception e, String code) { for (Throwable throwable : Throwables.getCausalChain(e)) { diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/TestRaptorIntegrationSmokeTest.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/TestRaptorIntegrationSmokeTest.java index 87c769550e0a..773460ea33b7 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/TestRaptorIntegrationSmokeTest.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/TestRaptorIntegrationSmokeTest.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.raptor; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestRaptorMetadata.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestRaptorMetadata.java index 5f60fe0dfdb1..eb106400a104 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestRaptorMetadata.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestRaptorMetadata.java @@ -50,6 +50,7 @@ import org.skife.jdbi.v2.Handle; import org.skife.jdbi.v2.util.BooleanMapper; import org.skife.jdbi.v2.util.LongMapper; +import org.testng.Assert.ThrowingRunnable; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -140,6 +141,66 @@ public void testRenameColumn() assertNotNull(metadata.getColumnHandles(SESSION, tableHandle).get("orderkey_renamed")); } + @Test + public void testDropColumn() + throws Exception + { + assertNull(metadata.getTableHandle(SESSION, DEFAULT_TEST_ORDERS)); + metadata.createTable(SESSION, buildTable(ImmutableMap.of(), tableMetadataBuilder(DEFAULT_TEST_ORDERS) + .column("orderkey", BIGINT) + .column("price", BIGINT))); + ConnectorTableHandle tableHandle = metadata.getTableHandle(SESSION, DEFAULT_TEST_ORDERS); + assertInstanceOf(tableHandle, RaptorTableHandle.class); + + RaptorTableHandle raptorTableHandle = (RaptorTableHandle) tableHandle; + + ColumnHandle lastColumn = metadata.getColumnHandles(SESSION, tableHandle).get("orderkey"); + metadata.dropColumn(SESSION, raptorTableHandle, lastColumn); + assertNull(metadata.getColumnHandles(SESSION, tableHandle).get("orderkey")); + } + + @Test + public void testDropColumnDisallowed() + throws Exception + { + assertNull(metadata.getTableHandle(SESSION, DEFAULT_TEST_ORDERS)); + Map properties = ImmutableMap.of( + BUCKET_COUNT_PROPERTY, 16, + BUCKETED_ON_PROPERTY, ImmutableList.of("orderkey"), + ORDERING_PROPERTY, ImmutableList.of("totalprice"), + TEMPORAL_COLUMN_PROPERTY, "orderdate"); + ConnectorTableMetadata ordersTable = buildTable(properties, tableMetadataBuilder(DEFAULT_TEST_ORDERS) + .column("orderkey", BIGINT) + .column("totalprice", DOUBLE) + .column("orderdate", DATE) + .column("highestid", BIGINT)); + metadata.createTable(SESSION, ordersTable); + + ConnectorTableHandle ordersTableHandle = metadata.getTableHandle(SESSION, DEFAULT_TEST_ORDERS); + assertInstanceOf(ordersTableHandle, RaptorTableHandle.class); + RaptorTableHandle ordersRaptorTableHandle = (RaptorTableHandle) ordersTableHandle; + assertEquals(ordersRaptorTableHandle.getTableId(), 1); + + assertInstanceOf(ordersRaptorTableHandle, RaptorTableHandle.class); + + // disallow dropping bucket, sort, temporal and highest-id columns + ColumnHandle bucketColumn = metadata.getColumnHandles(SESSION, ordersRaptorTableHandle).get("orderkey"); + assertThrows("Cannot drop bucket columns", () -> + metadata.dropColumn(SESSION, ordersTableHandle, bucketColumn)); + + ColumnHandle sortColumn = metadata.getColumnHandles(SESSION, ordersRaptorTableHandle).get("totalprice"); + assertThrows("Cannot drop sort columns", () -> + metadata.dropColumn(SESSION, ordersTableHandle, sortColumn)); + + ColumnHandle temporalColumn = metadata.getColumnHandles(SESSION, ordersRaptorTableHandle).get("orderdate"); + assertThrows("Cannot drop the temporal column", () -> + metadata.dropColumn(SESSION, ordersTableHandle, temporalColumn)); + + ColumnHandle highestColumn = metadata.getColumnHandles(SESSION, ordersRaptorTableHandle).get("highestid"); + assertThrows("Cannot drop the column which has the largest column ID in the table", () -> + metadata.dropColumn(SESSION, ordersTableHandle, highestColumn)); + } + @Test public void testRenameTable() throws Exception @@ -836,4 +897,15 @@ private static void assertTableColumnsEqual(List actual, List columnIds, List types) throws IOException { - OrcReader orcReader = new OrcReader(dataSource, new OrcMetadataReader(), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); + OrcReader orcReader = new OrcReader(dataSource, new OrcMetadataReader(), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); List columnNames = orcReader.getColumnNames(); assertEquals(columnNames.size(), columnIds.size()); @@ -69,7 +69,7 @@ public static OrcRecordReader createReader(OrcDataSource dataSource, List public static OrcRecordReader createReaderNoRows(OrcDataSource dataSource) throws IOException { - OrcReader orcReader = new OrcReader(dataSource, new OrcMetadataReader(), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); + OrcReader orcReader = new OrcReader(dataSource, new OrcMetadataReader(), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); assertEquals(orcReader.getColumnNames().size(), 0); diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestOrcFileRewriter.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestOrcFileRewriter.java index 428fc5d7c89d..cf22ef820fe8 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestOrcFileRewriter.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestOrcFileRewriter.java @@ -20,13 +20,13 @@ import com.facebook.presto.raptor.storage.OrcFileRewriter.OrcFileInfo; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardWriter.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardWriter.java index 4247421b38bd..45a257521232 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardWriter.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardWriter.java @@ -20,13 +20,13 @@ import com.facebook.presto.orc.OrcRecordReader; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-rcfile/pom.xml b/presto-rcfile/pom.xml index 2dcb1e07775b..af113f8ae402 100644 --- a/presto-rcfile/pom.xml +++ b/presto-rcfile/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-rcfile diff --git a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/AircompressorCodecFactory.java b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/AircompressorCodecFactory.java index 8c0f56c16f90..d30c9f2de38c 100644 --- a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/AircompressorCodecFactory.java +++ b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/AircompressorCodecFactory.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.rcfile; +import io.airlift.compress.gzip.JdkGzipCodec; import io.airlift.compress.lz4.Lz4Codec; import io.airlift.compress.lzo.LzoCodec; import io.airlift.compress.snappy.SnappyCodec; @@ -27,6 +28,7 @@ public class AircompressorCodecFactory private static final String LZO_CODEC_NAME_DEPRECATED = "org.apache.hadoop.io.compress.LzoCodec"; private static final String LZ4_CODEC_NAME = "org.apache.hadoop.io.compress.Lz4Codec"; private static final String LZ4_HC_CODEC_NAME = "org.apache.hadoop.io.compress.Lz4Codec"; + private static final String GZIP_CODEC_NAME = "org.apache.hadoop.io.compress.GzipCodec"; private final RcFileCodecFactory delegate; @@ -47,6 +49,9 @@ public RcFileCompressor createCompressor(String codecName) if (LZ4_CODEC_NAME.equals(codecName)) { return new AircompressorCompressor(new Lz4Codec()); } + if (GZIP_CODEC_NAME.equals(codecName)) { + return new AircompressorCompressor(new JdkGzipCodec()); + } return delegate.createCompressor(codecName); } @@ -62,6 +67,9 @@ public RcFileDecompressor createDecompressor(String codecName) if (LZ4_CODEC_NAME.equals(codecName) || LZ4_HC_CODEC_NAME.equals(codecName)) { return new AircompressorDecompressor(new Lz4Codec()); } + if (GZIP_CODEC_NAME.equals(codecName)) { + return new AircompressorDecompressor(new JdkGzipCodec()); + } return delegate.createDecompressor(codecName); } } diff --git a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/BufferedOutputStreamSliceOutput.java b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/BufferedOutputStreamSliceOutput.java index 8a5eb537615c..da0eb91f36a4 100644 --- a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/BufferedOutputStreamSliceOutput.java +++ b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/BufferedOutputStreamSliceOutput.java @@ -76,8 +76,12 @@ public void flush() public void close() throws IOException { - flushBufferToOutputStream(); - outputStream.close(); + try { + flushBufferToOutputStream(); + } + finally { + outputStream.close(); + } } @Override diff --git a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileReader.java b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileReader.java index 684f67a3ace6..20625063b236 100644 --- a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileReader.java +++ b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileReader.java @@ -137,21 +137,19 @@ private RcFileReader( this.writeValidation = requireNonNull(writeValidation, "writeValidation is null"); this.writeChecksumBuilder = writeValidation.map(validation -> createWriteChecksumBuilder(readColumns)); - checkArgument(offset >= 0, "offset is negative"); - checkArgument(offset < dataSource.getSize(), "offset is greater than data size"); - checkArgument(length >= 1, "length must be at least 1"); + verify(offset >= 0, "offset is negative"); + verify(offset < dataSource.getSize(), "offset is greater than data size"); + verify(length >= 1, "length must be at least 1"); this.length = length; this.end = offset + length; - checkArgument(end <= dataSource.getSize(), "offset plus length is greater than data size"); + verify(end <= dataSource.getSize(), "offset plus length is greater than data size"); // read header Slice magic = input.readSlice(RCFILE_MAGIC.length()); boolean compressed; if (RCFILE_MAGIC.equals(magic)) { version = input.readByte(); - if (version > CURRENT_VERSION) { - throw corrupt("RCFile version %s not supported: %s", version, dataSource); - } + verify(version <= CURRENT_VERSION, "RCFile version %s not supported: %s", version, dataSource); validateWrite(validation -> validation.getVersion() == version, "Unexpected file version"); compressed = input.readBoolean(); } @@ -160,19 +158,14 @@ else if (SEQUENCE_FILE_MAGIC.equals(magic)) { // first version of RCFile used magic SEQ with version 6 byte sequenceFileVersion = input.readByte(); - if (sequenceFileVersion != SEQUENCE_FILE_VERSION) { - throw corrupt("File %s is a SequenceFile not an RCFile", dataSource); - } + verify(sequenceFileVersion == SEQUENCE_FILE_VERSION, "File %s is a SequenceFile not an RCFile", dataSource); // this is the first version of RCFile this.version = FIRST_VERSION; Slice keyClassName = readLengthPrefixedString(input); Slice valueClassName = readLengthPrefixedString(input); - if (!RCFILE_KEY_BUFFER_NAME.equals(keyClassName) || !RCFILE_VALUE_BUFFER_NAME.equals(valueClassName)) { - throw corrupt("File %s is a SequenceFile not an RCFile", dataSource); - } - + verify(RCFILE_KEY_BUFFER_NAME.equals(keyClassName) && RCFILE_VALUE_BUFFER_NAME.equals(valueClassName), "File %s is a SequenceFile not an RCFile", dataSource); compressed = input.readBoolean(); // RC file is never block compressed @@ -197,12 +190,8 @@ else if (SEQUENCE_FILE_MAGIC.equals(magic)) { // read metadata int metadataEntries = Integer.reverseBytes(input.readInt()); - if (metadataEntries < 0) { - throw corrupt("Invalid metadata entry count %s in RCFile %s", metadataEntries, dataSource); - } - if (metadataEntries > MAX_METADATA_ENTRIES) { - throw corrupt("Too many metadata entries (%s) in RCFile %s", metadataEntries, dataSource); - } + verify(metadataEntries >= 0, "Invalid metadata entry count %s in RCFile %s", metadataEntries, dataSource); + verify(metadataEntries <= MAX_METADATA_ENTRIES, "Too many metadata entries (%s) in RCFile %s", metadataEntries, dataSource); ImmutableMap.Builder metadataBuilder = ImmutableMap.builder(); for (int i = 0; i < metadataEntries; i++) { metadataBuilder.put(readLengthPrefixedString(input).toStringUtf8(), readLengthPrefixedString(input).toStringUtf8()); @@ -220,9 +209,7 @@ else if (SEQUENCE_FILE_MAGIC.equals(magic)) { } // initialize columns - if (columnCount > MAX_COLUMN_COUNT) { - throw corrupt("Too many columns (%s) in RCFile %s", columnCountString, dataSource); - } + verify(columnCount <= MAX_COLUMN_COUNT, "Too many columns (%s) in RCFile %s", columnCountString, dataSource); columns = new Column[columnCount]; for (Entry entry : readColumns.entrySet()) { if (entry.getKey() < columnCount) { @@ -339,16 +326,12 @@ public int advance() } // read uncompressed size of row group (which is useless information) - if (input.remaining() < SIZE_OF_INT) { - throw corrupt("RCFile truncated %s", dataSource); - } + verify(input.remaining() >= SIZE_OF_INT, "RCFile truncated %s", dataSource); int unusedRowGroupSize = Integer.reverseBytes(input.readInt()); // read sequence sync if present if (unusedRowGroupSize == -1) { - if (input.remaining() < SIZE_OF_LONG + SIZE_OF_LONG + SIZE_OF_INT) { - throw corrupt("RCFile truncated %s", dataSource); - } + verify(input.remaining() >= SIZE_OF_LONG + SIZE_OF_LONG + SIZE_OF_INT, "RCFile truncated %s", dataSource); // The full sync sequence is "0xFFFFFFFF syncFirst syncSecond". If // this sequence begins in our segment, we must continue process until the @@ -361,9 +344,7 @@ public int advance() return -1; } - if (syncFirst != input.readLong() || syncSecond != input.readLong()) { - throw corrupt("Invalid sync in RCFile %s", dataSource); - } + verify(syncFirst == input.readLong() && syncSecond == input.readLong(), "Invalid sync in RCFile %s", dataSource); // read the useless uncompressed length unusedRowGroupSize = Integer.reverseBytes(input.readInt()); @@ -371,9 +352,7 @@ public int advance() else if (rowsRead > 0) { validateWrite(writeValidation -> false, "Expected sync sequence for every row group except the first one"); } - if (unusedRowGroupSize <= 0) { - throw corrupt("Invalid uncompressed row group length %s", unusedRowGroupSize); - } + verify(unusedRowGroupSize > 0, "Invalid uncompressed row group length %s", unusedRowGroupSize); // read row group header int uncompressedHeaderSize = Integer.reverseBytes(input.readInt()); @@ -396,9 +375,7 @@ else if (rowsRead > 0) { header = buffer; } else { - if (compressedHeaderSize != uncompressedHeaderSize) { - throw corrupt("Invalid RCFile %s", dataSource); - } + verify(compressedHeaderSize == uncompressedHeaderSize, "Invalid RCFile %s", dataSource); header = compressedHeaderBuffer; } BasicSliceInput headerInput = header.getInput(); @@ -433,9 +410,7 @@ else if (rowsRead > 0) { } // this value is not used but validate it is correct since it might signal corruption - if (unusedRowGroupSize != totalCompressedDataSize + uncompressedHeaderSize) { - throw corrupt("Invalid row group size"); - } + verify(unusedRowGroupSize == totalCompressedDataSize + uncompressedHeaderSize, "Invalid row group size"); validateWriteRowGroupChecksum(); validateWritePageChecksum(); @@ -481,13 +456,18 @@ private Slice readLengthPrefixedString(SliceInput in) throws RcFileCorruptionException { int length = toIntExact(readVInt(in)); - if (length > MAX_METADATA_STRING_LENGTH) { - throw corrupt("Metadata string value is too long (%s) in RCFile %s", length, in); - } - + verify(length <= MAX_METADATA_STRING_LENGTH, "Metadata string value is too long (%s) in RCFile %s", length, in); return in.readSlice(length); } + private void verify(boolean expression, String messageFormat, Object... args) + throws RcFileCorruptionException + { + if (!expression) { + throw corrupt(messageFormat, args); + } + } + private RcFileCorruptionException corrupt(String messageFormat, Object... args) { closeQuietly(); diff --git a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileWriter.java b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileWriter.java index 7b3e68e7d7f2..62c9040a8ab7 100644 --- a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileWriter.java +++ b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileWriter.java @@ -188,11 +188,15 @@ private void writeMetadataProperty(String key, String value) public void close() throws IOException { - writeRowGroup(); - output.close(); - keySectionOutput.destroy(); - for (ColumnEncoder columnEncoder : columnEncoders) { - columnEncoder.destroy(); + try { + writeRowGroup(); + output.close(); + } + finally { + keySectionOutput.destroy(); + for (ColumnEncoder columnEncoder : columnEncoders) { + columnEncoder.destroy(); + } } } diff --git a/presto-rcfile/src/test/java/com/facebook/presto/rcfile/RcFileTester.java b/presto-rcfile/src/test/java/com/facebook/presto/rcfile/RcFileTester.java index 228bb76aadfc..96ce0b58e328 100644 --- a/presto-rcfile/src/test/java/com/facebook/presto/rcfile/RcFileTester.java +++ b/presto-rcfile/src/test/java/com/facebook/presto/rcfile/RcFileTester.java @@ -22,8 +22,11 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlDate; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.SqlTimestamp; @@ -34,9 +37,6 @@ import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.spi.type.VarcharType; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; import com.facebook.presto.type.TypeRegistry; import com.google.common.base.Throwables; import com.google.common.collect.AbstractIterator; @@ -87,6 +87,7 @@ import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.compress.BZip2Codec; import org.apache.hadoop.io.compress.GzipCodec; import org.apache.hadoop.io.compress.Lz4Codec; import org.apache.hadoop.io.compress.SnappyCodec; @@ -123,6 +124,7 @@ import java.util.concurrent.ThreadLocalRandom; import static com.facebook.presto.rcfile.RcFileDecoderUtils.findFirstSyncPosition; +import static com.facebook.presto.rcfile.RcFileTester.Compression.BZIP2; import static com.facebook.presto.rcfile.RcFileTester.Compression.LZ4; import static com.facebook.presto.rcfile.RcFileTester.Compression.NONE; import static com.facebook.presto.rcfile.RcFileTester.Compression.SNAPPY; @@ -246,6 +248,13 @@ public RcFileEncoding getVectorEncoding() public enum Compression { + BZIP2 { + @Override + Optional getCodecName() + { + return Optional.of(BZip2Codec.class.getName()); + } + }, ZLIB { @Override Optional getCodecName() @@ -318,7 +327,7 @@ public static RcFileTester fullTestRcFileReader() // These compression algorithms were chosen to cover the three different // cases: uncompressed, aircompressor, and hadoop compression // We assume that the compression algorithms generally work - rcFileTester.compressions = ImmutableSet.of(NONE, LZ4, ZLIB); + rcFileTester.compressions = ImmutableSet.of(NONE, LZ4, ZLIB, BZIP2); return rcFileTester; } diff --git a/presto-record-decoder/pom.xml b/presto-record-decoder/pom.xml index e7909a0562ef..4915d72cd769 100644 --- a/presto-record-decoder/pom.xml +++ b/presto-record-decoder/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-record-decoder diff --git a/presto-redis/pom.xml b/presto-redis/pom.xml index fb3eb4ab45ed..087b22d1164d 100644 --- a/presto-redis/pom.xml +++ b/presto-redis/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-redis diff --git a/presto-resource-group-managers/pom.xml b/presto-resource-group-managers/pom.xml index 22c92eb8f054..81722baea1d8 100644 --- a/presto-resource-group-managers/pom.xml +++ b/presto-resource-group-managers/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-resource-group-managers diff --git a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/AbstractResourceConfigurationManager.java b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/AbstractResourceConfigurationManager.java index df9093323eab..e493d4bf18b2 100644 --- a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/AbstractResourceConfigurationManager.java +++ b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/AbstractResourceConfigurationManager.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.memory.ClusterMemoryPoolManager; import com.facebook.presto.spi.memory.MemoryPoolId; +import com.facebook.presto.spi.resourceGroups.QueryType; import com.facebook.presto.spi.resourceGroups.ResourceGroup; import com.facebook.presto.spi.resourceGroups.ResourceGroupConfigurationManager; import com.facebook.presto.spi.resourceGroups.ResourceGroupSelector; @@ -83,14 +84,16 @@ protected List buildSelectors(ManagerSpec managerSpec) { ImmutableList.Builder selectors = ImmutableList.builder(); for (SelectorSpec spec : managerSpec.getSelectors()) { - validateSelectors(managerSpec.getRootGroups(), spec.getGroup().getSegments()); - selectors.add(new StaticSelector(spec.getUserRegex(), spec.getSourceRegex(), spec.getGroup())); + validateSelectors(managerSpec.getRootGroups(), spec); + selectors.add(new StaticSelector(spec.getUserRegex(), spec.getSourceRegex(), spec.getQueryType(), spec.getGroup())); } return selectors.build(); } - private void validateSelectors(List groups, List selectorGroups) + private void validateSelectors(List groups, SelectorSpec spec) { + spec.getQueryType().ifPresent(this::validateQueryType); + List selectorGroups = spec.getGroup().getSegments(); StringBuilder fullyQualifiedGroupName = new StringBuilder(); while (!selectorGroups.isEmpty()) { ResourceGroupNameTemplate groupName = selectorGroups.get(0); @@ -108,6 +111,16 @@ private void validateSelectors(List groups, List { diff --git a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/SelectorSpec.java b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/SelectorSpec.java index f60f08c0504c..e0d39057c77c 100644 --- a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/SelectorSpec.java +++ b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/SelectorSpec.java @@ -28,15 +28,18 @@ public class SelectorSpec private final Optional userRegex; private final Optional sourceRegex; private final ResourceGroupIdTemplate group; + private final Optional queryType; @JsonCreator public SelectorSpec( @JsonProperty("user") Optional userRegex, @JsonProperty("source") Optional sourceRegex, + @JsonProperty("queryType") Optional queryType, @JsonProperty("group") ResourceGroupIdTemplate group) { this.userRegex = requireNonNull(userRegex, "userRegex is null"); this.sourceRegex = requireNonNull(sourceRegex, "sourceRegex is null"); + this.queryType = requireNonNull(queryType, "queryType is null"); this.group = requireNonNull(group, "group is null"); } @@ -55,6 +58,11 @@ public ResourceGroupIdTemplate getGroup() return group; } + public Optional getQueryType() + { + return queryType; + } + @Override public boolean equals(Object other) { @@ -66,6 +74,7 @@ public boolean equals(Object other) } SelectorSpec that = (SelectorSpec) other; return (group.equals(that.group) && + queryType.equals(that.queryType) && userRegex.map(Pattern::pattern).equals(that.userRegex.map(Pattern::pattern)) && userRegex.map(Pattern::flags).equals(that.userRegex.map(Pattern::flags)) && sourceRegex.map(Pattern::pattern).equals(that.sourceRegex.map(Pattern::pattern))) && @@ -76,6 +85,7 @@ public int hashCode() { return Objects.hash( group, + queryType, userRegex.map(Pattern::pattern), userRegex.map(Pattern::flags), sourceRegex.map(Pattern::pattern), @@ -91,6 +101,7 @@ public String toString() .add("userFlags", userRegex.map(Pattern::flags)) .add("sourceRegex", sourceRegex) .add("sourceFlags", sourceRegex.map(Pattern::flags)) + .add("queryType", queryType) .toString(); } } diff --git a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/StaticSelector.java b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/StaticSelector.java index 743c73ce3aaf..f73840d0c7dc 100644 --- a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/StaticSelector.java +++ b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/StaticSelector.java @@ -27,12 +27,14 @@ public class StaticSelector { private final Optional userRegex; private final Optional sourceRegex; + private final Optional queryType; private final ResourceGroupIdTemplate group; - public StaticSelector(Optional userRegex, Optional sourceRegex, ResourceGroupIdTemplate group) + public StaticSelector(Optional userRegex, Optional sourceRegex, Optional queryType, ResourceGroupIdTemplate group) { this.userRegex = requireNonNull(userRegex, "userRegex is null"); this.sourceRegex = requireNonNull(sourceRegex, "sourceRegex is null"); + this.queryType = requireNonNull(queryType, "queryType is null"); this.group = requireNonNull(group, "group is null"); } @@ -49,6 +51,13 @@ public Optional match(SelectionContext context) } } + if (queryType.isPresent()) { + String contextQueryType = context.getQueryType().orElse(""); + if (!queryType.get().equalsIgnoreCase(contextQueryType)) { + return Optional.empty(); + } + } + return Optional.of(group.expandTemplate(context)); } } diff --git a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/DbResourceGroupConfigurationManager.java b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/DbResourceGroupConfigurationManager.java index 438e687816b0..992fca47cd76 100644 --- a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/DbResourceGroupConfigurationManager.java +++ b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/DbResourceGroupConfigurationManager.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.resourceGroups.ResourceGroupSelector; import com.facebook.presto.spi.resourceGroups.SelectionContext; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; @@ -133,7 +134,8 @@ private synchronized Optional getCpuQuotaPeriodFromDb() return (!globalProperties.isEmpty()) ? globalProperties.get(0).getCpuQuotaPeriod() : Optional.empty(); } - private synchronized void load() + @VisibleForTesting + public synchronized void load() { Map.Entry> specsFromDb = buildSpecsFromDb(); ManagerSpec managerSpec = specsFromDb.getKey(); @@ -225,7 +227,10 @@ private synchronized Map.Entry rootGroups = rootGroupIds.stream().map(resourceGroupSpecMap::get).collect(Collectors.toList()); List selectors = dao.getSelectors().stream().map(selectorRecord -> - new SelectorSpec(selectorRecord.getUserRegex(), selectorRecord.getSourceRegex(), + new SelectorSpec( + selectorRecord.getUserRegex(), + selectorRecord.getSourceRegex(), + Optional.empty(), resourceGroupIdTemplateMap.get(selectorRecord.getResourceGroupId())) ).collect(Collectors.toList()); ManagerSpec managerSpec = new ManagerSpec(rootGroups, selectors, getCpuQuotaPeriodFromDb()); diff --git a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestFileResourceGroupConfigurationManager.java b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestFileResourceGroupConfigurationManager.java index b06cf2ff10e4..ba95cd972f34 100644 --- a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestFileResourceGroupConfigurationManager.java +++ b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestFileResourceGroupConfigurationManager.java @@ -16,18 +16,21 @@ import com.facebook.presto.spi.resourceGroups.ResourceGroup; import com.facebook.presto.spi.resourceGroups.ResourceGroupConfigurationManager; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; +import com.facebook.presto.spi.resourceGroups.ResourceGroupSelector; import com.facebook.presto.spi.resourceGroups.SelectionContext; import com.fasterxml.jackson.databind.JsonMappingException; import io.airlift.units.DataSize; import io.airlift.units.Duration; import org.testng.annotations.Test; +import java.util.List; import java.util.Optional; import java.util.regex.Pattern; import static com.facebook.presto.spi.resourceGroups.SchedulingPolicy.WEIGHTED; import static io.airlift.json.JsonCodec.jsonCodec; import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static java.lang.String.format; import static java.util.concurrent.TimeUnit.DAYS; import static java.util.concurrent.TimeUnit.HOURS; import static org.testng.Assert.assertEquals; @@ -51,7 +54,45 @@ public void testMissing() { ResourceGroupConfigurationManager manager = parse("resource_groups_config.json"); ResourceGroup missing = new TestingResourceGroup(new ResourceGroupId("missing")); - manager.configure(missing, new SelectionContext(true, "user", Optional.empty(), 1)); + manager.configure(missing, new SelectionContext(true, "user", Optional.empty(), 1, Optional.empty())); + } + + @Test + public void testQueryTypeConfiguration() + { + ResourceGroupConfigurationManager manager = parse("resource_groups_config_query_type.json"); + List selectors = manager.getSelectors(); + assertMatch(selectors, new SelectionContext(true, "test_user", Optional.empty(), 1, Optional.of("select")), "global.select"); + assertMatch(selectors, new SelectionContext(true, "test_user", Optional.empty(), 1, Optional.of("explain")), "global.explain"); + assertMatch(selectors, new SelectionContext(true, "test_user", Optional.empty(), 1, Optional.of("insert")), "global.insert"); + assertMatch(selectors, new SelectionContext(true, "test_user", Optional.empty(), 1, Optional.of("delete")), "global.delete"); + assertMatch(selectors, new SelectionContext(true, "test_user", Optional.empty(), 1, Optional.of("describe")), "global.describe"); + assertMatch(selectors, new SelectionContext(true, "test_user", Optional.empty(), 1, Optional.of("data_definition")), "global.data_definition"); + assertMatch(selectors, new SelectionContext(true, "test_user", Optional.empty(), 1, Optional.of("sth_else")), "global.other"); + } + + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Selector specifies an invalid query type: invalid_query_type") + public void testInvalidQueryTypeConfiguration() + { + parse("resource_groups_config_bad_query_type.json"); + } + + private void assertMatch(List selectors, SelectionContext context, String expectedResourceGroup) + { + Optional group = tryMatch(selectors, context); + assertTrue(group.isPresent(), "match expected"); + assertEquals(group.get().toString(), expectedResourceGroup, format("Expected: '%s' resource group, found: %s", expectedResourceGroup, group.get())); + } + + private Optional tryMatch(List selectors, SelectionContext context) + { + for (ResourceGroupSelector selector : selectors) { + Optional group = selector.match(context); + if (group.isPresent()) { + return group; + } + } + return Optional.empty(); } @Test @@ -59,7 +100,7 @@ public void testConfiguration() { ResourceGroupConfigurationManager manager = parse("resource_groups_config.json"); ResourceGroup global = new TestingResourceGroup(new ResourceGroupId("global")); - manager.configure(global, new SelectionContext(true, "user", Optional.empty(), 1)); + manager.configure(global, new SelectionContext(true, "user", Optional.empty(), 1, Optional.empty())); assertEquals(global.getSoftMemoryLimit(), new DataSize(1, MEGABYTE)); assertEquals(global.getSoftCpuLimit(), new Duration(1, HOURS)); assertEquals(global.getHardCpuLimit(), new Duration(1, DAYS)); @@ -73,7 +114,7 @@ public void testConfiguration() assertEquals(global.getRunningTimeLimit(), new Duration(1, HOURS)); ResourceGroup sub = new TestingResourceGroup(new ResourceGroupId(new ResourceGroupId("global"), "sub")); - manager.configure(sub, new SelectionContext(true, "user", Optional.empty(), 1)); + manager.configure(sub, new SelectionContext(true, "user", Optional.empty(), 1, Optional.empty())); assertEquals(sub.getSoftMemoryLimit(), new DataSize(2, MEGABYTE)); assertEquals(sub.getMaxRunningQueries(), 3); assertEquals(sub.getMaxQueuedQueries(), 4); diff --git a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestResourceGroupIdTemplate.java b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestResourceGroupIdTemplate.java index f0d94eb6e850..f3cb6659657f 100644 --- a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestResourceGroupIdTemplate.java +++ b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestResourceGroupIdTemplate.java @@ -28,9 +28,9 @@ public void testExpansion() { ResourceGroupIdTemplate template = new ResourceGroupIdTemplate("test.${USER}.${SOURCE}"); ResourceGroupId expected = new ResourceGroupId(new ResourceGroupId(new ResourceGroupId("test"), "u"), "s"); - assertEquals(template.expandTemplate(new SelectionContext(true, "u", Optional.of("s"), 1)), expected); + assertEquals(template.expandTemplate(new SelectionContext(true, "u", Optional.of("s"), 1, Optional.empty())), expected); template = new ResourceGroupIdTemplate("test.${USER}"); - assertEquals(template.expandTemplate(new SelectionContext(true, "alice.smith", Optional.empty(), 1)), new ResourceGroupId(new ResourceGroupId("test"), "alice.smith")); + assertEquals(template.expandTemplate(new SelectionContext(true, "alice.smith", Optional.empty(), 1, Optional.empty())), new ResourceGroupId(new ResourceGroupId("test"), "alice.smith")); } @Test(expectedExceptions = IllegalArgumentException.class) diff --git a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/TestDbResourceGroupConfigurationManager.java b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/TestDbResourceGroupConfigurationManager.java index 316c48d06760..45555e3d75c9 100644 --- a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/TestDbResourceGroupConfigurationManager.java +++ b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/TestDbResourceGroupConfigurationManager.java @@ -59,11 +59,11 @@ public void testConfiguration() daoProvider.get()); AtomicBoolean exported = new AtomicBoolean(); InternalResourceGroup global = new InternalResourceGroup.RootInternalResourceGroup("global", (group, export) -> exported.set(export), directExecutor()); - manager.configure(global, new SelectionContext(true, "user", Optional.empty(), 1)); + manager.configure(global, new SelectionContext(true, "user", Optional.empty(), 1, Optional.empty())); assertEqualsResourceGroup(global, "1MB", 1000, 100, WEIGHTED, DEFAULT_WEIGHT, true, new Duration(1, HOURS), new Duration(1, DAYS), new Duration(1, HOURS), new Duration(1, HOURS)); exported.set(false); InternalResourceGroup sub = global.getOrCreateSubGroup("sub"); - manager.configure(sub, new SelectionContext(true, "user", Optional.empty(), 1)); + manager.configure(sub, new SelectionContext(true, "user", Optional.empty(), 1, Optional.empty())); assertEqualsResourceGroup(sub, "2MB", 4, 3, FAIR, 5, false, new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(1, HOURS), new Duration(1, HOURS)); } @@ -121,7 +121,7 @@ public void testMissing() }, daoProvider.get()); InternalResourceGroup missing = new InternalResourceGroup.RootInternalResourceGroup("missing", (group, export) -> { }, directExecutor()); - manager.configure(missing, new SelectionContext(true, "user", Optional.empty(), 1)); + manager.configure(missing, new SelectionContext(true, "user", Optional.empty(), 1, Optional.empty())); } @Test(timeOut = 60_000) @@ -143,9 +143,9 @@ public void testReconfig() manager.start(); AtomicBoolean exported = new AtomicBoolean(); InternalResourceGroup global = new InternalResourceGroup.RootInternalResourceGroup("global", (group, export) -> exported.set(export), directExecutor()); - manager.configure(global, new SelectionContext(true, "user", Optional.empty(), 1)); + manager.configure(global, new SelectionContext(true, "user", Optional.empty(), 1, Optional.empty())); InternalResourceGroup globalSub = global.getOrCreateSubGroup("sub"); - manager.configure(globalSub, new SelectionContext(true, "user", Optional.empty(), 1)); + manager.configure(globalSub, new SelectionContext(true, "user", Optional.empty(), 1, Optional.empty())); // Verify record exists assertEqualsResourceGroup(globalSub, "2MB", 4, 3, FAIR, 5, false, new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(Long.MAX_VALUE, MILLISECONDS)); dao.updateResourceGroup(2, "sub", "3MB", 2, 1, "weighted", 6, true, "1h", "1d", null, null, 1L); diff --git a/presto-resource-group-managers/src/test/resources/resource_groups_config_bad_query_type.json b/presto-resource-group-managers/src/test/resources/resource_groups_config_bad_query_type.json new file mode 100644 index 000000000000..664e3d0cb2b6 --- /dev/null +++ b/presto-resource-group-managers/src/test/resources/resource_groups_config_bad_query_type.json @@ -0,0 +1,29 @@ +{ + "rootGroups": [ + { + "name": "global", + "softMemoryLimit": "1MB", + "maxRunning": 1, + "maxQueued": 1000, + "softCpuLimit": "1h", + "hardCpuLimit": "1d", + "subGroups": [ + { + "name": "select", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + } + ] + } + ], + "selectors": [ + { + "user" : "test_user", + "queryType" : "invalid_query_type", + "group": "global.select" + } + ], + "cpuQuotaPeriod": "1h" +} + diff --git a/presto-resource-group-managers/src/test/resources/resource_groups_config_query_type.json b/presto-resource-group-managers/src/test/resources/resource_groups_config_query_type.json new file mode 100644 index 000000000000..a2cf40912c2b --- /dev/null +++ b/presto-resource-group-managers/src/test/resources/resource_groups_config_query_type.json @@ -0,0 +1,94 @@ +{ + "rootGroups": [ + { + "name": "global", + "softMemoryLimit": "1MB", + "maxRunning": 100, + "maxQueued": 1000, + "softCpuLimit": "1h", + "hardCpuLimit": "1d", + "subGroups": [ + { + "name": "select", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + }, + { + "name": "explain", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + }, + { + "name": "insert", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + }, + { + "name": "delete", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + }, + { + "name": "describe", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + }, + { + "name": "data_definition", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + }, + { + "name": "other", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + } + ] + } + ], + "selectors": [ + { + "user" : "test_user", + "queryType" : "select", + "group": "global.select" + }, + { + "user" : "test_user", + "queryType" : "explain", + "group": "global.explain" + }, + { + "user" : "test_user", + "queryType" : "insert", + "group": "global.insert" + }, + { + "user" : "test_user", + "queryType" : "delete", + "group": "global.delete" + }, + { + "user" : "test_user", + "queryType" : "describe", + "group": "global.describe" + }, + { + "user" : "test_user", + "queryType" : "data_definition", + "group": "global.data_definition" + }, + { + "user": "test_user", + "group": "global.other" + } + ], + "cpuQuotaPeriod": "1h" +} + diff --git a/presto-server-rpm/pom.xml b/presto-server-rpm/pom.xml index c8b8be64cde8..4e1096e3cf57 100644 --- a/presto-server-rpm/pom.xml +++ b/presto-server-rpm/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-server-rpm diff --git a/presto-server/pom.xml b/presto-server/pom.xml index f15592f0a3f6..9df27545885e 100644 --- a/presto-server/pom.xml +++ b/presto-server/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-server diff --git a/presto-server/src/main/provisio/presto.xml b/presto-server/src/main/provisio/presto.xml index 83bc62f8a476..387f9c4c247f 100644 --- a/presto-server/src/main/provisio/presto.xml +++ b/presto-server/src/main/provisio/presto.xml @@ -151,4 +151,10 @@ + + + + + + diff --git a/presto-spi/pom.xml b/presto-spi/pom.xml index 3e032a884e8b..5367e6981c7f 100644 --- a/presto-spi/pom.xml +++ b/presto-spi/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-spi @@ -56,6 +56,12 @@ test + + it.unimi.dsi + fastutil + test + + com.google.guava guava diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/Page.java b/presto-spi/src/main/java/com/facebook/presto/spi/Page.java index 03267092c2b8..e3fa7903ef05 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/Page.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/Page.java @@ -252,6 +252,17 @@ private static int determinePositionCount(Block... blocks) return blocks[0].getPositionCount(); } + public static Page mask(Page page, int[] retainedPositions) + { + requireNonNull(page, "page is null"); + requireNonNull(retainedPositions, "retainedPositions is null"); + + Block[] blocks = Arrays.stream(page.getBlocks()) + .map(block -> new DictionaryBlock(block, retainedPositions)) + .toArray(Block[]::new); + return new Page(retainedPositions.length, blocks); + } + private static class DictionaryBlockIndexes { private final List blocks = new ArrayList<>(); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/SchemaTableName.java b/presto-spi/src/main/java/com/facebook/presto/spi/SchemaTableName.java index 8e7a8027f9a8..7abad8558393 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/SchemaTableName.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/SchemaTableName.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.spi; +import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Objects; @@ -25,6 +26,7 @@ public class SchemaTableName private final String schemaName; private final String tableName; + @JsonCreator public SchemaTableName(@JsonProperty("schema") String schemaName, @JsonProperty("table") String tableName) { this.schemaName = checkNotEmpty(schemaName, "schemaName").toLowerCase(ENGLISH); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractArrayBlock.java index 7cad65c976f9..b6600b6b43b7 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractArrayBlock.java @@ -90,7 +90,7 @@ public Block getRegion(int position, int length) } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { int positionCount = getPositionCount(); if (position < 0 || length < 0 || position + length > positionCount) { @@ -100,7 +100,7 @@ public int getRegionSizeInBytes(int position, int length) int valueStart = getOffsets()[getOffsetBase() + position]; int valueEnd = getOffsets()[getOffsetBase() + position + length]; - return getValues().getRegionSizeInBytes(valueStart, valueEnd - valueStart) + ((Integer.BYTES + Byte.BYTES) * length); + return getValues().getRegionSizeInBytes(valueStart, valueEnd - valueStart) + ((Integer.BYTES + Byte.BYTES) * (long) length); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractFixedWidthBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractFixedWidthBlock.java index 03e825f1787f..171b796b53c4 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractFixedWidthBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractFixedWidthBlock.java @@ -165,13 +165,13 @@ public boolean isNull(int position) } @Override - public int getRegionSizeInBytes(int positionOffset, int length) + public long getRegionSizeInBytes(int positionOffset, int length) { int positionCount = getPositionCount(); if (positionOffset < 0 || length < 0 || positionOffset + length > positionCount) { throw new IndexOutOfBoundsException("Invalid position " + positionOffset + " in block with " + positionCount + " positions"); } - return length * (fixedSize + Byte.BYTES); + return (fixedSize + Byte.BYTES) * (long) length; } private int valueOffset(int position) diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractInterleavedBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractInterleavedBlock.java index 0a3f465a50d7..aaf400fd8751 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractInterleavedBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractInterleavedBlock.java @@ -253,7 +253,7 @@ public Block copyRegion(int position, int length) } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { if (position == 0 && length == getPositionCount()) { // Calculation of getRegionSizeInBytes is expensive in this class. @@ -261,7 +261,7 @@ public int getRegionSizeInBytes(int position, int length) return getSizeInBytes(); } validateRange(position, length); - int result = 0; + long result = 0; for (int blockIndex = 0; blockIndex < getBlockCount(); blockIndex++) { result += getBlock(blockIndex).getRegionSizeInBytes(position / columns, length / columns); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractMapBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractMapBlock.java index 44a9d3b9fc83..02c627e3f7db 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractMapBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractMapBlock.java @@ -36,8 +36,10 @@ public abstract class AbstractMapBlock public AbstractMapBlock(Type keyType, MethodHandle keyNativeHashCode, MethodHandle keyBlockNativeEquals) { this.keyType = requireNonNull(keyType, "keyType is null"); - this.keyNativeHashCode = requireNonNull(keyNativeHashCode, "keyNativeHashCode is null"); - this.keyBlockNativeEquals = requireNonNull(keyBlockNativeEquals, "keyBlockNativeEquals is null"); + // keyNativeHashCode can only be null due to map block kill switch. deprecated.new-map-block + this.keyNativeHashCode = keyNativeHashCode; + // keyBlockNativeEquals can only be null due to map block kill switch. deprecated.new-map-block + this.keyBlockNativeEquals = keyBlockNativeEquals; } protected abstract Block getKeys(); @@ -141,7 +143,7 @@ public Block getRegion(int position, int length) } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { int positionCount = getPositionCount(); if (position < 0 || length < 0 || position + length > positionCount) { @@ -154,8 +156,8 @@ public int getRegionSizeInBytes(int position, int length) return getKeys().getRegionSizeInBytes(entriesStart, entryCount) + getValues().getRegionSizeInBytes(entriesStart, entryCount) + - (Integer.BYTES + Byte.BYTES) * length + - Integer.BYTES * HASH_MULTIPLIER * entryCount; + (Integer.BYTES + Byte.BYTES) * (long) length + + Integer.BYTES * HASH_MULTIPLIER * (long) entryCount; } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleArrayBlock.java index aefc794c40d4..9c5e9381782d 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleArrayBlock.java @@ -168,7 +168,7 @@ public Block getRegion(int position, int length) } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { throw new UnsupportedOperationException(); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleMapBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleMapBlock.java index c6513237e4e1..0cdba8fb2415 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleMapBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleMapBlock.java @@ -224,7 +224,7 @@ public Block getSingleValueBlock(int position) } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { throw new UnsupportedOperationException(); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlock.java index 5f2872a96372..6ad627167991 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlock.java @@ -15,7 +15,8 @@ import org.openjdk.jol.info.ClassLayout; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; +import java.util.function.BiConsumer; + import static io.airlift.slice.SizeOf.sizeOf; import static java.util.Objects.requireNonNull; @@ -30,8 +31,8 @@ public class ArrayBlock private final Block values; private final int[] offsets; - private int sizeInBytes; - private final int retainedSizeInBytes; + private long sizeInBytes; + private final long retainedSizeInBytes; public ArrayBlock(int positionCount, boolean[] valueIsNull, int[] offsets, Block values) { @@ -65,7 +66,7 @@ public ArrayBlock(int positionCount, boolean[] valueIsNull, int[] offsets, Block this.values = requireNonNull(values); sizeInBytes = -1; - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + values.getRetainedSizeInBytes() + sizeOf(offsets) + sizeOf(valueIsNull)); + retainedSizeInBytes = INSTANCE_SIZE + values.getRetainedSizeInBytes() + sizeOf(offsets) + sizeOf(valueIsNull); } @Override @@ -75,7 +76,7 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { // this is racy but is safe because sizeInBytes is an int and the calculation is stable if (sizeInBytes < 0) { @@ -88,15 +89,24 @@ private void calculateSize() { int valueStart = offsets[arrayOffset]; int valueEnd = offsets[arrayOffset + positionCount]; - sizeInBytes = intSaturatedCast(values.getRegionSizeInBytes(valueStart, valueEnd - valueStart) + ((Integer.BYTES + Byte.BYTES) * (long) this.positionCount)); + sizeInBytes = values.getRegionSizeInBytes(valueStart, valueEnd - valueStart) + ((Integer.BYTES + Byte.BYTES) * (long) this.positionCount); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, values.getRetainedSizeInBytes()); + consumer.accept(offsets, sizeOf(offsets)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override protected Block getValues() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlockBuilder.java index 5ba1d0795209..7d124f0c6140 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlockBuilder.java @@ -19,9 +19,9 @@ import javax.annotation.Nullable; import java.util.Arrays; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.max; import static java.util.Objects.requireNonNull; @@ -45,7 +45,7 @@ public class ArrayBlockBuilder private final BlockBuilder values; private boolean currentEntryOpened; - private int retainedSizeInBytes; + private long retainedSizeInBytes; /** * Caller of this constructor is responsible for making sure `valuesBlock` is constructed with the same `blockBuilderStatus` as the one in the argument @@ -93,17 +93,26 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return values.getSizeInBytes() + ((Integer.BYTES + Byte.BYTES) * positionCount); + return values.getSizeInBytes() + ((Integer.BYTES + Byte.BYTES) * (long) positionCount); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes + values.getRetainedSizeInBytes(); } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, values.getRetainedSizeInBytes()); + consumer.accept(offsets, sizeOf(offsets)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override protected Block getValues() { @@ -215,11 +224,10 @@ private void growCapacity() private void updateDataSize() { - long size = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(offsets); + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(offsets); if (blockBuilderStatus != null) { - size += BlockBuilderStatus.INSTANCE_SIZE; + retainedSizeInBytes += BlockBuilderStatus.INSTANCE_SIZE; } - retainedSizeInBytes = intSaturatedCast(size); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/Block.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/Block.java index 4f08875259ce..ad955ec42fa3 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/Block.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/Block.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import java.util.List; +import java.util.function.BiConsumer; public interface Block { @@ -162,18 +163,28 @@ default int compareTo(int leftPosition, int leftOffset, int leftLength, Block ri /** * Returns the logical size of this block in memory. */ - int getSizeInBytes(); + long getSizeInBytes(); /** * Returns the logical size of {@code block.getRegion(position, length)} in memory. */ - int getRegionSizeInBytes(int position, int length); + long getRegionSizeInBytes(int position, int length); /** * Returns the retained size of this block in memory. * This method is called from the inner most execution loop and must be fast. */ - int getRetainedSizeInBytes(); + long getRetainedSizeInBytes(); + + /** + * {@code consumer} visits each of the internal data container and accepts the size for it. + * This method can be helpful in cases such as memory counting for internal data structure. + * Also, the method should be non-recursive, only visit the elements at the top level, + * and specifically should not call retainedBytesForEachPart on nested blocks + * {@code consumer} should be called at least once with the current block and + * must include the instance size of the current block + */ + void retainedBytesForEachPart(BiConsumer consumer); /** * Get the encoding for this block. diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/BlockUtil.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/BlockUtil.java index 21487bfa55a9..d28fae9432b6 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/BlockUtil.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/BlockUtil.java @@ -76,12 +76,4 @@ else if (newSize > MAX_ARRAY_SIZE) { } return (int) newSize; } - - static int intSaturatedCast(long value) - { - if (value > Integer.MAX_VALUE) { - return Integer.MAX_VALUE; - } - return (int) value; - } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlock.java index e85ee5d1b22c..af9940d3fcda 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlock.java @@ -17,9 +17,9 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; public class ByteArrayBlock @@ -32,8 +32,8 @@ public class ByteArrayBlock private final boolean[] valueIsNull; private final byte[] values; - private final int sizeInBytes; - private final int retainedSizeInBytes; + private final long sizeInBytes; + private final long retainedSizeInBytes; public ByteArrayBlock(int positionCount, boolean[] valueIsNull, byte[] values) { @@ -61,28 +61,36 @@ public ByteArrayBlock(int positionCount, boolean[] valueIsNull, byte[] values) } this.valueIsNull = valueIsNull; - sizeInBytes = intSaturatedCast((Byte.BYTES + Byte.BYTES) * (long) positionCount); - retainedSizeInBytes = intSaturatedCast((INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values))); + sizeInBytes = (Byte.BYTES + Byte.BYTES) * (long) positionCount; + retainedSizeInBytes = (INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values)); } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Byte.BYTES + Byte.BYTES) * (long) length); + return (Byte.BYTES + Byte.BYTES) * (long) length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlockBuilder.java index 95a415633177..406dc6a907ee 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlockBuilder.java @@ -19,9 +19,9 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.max; @@ -41,7 +41,7 @@ public class ByteArrayBlockBuilder private boolean[] valueIsNull = new boolean[0]; private byte[] values = new byte[0]; - private int retainedSizeInBytes; + private long retainedSizeInBytes; public ByteArrayBlockBuilder(@Nullable BlockBuilderStatus blockBuilderStatus, int expectedEntries) { @@ -119,32 +119,38 @@ private void growCapacity() private void updateDataSize() { - long size = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); if (blockBuilderStatus != null) { - size += BlockBuilderStatus.INSTANCE_SIZE; + retainedSizeInBytes += BlockBuilderStatus.INSTANCE_SIZE; } - retainedSizeInBytes = intSaturatedCast(size); } - // Copied from ByteArrayBlock @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast((Byte.BYTES + Byte.BYTES) * (long) positionCount); + return (Byte.BYTES + Byte.BYTES) * (long) positionCount; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Byte.BYTES + Byte.BYTES) * (long) length); + return (Byte.BYTES + Byte.BYTES) * (long) length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/DictionaryBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/DictionaryBlock.java index 8b6b50b9c9b3..be60c7115824 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/DictionaryBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/DictionaryBlock.java @@ -22,12 +22,12 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidPositions; import static com.facebook.presto.spi.block.DictionaryId.randomDictionaryId; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.min; -import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class DictionaryBlock @@ -39,11 +39,16 @@ public class DictionaryBlock private final Block dictionary; private final int idsOffset; private final int[] ids; - private final int retainedSizeInBytes; - private volatile int sizeInBytes = -1; + private final long retainedSizeInBytes; + private volatile long sizeInBytes = -1; private volatile int uniqueIds = -1; private final DictionaryId dictionarySourceId; + public DictionaryBlock(Block dictionary, int[] ids) + { + this(requireNonNull(ids, "ids is null").length, dictionary, ids); + } + public DictionaryBlock(int positionCount, Block dictionary, int[] ids) { this(0, positionCount, dictionary, ids, false, randomDictionaryId()); @@ -82,7 +87,7 @@ private DictionaryBlock(int idsOffset, int positionCount, Block dictionary, int[ this.dictionary = dictionary; this.ids = ids; this.dictionarySourceId = requireNonNull(dictionarySourceId, "dictionarySourceId is null"); - this.retainedSizeInBytes = toIntExact(INSTANCE_SIZE + dictionary.getRetainedSizeInBytes() + sizeOf(ids)); + this.retainedSizeInBytes = INSTANCE_SIZE + dictionary.getRetainedSizeInBytes() + sizeOf(ids); if (dictionaryIsCompacted) { this.sizeInBytes = this.retainedSizeInBytes; @@ -187,7 +192,7 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { // this is racy but is safe because sizeInBytes is an int and the calculation is stable if (sizeInBytes < 0) { @@ -198,7 +203,7 @@ public int getSizeInBytes() private void calculateCompactSize() { - int sizeInBytes = 0; + long sizeInBytes = 0; int uniqueIds = 0; boolean[] seen = new boolean[dictionary.getPositionCount()]; for (int i = 0; i < positionCount; i++) { @@ -211,12 +216,12 @@ private void calculateCompactSize() seen[position] = true; } } - this.sizeInBytes = sizeInBytes + (positionCount * Integer.BYTES); + this.sizeInBytes = sizeInBytes + (Integer.BYTES * (long) positionCount); this.uniqueIds = uniqueIds; } @Override - public int getRegionSizeInBytes(int positionOffset, int length) + public long getRegionSizeInBytes(int positionOffset, int length) { if (positionOffset == 0 && length == getPositionCount()) { // Calculation of getRegionSizeInBytes is expensive in this class. @@ -224,7 +229,7 @@ public int getRegionSizeInBytes(int positionOffset, int length) return getSizeInBytes(); } - int sizeInBytes = 0; + long sizeInBytes = 0; boolean[] seen = new boolean[dictionary.getPositionCount()]; for (int i = positionOffset; i < positionOffset + length; i++) { int position = getId(i); @@ -235,15 +240,24 @@ public int getRegionSizeInBytes(int positionOffset, int length) seen[position] = true; } } - return sizeInBytes + (length * Integer.BYTES); + sizeInBytes += Integer.BYTES * (long) length; + return sizeInBytes; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(dictionary, dictionary.getRetainedSizeInBytes()); + consumer.accept(ids, sizeOf(ids)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public BlockEncoding getEncoding() { @@ -267,7 +281,7 @@ public Block copyPositions(List positions) } newIds[i] = oldIndexToNewIndex.get(oldIndex); } - return new DictionaryBlock(positions.size(), dictionary.copyPositions(positionsToCopy), newIds); + return new DictionaryBlock(dictionary.copyPositions(positionsToCopy), newIds); } @Override @@ -286,7 +300,7 @@ public Block copyRegion(int position, int length) throw new IndexOutOfBoundsException("Invalid position " + position + " in block with " + positionCount + " positions"); } int[] newIds = Arrays.copyOfRange(ids, idsOffset + position, idsOffset + position + length); - DictionaryBlock dictionaryBlock = new DictionaryBlock(length, dictionary, newIds); + DictionaryBlock dictionaryBlock = new DictionaryBlock(dictionary, newIds); return dictionaryBlock.compact(); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlock.java index 27725058b320..e28a8868239e 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlock.java @@ -19,9 +19,9 @@ import org.openjdk.jol.info.ClassLayout; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidPositions; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static java.util.Objects.requireNonNull; public class FixedWidthBlock @@ -72,15 +72,23 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast(getRawSlice().length() + valueIsNull.length()); + return getRawSlice().length() + (long) valueIsNull.length(); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { - return intSaturatedCast(INSTANCE_SIZE + getRawSlice().getRetainedSize() + valueIsNull.getRetainedSize()); + return INSTANCE_SIZE + getRawSlice().getRetainedSize() + valueIsNull.getRetainedSize(); + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(slice, (long) slice.getRetainedSize()); + consumer.accept(valueIsNull, (long) valueIsNull.getRetainedSize()); + consumer.accept(this, (long) INSTANCE_SIZE); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlockBuilder.java index bacdd812fb41..f5a004d18c82 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlockBuilder.java @@ -22,11 +22,11 @@ import javax.annotation.Nullable; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.MAX_ARRAY_SIZE; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; import static com.facebook.presto.spi.block.BlockUtil.checkValidPositions; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.SIZE_OF_BYTE; import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; @@ -85,19 +85,27 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast(sliceOutput.size() + valueIsNull.size()); + return sliceOutput.size() + (long) valueIsNull.size(); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { long size = INSTANCE_SIZE + sliceOutput.getRetainedSize() + valueIsNull.getRetainedSize(); if (blockBuilderStatus != null) { size += BlockBuilderStatus.INSTANCE_SIZE; } - return intSaturatedCast(size); + return size; + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(sliceOutput, (long) sliceOutput.getRetainedSize()); + consumer.accept(valueIsNull, (long) valueIsNull.getRetainedSize()); + consumer.accept(this, (long) INSTANCE_SIZE); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlock.java index 2a2226fa3280..3456da04d680 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlock.java @@ -17,9 +17,9 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; public class IntArrayBlock @@ -32,8 +32,8 @@ public class IntArrayBlock private final boolean[] valueIsNull; private final int[] values; - private final int sizeInBytes; - private final int retainedSizeInBytes; + private final long sizeInBytes; + private final long retainedSizeInBytes; public IntArrayBlock(int positionCount, boolean[] valueIsNull, int[] values) { @@ -61,28 +61,36 @@ public IntArrayBlock(int positionCount, boolean[] valueIsNull, int[] values) } this.valueIsNull = valueIsNull; - sizeInBytes = intSaturatedCast((Integer.BYTES + Byte.BYTES) * (long) positionCount); - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values)); + sizeInBytes = (Integer.BYTES + Byte.BYTES) * (long) positionCount; + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Integer.BYTES + Byte.BYTES) * (long) length); + return (Integer.BYTES + Byte.BYTES) * (long) length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlockBuilder.java index 4908f6b08f70..4df15caecd18 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlockBuilder.java @@ -19,10 +19,10 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.max; @@ -42,7 +42,7 @@ public class IntArrayBlockBuilder private boolean[] valueIsNull = new boolean[0]; private int[] values = new int[0]; - private int retainedSizeInBytes; + private long retainedSizeInBytes; public IntArrayBlockBuilder(@Nullable BlockBuilderStatus blockBuilderStatus, int expectedEntries) { @@ -120,32 +120,38 @@ private void growCapacity() private void updateDataSize() { - long size = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); if (blockBuilderStatus != null) { - size += BlockBuilderStatus.INSTANCE_SIZE; + retainedSizeInBytes += BlockBuilderStatus.INSTANCE_SIZE; } - retainedSizeInBytes = intSaturatedCast(size); } - // Copied from IntArrayBlock @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast((Integer.BYTES + Byte.BYTES) * (long) positionCount); + return (Integer.BYTES + Byte.BYTES) * (long) positionCount; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Integer.BYTES + Byte.BYTES) * (long) length); + return (Integer.BYTES + Byte.BYTES) * (long) length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlock.java index 1978c4633f38..c021733b4a1f 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlock.java @@ -15,7 +15,8 @@ import org.openjdk.jol.info.ClassLayout; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiConsumer; public class InterleavedBlock extends AbstractInterleavedBlock @@ -26,17 +27,17 @@ public class InterleavedBlock private final InterleavedBlockEncoding blockEncoding; private final int start; private final int positionCount; - private final int retainedSizeInBytes; + private final long retainedSizeInBytes; - private final AtomicInteger sizeInBytes; + private final AtomicLong sizeInBytes; public InterleavedBlock(Block[] blocks) { super(blocks.length); this.blocks = blocks; - int sizeInBytes = 0; - int retainedSizeInBytes = INSTANCE_SIZE; + long sizeInBytes = 0; + long retainedSizeInBytes = INSTANCE_SIZE; int positionCount = 0; int firstSubBlockPositionCount = blocks[0].getPositionCount(); for (int i = 0; i < getBlockCount(); i++) { @@ -52,11 +53,11 @@ public InterleavedBlock(Block[] blocks) this.blockEncoding = computeBlockEncoding(); this.start = 0; this.positionCount = positionCount; - this.sizeInBytes = new AtomicInteger(sizeInBytes); + this.sizeInBytes = new AtomicLong(sizeInBytes); this.retainedSizeInBytes = retainedSizeInBytes; } - private InterleavedBlock(Block[] blocks, int start, int positionCount, int retainedSizeInBytes, InterleavedBlockEncoding blockEncoding) + private InterleavedBlock(Block[] blocks, int start, int positionCount, long retainedSizeInBytes, InterleavedBlockEncoding blockEncoding) { super(blocks.length); this.blocks = blocks; @@ -64,7 +65,7 @@ private InterleavedBlock(Block[] blocks, int start, int positionCount, int retai this.positionCount = positionCount; this.retainedSizeInBytes = retainedSizeInBytes; this.blockEncoding = blockEncoding; - this.sizeInBytes = new AtomicInteger(-1); + this.sizeInBytes = new AtomicLong(-1); } @Override @@ -103,9 +104,9 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { - int sizeInBytes = this.sizeInBytes.get(); + long sizeInBytes = this.sizeInBytes.get(); if (sizeInBytes < 0) { sizeInBytes = 0; for (int i = 0; i < getBlockCount(); i++) { @@ -117,11 +118,18 @@ public int getSizeInBytes() } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(blocks, retainedSizeInBytes - INSTANCE_SIZE); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public String toString() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlockBuilder.java index 6da6ffc70f0e..32ff7c95b9d8 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlockBuilder.java @@ -18,6 +18,7 @@ import org.openjdk.jol.info.ClassLayout; import java.util.List; +import java.util.function.BiConsumer; import static java.util.Objects.requireNonNull; @@ -33,10 +34,10 @@ public class InterleavedBlockBuilder private int positionCount; private int currentBlockIndex; - private int sizeInBytes; - private int startSize; - private int retainedSizeInBytes; - private int startRetainedSize; + private long sizeInBytes; + private long startSize; + private long retainedSizeInBytes; + private long startRetainedSize; public InterleavedBlockBuilder(List types, BlockBuilderStatus blockBuilderStatus, int expectedEntries) { @@ -109,17 +110,24 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(blockBuilders, retainedSizeInBytes - INSTANCE_SIZE); + consumer.accept(this, (long) INSTANCE_SIZE); + } + private void recordStartSizesIfNecessary(BlockBuilder blockBuilder) { if (startSize < 0) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/LazyBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/LazyBlock.java index 10d3142e395c..7ac0b66ab43b 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/LazyBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/LazyBlock.java @@ -17,6 +17,7 @@ import org.openjdk.jol.info.ClassLayout; import java.util.List; +import java.util.function.BiConsumer; import static java.util.Objects.requireNonNull; @@ -164,26 +165,36 @@ public Block getSingleValueBlock(int position) } @Override - public int getSizeInBytes() + public long getSizeInBytes() { assureLoaded(); return block.getSizeInBytes(); } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { assureLoaded(); return block.getRegionSizeInBytes(position, length); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { assureLoaded(); return INSTANCE_SIZE + block.getRetainedSizeInBytes(); } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + // do not support LazyBlock (for now) for the following two reasons: + // (1) the method is mainly used for inspecting the identity and size of each element to prevent over counting + // (2) the method should be non-recursive and only inspects blocks at the top level; + // given LazyBlock is a wrapper for other blocks, it is not meaningful to only inspect the top-level elements + throw new UnsupportedOperationException(getClass().getName()); + } + @Override public BlockEncoding getEncoding() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlock.java index ef29a44d0899..17e0ff44ffd2 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlock.java @@ -17,9 +17,9 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.toIntExact; @@ -33,8 +33,8 @@ public class LongArrayBlock private final boolean[] valueIsNull; private final long[] values; - private final int sizeInBytes; - private final int retainedSizeInBytes; + private final long sizeInBytes; + private final long retainedSizeInBytes; public LongArrayBlock(int positionCount, boolean[] valueIsNull, long[] values) { @@ -62,28 +62,36 @@ public LongArrayBlock(int positionCount, boolean[] valueIsNull, long[] values) } this.valueIsNull = valueIsNull; - sizeInBytes = intSaturatedCast((Long.BYTES + Byte.BYTES) * (long) positionCount); - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values)); + sizeInBytes = (Long.BYTES + Byte.BYTES) * (long) positionCount; + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Long.BYTES + Byte.BYTES) * (long) length); + return (Long.BYTES + Byte.BYTES) * (long) length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlockBuilder.java index 87b4e46d7a0e..806a30560509 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlockBuilder.java @@ -19,10 +19,10 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.max; import static java.lang.Math.toIntExact; @@ -43,7 +43,7 @@ public class LongArrayBlockBuilder private boolean[] valueIsNull = new boolean[0]; private long[] values = new long[0]; - private int retainedSizeInBytes; + private long retainedSizeInBytes; public LongArrayBlockBuilder(@Nullable BlockBuilderStatus blockBuilderStatus, int expectedEntries) { @@ -121,32 +121,38 @@ private void growCapacity() private void updateDataSize() { - long size = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); if (blockBuilderStatus != null) { - size += BlockBuilderStatus.INSTANCE_SIZE; + retainedSizeInBytes += BlockBuilderStatus.INSTANCE_SIZE; } - retainedSizeInBytes = intSaturatedCast(size); } - // Copied from LongArrayBlock @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast((Long.BYTES + Byte.BYTES) * (long) positionCount); + return (Long.BYTES + Byte.BYTES) * (long) positionCount; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Long.BYTES + Byte.BYTES) * (long) length); + return (Long.BYTES + Byte.BYTES) * (long) length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlock.java index db44d3f8651a..f2d1fda189cc 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlock.java @@ -14,12 +14,15 @@ package com.facebook.presto.spi.block; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import org.openjdk.jol.info.ClassLayout; import java.lang.invoke.MethodHandle; +import java.util.Arrays; +import java.util.function.BiConsumer; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; +import static com.facebook.presto.spi.block.MapBlockBuilder.buildHashTable; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -38,14 +41,14 @@ public class MapBlock private final Block valueBlock; private final int[] hashTables; // hash to location in map; - private int sizeInBytes; - private final int retainedSizeInBytes; + private long sizeInBytes; + private final long retainedSizeInBytes; /** * @param keyBlockNativeEquals (T, Block, int)boolean * @param keyNativeHashCode (T)long */ - public MapBlock( + MapBlock( int startOffset, int positionCount, boolean[] mapIsNull, @@ -74,8 +77,7 @@ public MapBlock( this.hashTables = hashTables; this.sizeInBytes = -1; - this.retainedSizeInBytes = intSaturatedCast( - INSTANCE_SIZE + keyBlock.getRetainedSizeInBytes() + valueBlock.getRetainedSizeInBytes() + sizeOf(offsets) + sizeOf(mapIsNull) + sizeOf(hashTables)); + this.retainedSizeInBytes = INSTANCE_SIZE + keyBlock.getRetainedSizeInBytes() + valueBlock.getRetainedSizeInBytes() + sizeOf(offsets) + sizeOf(mapIsNull) + sizeOf(hashTables); } @Override @@ -121,7 +123,7 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { // this is racy but is safe because sizeInBytes is an int and the calculation is stable if (sizeInBytes < 0) { @@ -137,16 +139,27 @@ private void calculateSize() int entryCount = entriesEnd - entriesStart; sizeInBytes = keyBlock.getRegionSizeInBytes(entriesStart, entryCount) + valueBlock.getRegionSizeInBytes(entriesStart, entryCount) + - (Integer.BYTES + Byte.BYTES) * this.positionCount + - Integer.BYTES * HASH_MULTIPLIER * entryCount; + (Integer.BYTES + Byte.BYTES) * (long) this.positionCount + + Integer.BYTES * HASH_MULTIPLIER * (long) entryCount; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(keyBlock, keyBlock.getRetainedSizeInBytes()); + consumer.accept(valueBlock, valueBlock.getRetainedSizeInBytes()); + consumer.accept(offsets, sizeOf(offsets)); + consumer.accept(mapIsNull, sizeOf(mapIsNull)); + consumer.accept(hashTables, sizeOf(hashTables)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public String toString() { @@ -155,4 +168,50 @@ public String toString() sb.append('}'); return sb.toString(); } + + public static MapBlock fromKeyValueBlock( + boolean useNewMapBlock, + boolean[] mapIsNull, + int[] offsets, + Block keyBlock, + Block valueBlock, + MapType mapType, + MethodHandle keyBlockNativeEquals, + MethodHandle keyNativeHashCode, + MethodHandle keyBlockHashCode) + { + if (keyBlock.getPositionCount() != valueBlock.getPositionCount()) { + throw new IllegalArgumentException(format("keyBlock position count does not match valueBlock position count. %s %s", keyBlock.getPositionCount(), valueBlock.getPositionCount())); + } + int elementCount = keyBlock.getPositionCount(); + if (mapIsNull.length != offsets.length - 1) { + throw new IllegalArgumentException(format("mapIsNull.length-1 does not match offsets.length. %s %s", mapIsNull.length - 1, offsets.length)); + } + int mapCount = mapIsNull.length; + if (offsets[mapCount] != elementCount) { + throw new IllegalArgumentException(format("Last element of offsets does not match keyBlock position count. %s %s", offsets[mapCount], keyBlock.getPositionCount())); + } + int[] hashTables = new int[elementCount * HASH_MULTIPLIER]; + Arrays.fill(hashTables, -1); + for (int i = 0; i < mapCount; i++) { + int keyOffset = offsets[i]; + int keyCount = offsets[i + 1] - keyOffset; + if (keyCount < 0) { + throw new IllegalArgumentException(format("Offset is not monotonically ascending. offsets[%s]=%s, offsets[%s]=%s", i, offsets[i], i + 1, offsets[i + 1])); + } + buildHashTable(useNewMapBlock, keyBlock, keyOffset, keyCount, keyBlockHashCode, hashTables, keyOffset * HASH_MULTIPLIER, keyCount * HASH_MULTIPLIER); + } + + return new MapBlock( + 0, + mapCount, + mapIsNull, + offsets, + keyBlock, + valueBlock, + hashTables, + mapType.getKeyType(), + keyBlockNativeEquals, + keyNativeHashCode); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockBuilder.java index a2031641ff16..2467bbffdd92 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockBuilder.java @@ -21,9 +21,9 @@ import java.lang.invoke.MethodHandle; import java.util.Arrays; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -34,6 +34,7 @@ public class MapBlockBuilder { private static final int INSTANCE_SIZE = ClassLayout.parseClass(MapBlockBuilder.class).instanceSize(); + private final boolean useNewMapBlock; private final MethodHandle keyBlockHashCode; @Nullable @@ -49,16 +50,20 @@ public class MapBlockBuilder private boolean currentEntryOpened; public MapBlockBuilder( + boolean useNewMapBlock, Type keyType, Type valueType, - MethodHandle keyBlockNativeEquals, MethodHandle keyNativeHashCode, + MethodHandle keyBlockNativeEquals, + MethodHandle keyNativeHashCode, MethodHandle keyBlockHashCode, BlockBuilderStatus blockBuilderStatus, int expectedEntries) { this( + useNewMapBlock, keyType, - keyBlockNativeEquals, keyNativeHashCode, + keyBlockNativeEquals, + keyNativeHashCode, keyBlockHashCode, blockBuilderStatus, keyType.createBlockBuilder(blockBuilderStatus, expectedEntries), @@ -69,8 +74,10 @@ public MapBlockBuilder( } private MapBlockBuilder( + boolean useNewMapBlock, Type keyType, - MethodHandle keyBlockNativeEquals, MethodHandle keyNativeHashCode, + MethodHandle keyBlockNativeEquals, + MethodHandle keyNativeHashCode, MethodHandle keyBlockHashCode, @Nullable BlockBuilderStatus blockBuilderStatus, BlockBuilder keyBlockBuilder, @@ -81,7 +88,16 @@ private MapBlockBuilder( { super(keyType, keyNativeHashCode, keyBlockNativeEquals); - this.keyBlockHashCode = requireNonNull(keyBlockHashCode, "keyBlockHashCode is null"); + this.useNewMapBlock = useNewMapBlock; + if (useNewMapBlock) { + requireNonNull(keyBlockHashCode, "keyBlockHashCode is null"); + } + else { + if (keyBlockHashCode != null) { + throw new IllegalArgumentException("When useNewMapBlock is false, keyBlockHashCode should be null."); + } + } + this.keyBlockHashCode = keyBlockHashCode; this.blockBuilderStatus = blockBuilderStatus; this.positionCount = 0; @@ -135,21 +151,32 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return keyBlockBuilder.getSizeInBytes() + valueBlockBuilder.getSizeInBytes() + - (Integer.BYTES + Byte.BYTES) * positionCount + - Integer.BYTES * HASH_MULTIPLIER * keyBlockBuilder.getPositionCount(); + (Integer.BYTES + Byte.BYTES) * (long) positionCount + + Integer.BYTES * HASH_MULTIPLIER * (long) keyBlockBuilder.getPositionCount(); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { long size = INSTANCE_SIZE + keyBlockBuilder.getRetainedSizeInBytes() + valueBlockBuilder.getRetainedSizeInBytes() + sizeOf(offsets) + sizeOf(mapIsNull) + sizeOf(hashTables); if (blockBuilderStatus != null) { size += BlockBuilderStatus.INSTANCE_SIZE; } - return intSaturatedCast(size); + return size; + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(keyBlockBuilder, keyBlockBuilder.getRetainedSizeInBytes()); + consumer.accept(valueBlockBuilder, valueBlockBuilder.getRetainedSizeInBytes()); + consumer.accept(offsets, sizeOf(offsets)); + consumer.accept(mapIsNull, sizeOf(mapIsNull)); + consumer.accept(hashTables, sizeOf(hashTables)); + consumer.accept(this, (long) INSTANCE_SIZE); } @Override @@ -181,7 +208,7 @@ public BlockBuilder closeEntry() hashTables = Arrays.copyOf(hashTables, newSize); Arrays.fill(hashTables, oldSize, hashTables.length, -1); } - buildHashTable(keyBlockBuilder, previousAggregatedEntryCount, entryCount, keyBlockHashCode, hashTables, previousAggregatedEntryCount * HASH_MULTIPLIER, entryCount * HASH_MULTIPLIER); + buildHashTable(useNewMapBlock, keyBlockBuilder, previousAggregatedEntryCount, entryCount, keyBlockHashCode, hashTables, previousAggregatedEntryCount * HASH_MULTIPLIER, entryCount * HASH_MULTIPLIER); if (blockBuilderStatus != null) { blockBuilderStatus.addBytes(entryCount * HASH_MULTIPLIER * Integer.BYTES); } @@ -283,8 +310,10 @@ public BlockBuilder newBlockBuilderLike(BlockBuilderStatus blockBuilderStatus) { int newSize = calculateBlockResetSize(getPositionCount()); return new MapBlockBuilder( + useNewMapBlock, keyType, - keyBlockNativeEquals, keyNativeHashCode, + keyBlockNativeEquals, + keyNativeHashCode, keyBlockHashCode, blockBuilderStatus, keyBlockBuilder.newBlockBuilderLike(blockBuilderStatus), @@ -301,8 +330,12 @@ private static int[] newNegativeOneFilledArray(int size) return hashTable; } - private static void buildHashTable(Block keyBlock, int keyOffset, int keyCount, MethodHandle keyBlockHashCode, int[] outputHashTable, int hashTableOffset, int hashTableSize) + static void buildHashTable(boolean useNewMapBlock, Block keyBlock, int keyOffset, int keyCount, MethodHandle keyBlockHashCode, int[] outputHashTable, int hashTableOffset, int hashTableSize) { + if (!useNewMapBlock) { + return; + } + // This method assumes that keyBlock has no duplicated entries (in the specified range) for (int i = 0; i < keyCount; i++) { if (keyBlock.isNull(keyOffset + i)) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockEncoding.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockEncoding.java index 743ed239fdba..4bb06745fc5f 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockEncoding.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockEncoding.java @@ -47,8 +47,10 @@ public class MapBlockEncoding public MapBlockEncoding(Type keyType, MethodHandle keyBlockNativeEquals, MethodHandle keyNativeHashCode, BlockEncoding keyBlockEncoding, BlockEncoding valueBlockEncoding) { this.keyType = requireNonNull(keyType, "keyType is null"); - this.keyNativeHashCode = requireNonNull(keyNativeHashCode, "keyNativeHashCode is null"); - this.keyBlockNativeEquals = requireNonNull(keyBlockNativeEquals, "keyBlockNativeEquals"); + // keyNativeHashCode can only be null due to map block kill switch. deprecated.new-map-block + this.keyNativeHashCode = keyNativeHashCode; + // keyBlockNativeEquals can only be null due to map block kill switch. deprecated.new-map-block + this.keyBlockNativeEquals = keyBlockNativeEquals; this.keyBlockEncoding = requireNonNull(keyBlockEncoding, "keyBlockEncoding is null"); this.valueBlockEncoding = requireNonNull(valueBlockEncoding, "valueBlockEncoding is null"); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/RunLengthEncodedBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/RunLengthEncodedBlock.java index eb8ac0b0e5b8..87b4d9a9e59c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/RunLengthEncodedBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/RunLengthEncodedBlock.java @@ -19,6 +19,7 @@ import org.openjdk.jol.info.ClassLayout; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidPositions; import static java.lang.String.format; @@ -70,17 +71,24 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return value.getSizeInBytes(); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return INSTANCE_SIZE + value.getRetainedSizeInBytes(); } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(value, value.getRetainedSizeInBytes()); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public RunLengthBlockEncoding getEncoding() { @@ -102,7 +110,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { return value.getSizeInBytes(); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlock.java index 62c8b3676b1f..1a5aa29fe1d1 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlock.java @@ -17,9 +17,9 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; public class ShortArrayBlock @@ -32,8 +32,8 @@ public class ShortArrayBlock private final boolean[] valueIsNull; private final short[] values; - private final int sizeInBytes; - private final int retainedSizeInBytes; + private final long sizeInBytes; + private final long retainedSizeInBytes; public ShortArrayBlock(int positionCount, boolean[] valueIsNull, short[] values) { @@ -61,28 +61,36 @@ public ShortArrayBlock(int positionCount, boolean[] valueIsNull, short[] values) } this.valueIsNull = valueIsNull; - sizeInBytes = intSaturatedCast((Short.BYTES + Byte.BYTES) * (long) positionCount); - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values)); + sizeInBytes = (Short.BYTES + Byte.BYTES) * (long) positionCount; + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Short.BYTES + Byte.BYTES) * (long) length); + return (Short.BYTES + Byte.BYTES) * (long) length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlockBuilder.java index d2ddd318c909..f3a3b4798c6c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlockBuilder.java @@ -19,10 +19,10 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.max; @@ -42,7 +42,7 @@ public class ShortArrayBlockBuilder private boolean[] valueIsNull = new boolean[0]; private short[] values = new short[0]; - private int retainedSizeInBytes; + private long retainedSizeInBytes; public ShortArrayBlockBuilder(@Nullable BlockBuilderStatus blockBuilderStatus, int expectedEntries) { @@ -120,32 +120,38 @@ private void growCapacity() private void updateDataSize() { - long size = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); if (blockBuilderStatus != null) { - size += BlockBuilderStatus.INSTANCE_SIZE; + retainedSizeInBytes += BlockBuilderStatus.INSTANCE_SIZE; } - retainedSizeInBytes = intSaturatedCast(size); } - // Copied from ShortArrayBlock @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast((Short.BYTES + Byte.BYTES) * (long) positionCount); + return (Short.BYTES + Byte.BYTES) * (long) positionCount; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Short.BYTES + Byte.BYTES) * (long) length); + return (Short.BYTES + Byte.BYTES) * (long) length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleArrayBlockWriter.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleArrayBlockWriter.java index c02836bb52f8..e3a9fa6dcd37 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleArrayBlockWriter.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleArrayBlockWriter.java @@ -16,6 +16,8 @@ import io.airlift.slice.Slice; import org.openjdk.jol.info.ClassLayout; +import java.util.function.BiConsumer; + public class SingleArrayBlockWriter extends AbstractSingleArrayBlock implements BlockBuilder @@ -23,7 +25,7 @@ public class SingleArrayBlockWriter private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleArrayBlockWriter.class).instanceSize(); private final BlockBuilder blockBuilder; - private final int initialBlockBuilderSize; + private final long initialBlockBuilderSize; private int positionsWritten; public SingleArrayBlockWriter(BlockBuilder blockBuilder, int start) @@ -40,17 +42,24 @@ protected BlockBuilder getBlock() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return blockBuilder.getSizeInBytes() - initialBlockBuilderSize; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return INSTANCE_SIZE + blockBuilder.getRetainedSizeInBytes(); } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(blockBuilder, blockBuilder.getRetainedSizeInBytes()); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public BlockBuilder writeByte(int value) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlock.java index 5729bfa43202..d031eae95437 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlock.java @@ -20,17 +20,17 @@ import org.openjdk.jol.info.ClassLayout; import java.lang.invoke.MethodHandle; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.block.AbstractMapBlock.HASH_MULTIPLIER; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.slice.SizeOf.sizeOfIntArray; public class SingleMapBlock extends AbstractSingleMapBlock { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleMapBlockWriter.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleMapBlock.class).instanceSize(); private final int offset; private final int positionCount; @@ -62,17 +62,26 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast(keyBlock.getRegionSizeInBytes(offset / 2, positionCount / 2) + + return keyBlock.getRegionSizeInBytes(offset / 2, positionCount / 2) + valueBlock.getRegionSizeInBytes(offset / 2, positionCount / 2) + - sizeOfIntArray(positionCount / 2 * HASH_MULTIPLIER)); + sizeOfIntArray(positionCount / 2 * HASH_MULTIPLIER); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { - return intSaturatedCast(INSTANCE_SIZE + keyBlock.getRetainedSizeInBytes() + valueBlock.getRetainedSizeInBytes() + sizeOf(hashTable)); + return INSTANCE_SIZE + keyBlock.getRetainedSizeInBytes() + valueBlock.getRetainedSizeInBytes() + sizeOf(hashTable); + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(keyBlock, keyBlock.getRetainedSizeInBytes()); + consumer.accept(valueBlock, valueBlock.getRetainedSizeInBytes()); + consumer.accept(hashTable, sizeOf(hashTable)); + consumer.accept(this, (long) INSTANCE_SIZE); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockEncoding.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockEncoding.java index 7b4d97ced8d4..9c5ae989953c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockEncoding.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockEncoding.java @@ -47,8 +47,10 @@ public class SingleMapBlockEncoding public SingleMapBlockEncoding(Type keyType, MethodHandle keyNativeHashCode, MethodHandle keyBlockNativeEquals, BlockEncoding keyBlockEncoding, BlockEncoding valueBlockEncoding) { this.keyType = requireNonNull(keyType, "keyType is null"); - this.keyNativeHashCode = requireNonNull(keyNativeHashCode, "keyNativeHashCode is null"); - this.keyBlockNativeEquals = requireNonNull(keyBlockNativeEquals, "keyBlockNativeEquals"); + // keyNativeHashCode can only be null due to map block kill switch. deprecated.new-map-block + this.keyNativeHashCode = keyNativeHashCode; + // keyBlockNativeEquals can only be null due to map block kill switch. deprecated.new-map-block + this.keyBlockNativeEquals = keyBlockNativeEquals; this.keyBlockEncoding = requireNonNull(keyBlockEncoding, "keyBlockEncoding is null"); this.valueBlockEncoding = requireNonNull(valueBlockEncoding, "valueBlockEncoding is null"); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockWriter.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockWriter.java index a9704a6d7d00..840c068ca3e2 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockWriter.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockWriter.java @@ -16,6 +16,8 @@ import io.airlift.slice.Slice; import org.openjdk.jol.info.ClassLayout; +import java.util.function.BiConsumer; + public class SingleMapBlockWriter extends AbstractSingleMapBlock implements BlockBuilder @@ -24,7 +26,7 @@ public class SingleMapBlockWriter private final BlockBuilder keyBlockBuilder; private final BlockBuilder valueBlockBuilder; - private final int initialBlockBuilderSize; + private final long initialBlockBuilderSize; private int positionsWritten; private boolean writeToValueNext; @@ -38,17 +40,25 @@ public class SingleMapBlockWriter } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return keyBlockBuilder.getSizeInBytes() + valueBlockBuilder.getSizeInBytes() - initialBlockBuilderSize; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return INSTANCE_SIZE + keyBlockBuilder.getRetainedSizeInBytes() + valueBlockBuilder.getRetainedSizeInBytes(); } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(keyBlockBuilder, keyBlockBuilder.getRetainedSizeInBytes()); + consumer.accept(valueBlockBuilder, valueBlockBuilder.getRetainedSizeInBytes()); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public BlockBuilder writeByte(int value) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/SliceArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/SliceArrayBlock.java index 40b8030b5ded..4b17fac3e141 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/SliceArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/SliceArrayBlock.java @@ -21,9 +21,9 @@ import java.util.IdentityHashMap; import java.util.List; import java.util.Map; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidPositions; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; public class SliceArrayBlock @@ -32,8 +32,8 @@ public class SliceArrayBlock private static final int INSTANCE_SIZE = ClassLayout.parseClass(SliceArrayBlock.class).instanceSize(); private final int positionCount; private final Slice[] values; - private final int sizeInBytes; - private final int retainedSizeInBytes; + private final long sizeInBytes; + private final long retainedSizeInBytes; public SliceArrayBlock(int positionCount, Slice[] values) { @@ -115,17 +115,24 @@ public int getSliceLength(int position) } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, retainedSizeInBytes - INSTANCE_SIZE); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public Block getRegion(int positionOffset, int length) { @@ -178,7 +185,7 @@ public String toString() return sb.toString(); } - private static int getSliceArraySizeInBytes(Slice[] values, int offset, int length) + private static long getSliceArraySizeInBytes(Slice[] values, int offset, int length) { long sizeInBytes = 0; for (int i = offset; i < offset + length; i++) { @@ -187,11 +194,11 @@ private static int getSliceArraySizeInBytes(Slice[] values, int offset, int leng sizeInBytes += value.length(); } } - return intSaturatedCast(sizeInBytes); + return sizeInBytes; } @Override - public int getRegionSizeInBytes(int positionOffset, int length) + public long getRegionSizeInBytes(int positionOffset, int length) { int positionCount = getPositionCount(); if (positionOffset == 0 && length == positionCount) { @@ -206,7 +213,7 @@ public int getRegionSizeInBytes(int positionOffset, int length) return getSliceArraySizeInBytes(values, positionOffset, length); } - private static int getSliceArrayRetainedSizeInBytes(Slice[] values, boolean valueSlicesAreDistinct) + private static long getSliceArrayRetainedSizeInBytes(Slice[] values, boolean valueSlicesAreDistinct) { if (valueSlicesAreDistinct) { return getDistinctSliceArrayRetainedSize(values); @@ -215,10 +222,10 @@ private static int getSliceArrayRetainedSizeInBytes(Slice[] values, boolean valu } // when the slices are not distinct we need to do reference counting to calculate the total retained size - private static int getSliceArrayRetainedSizeInBytes(Slice[] values) + private static long getSliceArrayRetainedSizeInBytes(Slice[] values) { long sizeInBytes = sizeOf(values); - Map uniqueRetained = new IdentityHashMap<>(values.length); + Map uniqueRetained = new IdentityHashMap<>(); for (Slice value : values) { if (value == null) { continue; @@ -227,10 +234,10 @@ private static int getSliceArrayRetainedSizeInBytes(Slice[] values) sizeInBytes += value.getRetainedSize(); } } - return intSaturatedCast(sizeInBytes); + return sizeInBytes; } - private static int getDistinctSliceArrayRetainedSize(Slice[] values) + private static long getDistinctSliceArrayRetainedSize(Slice[] values) { long sizeInBytes = sizeOf(values); for (Slice value : values) { @@ -240,6 +247,6 @@ private static int getDistinctSliceArrayRetainedSize(Slice[] values) } sizeInBytes += value.getRetainedSize(); } - return intSaturatedCast(sizeInBytes); + return sizeInBytes; } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlock.java index 01ef6db99b6e..1b6bf8a4e11c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlock.java @@ -20,10 +20,10 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidPositions; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; public class VariableWidthBlock @@ -37,8 +37,8 @@ public class VariableWidthBlock private final int[] offsets; private final boolean[] valueIsNull; - private final int retainedSizeInBytes; - private final int sizeInBytes; + private final long retainedSizeInBytes; + private final long sizeInBytes; public VariableWidthBlock(int positionCount, Slice slice, int[] offsets, boolean[] valueIsNull) { @@ -71,8 +71,8 @@ public VariableWidthBlock(int positionCount, Slice slice, int[] offsets, boolean } this.valueIsNull = valueIsNull; - sizeInBytes = intSaturatedCast(offsets[arrayOffset + positionCount] - offsets[arrayOffset] + ((Integer.BYTES + Byte.BYTES) * (long) positionCount)); - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + slice.getRetainedSize() + sizeOf(valueIsNull) + sizeOf(offsets)); + sizeInBytes = offsets[arrayOffset + positionCount] - offsets[arrayOffset] + ((Integer.BYTES + Byte.BYTES) * (long) positionCount); + retainedSizeInBytes = INSTANCE_SIZE + slice.getRetainedSize() + sizeOf(valueIsNull) + sizeOf(offsets); } @Override @@ -101,23 +101,32 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast(offsets[arrayOffset + position + length] - offsets[arrayOffset + position] + ((Integer.BYTES + Byte.BYTES) * (long) length)); + return offsets[arrayOffset + position + length] - offsets[arrayOffset + position] + ((Integer.BYTES + Byte.BYTES) * (long) length); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(slice, (long) slice.getRetainedSize()); + consumer.accept(offsets, sizeOf(offsets)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public Block copyPositions(List positions) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlockBuilder.java index c4b7fca3adc5..8de3da1b8609 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlockBuilder.java @@ -23,10 +23,10 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.MAX_ARRAY_SIZE; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.SIZE_OF_BYTE; import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; @@ -97,31 +97,40 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { long arraysSizeInBytes = (Integer.BYTES + Byte.BYTES) * (long) positions; - return intSaturatedCast(sliceOutput.size() + arraysSizeInBytes); + return sliceOutput.size() + arraysSizeInBytes; } @Override - public int getRegionSizeInBytes(int positionOffset, int length) + public long getRegionSizeInBytes(int positionOffset, int length) { int positionCount = getPositionCount(); if (positionOffset < 0 || length < 0 || positionOffset + length > positionCount) { throw new IndexOutOfBoundsException("Invalid position " + positionOffset + " length " + length + " in block with " + positionCount + " positions"); } long arraysSizeInBytes = (Integer.BYTES + Byte.BYTES) * (long) length; - return intSaturatedCast(getOffset(positionOffset + length) - getOffset(positionOffset) + arraysSizeInBytes); + return getOffset(positionOffset + length) - getOffset(positionOffset) + arraysSizeInBytes; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { long size = INSTANCE_SIZE + sliceOutput.getRetainedSize() + arraysRetainedSizeInBytes; if (blockBuilderStatus != null) { size += BlockBuilderStatus.INSTANCE_SIZE; } - return intSaturatedCast(size); + return size; + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(sliceOutput, (long) sliceOutput.getRetainedSize()); + consumer.accept(offsets, sizeOf(offsets)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); } @Override @@ -260,7 +269,7 @@ private void initializeCapacity() private void updateArraysDataSize() { - arraysRetainedSizeInBytes = intSaturatedCast(sizeOf(valueIsNull) + sizeOf(offsets)); + arraysRetainedSizeInBytes = sizeOf(valueIsNull) + sizeOf(offsets); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorAccessControl.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorAccessControl.java index 4f572f536b4a..67602098b3d9 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorAccessControl.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorAccessControl.java @@ -26,6 +26,7 @@ import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateViewWithSelect; import static com.facebook.presto.spi.security.AccessDeniedException.denyDeleteTable; +import static com.facebook.presto.spi.security.AccessDeniedException.denyDropColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropSchema; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropView; @@ -157,6 +158,16 @@ default void checkCanAddColumn(ConnectorTransactionHandle transactionHandle, Ide denyAddColumn(tableName.toString()); } + /** + * Check if identity is allowed to drop columns from the specified table in this catalog. + * + * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed + */ + default void checkCanDropColumn(ConnectorTransactionHandle transactionHandle, Identity identity, SchemaTableName tableName) + { + denyDropColumn(tableName.toString()); + } + /** * Check if identity is allowed to rename a column in the specified table in this catalog. * diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java index 56ce4b5a0ecd..3facfd4f00e0 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java @@ -204,6 +204,14 @@ default void renameColumn(ConnectorSession session, ConnectorTableHandle tableHa throw new PrestoException(NOT_SUPPORTED, "This connector does not support renaming columns"); } + /** + * Drop the specified column + */ + default void dropColumn(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle column) + { + throw new PrestoException(NOT_SUPPORTED, "This connector does not support dropping columns"); + } + /** * Get the physical layout for a new table. */ diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java index 974f1c9c644b..4cbccefcfe31 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java @@ -233,6 +233,14 @@ public void renameColumn(ConnectorSession session, ConnectorTableHandle tableHan } } + @Override + public void dropColumn(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle column) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + delegate.dropColumn(session, tableHandle, column); + } + } + @Override public void renameTable(ConnectorSession session, ConnectorTableHandle tableHandle, SchemaTableName newTableName) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/QueryType.java b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/QueryType.java new file mode 100644 index 000000000000..f037e70d6186 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/QueryType.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.resourceGroups; + +public enum QueryType +{ + DATA_DEFINITION, + DELETE, + DESCRIBE, + EXPLAIN, + INSERT, + SELECT +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupInfo.java b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupInfo.java index abe8e6f245c5..e30d5393ada7 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupInfo.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupInfo.java @@ -16,6 +16,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import io.airlift.units.DataSize; +import io.airlift.units.Duration; import java.util.List; import java.util.Optional; @@ -30,7 +31,9 @@ public class ResourceGroupInfo private final DataSize softMemoryLimit; private final int maxRunningQueries; + private final Duration runningTimeLimit; private final int maxQueuedQueries; + private final Duration queuedTimeLimit; private final ResourceGroupState state; private final int numEligibleSubGroups; @@ -45,21 +48,25 @@ public ResourceGroupInfo( @JsonProperty("id") ResourceGroupId id, @JsonProperty("softMemoryLimit") DataSize softMemoryLimit, @JsonProperty("maxRunningQueries") int maxRunningQueries, + @JsonProperty("runningTimeLimit") Duration runningTimeLimit, @JsonProperty("maxQueuedQueries") int maxQueuedQueries, + @JsonProperty("queuedTimeLimit") Duration queuedTimeLimit, @JsonProperty("state") ResourceGroupState state, @JsonProperty("numEligibleSubGroups") int numEligibleSubGroups, @JsonProperty("memoryUsage") DataSize memoryUsage, @JsonProperty("numAggregatedRunningQueries") int numAggregatedRunningQueries, @JsonProperty("numAggregatedQueuedQueries") int numAggregatedQueuedQueries) { - this(id, softMemoryLimit, maxRunningQueries, maxQueuedQueries, state, numEligibleSubGroups, memoryUsage, numAggregatedRunningQueries, numAggregatedQueuedQueries, emptyList()); + this(id, softMemoryLimit, maxRunningQueries, runningTimeLimit, maxQueuedQueries, queuedTimeLimit, state, numEligibleSubGroups, memoryUsage, numAggregatedRunningQueries, numAggregatedQueuedQueries, emptyList()); } public ResourceGroupInfo( ResourceGroupId id, DataSize softMemoryLimit, int maxRunningQueries, + Duration runningTimeLimit, int maxQueuedQueries, + Duration queuedTimeLimit, ResourceGroupState state, int numEligibleSubGroups, DataSize memoryUsage, @@ -70,7 +77,9 @@ public ResourceGroupInfo( this.id = requireNonNull(id, "id is null"); this.softMemoryLimit = requireNonNull(softMemoryLimit, "softMemoryLimit is null"); this.maxRunningQueries = maxRunningQueries; + this.runningTimeLimit = runningTimeLimit; this.maxQueuedQueries = maxQueuedQueries; + this.queuedTimeLimit = queuedTimeLimit; this.state = requireNonNull(state, "state is null"); this.numEligibleSubGroups = numEligibleSubGroups; this.memoryUsage = requireNonNull(memoryUsage, "memoryUsage is null"); @@ -97,12 +106,24 @@ public int getMaxRunningQueries() return maxRunningQueries; } + @JsonProperty + public Duration getRunningTimeLimit() + { + return runningTimeLimit; + } + @JsonProperty public int getMaxQueuedQueries() { return maxQueuedQueries; } + @JsonProperty + public Duration getQueuedTimeLimit() + { + return queuedTimeLimit; + } + public List getSubGroups() { return subGroups; @@ -151,7 +172,9 @@ public ResourceGroupInfo createSingleNodeInfo() getId(), getSoftMemoryLimit(), getMaxRunningQueries(), + getRunningTimeLimit(), getMaxQueuedQueries(), + getQueuedTimeLimit(), getState(), getNumEligibleSubGroups(), getMemoryUsage(), diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/SelectionContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/SelectionContext.java index 9af417d819dc..6d0b961d51c9 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/SelectionContext.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/SelectionContext.java @@ -23,13 +23,15 @@ public final class SelectionContext private final String user; private final Optional source; private final int queryPriority; + private final Optional queryType; - public SelectionContext(boolean authenticated, String user, Optional source, int queryPriority) + public SelectionContext(boolean authenticated, String user, Optional source, int queryPriority, Optional queryType) { this.authenticated = authenticated; this.user = requireNonNull(user, "user is null"); this.source = requireNonNull(source, "source is null"); this.queryPriority = queryPriority; + this.queryType = requireNonNull(queryType, "queryType is null"); } public boolean isAuthenticated() @@ -51,4 +53,9 @@ public int getQueryPriority() { return queryPriority; } + + public Optional getQueryType() + { + return queryType; + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/security/AccessDeniedException.java b/presto-spi/src/main/java/com/facebook/presto/spi/security/AccessDeniedException.java index 81cb34dd356b..a1054eda9ecf 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/security/AccessDeniedException.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/security/AccessDeniedException.java @@ -138,6 +138,16 @@ public static void denyAddColumn(String tableName, String extraInfo) throw new AccessDeniedException(format("Cannot add a column to table %s%s", tableName, formatExtraInfo(extraInfo))); } + public static void denyDropColumn(String tableName) + { + denyDropColumn(tableName, null); + } + + public static void denyDropColumn(String tableName, String extraInfo) + { + throw new AccessDeniedException(format("Cannot drop a column from table %s%s", tableName, formatExtraInfo(extraInfo))); + } + public static void denyRenameColumn(String tableName) { denyRenameColumn(tableName, null); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/security/SystemAccessControl.java b/presto-spi/src/main/java/com/facebook/presto/spi/security/SystemAccessControl.java index 7226aa7c911d..de83d7213420 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/security/SystemAccessControl.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/security/SystemAccessControl.java @@ -28,6 +28,7 @@ import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateViewWithSelect; import static com.facebook.presto.spi.security.AccessDeniedException.denyDeleteTable; +import static com.facebook.presto.spi.security.AccessDeniedException.denyDropColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropSchema; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropView; @@ -191,6 +192,16 @@ default void checkCanAddColumn(Identity identity, CatalogSchemaTableName table) denyAddColumn(table.toString()); } + /** + * Check if identity is allowed to drop columns from the specified table in a catalog. + * + * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed + */ + default void checkCanDropColumn(Identity identity, CatalogSchemaTableName table) + { + denyDropColumn(table.toString()); + } + /** * Check if identity is allowed to rename a column in the specified table in a catalog. * diff --git a/presto-main/src/main/java/com/facebook/presto/type/ArrayType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/ArrayType.java similarity index 91% rename from presto-main/src/main/java/com/facebook/presto/type/ArrayType.java rename to presto-spi/src/main/java/com/facebook/presto/spi/type/ArrayType.java index 01c27ffea225..c4ad958fe353 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/ArrayType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/ArrayType.java @@ -11,20 +11,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.type; +package com.facebook.presto.spi.type; -import com.facebook.presto.operator.scalar.CombineHashFunction; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.AbstractArrayBlock; import com.facebook.presto.spi.block.ArrayBlockBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.type.AbstractType; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.TypeSignature; -import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import java.util.ArrayList; @@ -32,8 +26,9 @@ import java.util.List; import static com.facebook.presto.spi.type.StandardTypes.ARRAY; -import static com.facebook.presto.type.TypeUtils.checkElementNotNull; -import static com.facebook.presto.type.TypeUtils.hashPosition; +import static com.facebook.presto.spi.type.TypeUtils.checkElementNotNull; +import static com.facebook.presto.spi.type.TypeUtils.hashPosition; +import static java.util.Collections.singletonList; import static java.util.Objects.requireNonNull; public class ArrayType @@ -92,7 +87,7 @@ public long hash(Block block, int position) Block array = getObject(block, position); long hash = 0; for (int i = 0; i < array.getPositionCount(); i++) { - hash = CombineHashFunction.getHash(hash, hashPosition(elementType, array, i)); + hash = 31 * hash + hashPosition(elementType, array, i); } return hash; } @@ -210,7 +205,7 @@ public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, in @Override public List getTypeParameters() { - return ImmutableList.of(getElementType()); + return singletonList(getElementType()); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/type/MapType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/MapType.java similarity index 68% rename from presto-main/src/main/java/com/facebook/presto/type/MapType.java rename to presto-spi/src/main/java/com/facebook/presto/spi/type/MapType.java index 201254aa884e..7557e7288f66 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/MapType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/MapType.java @@ -11,19 +11,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.type; +package com.facebook.presto.spi.type; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.block.MapBlock; import com.facebook.presto.spi.block.MapBlockBuilder; -import com.facebook.presto.spi.type.AbstractType; -import com.facebook.presto.spi.type.StandardTypes; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.TypeSignature; -import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.google.common.collect.ImmutableList; +import com.facebook.presto.spi.block.SingleMapBlock; import java.lang.invoke.MethodHandle; import java.util.Collections; @@ -31,14 +27,17 @@ import java.util.List; import java.util.Map; -import static com.facebook.presto.type.TypeUtils.checkElementNotNull; -import static com.facebook.presto.type.TypeUtils.hashPosition; -import static com.google.common.base.Preconditions.checkArgument; +import static com.facebook.presto.spi.type.TypeUtils.checkElementNotNull; +import static com.facebook.presto.spi.type.TypeUtils.hashPosition; +import static java.lang.String.format; +import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; public class MapType extends AbstractType { + private final boolean useNewMapBlock; + private final Type keyType; private final Type valueType; private static final String MAP_NULL_ELEMENT_MSG = "MAP comparison not supported for null value elements"; @@ -48,24 +47,43 @@ public class MapType private final MethodHandle keyBlockHashCode; private final MethodHandle keyBlockNativeEquals; - MapType(Type keyType, Type valueType, MethodHandle keyBlockNativeEquals, MethodHandle keyNativeHashCode, MethodHandle keyBlockHashCode) + public MapType(boolean useNewMapBlock, Type keyType, Type valueType, MethodHandle keyBlockNativeEquals, MethodHandle keyNativeHashCode, MethodHandle keyBlockHashCode) { super(new TypeSignature(StandardTypes.MAP, TypeSignatureParameter.of(keyType.getTypeSignature()), TypeSignatureParameter.of(valueType.getTypeSignature())), Block.class); - checkArgument(keyType.isComparable(), "key type must be comparable"); + if (!keyType.isComparable()) { + throw new IllegalArgumentException(format("key type must be comparable, got %s", keyType)); + } + this.useNewMapBlock = useNewMapBlock; this.keyType = keyType; this.valueType = valueType; - this.keyBlockNativeEquals = requireNonNull(keyBlockNativeEquals, "keyBlockNativeEquals is null"); - this.keyNativeHashCode = requireNonNull(keyNativeHashCode, "keyNativeHashCode is null"); - this.keyBlockHashCode = requireNonNull(keyBlockHashCode, "keyBlockHashCode is null"); + if (useNewMapBlock) { + requireNonNull(keyBlockNativeEquals, "keyBlockNativeEquals is null"); + requireNonNull(keyNativeHashCode, "keyNativeHashCode is null"); + requireNonNull(keyBlockHashCode, "keyBlockHashCode is null"); + } + else { + if (keyBlockNativeEquals != null) { + throw new IllegalArgumentException("When useNewMapBlock is false, keyBlockNativeEquals should be null."); + } + if (keyNativeHashCode != null) { + throw new IllegalArgumentException("When useNewMapBlock is false, keyNativeHashCode should be null."); + } + if (keyBlockHashCode != null) { + throw new IllegalArgumentException("When useNewMapBlock is false, keyBlockHashCode should be null."); + } + } + this.keyBlockNativeEquals = keyBlockNativeEquals; + this.keyNativeHashCode = keyNativeHashCode; + this.keyBlockHashCode = keyBlockHashCode; } @Override public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { - return new MapBlockBuilder(keyType, valueType, keyBlockNativeEquals, keyNativeHashCode, keyBlockHashCode, blockBuilderStatus, expectedEntries); + return new MapBlockBuilder(useNewMapBlock, keyType, valueType, keyBlockNativeEquals, keyNativeHashCode, keyBlockHashCode, blockBuilderStatus, expectedEntries); } @Override @@ -182,10 +200,13 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - Block mapBlock = block.getObject(position, Block.class); + Block singleMapBlock = block.getObject(position, Block.class); + if (!(singleMapBlock instanceof SingleMapBlock)) { + throw new UnsupportedOperationException("Map is encoded with legacy block representation"); + } Map map = new HashMap<>(); - for (int i = 0; i < mapBlock.getPositionCount(); i += 2) { - map.put(keyType.getObjectValue(session, mapBlock, i), valueType.getObjectValue(session, mapBlock, i + 1)); + for (int i = 0; i < singleMapBlock.getPositionCount(); i += 2) { + map.put(keyType.getObjectValue(session, singleMapBlock, i), valueType.getObjectValue(session, singleMapBlock, i + 1)); } return Collections.unmodifiableMap(map); @@ -218,7 +239,7 @@ public void writeObject(BlockBuilder blockBuilder, Object value) @Override public List getTypeParameters() { - return ImmutableList.of(getKeyType(), getValueType()); + return asList(getKeyType(), getValueType()); } @Override @@ -226,4 +247,18 @@ public String getDisplayName() { return "map(" + keyType.getDisplayName() + ", " + valueType.getDisplayName() + ")"; } + + public MapBlock createBlockFromKeyValue(boolean[] mapIsNull, int[] offsets, Block keyBlock, Block valueBlock) + { + return MapBlock.fromKeyValueBlock( + useNewMapBlock, + mapIsNull, + offsets, + keyBlock, + valueBlock, + this, + keyBlockNativeEquals, + keyNativeHashCode, + keyBlockHashCode); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/type/RowType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/RowType.java similarity index 81% rename from presto-main/src/main/java/com/facebook/presto/type/RowType.java rename to presto-spi/src/main/java/com/facebook/presto/spi/type/RowType.java index 84da31a00a96..9ca20f3904e8 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/RowType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/RowType.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.type; +package com.facebook.presto.spi.type; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; @@ -21,12 +21,6 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; -import com.facebook.presto.spi.type.AbstractType; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.TypeSignature; -import com.google.common.base.Joiner; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Collections; @@ -34,8 +28,6 @@ import java.util.Optional; import static com.facebook.presto.spi.type.StandardTypes.ROW; -import static com.facebook.presto.type.TypeUtils.hashPosition; -import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; /** @@ -49,20 +41,33 @@ public class RowType public RowType(List fieldTypes, Optional> fieldNames) { - super(new TypeSignature( - ROW, - Lists.transform(fieldTypes, Type::getTypeSignature), - fieldNames.orElse(ImmutableList.of()).stream() - .collect(toImmutableList())), - Block.class); + super(toTypeSignature(fieldTypes, fieldNames), Block.class); - ImmutableList.Builder builder = ImmutableList.builder(); + List fields = new ArrayList<>(); for (int i = 0; i < fieldTypes.size(); i++) { int index = i; - builder.add(new RowField(fieldTypes.get(i), fieldNames.map((names) -> names.get(index)))); + fields.add(new RowField(fieldTypes.get(i), fieldNames.map((names) -> names.get(index)))); } - fields = builder.build(); - this.fieldTypes = ImmutableList.copyOf(fieldTypes); + this.fields = fields; + this.fieldTypes = fieldTypes; + } + + private static TypeSignature toTypeSignature(List fieldTypes, Optional> fieldNames) + { + int size = fieldTypes.size(); + if (size == 0) { + throw new IllegalArgumentException("Row type must have at least 1 field"); + } + + List elementTypeSignatures = new ArrayList<>(); + List literalParameters = new ArrayList<>(); + for (int i = 0; i < size; i++) { + elementTypeSignatures.add(fieldTypes.get(i).getTypeSignature()); + if (fieldNames.isPresent()) { + literalParameters.add(fieldNames.get().get(i)); + } + } + return new TypeSignature(ROW, elementTypeSignatures, literalParameters); } @Override @@ -87,17 +92,21 @@ public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, in public String getDisplayName() { // Convert to standard sql name - List fieldDisplayNames = new ArrayList<>(); + StringBuilder result = new StringBuilder(); + result.append(ROW).append('('); for (RowField field : fields) { String typeDisplayName = field.getType().getDisplayName(); if (field.getName().isPresent()) { - fieldDisplayNames.add(field.getName().get() + " " + typeDisplayName); + result.append(field.getName().get()).append(' ').append(typeDisplayName); } else { - fieldDisplayNames.add(typeDisplayName); + result.append(typeDisplayName); } + result.append(", "); } - return ROW + "(" + Joiner.on(", ").join(fieldDisplayNames) + ")"; + result.setLength(result.length() - 2); + result.append(')'); + return result.toString(); } @Override @@ -233,7 +242,7 @@ public long hash(Block block, int position) long result = 1; for (int i = 0; i < arrayBlock.getPositionCount(); i++) { Type elementType = fields.get(i).getType(); - result = 31 * result + hashPosition(elementType, arrayBlock, i); + result = 31 * result + TypeUtils.hashPosition(elementType, arrayBlock, i); } return result; } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeUtils.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeUtils.java index efa34bcf0dd4..6648110ccf1b 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeUtils.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeUtils.java @@ -13,13 +13,18 @@ */ package com.facebook.presto.spi.type; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; + public final class TypeUtils { + public static final int NULL_HASH_CODE = 0; + private TypeUtils() { } @@ -83,4 +88,19 @@ else if (value instanceof String) { type.writeObject(blockBuilder, value); } } + + static long hashPosition(Type type, Block block, int position) + { + if (block.isNull(position)) { + return NULL_HASH_CODE; + } + return type.hash(block, position); + } + + static void checkElementNotNull(boolean isNull, String errorMsg) + { + if (isNull) { + throw new PrestoException(NOT_SUPPORTED, errorMsg); + } + } } diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/TestPage.java b/presto-spi/src/test/java/com/facebook/presto/spi/TestPage.java index 605f7076f12d..ea9da285e399 100644 --- a/presto-spi/src/test/java/com/facebook/presto/spi/TestPage.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/TestPage.java @@ -87,7 +87,7 @@ public void testCompactDictionaryBlocks() int otherDictionaryUsedPositions = 30; int[] otherDictionaryIds = getDictionaryIds(positionCount, otherDictionaryUsedPositions); SliceArrayBlock dictionary3 = new SliceArrayBlock(70, createExpectedValues(70)); - DictionaryBlock randomSourceIdBlock = new DictionaryBlock(positionCount, dictionary3, otherDictionaryIds); + DictionaryBlock randomSourceIdBlock = new DictionaryBlock(dictionary3, otherDictionaryIds); Page page = new Page(commonSourceIdBlock1, randomSourceIdBlock, commonSourceIdBlock2); page.compact(); diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/block/TestBlockRetainedSizeBreakdown.java b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestBlockRetainedSizeBreakdown.java new file mode 100644 index 000000000000..7928ecc79245 --- /dev/null +++ b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestBlockRetainedSizeBreakdown.java @@ -0,0 +1,207 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.block; + +import com.facebook.presto.spi.type.Type; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import it.unimi.dsi.fastutil.Hash.Strategy; +import it.unimi.dsi.fastutil.objects.Object2LongOpenCustomHashMap; +import org.testng.annotations.Test; + +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiConsumer; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.TinyintType.TINYINT; +import static com.facebook.presto.spi.type.TypeUtils.writeNativeValue; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static io.airlift.slice.Slices.utf8Slice; +import static org.testng.Assert.assertEquals; + +public class TestBlockRetainedSizeBreakdown +{ + private static final int EXPECTED_ENTRIES = 100; + + @Test + public void testArrayBlock() + { + BlockBuilder arrayBlockBuilder = new ArrayBlockBuilder(BIGINT, new BlockBuilderStatus(), EXPECTED_ENTRIES); + for (int i = 0; i < EXPECTED_ENTRIES; i++) { + BlockBuilder arrayElementBuilder = arrayBlockBuilder.beginBlockEntry(); + writeNativeValue(BIGINT, arrayElementBuilder, castIntegerToObject(i, BIGINT)); + arrayBlockBuilder.closeEntry(); + } + checkRetainedSize(arrayBlockBuilder.build(), false); + } + + @Test + public void testByteArrayBlock() + { + BlockBuilder blockBuilder = new ByteArrayBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES); + for (int i = 0; i < EXPECTED_ENTRIES; i++) { + blockBuilder.writeByte(i); + } + checkRetainedSize(blockBuilder.build(), false); + } + + @Test + public void testDictionaryBlock() + { + Block keyDictionaryBlock = createSliceArrayBlock(EXPECTED_ENTRIES); + int[] keyIds = new int[EXPECTED_ENTRIES]; + for (int i = 0; i < keyIds.length; i++) { + keyIds[i] = i; + } + checkRetainedSize(new DictionaryBlock(EXPECTED_ENTRIES, keyDictionaryBlock, keyIds), false); + } + + @Test + public void testFixedWidthBlock() + { + BlockBuilder blockBuilder = new FixedWidthBlockBuilder(8, new BlockBuilderStatus(), EXPECTED_ENTRIES); + writeEntries(EXPECTED_ENTRIES, blockBuilder, DOUBLE); + checkRetainedSize(blockBuilder.build(), true); + } + + @Test + public void testIntArrayBlock() + { + BlockBuilder blockBuilder = new IntArrayBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES); + writeEntries(EXPECTED_ENTRIES, blockBuilder, INTEGER); + checkRetainedSize(blockBuilder.build(), false); + } + + @Test + public void testInterleavedBlock() + { + BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(INTEGER, INTEGER), new BlockBuilderStatus(), EXPECTED_ENTRIES); + writeEntries(EXPECTED_ENTRIES, blockBuilder, INTEGER); + checkRetainedSize(blockBuilder.build(), false); + } + + @Test + public void testLongArrayBlock() + { + BlockBuilder blockBuilder = new LongArrayBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES); + writeEntries(EXPECTED_ENTRIES, blockBuilder, BIGINT); + checkRetainedSize(blockBuilder.build(), false); + } + + @Test + public void testRunLengthEncodedBlock() + { + BlockBuilder blockBuilder = new LongArrayBlockBuilder(new BlockBuilderStatus(), 1); + writeEntries(1, blockBuilder, BIGINT); + checkRetainedSize(new RunLengthEncodedBlock(blockBuilder.build(), 1), false); + } + + @Test + public void testShortArrayBlock() + { + BlockBuilder blockBuilder = new ShortArrayBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES); + for (int i = 0; i < EXPECTED_ENTRIES; i++) { + blockBuilder.writeShort(i); + } + checkRetainedSize(blockBuilder.build(), false); + } + + @Test + public void testSliceArrayBlock() + { + checkRetainedSize(createSliceArrayBlock(EXPECTED_ENTRIES), true); + } + + @Test + public void testVariableWidthBlock() + { + BlockBuilder blockBuilder = new VariableWidthBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES, 4); + writeEntries(EXPECTED_ENTRIES, blockBuilder, VARCHAR); + checkRetainedSize(blockBuilder.build(), false); + } + + private static final class ObjectStrategy + implements Strategy + { + @Override + public int hashCode(Object object) + { + return System.identityHashCode(object); + } + + @Override + public boolean equals(Object left, Object right) + { + return left == right; + } + } + + private static void checkRetainedSize(Block block, boolean getRegionCreateNewObjects) + { + AtomicLong objectSize = new AtomicLong(); + Object2LongOpenCustomHashMap trackedObjects = new Object2LongOpenCustomHashMap<>(new ObjectStrategy()); + + BiConsumer consumer = (object, size) -> { + objectSize.addAndGet(size); + trackedObjects.addTo(object, 1); + }; + + block.retainedBytesForEachPart(consumer); + assertEquals(objectSize.get(), block.getRetainedSizeInBytes()); + + Block copyBlock = block.getRegion(0, block.getPositionCount() / 2); + copyBlock.retainedBytesForEachPart(consumer); + assertEquals(objectSize.get(), block.getRetainedSizeInBytes() + copyBlock.getRetainedSizeInBytes()); + + assertEquals(trackedObjects.getLong(block), 1); + assertEquals(trackedObjects.getLong(copyBlock), 1); + trackedObjects.remove(block); + trackedObjects.remove(copyBlock); + for (long value : trackedObjects.values()) { + assertEquals(value, getRegionCreateNewObjects ? 1 : 2); + } + } + + private static void writeEntries(int expectedEntries, BlockBuilder blockBuilder, Type type) + { + for (int i = 0; i < expectedEntries; i++) { + writeNativeValue(type, blockBuilder, castIntegerToObject(i, type)); + } + } + + private static Object castIntegerToObject(int value, Type type) + { + if (type == INTEGER || type == TINYINT || type == BIGINT) { + return (long) value; + } + if (type == VARCHAR) { + return String.valueOf(value); + } + if (type == DOUBLE) { + return (double) value; + } + throw new UnsupportedOperationException(); + } + + private static Block createSliceArrayBlock(int entries) + { + Slice[] sliceArray = new Slice[entries]; + for (int i = 0; i < entries; i++) { + sliceArray[i] = utf8Slice(i + ""); + } + return new SliceArrayBlock(sliceArray.length, sliceArray); + } +} diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/block/TestDictionaryBlockEncoding.java b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestDictionaryBlockEncoding.java index 843350ee0adb..6f1f6ec3e272 100644 --- a/presto-spi/src/test/java/com/facebook/presto/spi/block/TestDictionaryBlockEncoding.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestDictionaryBlockEncoding.java @@ -92,7 +92,7 @@ public void testRoundTrip() } BlockEncoding blockEncoding = new DictionaryBlockEncoding(new VariableWidthBlockEncoding()); - DictionaryBlock dictionaryBlock = new DictionaryBlock(positionCount, dictionary, ids); + DictionaryBlock dictionaryBlock = new DictionaryBlock(dictionary, ids); DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1024); blockEncoding.writeBlock(sliceOutput, dictionaryBlock); diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestArrayType.java b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestArrayType.java similarity index 95% rename from presto-main/src/test/java/com/facebook/presto/type/TestArrayType.java rename to presto-spi/src/test/java/com/facebook/presto/spi/type/TestArrayType.java index a5281e225bfe..ba1b0dcc4c75 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestArrayType.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestArrayType.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.type; +package com.facebook.presto.spi.type; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestMapType.java b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestMapType.java similarity index 51% rename from presto-main/src/test/java/com/facebook/presto/type/TestMapType.java rename to presto-spi/src/test/java/com/facebook/presto/spi/type/TestMapType.java index 0903d5b1940e..59260f5fd01e 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestMapType.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestMapType.java @@ -11,14 +11,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.type; +package com.facebook.presto.spi.type; +import com.facebook.presto.spi.block.MethodHandleUtil; import org.testng.annotations.Test; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; -import static com.facebook.presto.util.StructuralTestUtil.mapType; import static org.testng.Assert.assertEquals; public class TestMapType @@ -26,10 +26,27 @@ public class TestMapType @Test public void testMapDisplayName() { - MapType mapType = mapType(BIGINT, createVarcharType(42)); + MapType mapType = new MapType( + true, + BIGINT, + createVarcharType(42), + MethodHandleUtil.methodHandle(TestMapType.class, "throwUnsupportedOperation"), + MethodHandleUtil.methodHandle(TestMapType.class, "throwUnsupportedOperation"), + MethodHandleUtil.methodHandle(TestMapType.class, "throwUnsupportedOperation")); assertEquals(mapType.getDisplayName(), "map(bigint, varchar(42))"); - mapType = mapType(BIGINT, VARCHAR); + mapType = new MapType( + true, + BIGINT, + VARCHAR, + MethodHandleUtil.methodHandle(TestMapType.class, "throwUnsupportedOperation"), + MethodHandleUtil.methodHandle(TestMapType.class, "throwUnsupportedOperation"), + MethodHandleUtil.methodHandle(TestMapType.class, "throwUnsupportedOperation")); assertEquals(mapType.getDisplayName(), "map(bigint, varchar)"); } + + public static void throwUnsupportedOperation() + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestRowType.java b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestRowType.java similarity index 56% rename from presto-main/src/test/java/com/facebook/presto/type/TestRowType.java rename to presto-spi/src/test/java/com/facebook/presto/spi/type/TestRowType.java index 1bed39e7b405..d86bf78267af 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestRowType.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestRowType.java @@ -11,18 +11,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.type; +package com.facebook.presto.spi.type; -import com.facebook.presto.spi.type.Type; import org.testng.annotations.Test; import java.util.List; import java.util.Optional; +import static com.facebook.presto.spi.block.MethodHandleUtil.methodHandle; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; -import static com.facebook.presto.util.StructuralTestUtil.mapType; import static java.util.Arrays.asList; import static org.testng.Assert.assertEquals; @@ -31,7 +30,17 @@ public class TestRowType @Test public void testRowDisplayName() { - List types = asList(BOOLEAN, DOUBLE, new ArrayType(VARCHAR), mapType(BOOLEAN, DOUBLE)); + List types = asList( + BOOLEAN, + DOUBLE, + new ArrayType(VARCHAR), + new MapType( + true, + BOOLEAN, + DOUBLE, + methodHandle(TestMapType.class, "throwUnsupportedOperation"), + methodHandle(TestMapType.class, "throwUnsupportedOperation"), + methodHandle(TestMapType.class, "throwUnsupportedOperation"))); Optional> names = Optional.of(asList("bool_col", "double_col", "array_col", "map_col")); RowType row = new RowType(types, names); assertEquals( @@ -42,10 +51,25 @@ public void testRowDisplayName() @Test public void testRowDisplayNoColumnNames() { - List types = asList(BOOLEAN, DOUBLE, new ArrayType(VARCHAR), mapType(BOOLEAN, DOUBLE)); + List types = asList( + BOOLEAN, + DOUBLE, + new ArrayType(VARCHAR), + new MapType( + true, + BOOLEAN, + DOUBLE, + methodHandle(TestMapType.class, "throwUnsupportedOperation"), + methodHandle(TestMapType.class, "throwUnsupportedOperation"), + methodHandle(TestMapType.class, "throwUnsupportedOperation"))); RowType row = new RowType(types, Optional.empty()); assertEquals( row.getDisplayName(), "row(boolean, double, array(varchar), map(boolean, double))"); } + + public static void throwUnsupportedOperation() + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-sqlserver/pom.xml b/presto-sqlserver/pom.xml index 09028af692db..dd44e45bfc93 100644 --- a/presto-sqlserver/pom.xml +++ b/presto-sqlserver/pom.xml @@ -3,7 +3,7 @@ presto-root com.facebook.presto - 0.179-tw-0.36 + 0.181-tw-0.37 4.0.0 diff --git a/presto-teradata-functions/pom.xml b/presto-teradata-functions/pom.xml index f089ca39b26c..77c6f701c613 100644 --- a/presto-teradata-functions/pom.xml +++ b/presto-teradata-functions/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-teradata-functions diff --git a/presto-teradata-functions/src/main/java/com/facebook/presto/teradata/functions/TeradataDateFunctions.java b/presto-teradata-functions/src/main/java/com/facebook/presto/teradata/functions/TeradataDateFunctions.java index 3ffb07422a0b..29f12b884fbc 100644 --- a/presto-teradata-functions/src/main/java/com/facebook/presto/teradata/functions/TeradataDateFunctions.java +++ b/presto-teradata-functions/src/main/java/com/facebook/presto/teradata/functions/TeradataDateFunctions.java @@ -37,7 +37,7 @@ import static com.facebook.presto.spi.type.TimeZoneKey.MAX_TIME_ZONE_KEY; import static com.facebook.presto.spi.type.TimeZoneKey.getTimeZoneKeys; import static com.facebook.presto.teradata.functions.dateformat.DateFormatParser.createDateTimeFormatter; -import static com.google.common.base.Throwables.propagateIfInstanceOf; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static io.airlift.slice.Slices.utf8Slice; import static java.nio.charset.StandardCharsets.UTF_8; @@ -92,8 +92,8 @@ public static long toDate( return (long) castToDate.invokeExact(session, millis); } catch (Throwable t) { - propagateIfInstanceOf(t, Error.class); - propagateIfInstanceOf(t, PrestoException.class); + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } diff --git a/presto-teradata-functions/src/main/java/com/facebook/presto/teradata/functions/TeradataStringFunctions.java b/presto-teradata-functions/src/main/java/com/facebook/presto/teradata/functions/TeradataStringFunctions.java index f8785f650c8f..3fa3aa89890c 100644 --- a/presto-teradata-functions/src/main/java/com/facebook/presto/teradata/functions/TeradataStringFunctions.java +++ b/presto-teradata-functions/src/main/java/com/facebook/presto/teradata/functions/TeradataStringFunctions.java @@ -27,7 +27,7 @@ import java.lang.invoke.MethodHandle; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; -import static com.google.common.base.Throwables.propagateIfInstanceOf; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static java.nio.charset.StandardCharsets.UTF_16BE; public final class TeradataStringFunctions @@ -52,8 +52,8 @@ public static long index( return (long) method.invokeExact(string, substring); } catch (Throwable t) { - propagateIfInstanceOf(t, Error.class); - propagateIfInstanceOf(t, PrestoException.class); + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } @@ -75,8 +75,8 @@ public static Slice substring( return (Slice) method.invokeExact(utf8, start); } catch (Throwable t) { - propagateIfInstanceOf(t, Error.class); - propagateIfInstanceOf(t, PrestoException.class); + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } @@ -99,8 +99,8 @@ public static Slice substring( return (Slice) method.invokeExact(utf8, start, length); } catch (Throwable t) { - propagateIfInstanceOf(t, Error.class); - propagateIfInstanceOf(t, PrestoException.class); + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } diff --git a/presto-testing-server-launcher/pom.xml b/presto-testing-server-launcher/pom.xml index 996b904a9e00..7f5b625a7f1d 100644 --- a/presto-testing-server-launcher/pom.xml +++ b/presto-testing-server-launcher/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-testing-server-launcher diff --git a/presto-tests/pom.xml b/presto-tests/pom.xml index 2286f2c8deb1..bc41260dfdc3 100644 --- a/presto-tests/pom.xml +++ b/presto-tests/pom.xml @@ -5,7 +5,7 @@ presto-root com.facebook.presto - 0.179-tw-0.36 + 0.181-tw-0.37 presto-tests @@ -67,11 +67,6 @@ discovery-server - - io.airlift - http-client - - io.airlift http-server @@ -127,6 +122,11 @@ testing + + com.squareup.okhttp3 + okhttp + + com.fasterxml.jackson.core jackson-annotations diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java index 569ef2ed4169..99fa8ac4f681 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java @@ -39,6 +39,7 @@ import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.CREATE_VIEW; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.CREATE_VIEW_WITH_SELECT_TABLE; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.CREATE_VIEW_WITH_SELECT_VIEW; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_COLUMN; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_TABLE; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.RENAME_COLUMN; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.RENAME_TABLE; @@ -291,8 +292,11 @@ public void testExplainAnalyzeDDL() private void assertExplainAnalyze(@Language("SQL") String query) { String value = getOnlyElement(computeActual(query).getOnlyColumnAsSet()); + + assertTrue(value.matches("(?s:.*)CPU:.*, Input:.*, Output(?s:.*)"), format("Expected output to contain \"CPU:.*, Input:.*, Output\", but it is %s", value)); + // TODO: check that rendered plan is as expected, once stats are collected in a consistent way - assertTrue(value.contains("Cost: "), format("Expected output to contain \"Cost: \", but it is %s", value)); + // assertTrue(value.contains("Cost: "), format("Expected output to contain \"Cost: \", but it is %s", value)); } protected void assertCreateTableAsSelect(String table, @Language("SQL") String query, @Language("SQL") String rowCountQuery) @@ -351,6 +355,17 @@ public void testRenameColumn() assertFalse(getQueryRunner().tableExists(getSession(), "test_rename_column")); } + @Test + public void testDropColumn() + { + assertUpdate("CREATE TABLE test_drop_column AS SELECT 123 x, 111 a", 1); + + assertUpdate("ALTER TABLE test_drop_column DROP COLUMN x"); + assertQueryFails("SELECT x FROM test_drop_column", ".* Column 'x' cannot be resolved"); + + assertQueryFails("ALTER TABLE test_drop_column DROP COLUMN a", "Cannot drop the only column in a table"); + } + @Test public void testAddColumn() { @@ -808,6 +823,7 @@ public void testNonQueryAccessControl() assertAccessDenied("DROP TABLE orders", "Cannot drop table .*.orders.*", privilege("orders", DROP_TABLE)); assertAccessDenied("ALTER TABLE orders RENAME TO foo", "Cannot rename table .*.orders.* to .*.foo.*", privilege("orders", RENAME_TABLE)); assertAccessDenied("ALTER TABLE orders ADD COLUMN foo bigint", "Cannot add a column to table .*.orders.*", privilege("orders", ADD_COLUMN)); + assertAccessDenied("ALTER TABLE orders DROP COLUMN foo", "Cannot drop a column from table .*.orders.*", privilege("orders", DROP_COLUMN)); assertAccessDenied("ALTER TABLE orders RENAME COLUMN orderkey TO foo", "Cannot rename a column in table .*.orders.*", privilege("orders", RENAME_COLUMN)); assertAccessDenied("CREATE VIEW foo as SELECT * FROM orders", "Cannot create view .*.foo.*", privilege("foo", CREATE_VIEW)); // todo add DROP VIEW test... not all connectors have view support diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 0d0b75f6f61b..cd548a6c29f7 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -92,6 +92,7 @@ import static com.google.common.collect.Iterables.transform; import static io.airlift.tpch.TpchTable.ORDERS; import static java.lang.String.format; +import static java.util.Arrays.asList; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; import static java.util.stream.IntStream.range; @@ -587,6 +588,16 @@ public void testUnnest() "SELECT * FROM (SELECT custkey FROM orders ORDER BY orderkey LIMIT 1) CROSS JOIN (VALUES (10, 1), (20, 2), (30, 3))"); assertQuery("SELECT * FROM orders, UNNEST(ARRAY[1])", "SELECT orders.*, 1 FROM orders"); + + assertQueryFails( + "SELECT * FROM (VALUES array[2, 2]) a(x) LEFT OUTER JOIN UNNEST(x) ON true", + "line .*: UNNEST on other than the right side of CROSS JOIN is not supported"); + assertQueryFails( + "SELECT * FROM (VALUES array[2, 2]) a(x) RIGHT OUTER JOIN UNNEST(x) ON true", + "line .*: UNNEST on other than the right side of CROSS JOIN is not supported"); + assertQueryFails( + "SELECT * FROM (VALUES array[2, 2]) a(x) FULL OUTER JOIN UNNEST(x) ON true", + "line .*: UNNEST on other than the right side of CROSS JOIN is not supported"); } @Test @@ -621,6 +632,9 @@ public void testRows() public void testMaps() { assertQuery("SELECT m[max_key] FROM (SELECT map_agg(orderkey, orderkey) m, max(orderkey) max_key FROM orders)", "SELECT max(orderkey) FROM orders"); + // Make sure that even if the map constructor throws with the NULL key the block builders are left in a consistent state + // and the TRY() call eventually succeeds and return NULL values. + assertQuery("SELECT JSON_FORMAT(CAST(TRY(MAP(ARRAY[NULL], ARRAY[x])) AS JSON)) FROM (VALUES 1, 2) t(x)", "SELECT * FROM (VALUES NULL, NULL)"); } @Test @@ -1105,6 +1119,11 @@ public void testOrderByWithOutputColumnReference() assertQueryOrdered("select a as foo FROM (values (1,2),(3,2)) t(a,b) GROUP BY GROUPING SETS ((a), (a, b)) HAVING b IS NOT NULL ORDER BY -a", "VALUES 3, 1"); assertQueryOrdered("select max(a) FROM (values (1,2),(3,2)) t(a,b) ORDER BY max(-a)", "VALUES 3"); assertQueryFails("SELECT max(a) AS a FROM (values (1,2)) t(a,b) GROUP BY b ORDER BY max(a+b)", ".*Invalid reference to output projection attribute from ORDER BY aggregation"); + assertQueryOrdered("SELECT -a as a, a as b FROM (VALUES 1, 2) t(a) GROUP BY t.a ORDER BY a", "VALUES (-2, 2), (-1, 1)"); + assertQueryOrdered("SELECT -a as a, a as b FROM (VALUES 1, 2) t(a) GROUP BY t.a ORDER BY t.a", "VALUES (-1, 1), (-2, 2)"); + assertQueryOrdered("SELECT -a as a, a as b FROM (VALUES 1, 2) t(a) GROUP BY a ORDER BY t.a", "VALUES (-1, 1), (-2, 2)"); + assertQueryOrdered("SELECT -a as a, a as b FROM (VALUES 1, 2) t(a) GROUP BY a ORDER BY t.a+2*a", "VALUES (-2, 2), (-1, 1)"); + assertQueryOrdered("SELECT -a as a, a as b FROM (VALUES 1, 2) t(a) GROUP BY t.a ORDER BY t.a+2*a", "VALUES (-2, 2), (-1, 1)"); // lambdas assertQueryOrdered("SELECT x as y FROM (values (1,2), (2,3)) t(x, y) GROUP BY x ORDER BY apply(x, x -> -x) + 2*x", "VALUES 1, 2"); @@ -1172,6 +1191,20 @@ public void testOrderByWithAggregation() "GROUP BY x\n" + "ORDER BY sum(cast(t.x AS double))", "VALUES ('1.0', 1.0)"); + + Session legacyOrderBy = Session.builder(getSession()) + .setSystemProperty(LEGACY_ORDER_BY, "true") + .build(); + + for (Session session : asList(getSession(), legacyOrderBy)) { + for (String groupBy : asList("x.letter", "letter")) { + for (String orderBy : asList("x.letter", "letter")) { + for (String output : asList("", ", letter", ", letter as y")) { + assertQueryOrdered(session, format("select count(*) %s from (select substr(name,1,1) letter from nation) x group by %s order by %s", output, groupBy, orderBy)); + } + } + } + } } @Test @@ -1848,19 +1881,19 @@ public void testGrouping() throws Exception { assertQuery( - "SELECT a, b as t, sum(c), grouping(a, b) + grouping(a) " + - "FROM (VALUES ('h', 'j', 11), ('k', 'l', 7)) AS t (a, b, c) " + - "GROUP BY GROUPING SETS ( (a), (b)) " + - "ORDER BY grouping(b) ASC", - "VALUES (NULL, 'j', 11, 3), (NULL, 'l', 7, 3), ('h', NULL, 11, 1), ('k', NULL, 7, 1)"); + "SELECT a, b as t, sum(c), grouping(a, b) + grouping(a) " + + "FROM (VALUES ('h', 'j', 11), ('k', 'l', 7)) AS t (a, b, c) " + + "GROUP BY GROUPING SETS ( (a), (b)) " + + "ORDER BY grouping(b) ASC", + "VALUES (NULL, 'j', 11, 3), (NULL, 'l', 7, 3), ('h', NULL, 11, 1), ('k', NULL, 7, 1)"); assertQuery( - "SELECT a, sum(b), grouping(a) FROM (VALUES ('h', 11, 0), ('k', 7, 0)) AS t (a, b, c) GROUP BY GROUPING SETS (a)", - "VALUES ('h', 11, 0), ('k', 7, 0)"); + "SELECT a, sum(b), grouping(a) FROM (VALUES ('h', 11, 0), ('k', 7, 0)) AS t (a, b, c) GROUP BY GROUPING SETS (a)", + "VALUES ('h', 11, 0), ('k', 7, 0)"); assertQuery( - "SELECT a, b, sum(c), grouping(a, b) FROM (VALUES ('h', 'j', 11), ('k', 'l', 7) ) AS t (a, b, c) GROUP BY GROUPING SETS ( (a), (b)) HAVING grouping(a, b) > 1 ", - "VALUES (NULL, 'j', 11, 2), (NULL, 'l', 7, 2)"); + "SELECT a, b, sum(c), grouping(a, b) FROM (VALUES ('h', 'j', 11), ('k', 'l', 7) ) AS t (a, b, c) GROUP BY GROUPING SETS ( (a), (b)) HAVING grouping(a, b) > 1 ", + "VALUES (NULL, 'j', 11, 2), (NULL, 'l', 7, 2)"); assertQuery("SELECT a, grouping(a) * 1.0 FROM (VALUES (1) ) AS t (a) GROUP BY a", "VALUES (1, 0.0)"); @@ -1869,7 +1902,7 @@ public void testGrouping() "VALUES (1, 0, 0)"); assertQuery("SELECT grouping(a) FROM (VALUES ('h', 'j', 11), ('k', 'l', 7)) AS t (a, b, c) GROUP BY GROUPING SETS (a,c), c*2", - "VALUES (0), (1), (0), (1)"); + "VALUES (0), (1), (0), (1)"); } @Test @@ -1904,23 +1937,23 @@ public void testGroupingInWindowFunction() throws Exception { assertQuery( - "SELECT orderkey, custkey, sum(totalprice), grouping(orderkey)+grouping(custkey) as g, " + - " rank() OVER (PARTITION BY grouping(orderkey)+grouping(custkey), " + - " CASE WHEN grouping(orderkey) = 0 THEN custkey END ORDER BY orderkey ASC) as r " + - "FROM orders " + - "GROUP BY ROLLUP (orderkey, custkey) " + - "ORDER BY orderkey, custkey " + - "LIMIT 10", - "VALUES (1, 370, 172799.49, 0, 1), " + - " (1, NULL, 172799.49, 1, 1), " + - " (2, 781, 38426.09, 0, 1), " + - " (2, NULL, 38426.09, 1, 2), " + - " (3, 1234, 205654.30, 0, 1), " + - " (3, NULL, 205654.30, 1, 3), " + - " (4, 1369, 56000.91, 0, 1), " + - " (4, NULL, 56000.91, 1, 4), " + - " (5, 445, 105367.67, 0, 1), " + - " (5, NULL, 105367.67, 1, 5)"); + "SELECT orderkey, custkey, sum(totalprice), grouping(orderkey)+grouping(custkey) as g, " + + " rank() OVER (PARTITION BY grouping(orderkey)+grouping(custkey), " + + " CASE WHEN grouping(orderkey) = 0 THEN custkey END ORDER BY orderkey ASC) as r " + + "FROM orders " + + "GROUP BY ROLLUP (orderkey, custkey) " + + "ORDER BY orderkey, custkey " + + "LIMIT 10", + "VALUES (1, 370, 172799.49, 0, 1), " + + " (1, NULL, 172799.49, 1, 1), " + + " (2, 781, 38426.09, 0, 1), " + + " (2, NULL, 38426.09, 1, 2), " + + " (3, 1234, 205654.30, 0, 1), " + + " (3, NULL, 205654.30, 1, 3), " + + " (4, 1369, 56000.91, 0, 1), " + + " (4, NULL, 56000.91, 1, 4), " + + " (5, 445, 105367.67, 0, 1), " + + " (5, NULL, 105367.67, 1, 5)"); } @Test @@ -1935,54 +1968,54 @@ public void testGroupingInTableSubquery() // Inner query has a single GROUP BY and outer query has GROUPING SETS assertQuery( - "SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey), g " + - "FROM " + - " (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " + - " FROM orders " + - " GROUP BY orderkey, custkey " + - " ORDER BY agg_price ASC " + - " LIMIT 5) as t " + - "GROUP BY GROUPING SETS ((orderkey, custkey), g) " + - "ORDER BY outer_sum", - "VALUES (35271, 334, 874.89, 0, NULL), " + - " (28647, 1351, 924.33, 0, NULL), " + - " (58145, 862, 929.03, 0, NULL), " + - " (8354, 634, 974.04, 0, NULL), " + - " (37415, 301, 986.63, 0, NULL), " + - " (NULL, NULL, 4688.92, 3, 0)"); + "SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey), g " + + "FROM " + + " (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " + + " FROM orders " + + " GROUP BY orderkey, custkey " + + " ORDER BY agg_price ASC " + + " LIMIT 5) as t " + + "GROUP BY GROUPING SETS ((orderkey, custkey), g) " + + "ORDER BY outer_sum", + "VALUES (35271, 334, 874.89, 0, NULL), " + + " (28647, 1351, 924.33, 0, NULL), " + + " (58145, 862, 929.03, 0, NULL), " + + " (8354, 634, 974.04, 0, NULL), " + + " (37415, 301, 986.63, 0, NULL), " + + " (NULL, NULL, 4688.92, 3, 0)"); // Inner query has GROUPING SETS and outer query has GROUP BY assertQuery( - "SELECT orderkey, custkey, g, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " + - "FROM " + - " (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " + - " FROM orders " + - " GROUP BY GROUPING SETS ((custkey), (orderkey)) " + - " ORDER BY agg_price ASC " + - " LIMIT 5) as t " + - "GROUP BY orderkey, custkey, g", - "VALUES (28647, NULL, 2, 924.33, 0), " + - " (8354, NULL, 2, 974.04, 0), " + - " (37415, NULL, 2, 986.63, 0), " + - " (58145, NULL, 2, 929.03, 0), " + - " (35271, NULL, 2, 874.89, 0)"); + "SELECT orderkey, custkey, g, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " + + "FROM " + + " (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " + + " FROM orders " + + " GROUP BY GROUPING SETS ((custkey), (orderkey)) " + + " ORDER BY agg_price ASC " + + " LIMIT 5) as t " + + "GROUP BY orderkey, custkey, g", + "VALUES (28647, NULL, 2, 924.33, 0), " + + " (8354, NULL, 2, 974.04, 0), " + + " (37415, NULL, 2, 986.63, 0), " + + " (58145, NULL, 2, 929.03, 0), " + + " (35271, NULL, 2, 874.89, 0)"); // Inner query has GROUPING SETS but no grouping and outer query has a simple GROUP BY assertQuery( - "SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " + - "FROM " + - " (SELECT orderkey, custkey, sum(totalprice) as agg_price " + - " FROM orders " + - " GROUP BY GROUPING SETS ((custkey), (orderkey)) " + - " ORDER BY agg_price ASC NULLS FIRST) as t " + - "GROUP BY orderkey, custkey " + - "ORDER BY outer_sum ASC NULLS FIRST " + - "LIMIT 5", - "VALUES (35271, NULL, 874.89, 0), " + - " (28647, NULL, 924.33, 0), " + - " (58145, NULL, 929.03, 0), " + - " (8354, NULL, 974.04, 0), " + - " (37415, NULL, 986.63, 0)"); + "SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " + + "FROM " + + " (SELECT orderkey, custkey, sum(totalprice) as agg_price " + + " FROM orders " + + " GROUP BY GROUPING SETS ((custkey), (orderkey)) " + + " ORDER BY agg_price ASC NULLS FIRST) as t " + + "GROUP BY orderkey, custkey " + + "ORDER BY outer_sum ASC NULLS FIRST " + + "LIMIT 5", + "VALUES (35271, NULL, 874.89, 0), " + + " (28647, NULL, 924.33, 0), " + + " (58145, NULL, 929.03, 0), " + + " (8354, NULL, 974.04, 0), " + + " (37415, NULL, 986.63, 0)"); } @Test @@ -6043,11 +6076,11 @@ public void testChainedUnionsWithOrder() public void testUnionWithTopN() { assertQuery("SELECT * FROM (" + - " SELECT regionkey FROM nation " + - " UNION ALL " + - " SELECT nationkey FROM nation" + - ") t(a) " + - "ORDER BY a LIMIT 1", + " SELECT regionkey FROM nation " + + " UNION ALL " + + " SELECT nationkey FROM nation" + + ") t(a) " + + "ORDER BY a LIMIT 1", "SELECT 0"); } @@ -7047,11 +7080,11 @@ public void testCorrelatedScalarSubqueries() assertQueryFails("SELECT * FROM lineitem l WHERE 1 = (SELECT (SELECT 2 * l.orderkey))", errorMsg); // explicit limit in subquery - assertQueryFails("SELECT (SELECT count(*) FROM (SELECT * FROM (values (7,1)) t(orderkey, value) WHERE orderkey = corr_key LIMIT 1)) FROM (values 7) t(corr_key)", errorMsg); + assertQueryFails("SELECT (SELECT count(*) FROM (VALUES (7,1)) t(orderkey, value) WHERE orderkey = corr_key LIMIT 1) FROM (values 7) t(corr_key)", errorMsg); } @Test - public void testCorrelatedScalarSubqueriesWithCountScalarAggregationAndEqualityPredicatesInWhere() + public void testCorrelatedScalarSubqueriesWithScalarAggregationAndEqualityPredicatesInWhere() { assertQuery("SELECT (SELECT count(*) WHERE o.orderkey = 1) FROM orders o"); assertQuery("SELECT count(*) FROM orders o WHERE 1 = (SELECT count(*) WHERE o.orderkey = 0)"); @@ -7061,7 +7094,7 @@ public void testCorrelatedScalarSubqueriesWithCountScalarAggregationAndEqualityP "(SELECT count(*) FROM region r WHERE n.regionkey = r.regionkey) > 1"); assertQueryFails( "SELECT count(*) FROM nation n WHERE " + - "(SELECT count(*) FROM (SELECT count(*) FROM region r WHERE n.regionkey = r.regionkey)) > 1", + "(SELECT avg(a) FROM (SELECT count(*) FROM region r WHERE n.regionkey = r.regionkey) t(a)) > 1", "Unexpected node: com.facebook.presto.sql.planner.plan.LateralJoinNode"); // with duplicated rows @@ -7179,8 +7212,8 @@ public void testCorrelatedScalarSubqueriesWithScalarAggregation() //count in subquery assertQuery("SELECT * " + - "FROM (VALUES (0),( 1), (2), (7)) as v1(c1) " + - "WHERE v1.c1 > (SELECT count(c1) from (VALUES (0),( 1), (2)) as v2(c1) WHERE v1.c1 = v2.c1)", + "FROM (VALUES (0),( 1), (2), (7)) as v1(c1) " + + "WHERE v1.c1 > (SELECT count(c1) from (VALUES (0),( 1), (2)) as v2(c1) WHERE v1.c1 = v2.c1)", "VALUES (2), (7)"); } @@ -8856,4 +8889,80 @@ public void testAggregationPushedBelowOuterJoin() "GROUP BY v1.col1", "VALUES 24"); } + + @Test + public void testLateralJoin() + { + assertQuery( + "SELECT name FROM nation, LATERAL (SELECT 1 WHERE false)", + "SELECT 1 WHERE false"); + + assertQuery( + "SELECT name FROM nation, LATERAL (SELECT 1)", + "SELECT name FROM nation"); + + assertQuery( + "SELECT name FROM nation, LATERAL (SELECT 1 WHERE name = 'ola')", + "SELECT 1 WHERE false"); + + assertQuery( + "SELECT nationkey, a FROM nation, LATERAL (SELECT max(region.name) FROM region WHERE region.regionkey <= nation.regionkey) t(a) ORDER BY nationkey LIMIT 1", + "VALUES (0, 'AFRICA')"); + + assertQuery( + "SELECT nationkey, a FROM nation, LATERAL (SELECT region.name || '_' FROM region WHERE region.regionkey = nation.regionkey) t(a) ORDER BY nationkey LIMIT 1", + "VALUES (0, 'AFRICA_')"); + + assertQuery( + "SELECT nationkey, a, b, name FROM nation, LATERAL (SELECT nationkey + 2 AS a), LATERAL (SELECT a * -1 AS b) ORDER BY b LIMIT 1", + "VALUES (24, 26, -26, 'UNITED STATES')"); + + assertQuery( + "SELECT * FROM region r, LATERAL (SELECT * FROM nation) n WHERE n.regionkey = r.regionkey", + "SELECT * FROM region, nation WHERE nation.regionkey = region.regionkey"); + assertQuery( + "SELECT * FROM region, LATERAL (SELECT * FROM nation WHERE nation.regionkey = region.regionkey)", + "SELECT * FROM region, nation WHERE nation.regionkey = region.regionkey"); + + assertQuery( + "SELECT quantity, extendedprice, avg_price, low, high " + + "FROM lineitem, " + + "LATERAL (SELECT extendedprice / quantity AS avg_price) average_price, " + + "LATERAL (SELECT avg_price * 0.9 AS low) lower_bound, " + + "LATERAL (SELECT avg_price * 1.1 AS high) upper_bound " + + "ORDER BY extendedprice, quantity LIMIT 1", + "VALUES (1.0, 904.0, 904.0, 813.6, 994.400)"); + + assertQuery( + "SELECT y FROM (VALUES array[2, 3]) a(x) CROSS JOIN LATERAL(SELECT x[1]) b(y)", + "SELECT 2"); + assertQuery( + "SELECT * FROM (VALUES 2) a(x) CROSS JOIN LATERAL(SELECT x + 1)", + "SELECT 2, 3"); + assertQuery( + "SELECT * FROM (VALUES 2) a(x) CROSS JOIN LATERAL(SELECT x)", + "SELECT 2, 2"); + assertQuery( + "SELECT * FROM (VALUES 2) a(x) CROSS JOIN LATERAL(SELECT x, x + 1)", + "SELECT 2, 2, 3"); + + assertQueryFails( + "SELECT * FROM (VALUES array[2, 2]) a(x) LEFT OUTER JOIN LATERAL(VALUES x) ON true", + "line .*: LATERAL on other than the right side of CROSS JOIN is not supported"); + assertQueryFails( + "SELECT * FROM (VALUES array[2, 2]) a(x) RIGHT OUTER JOIN LATERAL(VALUES x) ON true", + "line .*: LATERAL on other than the right side of CROSS JOIN is not supported"); + assertQueryFails( + "SELECT * FROM (VALUES array[2, 2]) a(x) FULL OUTER JOIN LATERAL(VALUES x) ON true", + "line .*: LATERAL on other than the right side of CROSS JOIN is not supported"); + } + + @Test + public void testPruningCountAggregationOverScalar() + { + assertQuery("SELECT COUNT(*) FROM (SELECT SUM(orderkey) FROM orders)"); + assertQuery( + "SELECT COUNT(*) FROM (SELECT SUM(orderkey) FROM orders group by custkey)", + "VALUES 1000"); + } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index 9a94c0c9a770..66f7f4a13115 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -30,6 +30,7 @@ import com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilege; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.airlift.units.Duration; import org.intellij.lang.annotations.Language; import org.testng.SkipException; import org.testng.annotations.AfterClass; @@ -117,11 +118,6 @@ protected void assertQuery(Session session, @Language("SQL") String sql) QueryAssertions.assertQuery(queryRunner, session, sql, h2QueryRunner, sql, false, false); } - public void assertQueryOrdered(@Language("SQL") String sql) - { - QueryAssertions.assertQuery(queryRunner, getSession(), sql, h2QueryRunner, sql, true, false); - } - protected void assertQuery(@Language("SQL") String actual, @Language("SQL") String expected) { QueryAssertions.assertQuery(queryRunner, getSession(), actual, h2QueryRunner, expected, false, false); @@ -132,6 +128,16 @@ protected void assertQuery(Session session, @Language("SQL") String actual, @Lan QueryAssertions.assertQuery(queryRunner, session, actual, h2QueryRunner, expected, false, false); } + public void assertQueryOrdered(@Language("SQL") String sql) + { + assertQueryOrdered(getSession(), sql); + } + + public void assertQueryOrdered(Session session, @Language("SQL") String sql) + { + assertQueryOrdered(session, sql, sql); + } + protected void assertQueryOrdered(@Language("SQL") String actual, @Language("SQL") String expected) { assertQueryOrdered(getSession(), actual, expected); @@ -172,20 +178,19 @@ protected void assertUpdate(Session session, @Language("SQL") String sql, long c QueryAssertions.assertUpdate(queryRunner, session, sql, OptionalLong.of(count)); } + protected void assertQueryFailsEventually(@Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp, Duration timeout) + { + QueryAssertions.assertQueryFailsEventually(queryRunner, getSession(), sql, expectedMessageRegExp, timeout); + } + protected void assertQueryFails(@Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) { - assertQueryFails(getSession(), sql, expectedMessageRegExp); + QueryAssertions.assertQueryFails(queryRunner, getSession(), sql, expectedMessageRegExp); } protected void assertQueryFails(Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) { - try { - queryRunner.execute(session, sql); - fail(format("Expected query to fail: %s", sql)); - } - catch (RuntimeException ex) { - assertExceptionMessage(sql, ex, expectedMessageRegExp); - } + QueryAssertions.assertQueryFails(queryRunner, session, sql, expectedMessageRegExp); } protected void assertAccessAllowed(@Language("SQL") String sql, TestingPrivilege... deniedPrivileges) diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestingPrestoClient.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestingPrestoClient.java index 469aacb74cec..6f83ed817618 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestingPrestoClient.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestingPrestoClient.java @@ -27,11 +27,8 @@ import com.google.common.base.Function; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.http.client.HttpClient; -import io.airlift.http.client.HttpClientConfig; -import io.airlift.http.client.jetty.JettyHttpClient; -import io.airlift.json.JsonCodec; import io.airlift.units.Duration; +import okhttp3.OkHttpClient; import org.intellij.lang.annotations.Language; import java.io.Closeable; @@ -45,35 +42,28 @@ import static com.facebook.presto.transaction.TransactionBuilder.transaction; import static com.google.common.base.Verify.verify; import static com.google.common.collect.Iterables.transform; -import static io.airlift.json.JsonCodec.jsonCodec; import static java.util.Objects.requireNonNull; public abstract class AbstractTestingPrestoClient implements Closeable { - private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); - private final TestingPrestoServer prestoServer; private final Session defaultSession; - private final HttpClient httpClient; + private final OkHttpClient httpClient = new OkHttpClient(); protected AbstractTestingPrestoClient(TestingPrestoServer prestoServer, Session defaultSession) { this.prestoServer = requireNonNull(prestoServer, "prestoServer is null"); this.defaultSession = requireNonNull(defaultSession, "defaultSession is null"); - - this.httpClient = new JettyHttpClient( - new HttpClientConfig() - .setConnectTimeout(new Duration(1, TimeUnit.DAYS)) - .setIdleTimeout(new Duration(10, TimeUnit.DAYS))); } @Override public void close() { - this.httpClient.close(); + httpClient.dispatcher().executorService().shutdown(); + httpClient.connectionPool().evictAll(); } protected abstract ResultsSession getResultSession(Session session); @@ -89,7 +79,7 @@ public ResultWithQueryId execute(Session session, @Language("SQL") String sql ClientSession clientSession = toClientSession(session, prestoServer.getBaseUrl(), true, new Duration(2, TimeUnit.MINUTES)); - try (StatementClient client = new StatementClient(httpClient, QUERY_RESULTS_CODEC, clientSession, sql)) { + try (StatementClient client = new StatementClient(httpClient, clientSession, sql)) { while (client.isValid()) { QueryResults results = client.current(); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java b/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java index 589783f08225..51bb93f13594 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java @@ -30,6 +30,7 @@ import java.util.OptionalLong; import java.util.function.Supplier; +import static com.google.common.base.Strings.nullToEmpty; import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; import static io.airlift.units.Duration.nanosSince; import static java.lang.String.format; @@ -186,6 +187,41 @@ public static void assertContains(MaterializedResult all, MaterializedResult exp } } + protected static void assertQueryFailsEventually(QueryRunner queryRunner, Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp, Duration timeout) + { + long start = System.nanoTime(); + while (!Thread.currentThread().isInterrupted()) { + try { + assertQueryFails(queryRunner, session, sql, expectedMessageRegExp); + return; + } + catch (AssertionError e) { + if (nanosSince(start).compareTo(timeout) > 0) { + throw e; + } + } + sleepUninterruptibly(50, MILLISECONDS); + } + } + + protected static void assertQueryFails(QueryRunner queryRunner, Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) + { + try { + queryRunner.execute(session, sql); + fail(format("Expected query to fail: %s", sql)); + } + catch (RuntimeException ex) { + assertExceptionMessage(sql, ex, expectedMessageRegExp); + } + } + + private static void assertExceptionMessage(String sql, Exception exception, @Language("RegExp") String regex) + { + if (!nullToEmpty(exception.getMessage()).matches(regex)) { + fail(format("Expected exception message '%s' to match '%s' for query: %s", exception.getMessage(), regex, sql), exception); + } + } + public static void copyTpchTables( QueryRunner queryRunner, String sourceCatalog, diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/StructuralTestUtil.java b/presto-tests/src/main/java/com/facebook/presto/tests/StructuralTestUtil.java index cfd85eed7e18..1522c74193f7 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/StructuralTestUtil.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/StructuralTestUtil.java @@ -21,12 +21,12 @@ import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.MapType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/TestingPrestoClient.java b/presto-tests/src/main/java/com/facebook/presto/tests/TestingPrestoClient.java index 88b9b9e90c21..ca2cb4f60c76 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/TestingPrestoClient.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/TestingPrestoClient.java @@ -18,14 +18,14 @@ import com.facebook.presto.client.IntervalYearMonth; import com.facebook.presto.client.QueryResults; import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.TimeZoneKey; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarcharType; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; import com.facebook.presto.type.SqlIntervalDayTime; import com.facebook.presto.type.SqlIntervalYearMonth; import com.google.common.base.Function; diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/TestQueues.java b/presto-tests/src/test/java/com/facebook/presto/execution/TestQueues.java index 760dbf0a6df4..2e576f22522b 100644 --- a/presto-tests/src/test/java/com/facebook/presto/execution/TestQueues.java +++ b/presto-tests/src/test/java/com/facebook/presto/execution/TestQueues.java @@ -17,13 +17,17 @@ import com.facebook.presto.resourceGroups.ResourceGroupManagerPlugin; import com.facebook.presto.spi.QueryId; import com.facebook.presto.tests.DistributedQueryRunner; +import com.facebook.presto.tests.tpch.TpchQueryRunner; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.testng.annotations.Test; import java.util.Map; +import java.util.Optional; +import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT; import static com.facebook.presto.execution.QueryState.FAILED; +import static com.facebook.presto.execution.QueryState.FINISHED; import static com.facebook.presto.execution.QueryState.QUEUED; import static com.facebook.presto.execution.QueryState.RUNNING; import static com.facebook.presto.execution.TestQueryRunnerUtil.cancelQuery; @@ -32,7 +36,9 @@ import static com.facebook.presto.execution.TestQueryRunnerUtil.waitForQueryState; import static com.facebook.presto.spi.StandardErrorCode.QUERY_REJECTED; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.lang.String.format; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; // run single threaded to avoid creating multiple query runners at once @Test(singleThreaded = true) @@ -195,6 +201,32 @@ public void testResourceGroupManagerRejection() testRejection(true); } + @Test(timeOut = 240_000) + public void testQueryTypeBasedSelection() + throws Exception + { + try (DistributedQueryRunner queryRunner = TpchQueryRunner.createQueryRunner(ImmutableMap.of(), ImmutableMap.of("experimental.resource-groups-enabled", "true"))) { + queryRunner.installPlugin(new ResourceGroupManagerPlugin()); + queryRunner.getCoordinator().getResourceGroupManager().get() + .setConfigurationManager("file", ImmutableMap.of("resource-groups.config-file", getResourceFilePath("resource_groups_query_type_based_config.json"))); + assertResourceGroup(queryRunner, LONG_LASTING_QUERY, "global.select"); + assertResourceGroup(queryRunner, "SHOW TABLES", "global.describe"); + assertResourceGroup(queryRunner, "EXPLAIN " + LONG_LASTING_QUERY, "global.explain"); + assertResourceGroup(queryRunner, "DESCRIBE lineitem", "global.describe"); + assertResourceGroup(queryRunner, "RESET SESSION " + HASH_PARTITION_COUNT, "global.data_definition"); + } + } + + private void assertResourceGroup(DistributedQueryRunner queryRunner, String query, String expectedResourceGroup) + throws InterruptedException + { + QueryId queryId = createQuery(queryRunner, newSession(), query); + waitForQueryState(queryRunner, queryId, ImmutableSet.of(RUNNING, FINISHED)); + Optional resourceGroupName = queryRunner.getCoordinator().getQueryManager().getQueryInfo(queryId).getResourceGroupName(); + assertTrue(resourceGroupName.isPresent(), "Query should have a resource group"); + assertEquals(resourceGroupName.get().toString(), expectedResourceGroup, format("Expected: '%s' resource group, found: %s", expectedResourceGroup, resourceGroupName.get())); + } + private void testRejection(boolean resourceGroups) throws Exception { diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/H2TestUtil.java b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/H2TestUtil.java new file mode 100644 index 000000000000..e6f4550b647e --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/H2TestUtil.java @@ -0,0 +1,165 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.resourceGroups.db; + +import com.facebook.presto.Session; +import com.facebook.presto.execution.QueryManager; +import com.facebook.presto.execution.QueryState; +import com.facebook.presto.resourceGroups.db.DbResourceGroupConfig; +import com.facebook.presto.resourceGroups.db.H2DaoProvider; +import com.facebook.presto.resourceGroups.db.H2ResourceGroupsDao; +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.resourceGroups.ResourceGroupSelector; +import com.facebook.presto.sql.parser.SqlParserOptions; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.facebook.presto.tpch.TpchPlugin; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Random; +import java.util.Set; + +import static com.facebook.presto.execution.QueryState.RUNNING; +import static com.facebook.presto.execution.QueryState.TERMINAL_QUERY_STATES; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +class H2TestUtil +{ + private static final String CONFIGURATION_MANAGER_TYPE = "h2"; + + private H2TestUtil() {} + + public static Session adhocSession() + { + return testSessionBuilder() + .setCatalog("tpch") + .setSchema("sf100000") + .setSource("adhoc") + .build(); + } + + public static Session dashboardSession() + { + return testSessionBuilder() + .setCatalog("tpch") + .setSchema("sf100000") + .setSource("dashboard") + .build(); + } + + public static Session rejectingSession() + { + return testSessionBuilder() + .setCatalog("tpch") + .setSchema("sf100000") + .setSource("reject") + .build(); + } + + public static void waitForCompleteQueryCount(DistributedQueryRunner queryRunner, int expectedCount) + throws InterruptedException + { + waitForQueryCount(queryRunner, TERMINAL_QUERY_STATES, expectedCount); + } + + public static void waitForRunningQueryCount(DistributedQueryRunner queryRunner, int expectedCount) + throws InterruptedException + { + waitForQueryCount(queryRunner, ImmutableSet.of(RUNNING), expectedCount); + } + + public static void waitForQueryCount(DistributedQueryRunner queryRunner, Set countingStates, int expectedCount) + throws InterruptedException + { + QueryManager queryManager = queryRunner.getCoordinator().getQueryManager(); + while (queryManager.getAllQueryInfo().stream() + .filter(q -> countingStates.contains(q.getState())).count() != expectedCount) { + MILLISECONDS.sleep(500); + } + } + + public static String getDbConfigUrl() + { + return "jdbc:h2:mem:test_" + Math.abs(new Random().nextLong()); + } + + public static H2ResourceGroupsDao getDao(String url) + { + DbResourceGroupConfig dbResourceGroupConfig = new DbResourceGroupConfig() + .setConfigDbUrl(url); + H2ResourceGroupsDao dao = new H2DaoProvider(dbResourceGroupConfig).get(); + dao.createResourceGroupsTable(); + dao.createSelectorsTable(); + dao.createResourceGroupsGlobalPropertiesTable(); + return dao; + } + + public static DistributedQueryRunner createQueryRunner(String dbConfigUrl, H2ResourceGroupsDao dao) + throws Exception + { + DistributedQueryRunner queryRunner = new DistributedQueryRunner( + testSessionBuilder().setCatalog("tpch").setSchema("tiny").build(), + 2, + ImmutableMap.of("experimental.resource-groups-enabled", "true"), + ImmutableMap.of(), + new SqlParserOptions()); + try { + Plugin h2ResourceGroupManagerPlugin = new H2ResourceGroupManagerPlugin(); + queryRunner.installPlugin(h2ResourceGroupManagerPlugin); + queryRunner.getCoordinator().getResourceGroupManager().get() + .setConfigurationManager(CONFIGURATION_MANAGER_TYPE, ImmutableMap.of("resource-groups.config-db-url", dbConfigUrl)); + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch"); + setup(queryRunner, dao); + return queryRunner; + } + catch (Exception e) { + queryRunner.close(); + throw e; + } + } + + public static DistributedQueryRunner getSimpleQueryRunner() + throws Exception + { + String dbConfigUrl = getDbConfigUrl(); + H2ResourceGroupsDao dao = getDao(dbConfigUrl); + return createQueryRunner(dbConfigUrl, dao); + } + + private static void setup(DistributedQueryRunner queryRunner, H2ResourceGroupsDao dao) + throws InterruptedException + { + dao.insertResourceGroupsGlobalProperties("cpu_quota_period", "1h"); + dao.insertResourceGroup(1, "global", "1MB", 100, 1000, null, null, null, null, null, null, null, null); + dao.insertResourceGroup(2, "bi-${USER}", "1MB", 3, 2, null, null, null, null, null, null, null, 1L); + dao.insertResourceGroup(3, "user-${USER}", "1MB", 3, 3, null, null, null, null, null, null, null, 1L); + dao.insertResourceGroup(4, "adhoc-${USER}", "1MB", 3, 3, null, null, null, null, null, null, null, 3L); + dao.insertResourceGroup(5, "dashboard-${USER}", "1MB", 1, 1, null, null, null, null, null, null, null, 3L); + dao.insertSelector(2, "user.*", "test"); + dao.insertSelector(4, "user.*", "(?i).*adhoc.*"); + dao.insertSelector(5, "user.*", "(?i).*dashboard.*"); + // Selectors are loaded last + while (getSelectors(queryRunner).size() != 3) { + MILLISECONDS.sleep(500); + } + } + + public static List getSelectors(DistributedQueryRunner queryRunner) + { + return queryRunner.getCoordinator().getResourceGroupManager().get().getConfigurationManager().getSelectors(); + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestQueues.java b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestQueues.java index 3c06b88dffe1..4efdc0c7d6da 100644 --- a/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestQueues.java +++ b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestQueues.java @@ -13,39 +13,35 @@ */ package com.facebook.presto.execution.resourceGroups.db; -import com.facebook.presto.Session; import com.facebook.presto.execution.QueryManager; -import com.facebook.presto.execution.QueryState; -import com.facebook.presto.execution.TestingSessionFactory; -import com.facebook.presto.execution.resourceGroups.ResourceGroupManager; -import com.facebook.presto.resourceGroups.db.DbResourceGroupConfig; -import com.facebook.presto.resourceGroups.db.H2DaoProvider; +import com.facebook.presto.execution.resourceGroups.InternalResourceGroupManager; +import com.facebook.presto.resourceGroups.db.DbResourceGroupConfigurationManager; import com.facebook.presto.resourceGroups.db.H2ResourceGroupsDao; -import com.facebook.presto.spi.Plugin; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.resourceGroups.ResourceGroupInfo; -import com.facebook.presto.spi.resourceGroups.ResourceGroupSelector; -import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.tests.DistributedQueryRunner; -import com.facebook.presto.tests.tpch.TpchQueryRunner; -import com.facebook.presto.tpch.TpchPlugin; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import org.testng.annotations.Test; -import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.Set; import java.util.concurrent.TimeUnit; import static com.facebook.presto.execution.QueryState.FAILED; import static com.facebook.presto.execution.QueryState.QUEUED; import static com.facebook.presto.execution.QueryState.RUNNING; -import static com.facebook.presto.execution.QueryState.TERMINAL_QUERY_STATES; +import static com.facebook.presto.execution.TestQueryRunnerUtil.cancelQuery; +import static com.facebook.presto.execution.TestQueryRunnerUtil.createQuery; +import static com.facebook.presto.execution.TestQueryRunnerUtil.waitForQueryState; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.adhocSession; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.createQueryRunner; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.dashboardSession; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.getDao; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.getDbConfigUrl; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.getSelectors; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.getSimpleQueryRunner; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.rejectingSession; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.waitForCompleteQueryCount; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.waitForRunningQueryCount; import static com.facebook.presto.spi.StandardErrorCode.QUERY_REJECTED; -import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.testng.Assert.assertEquals; @@ -54,7 +50,6 @@ public class TestQueues { // Copy of TestQueues with tests for db reconfiguration of resource groups - private static final String NAME = "h2"; private static final String LONG_LASTING_QUERY = "SELECT COUNT(*) FROM lineitem"; @Test(timeOut = 60_000) @@ -80,15 +75,14 @@ public void testBasic() String dbConfigUrl = getDbConfigUrl(); H2ResourceGroupsDao dao = getDao(dbConfigUrl); try (DistributedQueryRunner queryRunner = createQueryRunner(dbConfigUrl, dao)) { - QueryManager queryManager = queryRunner.getCoordinator().getQueryManager(); // submit first "dashboard" query - QueryId firstDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId firstDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); // wait for the first "dashboard" query to start waitForQueryState(queryRunner, firstDashboardQuery, RUNNING); waitForRunningQueryCount(queryRunner, 1); // submit second "dashboard" query - QueryId secondDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId secondDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); MILLISECONDS.sleep(2000); // wait for the second "dashboard" query to be queued ("dashboard.${USER}" queue strategy only allows one "dashboard" query to be accepted for execution) waitForQueryState(queryRunner, secondDashboardQuery, QUEUED); @@ -97,16 +91,16 @@ public void testBasic() dao.updateResourceGroup(3, "user-${USER}", "1MB", 3, 4, null, null, null, null, null, null, null, 1L); dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 2, null, null, null, null, null, null, null, 3L); waitForQueryState(queryRunner, secondDashboardQuery, RUNNING); - QueryId thirdDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId thirdDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, thirdDashboardQuery, QUEUED); waitForRunningQueryCount(queryRunner, 2); // submit first non "dashboard" query - QueryId firstNonDashboardQuery = createQuery(queryRunner, newSession(), LONG_LASTING_QUERY); + QueryId firstNonDashboardQuery = createQuery(queryRunner, adhocSession(), LONG_LASTING_QUERY); // wait for the first non "dashboard" query to start waitForQueryState(queryRunner, firstNonDashboardQuery, RUNNING); waitForRunningQueryCount(queryRunner, 3); // submit second non "dashboard" query - QueryId secondNonDashboardQuery = createQuery(queryRunner, newSession(), LONG_LASTING_QUERY); + QueryId secondNonDashboardQuery = createQuery(queryRunner, adhocSession(), LONG_LASTING_QUERY); // wait for the second non "dashboard" query to start waitForQueryState(queryRunner, secondNonDashboardQuery, RUNNING); waitForRunningQueryCount(queryRunner, 4); @@ -126,45 +120,46 @@ public void testTwoQueriesAtSameTime() String dbConfigUrl = getDbConfigUrl(); H2ResourceGroupsDao dao = getDao(dbConfigUrl); try (DistributedQueryRunner queryRunner = createQueryRunner(dbConfigUrl, dao)) { - QueryId firstDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); - QueryId secondDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); - - ImmutableSet queuedOrRunning = ImmutableSet.of(QUEUED, RUNNING); + QueryId firstDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); + QueryId secondDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, firstDashboardQuery, RUNNING); waitForQueryState(queryRunner, secondDashboardQuery, QUEUED); } } - @Test(timeOut = 60_000) + @Test(timeOut = 90_000) public void testTooManyQueries() throws Exception { String dbConfigUrl = getDbConfigUrl(); H2ResourceGroupsDao dao = getDao(dbConfigUrl); try (DistributedQueryRunner queryRunner = createQueryRunner(dbConfigUrl, dao)) { - QueryId firstDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId firstDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, firstDashboardQuery, RUNNING); - QueryId secondDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId secondDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, secondDashboardQuery, QUEUED); - QueryId thirdDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId thirdDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, thirdDashboardQuery, FAILED); // Allow one more query to run and resubmit third query dao.updateResourceGroup(3, "user-${USER}", "1MB", 3, 4, null, null, null, null, null, null, null, 1L); dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 2, null, null, null, null, null, null, null, 3L); + + InternalResourceGroupManager manager = queryRunner.getCoordinator().getResourceGroupManager().get(); + DbResourceGroupConfigurationManager dbConfigurationManager = (DbResourceGroupConfigurationManager) manager.getConfigurationManager(); + + // Trigger reload to make the test more deterministic + dbConfigurationManager.load(); waitForQueryState(queryRunner, secondDashboardQuery, RUNNING); - thirdDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + thirdDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, thirdDashboardQuery, QUEUED); - // Lower running queries in dashboard resource groups and wait until groups are reconfigured + // Lower running queries in dashboard resource groups and reload the config dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 1, null, null, null, null, null, null, null, 3L); - ResourceGroupManager manager = queryRunner.getCoordinator().getResourceGroupManager().get(); - while (manager.getResourceGroupInfo( - new ResourceGroupId(new ResourceGroupId(new ResourceGroupId("global"), "user-user"), "dashboard-user")).getMaxRunningQueries() != 1) { - MILLISECONDS.sleep(500); - } + dbConfigurationManager.load(); + // Cancel query and verify that third query is still queued cancelQuery(queryRunner, firstDashboardQuery); waitForQueryState(queryRunner, firstDashboardQuery, FAILED); @@ -181,7 +176,7 @@ public void testRejection() H2ResourceGroupsDao dao = getDao(dbConfigUrl); try (DistributedQueryRunner queryRunner = createQueryRunner(dbConfigUrl, dao)) { // Verify the query cannot be submitted - QueryId queryId = createQuery(queryRunner, newRejectionSession(), LONG_LASTING_QUERY); + QueryId queryId = createQuery(queryRunner, rejectingSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, queryId, FAILED); QueryManager queryManager = queryRunner.getCoordinator().getQueryManager(); assertEquals(queryManager.getQueryInfo(queryId).getErrorCode(), QUERY_REJECTED.toErrorCode()); @@ -192,14 +187,14 @@ public void testRejection() MILLISECONDS.sleep(500); } // Verify the query can be submitted - queryId = createQuery(queryRunner, newRejectionSession(), LONG_LASTING_QUERY); + queryId = createQuery(queryRunner, rejectingSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, queryId, RUNNING); dao.deleteSelector(4, "user.*", "(?i).*reject.*"); while (getSelectors(queryRunner).size() != selectorCount) { MILLISECONDS.sleep(500); } // Verify the query cannot be submitted - queryId = createQuery(queryRunner, newRejectionSession(), LONG_LASTING_QUERY); + queryId = createQuery(queryRunner, rejectingSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, queryId, FAILED); } } @@ -212,7 +207,7 @@ public void testRunningTimeLimit() H2ResourceGroupsDao dao = getDao(dbConfigUrl); try (DistributedQueryRunner queryRunner = createQueryRunner(dbConfigUrl, dao)) { dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 1, null, null, null, null, null, null, "3s", 3L); - QueryId firstDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId firstDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, firstDashboardQuery, FAILED); } } @@ -225,164 +220,11 @@ public void testQueuedTimeLimit() H2ResourceGroupsDao dao = getDao(dbConfigUrl); try (DistributedQueryRunner queryRunner = createQueryRunner(dbConfigUrl, dao)) { dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 1, null, null, null, null, null, "5s", null, 3L); - QueryId firstDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId firstDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, firstDashboardQuery, RUNNING); - QueryId secondDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId secondDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, secondDashboardQuery, QUEUED); waitForQueryState(queryRunner, secondDashboardQuery, FAILED); } } - - private static Session newSession() - { - return testSessionBuilder() - .setCatalog("tpch") - .setSchema("sf100000") - .setSource("adhoc") - .build(); - } - - private static Session newDashboardSession() - { - return testSessionBuilder() - .setCatalog("tpch") - .setSchema("sf100000") - .setSource("dashboard") - .build(); - } - - private static Session newRejectionSession() - { - return testSessionBuilder() - .setCatalog("tpch") - .setSchema("sf100000") - .setSource("reject") - .build(); - } - - private static QueryId createQuery(DistributedQueryRunner queryRunner, Session session, String sql) - { - return queryRunner.getCoordinator().getQueryManager().createQuery(new TestingSessionFactory(session), sql).getQueryId(); - } - - private static void cancelQuery(DistributedQueryRunner queryRunner, QueryId queryId) - { - queryRunner.getCoordinator().getQueryManager().cancelQuery(queryId); - } - - private static void waitForCompleteQueryCount(DistributedQueryRunner queryRunner, int expectedCount) - throws InterruptedException - { - waitForQueryCount(queryRunner, TERMINAL_QUERY_STATES, expectedCount); - } - - private static void waitForRunningQueryCount(DistributedQueryRunner queryRunner, int expectedCount) - throws InterruptedException - { - waitForQueryCount(queryRunner, ImmutableSet.of(RUNNING), expectedCount); - } - - private static void waitForQueryCount(DistributedQueryRunner queryRunner, Set countingStates, int expectedCount) - throws InterruptedException - { - QueryManager queryManager = queryRunner.getCoordinator().getQueryManager(); - while (queryManager.getAllQueryInfo().stream().filter(q -> countingStates.contains(q.getState())).count() != expectedCount) { - MILLISECONDS.sleep(500); - } - } - - private static void waitForQueryState(DistributedQueryRunner queryRunner, QueryId queryId, QueryState expectedQueryState) - throws InterruptedException - { - waitForQueryState(queryRunner, queryId, ImmutableSet.of(expectedQueryState)); - } - - private static void waitForQueryState(DistributedQueryRunner queryRunner, QueryId queryId, Set expectedQueryStates) - throws InterruptedException - { - QueryManager queryManager = queryRunner.getCoordinator().getQueryManager(); - while (!expectedQueryStates.contains(queryManager.getQueryInfo(queryId).getState())) { - MILLISECONDS.sleep(500); - } - } - - private static String getDbConfigUrl() - { - Random rnd = new Random(); - return "jdbc:h2:mem:test_" + Math.abs(rnd.nextLong()); - } - - private static H2ResourceGroupsDao getDao(String url) - { - DbResourceGroupConfig dbResourceGroupConfig = new DbResourceGroupConfig() - .setConfigDbUrl(url); - H2ResourceGroupsDao dao = new H2DaoProvider(dbResourceGroupConfig).get(); - dao.createResourceGroupsTable(); - dao.createSelectorsTable(); - dao.createResourceGroupsGlobalPropertiesTable(); - return dao; - } - - private static DistributedQueryRunner createQueryRunner(String dbConfigUrl, H2ResourceGroupsDao dao) - throws Exception - { - ImmutableMap.Builder builder = ImmutableMap.builder(); - builder.put("experimental.resource-groups-enabled", "true"); - Map properties = builder.build(); - DistributedQueryRunner queryRunner = new DistributedQueryRunner(testSessionBuilder().build(), 2, ImmutableMap.of(), properties, new SqlParserOptions()); - try { - Plugin h2ResourceGroupManagerPlugin = new H2ResourceGroupManagerPlugin(); - queryRunner.installPlugin(h2ResourceGroupManagerPlugin); - queryRunner.getCoordinator().getResourceGroupManager().get() - .setConfigurationManager(NAME, ImmutableMap.of("resource-groups.config-db-url", dbConfigUrl)); - queryRunner.installPlugin(new TpchPlugin()); - queryRunner.createCatalog("tpch", "tpch"); - setup(queryRunner, dao); - return queryRunner; - } - catch (Exception e) { - queryRunner.close(); - throw e; - } - } - - static DistributedQueryRunner getSimpleQueryRunner() - throws Exception - { - String dbConfigUrl = getDbConfigUrl(); - H2ResourceGroupsDao dao = getDao(dbConfigUrl); - ImmutableMap.Builder builder = ImmutableMap.builder(); - builder.put("experimental.resource-groups-enabled", "true"); - Map properties = builder.build(); - DistributedQueryRunner queryRunner = TpchQueryRunner.createQueryRunner(properties); - Plugin h2ResourceGroupManagerPlugin = new H2ResourceGroupManagerPlugin(); - queryRunner.installPlugin(h2ResourceGroupManagerPlugin); - queryRunner.getCoordinator().getResourceGroupManager().get() - .setConfigurationManager(NAME, ImmutableMap.of("resource-groups.config-db-url", dbConfigUrl)); - setup(queryRunner, dao); - return queryRunner; - } - - private static void setup(DistributedQueryRunner queryRunner, H2ResourceGroupsDao dao) - throws InterruptedException - { - dao.insertResourceGroupsGlobalProperties("cpu_quota_period", "1h"); - dao.insertResourceGroup(1, "global", "1MB", 100, 1000, null, null, null, null, null, null, null, null); - dao.insertResourceGroup(2, "bi-${USER}", "1MB", 3, 2, null, null, null, null, null, null, null, 1L); - dao.insertResourceGroup(3, "user-${USER}", "1MB", 3, 3, null, null, null, null, null, null, null, 1L); - dao.insertResourceGroup(4, "adhoc-${USER}", "1MB", 3, 3, null, null, null, null, null, null, null, 3L); - dao.insertResourceGroup(5, "dashboard-${USER}", "1MB", 1, 1, null, null, null, null, null, null, null, 3L); - dao.insertSelector(2, "user.*", "test"); - dao.insertSelector(4, "user.*", "(?i).*adhoc.*"); - dao.insertSelector(5, "user.*", "(?i).*dashboard.*"); - // Selectors are loaded last - while (getSelectors(queryRunner).size() != 3) { - MILLISECONDS.sleep(500); - } - } - - private static List getSelectors(DistributedQueryRunner queryRunner) - { - return queryRunner.getCoordinator().getResourceGroupManager().get().getConfigurationManager().getSelectors(); - } } diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestResourceGroupIntegration.java b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestResourceGroupIntegration.java index babc76d83757..be44580e8912 100644 --- a/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestResourceGroupIntegration.java +++ b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestResourceGroupIntegration.java @@ -17,7 +17,7 @@ import org.testng.annotations.Test; import static com.facebook.presto.execution.resourceGroups.TestResourceGroupIntegration.waitForGlobalResourceGroup; -import static com.facebook.presto.execution.resourceGroups.db.TestQueues.getSimpleQueryRunner; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.getSimpleQueryRunner; public class TestResourceGroupIntegration { diff --git a/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java b/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java index 7318b8a5e67f..b63dfce3044a 100644 --- a/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java +++ b/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java @@ -124,6 +124,7 @@ public void testOutOfMemoryKiller() while (!queryDone) { for (QueryInfo info : queryRunner.getCoordinator().getQueryManager().getAllQueryInfo()) { if (info.getState().isDone()) { + assertNotNull(info.getErrorCode()); assertEquals(info.getErrorCode().getCode(), CLUSTER_OUT_OF_MEMORY.toErrorCode().getCode()); queryDone = true; break; diff --git a/presto-tests/src/test/resources/resource_groups_query_type_based_config.json b/presto-tests/src/test/resources/resource_groups_query_type_based_config.json new file mode 100644 index 000000000000..9df7de9ab957 --- /dev/null +++ b/presto-tests/src/test/resources/resource_groups_query_type_based_config.json @@ -0,0 +1,87 @@ +{ + "rootGroups": [ + { + "name": "global", + "softMemoryLimit": "1MB", + "maxRunning": 100, + "maxQueued": 1000, + "softCpuLimit": "1h", + "hardCpuLimit": "1d", + "subGroups": [ + { + "name": "select", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + }, + { + "name": "explain", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + }, + { + "name": "insert", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + }, + { + "name": "delete", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + }, + { + "name": "describe", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + }, + { + "name": "data_definition", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + }, + { + "name": "other", + "softMemoryLimit": "2MB", + "maxRunning": 3, + "maxQueued": 4 + } + ] + } + ], + "selectors": [ + { + "queryType" : "select", + "group": "global.select" + }, + { + "queryType" : "explain", + "group": "global.explain" + }, + { + "queryType" : "insert", + "group": "global.insert" + }, + { + "queryType" : "delete", + "group": "global.delete" + }, + { + "queryType" : "describe", + "group": "global.describe" + }, + { + "queryType" : "data_definition", + "group": "global.data_definition" + }, + { + "group": "global.other" + } + ], + "cpuQuotaPeriod": "1h" +} + diff --git a/presto-thrift-connector-api/pom.xml b/presto-thrift-connector-api/pom.xml new file mode 100644 index 000000000000..c8b7b537ee95 --- /dev/null +++ b/presto-thrift-connector-api/pom.xml @@ -0,0 +1,84 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.181-tw-0.37 + + + presto-thrift-connector-api + Presto - Thrift Connector API + jar + + + ${project.parent.basedir} + + + + + com.google.guava + guava + + + + com.google.code.findbugs + annotations + + + + com.facebook.swift + swift-annotations + + + + com.facebook.presto + presto-spi + + + + io.airlift + slice + + + + com.fasterxml.jackson.core + jackson-annotations + + + + + + + com.facebook.swift + swift-javadoc + provided + + + + com.facebook.swift + swift-codec + provided + + + + + org.testng + testng + test + + + + com.facebook.presto + presto-main + test + + + + io.airlift + stats + test + + + diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/NameValidationUtils.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/NameValidationUtils.java new file mode 100644 index 000000000000..92c1abd3c763 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/NameValidationUtils.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; + +final class NameValidationUtils +{ + private NameValidationUtils() {} + + public static String checkValidName(String name) + { + checkArgument(!isNullOrEmpty(name), "name is null or empty"); + checkArgument('a' <= name.charAt(0) && name.charAt(0) <= 'z', "name must start with a lowercase latin letter: '%s'", name); + for (int i = 1; i < name.length(); i++) { + char ch = name.charAt(i); + checkArgument('a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' || ch == '_', + "name must contain only lowercase latin letters, digits or underscores: '%s'", name); + } + return name; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftBlock.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftBlock.java new file mode 100644 index 000000000000..33838b84fbfe --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftBlock.java @@ -0,0 +1,311 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftBigint; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftBigintArray; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftBoolean; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftColumnData; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftDate; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftDouble; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftHyperLogLog; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftInteger; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftJson; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTimestamp; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftVarchar; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.presto.spi.type.StandardTypes.ARRAY; +import static com.facebook.presto.spi.type.StandardTypes.BIGINT; +import static com.facebook.presto.spi.type.StandardTypes.BOOLEAN; +import static com.facebook.presto.spi.type.StandardTypes.DATE; +import static com.facebook.presto.spi.type.StandardTypes.DOUBLE; +import static com.facebook.presto.spi.type.StandardTypes.HYPER_LOG_LOG; +import static com.facebook.presto.spi.type.StandardTypes.INTEGER; +import static com.facebook.presto.spi.type.StandardTypes.JSON; +import static com.facebook.presto.spi.type.StandardTypes.TIMESTAMP; +import static com.facebook.presto.spi.type.StandardTypes.VARCHAR; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.Iterables.getOnlyElement; + +@ThriftStruct +public final class PrestoThriftBlock +{ + // number + private final PrestoThriftInteger integerData; + private final PrestoThriftBigint bigintData; + private final PrestoThriftDouble doubleData; + + // variable width + private final PrestoThriftVarchar varcharData; + + // boolean + private final PrestoThriftBoolean booleanData; + + // temporal + private final PrestoThriftDate dateData; + private final PrestoThriftTimestamp timestampData; + + // special + private final PrestoThriftJson jsonData; + private final PrestoThriftHyperLogLog hyperLogLogData; + + // array + private final PrestoThriftBigintArray bigintArrayData; + + // non-thrift field which points to non-null data item + private final PrestoThriftColumnData dataReference; + + @ThriftConstructor + public PrestoThriftBlock( + @Nullable PrestoThriftInteger integerData, + @Nullable PrestoThriftBigint bigintData, + @Nullable PrestoThriftDouble doubleData, + @Nullable PrestoThriftVarchar varcharData, + @Nullable PrestoThriftBoolean booleanData, + @Nullable PrestoThriftDate dateData, + @Nullable PrestoThriftTimestamp timestampData, + @Nullable PrestoThriftJson jsonData, + @Nullable PrestoThriftHyperLogLog hyperLogLogData, + @Nullable PrestoThriftBigintArray bigintArrayData) + { + this.integerData = integerData; + this.bigintData = bigintData; + this.doubleData = doubleData; + this.varcharData = varcharData; + this.booleanData = booleanData; + this.dateData = dateData; + this.timestampData = timestampData; + this.jsonData = jsonData; + this.hyperLogLogData = hyperLogLogData; + this.bigintArrayData = bigintArrayData; + this.dataReference = theOnlyNonNull(integerData, bigintData, doubleData, varcharData, booleanData, dateData, timestampData, jsonData, hyperLogLogData, bigintArrayData); + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public PrestoThriftInteger getIntegerData() + { + return integerData; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public PrestoThriftBigint getBigintData() + { + return bigintData; + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public PrestoThriftDouble getDoubleData() + { + return doubleData; + } + + @Nullable + @ThriftField(value = 4, requiredness = OPTIONAL) + public PrestoThriftVarchar getVarcharData() + { + return varcharData; + } + + @Nullable + @ThriftField(value = 5, requiredness = OPTIONAL) + public PrestoThriftBoolean getBooleanData() + { + return booleanData; + } + + @Nullable + @ThriftField(value = 6, requiredness = OPTIONAL) + public PrestoThriftDate getDateData() + { + return dateData; + } + + @Nullable + @ThriftField(value = 7, requiredness = OPTIONAL) + public PrestoThriftTimestamp getTimestampData() + { + return timestampData; + } + + @Nullable + @ThriftField(value = 8, requiredness = OPTIONAL) + public PrestoThriftJson getJsonData() + { + return jsonData; + } + + @Nullable + @ThriftField(value = 9, requiredness = OPTIONAL) + public PrestoThriftHyperLogLog getHyperLogLogData() + { + return hyperLogLogData; + } + + @Nullable + @ThriftField(value = 10, requiredness = OPTIONAL) + public PrestoThriftBigintArray getBigintArrayData() + { + return bigintArrayData; + } + + public Block toBlock(Type desiredType) + { + return dataReference.toBlock(desiredType); + } + + public int numberOfRecords() + { + return dataReference.numberOfRecords(); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftBlock other = (PrestoThriftBlock) obj; + // remaining fields are guaranteed to be null by the constructor + return Objects.equals(this.dataReference, other.dataReference); + } + + @Override + public int hashCode() + { + return Objects.hash(integerData, bigintData, doubleData, varcharData, booleanData, dateData, timestampData, jsonData, hyperLogLogData, bigintArrayData); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("data", dataReference) + .toString(); + } + + public static PrestoThriftBlock integerData(PrestoThriftInteger integerData) + { + return new PrestoThriftBlock(integerData, null, null, null, null, null, null, null, null, null); + } + + public static PrestoThriftBlock bigintData(PrestoThriftBigint bigintData) + { + return new PrestoThriftBlock(null, bigintData, null, null, null, null, null, null, null, null); + } + + public static PrestoThriftBlock doubleData(PrestoThriftDouble doubleData) + { + return new PrestoThriftBlock(null, null, doubleData, null, null, null, null, null, null, null); + } + + public static PrestoThriftBlock varcharData(PrestoThriftVarchar varcharData) + { + return new PrestoThriftBlock(null, null, null, varcharData, null, null, null, null, null, null); + } + + public static PrestoThriftBlock booleanData(PrestoThriftBoolean booleanData) + { + return new PrestoThriftBlock(null, null, null, null, booleanData, null, null, null, null, null); + } + + public static PrestoThriftBlock dateData(PrestoThriftDate dateData) + { + return new PrestoThriftBlock(null, null, null, null, null, dateData, null, null, null, null); + } + + public static PrestoThriftBlock timestampData(PrestoThriftTimestamp timestampData) + { + return new PrestoThriftBlock(null, null, null, null, null, null, timestampData, null, null, null); + } + + public static PrestoThriftBlock jsonData(PrestoThriftJson jsonData) + { + return new PrestoThriftBlock(null, null, null, null, null, null, null, jsonData, null, null); + } + + public static PrestoThriftBlock hyperLogLogData(PrestoThriftHyperLogLog hyperLogLogData) + { + return new PrestoThriftBlock(null, null, null, null, null, null, null, null, hyperLogLogData, null); + } + + public static PrestoThriftBlock bigintArrayData(PrestoThriftBigintArray bigintArrayData) + { + return new PrestoThriftBlock(null, null, null, null, null, null, null, null, null, bigintArrayData); + } + + public static PrestoThriftBlock fromBlock(Block block, Type type) + { + switch (type.getTypeSignature().getBase()) { + case INTEGER: + return PrestoThriftInteger.fromBlock(block); + case BIGINT: + return PrestoThriftBigint.fromBlock(block); + case DOUBLE: + return PrestoThriftDouble.fromBlock(block); + case VARCHAR: + return PrestoThriftVarchar.fromBlock(block, type); + case BOOLEAN: + return PrestoThriftBoolean.fromBlock(block); + case DATE: + return PrestoThriftDate.fromBlock(block); + case TIMESTAMP: + return PrestoThriftTimestamp.fromBlock(block); + case JSON: + return PrestoThriftJson.fromBlock(block, type); + case HYPER_LOG_LOG: + return PrestoThriftHyperLogLog.fromBlock(block); + case ARRAY: + Type elementType = getOnlyElement(type.getTypeParameters()); + if (BigintType.BIGINT.equals(elementType)) { + return PrestoThriftBigintArray.fromBlock(block); + } + else { + throw new IllegalArgumentException("Unsupported array block type: " + type); + } + default: + throw new IllegalArgumentException("Unsupported block type: " + type); + } + } + + private static PrestoThriftColumnData theOnlyNonNull(PrestoThriftColumnData... columnsData) + { + PrestoThriftColumnData result = null; + for (PrestoThriftColumnData data : columnsData) { + if (data != null) { + checkArgument(result == null, "more than one type is present"); + result = data; + } + } + checkArgument(result != null, "no types are present"); + return result; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftColumnMetadata.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftColumnMetadata.java new file mode 100644 index 000000000000..16bba53e4616 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftColumnMetadata.java @@ -0,0 +1,115 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.NameValidationUtils.checkValidName; +import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftColumnMetadata +{ + private final String name; + private final String type; + private final String comment; + private final boolean hidden; + + @ThriftConstructor + public PrestoThriftColumnMetadata(String name, String type, @Nullable String comment, boolean hidden) + { + this.name = checkValidName(name); + this.type = requireNonNull(type, "type is null"); + this.comment = comment; + this.hidden = hidden; + } + + @ThriftField(1) + public String getName() + { + return name; + } + + @ThriftField(2) + public String getType() + { + return type; + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public String getComment() + { + return comment; + } + + @ThriftField(4) + public boolean isHidden() + { + return hidden; + } + + public ColumnMetadata toColumnMetadata(TypeManager typeManager) + { + return new ColumnMetadata( + name, + typeManager.getType(parseTypeSignature(type)), + comment, + hidden); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftColumnMetadata other = (PrestoThriftColumnMetadata) obj; + return Objects.equals(this.name, other.name) && + Objects.equals(this.type, other.type) && + Objects.equals(this.comment, other.comment) && + this.hidden == other.hidden; + } + + @Override + public int hashCode() + { + return Objects.hash(name, type, comment, hidden); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("name", name) + .add("type", type) + .add("comment", comment) + .add("hidden", hidden) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftDomain.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftDomain.java new file mode 100644 index 000000000000..a4eb6eb6fc5f --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftDomain.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftValueSet; +import com.facebook.presto.spi.predicate.Domain; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftValueSet.fromValueSet; +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftDomain +{ + private final PrestoThriftValueSet valueSet; + private final boolean nullAllowed; + + @ThriftConstructor + public PrestoThriftDomain(PrestoThriftValueSet valueSet, boolean nullAllowed) + { + this.valueSet = requireNonNull(valueSet, "valueSet is null"); + this.nullAllowed = nullAllowed; + } + + @ThriftField(1) + public PrestoThriftValueSet getValueSet() + { + return valueSet; + } + + @ThriftField(2) + public boolean isNullAllowed() + { + return nullAllowed; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftDomain other = (PrestoThriftDomain) obj; + return Objects.equals(this.valueSet, other.valueSet) && + this.nullAllowed == other.nullAllowed; + } + + @Override + public int hashCode() + { + return Objects.hash(valueSet, nullAllowed); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("valueSet", valueSet) + .add("nullAllowed", nullAllowed) + .toString(); + } + + public static PrestoThriftDomain fromDomain(Domain domain) + { + return new PrestoThriftDomain(fromValueSet(domain.getValues()), domain.isNullAllowed()); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftHostAddress.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftHostAddress.java new file mode 100644 index 000000000000..d3317fdf4718 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftHostAddress.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.spi.HostAddress; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftHostAddress +{ + private final String host; + private final int port; + + @ThriftConstructor + public PrestoThriftHostAddress(String host, int port) + { + this.host = requireNonNull(host, "host is null"); + this.port = port; + } + + @ThriftField(1) + public String getHost() + { + return host; + } + + @ThriftField(2) + public int getPort() + { + return port; + } + + public HostAddress toHostAddress() + { + return HostAddress.fromParts(getHost(), getPort()); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftHostAddress other = (PrestoThriftHostAddress) obj; + return Objects.equals(this.host, other.host) && + this.port == other.port; + } + + @Override + public int hashCode() + { + return Objects.hash(host, port); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("host", host) + .add("port", port) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftId.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftId.java new file mode 100644 index 000000000000..082c960fb8f8 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftId.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.BaseEncoding; + +import java.util.Arrays; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftId +{ + private static final int PREFIX_SUFFIX_BYTES = 8; + private static final String FILLER = ".."; + private static final int MAX_DISPLAY_CHARACTERS = PREFIX_SUFFIX_BYTES * 4 + FILLER.length(); + + private final byte[] id; + + @JsonCreator + @ThriftConstructor + public PrestoThriftId(@JsonProperty("id") byte[] id) + { + this.id = requireNonNull(id, "id is null"); + } + + @JsonProperty + @ThriftField(1) + public byte[] getId() + { + return id; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftId other = (PrestoThriftId) obj; + return Arrays.equals(this.id, other.id); + } + + @Override + public int hashCode() + { + return Arrays.hashCode(id); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("id", summarize(id)) + .toString(); + } + + @VisibleForTesting + static String summarize(byte[] value) + { + if (value.length * 2 <= MAX_DISPLAY_CHARACTERS) { + return BaseEncoding.base16().encode(value); + } + return BaseEncoding.base16().encode(value, 0, PREFIX_SUFFIX_BYTES) + + FILLER + + BaseEncoding.base16().encode(value, value.length - PREFIX_SUFFIX_BYTES, PREFIX_SUFFIX_BYTES); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableColumnSet.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableColumnSet.java new file mode 100644 index 000000000000..41b1e7f8b6d9 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableColumnSet.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; +import java.util.Set; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; + +@ThriftStruct +public final class PrestoThriftNullableColumnSet +{ + private final Set columns; + + @ThriftConstructor + public PrestoThriftNullableColumnSet(@Nullable Set columns) + { + this.columns = columns; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public Set getColumns() + { + return columns; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftNullableColumnSet other = (PrestoThriftNullableColumnSet) obj; + return Objects.equals(this.columns, other.columns); + } + + @Override + public int hashCode() + { + return Objects.hashCode(columns); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("columns", columns) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableSchemaName.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableSchemaName.java new file mode 100644 index 000000000000..f048a9e92bea --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableSchemaName.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; + +@ThriftStruct +public final class PrestoThriftNullableSchemaName +{ + private final String schemaName; + + @ThriftConstructor + public PrestoThriftNullableSchemaName(@Nullable String schemaName) + { + this.schemaName = schemaName; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public String getSchemaName() + { + return schemaName; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftNullableSchemaName other = (PrestoThriftNullableSchemaName) obj; + return Objects.equals(this.schemaName, other.schemaName); + } + + @Override + public int hashCode() + { + return Objects.hashCode(schemaName); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("schemaName", schemaName) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableTableMetadata.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableTableMetadata.java new file mode 100644 index 000000000000..a94cfcf401b5 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableTableMetadata.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; + +@ThriftStruct +public final class PrestoThriftNullableTableMetadata +{ + private final PrestoThriftTableMetadata tableMetadata; + + @ThriftConstructor + public PrestoThriftNullableTableMetadata(@Nullable PrestoThriftTableMetadata tableMetadata) + { + this.tableMetadata = tableMetadata; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public PrestoThriftTableMetadata getTableMetadata() + { + return tableMetadata; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftNullableTableMetadata other = (PrestoThriftNullableTableMetadata) obj; + return Objects.equals(this.tableMetadata, other.tableMetadata); + } + + @Override + public int hashCode() + { + return Objects.hashCode(tableMetadata); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("tableMetadata", tableMetadata) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableToken.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableToken.java new file mode 100644 index 000000000000..b8670e5cb152 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableToken.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; + +@ThriftStruct +public final class PrestoThriftNullableToken +{ + private final PrestoThriftId token; + + @ThriftConstructor + public PrestoThriftNullableToken(@Nullable PrestoThriftId token) + { + this.token = token; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public PrestoThriftId getToken() + { + return token; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftNullableToken other = (PrestoThriftNullableToken) obj; + return Objects.equals(this.token, other.token); + } + + @Override + public int hashCode() + { + return Objects.hashCode(token); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("token", token) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftPageResult.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftPageResult.java new file mode 100644 index 000000000000..71d046d3c37a --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftPageResult.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Objects; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftPageResult +{ + private final List columnBlocks; + private final int rowCount; + private final PrestoThriftId nextToken; + + @ThriftConstructor + public PrestoThriftPageResult(List columnBlocks, int rowCount, @Nullable PrestoThriftId nextToken) + { + this.columnBlocks = requireNonNull(columnBlocks, "columnBlocks is null"); + checkArgument(rowCount >= 0, "rowCount is negative"); + checkAllColumnsAreOfExpectedSize(columnBlocks, rowCount); + this.rowCount = rowCount; + this.nextToken = nextToken; + } + + /** + * Returns data in a columnar format. + * Columns in this list must be in the order they were requested by the engine. + */ + @ThriftField(1) + public List getColumnBlocks() + { + return columnBlocks; + } + + @ThriftField(2) + public int getRowCount() + { + return rowCount; + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public PrestoThriftId getNextToken() + { + return nextToken; + } + + @Nullable + public Page toPage(List columnTypes) + { + if (rowCount == 0) { + return null; + } + checkArgument(columnBlocks.size() == columnTypes.size(), "columns and types have different sizes"); + int numberOfColumns = columnBlocks.size(); + if (numberOfColumns == 0) { + // request/response with no columns, used for queries like "select count star" + return new Page(rowCount); + } + Block[] blocks = new Block[numberOfColumns]; + for (int i = 0; i < numberOfColumns; i++) { + blocks[i] = columnBlocks.get(i).toBlock(columnTypes.get(i)); + } + return new Page(blocks); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftPageResult other = (PrestoThriftPageResult) obj; + return Objects.equals(this.columnBlocks, other.columnBlocks) && + this.rowCount == other.rowCount && + Objects.equals(this.nextToken, other.nextToken); + } + + @Override + public int hashCode() + { + return Objects.hash(columnBlocks, rowCount, nextToken); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("columnBlocks", columnBlocks) + .add("rowCount", rowCount) + .add("nextToken", nextToken) + .toString(); + } + + private static void checkAllColumnsAreOfExpectedSize(List columnBlocks, int expectedNumberOfRows) + { + for (int i = 0; i < columnBlocks.size(); i++) { + checkArgument(columnBlocks.get(i).numberOfRecords() == expectedNumberOfRows, + "Incorrect number of records for column with index %s: expected %s, got %s", + i, expectedNumberOfRows, columnBlocks.get(i).numberOfRecords()); + } + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSchemaTableName.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSchemaTableName.java new file mode 100644 index 000000000000..0efce302de41 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSchemaTableName.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.NameValidationUtils.checkValidName; +import static com.google.common.base.MoreObjects.toStringHelper; + +@ThriftStruct +public final class PrestoThriftSchemaTableName +{ + private final String schemaName; + private final String tableName; + + @ThriftConstructor + public PrestoThriftSchemaTableName(String schemaName, String tableName) + { + this.schemaName = checkValidName(schemaName); + this.tableName = checkValidName(tableName); + } + + @ThriftField(1) + public String getSchemaName() + { + return schemaName; + } + + @ThriftField(2) + public String getTableName() + { + return tableName; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftSchemaTableName other = (PrestoThriftSchemaTableName) obj; + return Objects.equals(this.schemaName, other.schemaName) && + Objects.equals(this.tableName, other.tableName); + } + + @Override + public int hashCode() + { + return Objects.hash(schemaName, tableName); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("schemaName", schemaName) + .add("tableName", tableName) + .toString(); + } + + public SchemaTableName toSchemaTableName() + { + return new SchemaTableName(getSchemaName(), getTableName()); + } + + public static PrestoThriftSchemaTableName fromSchemaTableName(SchemaTableName schemaTableName) + { + return new PrestoThriftSchemaTableName(schemaTableName.getSchemaName(), schemaTableName.getTableName()); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftService.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftService.java new file mode 100644 index 000000000000..f40babdc7f6a --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftService.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.service.ThriftMethod; +import com.facebook.swift.service.ThriftService; +import com.google.common.util.concurrent.ListenableFuture; + +import java.io.Closeable; +import java.util.List; + +/** + * Presto Thrift service definition. + * This thrift service needs to be implemented in order to be used with Thrift Connector. + */ +@ThriftService +public interface PrestoThriftService + extends Closeable +{ + /** + * Returns available schema names. + */ + @ThriftMethod("prestoListSchemaNames") + List listSchemaNames() + throws PrestoThriftServiceException; + + /** + * Returns tables for the given schema name. + * + * @param schemaNameOrNull a structure containing schema name or {@literal null} + * @return a list of table names with corresponding schemas. If schema name is null then returns + * a list of tables for all schemas. Returns an empty list if a schema does not exist + */ + @ThriftMethod("prestoListTables") + List listTables( + @ThriftField(name = "schemaNameOrNull") PrestoThriftNullableSchemaName schemaNameOrNull) + throws PrestoThriftServiceException; + + /** + * Returns metadata for a given table. + * + * @param schemaTableName schema and table name + * @return metadata for a given table, or a {@literal null} value inside if it does not exist + */ + @ThriftMethod("prestoGetTableMetadata") + PrestoThriftNullableTableMetadata getTableMetadata( + @ThriftField(name = "schemaTableName") PrestoThriftSchemaTableName schemaTableName) + throws PrestoThriftServiceException; + + /** + * Returns a batch of splits. + * + * @param schemaTableName schema and table name + * @param desiredColumns a superset of columns to return; empty set means "no columns", {@literal null} set means "all columns" + * @param outputConstraint constraint on the returned data + * @param maxSplitCount maximum number of splits to return + * @param nextToken token from a previous split batch or {@literal null} if it is the first call + * @return a batch of splits + */ + @ThriftMethod("prestoGetSplits") + ListenableFuture getSplits( + @ThriftField(name = "schemaTableName") PrestoThriftSchemaTableName schemaTableName, + @ThriftField(name = "desiredColumns") PrestoThriftNullableColumnSet desiredColumns, + @ThriftField(name = "outputConstraint") PrestoThriftTupleDomain outputConstraint, + @ThriftField(name = "maxSplitCount") int maxSplitCount, + @ThriftField(name = "nextToken") PrestoThriftNullableToken nextToken) + throws PrestoThriftServiceException; + + /** + * Returns a batch of rows for the given split. + * + * @param splitId split id as returned in split batch + * @param columns a list of column names to return + * @param maxBytes maximum size of returned data in bytes + * @param nextToken token from a previous batch or {@literal null} if it is the first call + * @return a batch of table data + */ + @ThriftMethod("prestoGetRows") + ListenableFuture getRows( + @ThriftField(name = "splitId") PrestoThriftId splitId, + @ThriftField(name = "columns") List columns, + @ThriftField(name = "maxBytes") long maxBytes, + @ThriftField(name = "nextToken") PrestoThriftNullableToken nextToken) + throws PrestoThriftServiceException; + + @Override + void close(); +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftServiceException.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftServiceException.java new file mode 100644 index 000000000000..fb51007111cf --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftServiceException.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +@ThriftStruct +public final class PrestoThriftServiceException + extends RuntimeException +{ + private final boolean retryable; + + @ThriftConstructor + public PrestoThriftServiceException(String message, boolean retryable) + { + super(message); + this.retryable = retryable; + } + + @Override + @ThriftField(1) + public String getMessage() + { + return super.getMessage(); + } + + @ThriftField(2) + public boolean isRetryable() + { + return retryable; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSplit.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSplit.java new file mode 100644 index 000000000000..a146d9bc727d --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSplit.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftSplit +{ + private final PrestoThriftId splitId; + private final List hosts; + + @ThriftConstructor + public PrestoThriftSplit(PrestoThriftId splitId, List hosts) + { + this.splitId = requireNonNull(splitId, "splitId is null"); + this.hosts = requireNonNull(hosts, "hosts is null"); + } + + @ThriftField(1) + public PrestoThriftId getSplitId() + { + return splitId; + } + + @ThriftField(2) + public List getHosts() + { + return hosts; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftSplit other = (PrestoThriftSplit) obj; + return Objects.equals(this.splitId, other.splitId) && + Objects.equals(this.hosts, other.hosts); + } + + @Override + public int hashCode() + { + return Objects.hash(splitId, hosts); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("splitId", splitId) + .add("hosts", hosts) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSplitBatch.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSplitBatch.java new file mode 100644 index 000000000000..95f265207ea9 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSplitBatch.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Objects; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftSplitBatch +{ + private final List splits; + private final PrestoThriftId nextToken; + + @ThriftConstructor + public PrestoThriftSplitBatch(List splits, @Nullable PrestoThriftId nextToken) + { + this.splits = requireNonNull(splits, "splits is null"); + this.nextToken = nextToken; + } + + @ThriftField(1) + public List getSplits() + { + return splits; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public PrestoThriftId getNextToken() + { + return nextToken; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftSplitBatch other = (PrestoThriftSplitBatch) obj; + return Objects.equals(this.splits, other.splits) && + Objects.equals(this.nextToken, other.nextToken); + } + + @Override + public int hashCode() + { + return Objects.hash(splits, nextToken); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfSplits", splits.size()) + .add("nextToken", nextToken) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftTableMetadata.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftTableMetadata.java new file mode 100644 index 000000000000..0e9fceba52da --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftTableMetadata.java @@ -0,0 +1,118 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; +import com.google.common.collect.ImmutableMap; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftTableMetadata +{ + private final PrestoThriftSchemaTableName schemaTableName; + private final List columns; + private final String comment; + + @ThriftConstructor + public PrestoThriftTableMetadata( + @ThriftField(name = "schemaTableName") PrestoThriftSchemaTableName schemaTableName, + @ThriftField(name = "columns") List columns, + @ThriftField(name = "comment") @Nullable String comment) + { + this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); + this.columns = requireNonNull(columns, "columns is null"); + this.comment = comment; + } + + @ThriftField(1) + public PrestoThriftSchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + @ThriftField(2) + public List getColumns() + { + return columns; + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public String getComment() + { + return comment; + } + + public ConnectorTableMetadata toConnectorTableMetadata(TypeManager typeManager) + { + return new ConnectorTableMetadata( + schemaTableName.toSchemaTableName(), + columnMetadata(typeManager), + ImmutableMap.of(), + Optional.ofNullable(comment)); + } + + private List columnMetadata(TypeManager typeManager) + { + return columns.stream() + .map(column -> column.toColumnMetadata(typeManager)) + .collect(toImmutableList()); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftTableMetadata other = (PrestoThriftTableMetadata) obj; + return Objects.equals(this.schemaTableName, other.schemaTableName) && + Objects.equals(this.columns, other.columns) && + Objects.equals(this.comment, other.comment); + } + + @Override + public int hashCode() + { + return Objects.hash(schemaTableName, columns, comment); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("schemaTableName", schemaTableName) + .add("numberOfColumns", columns.size()) + .add("comment", comment) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftTupleDomain.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftTupleDomain.java new file mode 100644 index 000000000000..a1631d26e278 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftTupleDomain.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; +import com.google.common.collect.ImmutableSet; + +import javax.annotation.Nullable; + +import java.util.Map; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.NameValidationUtils.checkValidName; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; + +@ThriftStruct +public final class PrestoThriftTupleDomain +{ + private final Map domains; + + @ThriftConstructor + public PrestoThriftTupleDomain(@Nullable Map domains) + { + if (domains != null) { + for (String name : domains.keySet()) { + checkValidName(name); + } + } + this.domains = domains; + } + + /** + * Return a map of column names to constraints. + */ + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public Map getDomains() + { + return domains; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftTupleDomain other = (PrestoThriftTupleDomain) obj; + return Objects.equals(this.domains, other.domains); + } + + @Override + public int hashCode() + { + return Objects.hashCode(domains); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("columnsWithConstraints", domains != null ? domains.keySet() : ImmutableSet.of()) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBigint.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBigint.java new file mode 100644 index 000000000000..8ebe4c6c33f5 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBigint.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.LongArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.bigintData; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.fromLongBasedBlock; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code longs} array are values for each row. If row is null then value is ignored. + */ +@ThriftStruct +public final class PrestoThriftBigint + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final long[] longs; + + @ThriftConstructor + public PrestoThriftBigint( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "longs") @Nullable long[] longs) + { + checkArgument(sameSizeIfPresent(nulls, longs), "nulls and values must be of the same size"); + this.nulls = nulls; + this.longs = longs; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public long[] getLongs() + { + return longs; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(BIGINT.equals(desiredType), "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + return new LongArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + longs == null ? new long[numberOfRecords] : longs); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (longs != null) { + return longs.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftBigint other = (PrestoThriftBigint) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.longs, other.longs); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(longs)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + return fromLongBasedBlock(block, BIGINT, (nulls, longs) -> bigintData(new PrestoThriftBigint(nulls, longs))); + } + + private static boolean sameSizeIfPresent(boolean[] nulls, long[] longs) + { + return nulls == null || longs == null || nulls.length == longs.length; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBigintArray.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBigintArray.java new file mode 100644 index 000000000000..6232ba23f8f4 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBigintArray.java @@ -0,0 +1,180 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.AbstractArrayBlock; +import com.facebook.presto.spi.block.ArrayBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.LongArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.bigintArrayData; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.calculateOffsets; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.sameSizeIfPresent; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.totalSize; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Each elements of {@code sizes} array contains the number of elements in the corresponding values array. + * If row is null then the corresponding element in {@code sizes} is ignored. + * {@code values} is a bigint block containing array elements one after another for all rows. + * The total number of elements in bigint block must be equal to the sum of all sizes. + */ +@ThriftStruct +public final class PrestoThriftBigintArray + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final int[] sizes; + private final PrestoThriftBigint values; + + @ThriftConstructor + public PrestoThriftBigintArray( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "sizes") @Nullable int[] sizes, + @ThriftField(name = "values") @Nullable PrestoThriftBigint values) + { + checkArgument(sameSizeIfPresent(nulls, sizes), "nulls and values must be of the same size"); + checkArgument(totalSize(nulls, sizes) == numberOfValues(values), "total number of values doesn't match expected size"); + this.nulls = nulls; + this.sizes = sizes; + this.values = values; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public int[] getSizes() + { + return sizes; + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public PrestoThriftBigint getValues() + { + return values; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(desiredType.getTypeParameters().size() == 1 && BIGINT.equals(desiredType.getTypeParameters().get(0)), + "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + return new ArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + calculateOffsets(sizes, nulls, numberOfRecords), + values != null ? values.toBlock(BIGINT) : new LongArrayBlock(0, new boolean[] {}, new long[] {})); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (sizes != null) { + return sizes.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftBigintArray other = (PrestoThriftBigintArray) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.sizes, other.sizes) && + Objects.equals(this.values, other.values); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(sizes), values); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + checkArgument(block instanceof AbstractArrayBlock, "block is not of an array type"); + AbstractArrayBlock arrayBlock = (AbstractArrayBlock) block; + int positions = arrayBlock.getPositionCount(); + if (positions == 0) { + return bigintArrayData(new PrestoThriftBigintArray(null, null, null)); + } + boolean[] nulls = null; + int[] sizes = null; + for (int position = 0; position < positions; position++) { + if (arrayBlock.isNull(position)) { + if (nulls == null) { + nulls = new boolean[positions]; + } + nulls[position] = true; + } + else { + if (sizes == null) { + sizes = new int[positions]; + } + sizes[position] = arrayBlock.apply((valuesBlock, startPosition, length) -> length, position); + } + } + PrestoThriftBigint values = arrayBlock + .apply((valuesBlock, startPosition, length) -> PrestoThriftBigint.fromBlock(valuesBlock), 0) + .getBigintData(); + checkState(values != null, "values must be present"); + checkState(totalSize(nulls, sizes) == values.numberOfRecords(), "unexpected number of values"); + return bigintArrayData(new PrestoThriftBigintArray(nulls, sizes, values)); + } + + private static int numberOfValues(PrestoThriftBigint values) + { + return values != null ? values.numberOfRecords() : 0; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBoolean.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBoolean.java new file mode 100644 index 000000000000..1f7067406e76 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBoolean.java @@ -0,0 +1,157 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.ByteArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.booleanData; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code booleans} array are values for each row. If row is null then value is ignored. + */ +@ThriftStruct +public final class PrestoThriftBoolean + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final boolean[] booleans; + + @ThriftConstructor + public PrestoThriftBoolean(@Nullable boolean[] nulls, @Nullable boolean[] booleans) + { + checkArgument(sameSizeIfPresent(nulls, booleans), "nulls and values must be of the same size"); + this.nulls = nulls; + this.booleans = booleans; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public boolean[] getBooleans() + { + return booleans; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(BOOLEAN.equals(desiredType), "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + return new ByteArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + booleans == null ? new byte[numberOfRecords] : toByteArray(booleans)); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (booleans != null) { + return booleans.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftBoolean other = (PrestoThriftBoolean) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.booleans, other.booleans); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(booleans)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + int positions = block.getPositionCount(); + if (positions == 0) { + return booleanData(new PrestoThriftBoolean(null, null)); + } + boolean[] nulls = null; + boolean[] booleans = null; + for (int position = 0; position < positions; position++) { + if (block.isNull(position)) { + if (nulls == null) { + nulls = new boolean[positions]; + } + nulls[position] = true; + } + else { + if (booleans == null) { + booleans = new boolean[positions]; + } + booleans[position] = BOOLEAN.getBoolean(block, position); + } + } + return booleanData(new PrestoThriftBoolean(nulls, booleans)); + } + + private static boolean sameSizeIfPresent(boolean[] nulls, boolean[] booleans) + { + return nulls == null || booleans == null || nulls.length == booleans.length; + } + + private static byte[] toByteArray(boolean[] booleans) + { + byte[] bytes = new byte[booleans.length]; + for (int i = 0; i < booleans.length; i++) { + bytes[i] = booleans[i] ? (byte) 1 : (byte) 0; + } + return bytes; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftColumnData.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftColumnData.java new file mode 100644 index 000000000000..eb687625e002 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftColumnData.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.Type; + +public interface PrestoThriftColumnData +{ + Block toBlock(Type desiredType); + + int numberOfRecords(); +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftDate.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftDate.java new file mode 100644 index 000000000000..cb93a885f036 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftDate.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.IntArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.dateData; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.fromIntBasedBlock; +import static com.facebook.presto.spi.type.DateType.DATE; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code dates} array are date values for each row represented as the number + * of days passed since 1970-01-01. + * If row is null then value is ignored. + */ +@ThriftStruct +public final class PrestoThriftDate + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final int[] dates; + + @ThriftConstructor + public PrestoThriftDate( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "dates") @Nullable int[] dates) + { + checkArgument(sameSizeIfPresent(nulls, dates), "nulls and values must be of the same size"); + this.nulls = nulls; + this.dates = dates; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public int[] getDates() + { + return dates; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(DATE.equals(desiredType), "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + return new IntArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + dates == null ? new int[numberOfRecords] : dates); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (dates != null) { + return dates.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftDate other = (PrestoThriftDate) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.dates, other.dates); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(dates)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + return fromIntBasedBlock(block, DATE, (nulls, ints) -> dateData(new PrestoThriftDate(nulls, ints))); + } + + private static boolean sameSizeIfPresent(boolean[] nulls, int[] dates) + { + return nulls == null || dates == null || nulls.length == dates.length; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftDouble.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftDouble.java new file mode 100644 index 000000000000..489515d8a470 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftDouble.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.LongArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.booleanData; +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.doubleData; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Double.doubleToLongBits; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code doubles} array are values for each row. If row is null then value is ignored. + */ +@ThriftStruct +public final class PrestoThriftDouble + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final double[] doubles; + + @ThriftConstructor + public PrestoThriftDouble(@Nullable boolean[] nulls, @Nullable double[] doubles) + { + checkArgument(sameSizeIfPresent(nulls, doubles), "nulls and values must be of the same size"); + this.nulls = nulls; + this.doubles = doubles; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public double[] getDoubles() + { + return doubles; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(DOUBLE.equals(desiredType), "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + long[] longs = new long[numberOfRecords]; + if (doubles != null) { + for (int i = 0; i < numberOfRecords; i++) { + longs[i] = doubleToLongBits(doubles[i]); + } + } + return new LongArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + longs); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (doubles != null) { + return doubles.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftDouble other = (PrestoThriftDouble) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.doubles, other.doubles); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(doubles)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + int positions = block.getPositionCount(); + if (positions == 0) { + return booleanData(new PrestoThriftBoolean(null, null)); + } + boolean[] nulls = null; + double[] doubles = null; + for (int position = 0; position < positions; position++) { + if (block.isNull(position)) { + if (nulls == null) { + nulls = new boolean[positions]; + } + nulls[position] = true; + } + else { + if (doubles == null) { + doubles = new double[positions]; + } + doubles[position] = DOUBLE.getDouble(block, position); + } + } + return doubleData(new PrestoThriftDouble(nulls, doubles)); + } + + private static boolean sameSizeIfPresent(boolean[] nulls, double[] doubles) + { + return nulls == null || doubles == null || nulls.length == doubles.length; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftHyperLogLog.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftHyperLogLog.java new file mode 100644 index 000000000000..308b285ebe1c --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftHyperLogLog.java @@ -0,0 +1,123 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.hyperLogLogData; +import static com.facebook.presto.connector.thrift.api.datatypes.SliceData.fromSliceBasedBlock; +import static com.facebook.presto.spi.type.HyperLogLogType.HYPER_LOG_LOG; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Each elements of {@code sizes} array contains the length in bytes for the corresponding element. + * If row is null then the corresponding element in {@code sizes} is ignored. + * {@code bytes} array contains encoded byte values for HyperLogLog representation as defined in + * Airlift specification: href="https://github.com/airlift/airlift/blob/master/stats/docs/hll.md + * Values for all rows are written to {@code bytes} array one after another. + * The total number of bytes must be equal to the sum of all sizes. + */ +@ThriftStruct +public final class PrestoThriftHyperLogLog + implements PrestoThriftColumnData +{ + private final SliceData sliceType; + + @ThriftConstructor + public PrestoThriftHyperLogLog( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "sizes") @Nullable int[] sizes, + @ThriftField(name = "bytes") @Nullable byte[] bytes) + { + this.sliceType = new SliceData(nulls, sizes, bytes); + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return sliceType.getNulls(); + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public int[] getSizes() + { + return sliceType.getSizes(); + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public byte[] getBytes() + { + return sliceType.getBytes(); + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(HYPER_LOG_LOG.equals(desiredType), "type doesn't match: %s", desiredType); + return sliceType.toBlock(desiredType); + } + + @Override + public int numberOfRecords() + { + return sliceType.numberOfRecords(); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftHyperLogLog other = (PrestoThriftHyperLogLog) obj; + return Objects.equals(this.sliceType, other.sliceType); + } + + @Override + public int hashCode() + { + return sliceType.hashCode(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + return fromSliceBasedBlock(block, HYPER_LOG_LOG, (nulls, sizes, bytes) -> hyperLogLogData(new PrestoThriftHyperLogLog(nulls, sizes, bytes))); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftInteger.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftInteger.java new file mode 100644 index 000000000000..8be9477330fa --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftInteger.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.IntArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.integerData; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.fromIntBasedBlock; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code ints} array are values for each row. If row is null then value is ignored. + */ +@ThriftStruct +public final class PrestoThriftInteger + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final int[] ints; + + @ThriftConstructor + public PrestoThriftInteger( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "ints") @Nullable int[] ints) + { + checkArgument(sameSizeIfPresent(nulls, ints), "nulls and values must be of the same size"); + this.nulls = nulls; + this.ints = ints; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public int[] getInts() + { + return ints; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(INTEGER.equals(desiredType), "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + return new IntArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + ints == null ? new int[numberOfRecords] : ints); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (ints != null) { + return ints.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftInteger other = (PrestoThriftInteger) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.ints, other.ints); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(ints)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + return fromIntBasedBlock(block, INTEGER, (nulls, ints) -> integerData(new PrestoThriftInteger(nulls, ints))); + } + + private static boolean sameSizeIfPresent(boolean[] nulls, int[] ints) + { + return nulls == null || ints == null || nulls.length == ints.length; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftJson.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftJson.java new file mode 100644 index 000000000000..1e173fc8e18e --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftJson.java @@ -0,0 +1,119 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.jsonData; +import static com.facebook.presto.connector.thrift.api.datatypes.SliceData.fromSliceBasedBlock; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Each elements of {@code sizes} array contains the length in bytes for the corresponding element. + * If row is null then the corresponding element in {@code sizes} is ignored. + * {@code bytes} array contains uft8 encoded byte values for string representation of json. + * Values for all rows are written to {@code bytes} array one after another. + * The total number of bytes must be equal to the sum of all sizes. + */ +@ThriftStruct +public final class PrestoThriftJson + implements PrestoThriftColumnData +{ + private final SliceData sliceType; + + @ThriftConstructor + public PrestoThriftJson( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "sizes") @Nullable int[] sizes, + @ThriftField(name = "bytes") @Nullable byte[] bytes) + { + this.sliceType = new SliceData(nulls, sizes, bytes); + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return sliceType.getNulls(); + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public int[] getSizes() + { + return sliceType.getSizes(); + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public byte[] getBytes() + { + return sliceType.getBytes(); + } + + @Override + public Block toBlock(Type desiredType) + { + return sliceType.toBlock(desiredType); + } + + @Override + public int numberOfRecords() + { + return sliceType.numberOfRecords(); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftJson other = (PrestoThriftJson) obj; + return Objects.equals(this.sliceType, other.sliceType); + } + + @Override + public int hashCode() + { + return sliceType.hashCode(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block, Type type) + { + return fromSliceBasedBlock(block, type, (nulls, sizes, bytes) -> jsonData(new PrestoThriftJson(nulls, sizes, bytes))); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftTimestamp.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftTimestamp.java new file mode 100644 index 000000000000..d39c46e0f137 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftTimestamp.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.LongArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.timestampData; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.fromLongBasedBlock; +import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code timestamps} array are values for each row represented as the number + * of milliseconds passed since 1970-01-01T00:00:00 UTC. + * If row is null then value is ignored. + */ +@ThriftStruct +public final class PrestoThriftTimestamp + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final long[] timestamps; + + @ThriftConstructor + public PrestoThriftTimestamp( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "timestamps") @Nullable long[] timestamps) + { + checkArgument(sameSizeIfPresent(nulls, timestamps), "nulls and values must be of the same size"); + this.nulls = nulls; + this.timestamps = timestamps; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public long[] getTimestamps() + { + return timestamps; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(TIMESTAMP.equals(desiredType), "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + return new LongArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + timestamps == null ? new long[numberOfRecords] : timestamps); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (timestamps != null) { + return timestamps.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftTimestamp other = (PrestoThriftTimestamp) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.timestamps, other.timestamps); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(timestamps)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + return fromLongBasedBlock(block, TIMESTAMP, (nulls, longs) -> timestampData(new PrestoThriftTimestamp(nulls, longs))); + } + + private static boolean sameSizeIfPresent(boolean[] nulls, long[] timestamps) + { + return nulls == null || timestamps == null || nulls.length == timestamps.length; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftTypeUtils.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftTypeUtils.java new file mode 100644 index 000000000000..5825334e0863 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftTypeUtils.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.Type; + +import java.util.function.BiFunction; + +final class PrestoThriftTypeUtils +{ + private PrestoThriftTypeUtils() + { + } + + public static PrestoThriftBlock fromLongBasedBlock(Block block, Type type, BiFunction result) + { + int positions = block.getPositionCount(); + if (positions == 0) { + return result.apply(null, null); + } + boolean[] nulls = null; + long[] longs = null; + for (int position = 0; position < positions; position++) { + if (block.isNull(position)) { + if (nulls == null) { + nulls = new boolean[positions]; + } + nulls[position] = true; + } + else { + if (longs == null) { + longs = new long[positions]; + } + longs[position] = type.getLong(block, position); + } + } + return result.apply(nulls, longs); + } + + public static PrestoThriftBlock fromIntBasedBlock(Block block, Type type, BiFunction result) + { + int positions = block.getPositionCount(); + if (positions == 0) { + return result.apply(null, null); + } + boolean[] nulls = null; + int[] ints = null; + for (int position = 0; position < positions; position++) { + if (block.isNull(position)) { + if (nulls == null) { + nulls = new boolean[positions]; + } + nulls[position] = true; + } + else { + if (ints == null) { + ints = new int[positions]; + } + ints[position] = (int) type.getLong(block, position); + } + } + return result.apply(nulls, ints); + } + + public static int totalSize(boolean[] nulls, int[] sizes) + { + int numberOfRecords; + if (nulls != null) { + numberOfRecords = nulls.length; + } + else if (sizes != null) { + numberOfRecords = sizes.length; + } + else { + numberOfRecords = 0; + } + int total = 0; + for (int i = 0; i < numberOfRecords; i++) { + if (nulls == null || !nulls[i]) { + total += sizes[i]; + } + } + return total; + } + + public static int[] calculateOffsets(int[] sizes, boolean[] nulls, int totalRecords) + { + if (sizes == null) { + return new int[totalRecords + 1]; + } + int[] offsets = new int[totalRecords + 1]; + offsets[0] = 0; + for (int i = 0; i < totalRecords; i++) { + int size = nulls != null && nulls[i] ? 0 : sizes[i]; + offsets[i + 1] = offsets[i] + size; + } + return offsets; + } + + public static boolean sameSizeIfPresent(boolean[] nulls, int[] sizes) + { + return nulls == null || sizes == null || nulls.length == sizes.length; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftVarchar.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftVarchar.java new file mode 100644 index 000000000000..063d9360bc53 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftVarchar.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.VarcharType; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.varcharData; +import static com.facebook.presto.connector.thrift.api.datatypes.SliceData.fromSliceBasedBlock; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Each elements of {@code sizes} array contains the length in bytes for the corresponding element. + * If row is null then the corresponding element in {@code sizes} is ignored. + * {@code bytes} array contains uft8 encoded byte values. + * Values for all rows are written to {@code bytes} array one after another. + * The total number of bytes must be equal to the sum of all sizes. + */ +@ThriftStruct +public final class PrestoThriftVarchar + implements PrestoThriftColumnData +{ + private final SliceData sliceType; + + @ThriftConstructor + public PrestoThriftVarchar( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "sizes") @Nullable int[] sizes, + @ThriftField(name = "bytes") @Nullable byte[] bytes) + { + this.sliceType = new SliceData(nulls, sizes, bytes); + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return sliceType.getNulls(); + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public int[] getSizes() + { + return sliceType.getSizes(); + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public byte[] getBytes() + { + return sliceType.getBytes(); + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(desiredType.getClass() == VarcharType.class, "type doesn't match: %s", desiredType); + return sliceType.toBlock(desiredType); + } + + @Override + public int numberOfRecords() + { + return sliceType.numberOfRecords(); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftVarchar other = (PrestoThriftVarchar) obj; + return Objects.equals(this.sliceType, other.sliceType); + } + + @Override + public int hashCode() + { + return sliceType.hashCode(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block, Type type) + { + return fromSliceBasedBlock(block, type, (nulls, sizes, bytes) -> varcharData(new PrestoThriftVarchar(nulls, sizes, bytes))); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/SliceData.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/SliceData.java new file mode 100644 index 000000000000..20aedd9f4701 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/SliceData.java @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.VariableWidthBlock; +import com.facebook.presto.spi.type.Type; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.calculateOffsets; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.sameSizeIfPresent; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.totalSize; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +final class SliceData + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final int[] sizes; + private final byte[] bytes; + + public SliceData(@Nullable boolean[] nulls, @Nullable int[] sizes, @Nullable byte[] bytes) + { + checkArgument(sameSizeIfPresent(nulls, sizes), "nulls and values must be of the same size"); + checkArgument(totalSize(nulls, sizes) == (bytes != null ? bytes.length : 0), "total bytes size doesn't match expected size"); + this.nulls = nulls; + this.sizes = sizes; + this.bytes = bytes; + } + + @Nullable + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + public int[] getSizes() + { + return sizes; + } + + @Nullable + public byte[] getBytes() + { + return bytes; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(desiredType.getJavaType() == Slice.class, "type doesn't match: %s", desiredType); + Slice values = bytes == null ? Slices.EMPTY_SLICE : Slices.wrappedBuffer(bytes); + int numberOfRecords = numberOfRecords(); + return new VariableWidthBlock( + numberOfRecords, + values, + calculateOffsets(sizes, nulls, numberOfRecords), + nulls == null ? new boolean[numberOfRecords] : nulls); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (sizes != null) { + return sizes.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + SliceData other = (SliceData) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.sizes, other.sizes) && + Arrays.equals(this.bytes, other.bytes); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(sizes), Arrays.hashCode(bytes)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromSliceBasedBlock(Block block, Type type, CreateSliceThriftBlockFunction create) + { + int positions = block.getPositionCount(); + if (positions == 0) { + return create.apply(null, null, null); + } + boolean[] nulls = null; + int[] sizes = null; + byte[] bytes = null; + int bytesIndex = 0; + for (int position = 0; position < positions; position++) { + if (block.isNull(position)) { + if (nulls == null) { + nulls = new boolean[positions]; + } + nulls[position] = true; + } + else { + Slice value = type.getSlice(block, position); + if (sizes == null) { + sizes = new int[positions]; + int totalBytes = totalSliceBytes(block); + if (totalBytes > 0) { + bytes = new byte[totalBytes]; + } + } + int length = value.length(); + sizes[position] = length; + if (length > 0) { + checkState(bytes != null); + value.getBytes(0, bytes, bytesIndex, length); + bytesIndex += length; + } + } + } + checkState(bytes == null || bytesIndex == bytes.length); + return create.apply(nulls, sizes, bytes); + } + + private static int totalSliceBytes(Block block) + { + int totalBytes = 0; + int positions = block.getPositionCount(); + for (int position = 0; position < positions; position++) { + totalBytes += block.getSliceLength(position); + } + return totalBytes; + } + + public interface CreateSliceThriftBlockFunction + { + PrestoThriftBlock apply(boolean[] nulls, int[] sizes, byte[] bytes); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftAllOrNoneValueSet.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftAllOrNoneValueSet.java new file mode 100644 index 000000000000..ce2c92768152 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftAllOrNoneValueSet.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.spi.predicate.AllOrNoneValueSet; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import static com.google.common.base.MoreObjects.toStringHelper; + +/** + * Set that either includes all values, or excludes all values. + */ +@ThriftStruct +public final class PrestoThriftAllOrNoneValueSet +{ + private final boolean all; + + @ThriftConstructor + public PrestoThriftAllOrNoneValueSet(boolean all) + { + this.all = all; + } + + @ThriftField(1) + public boolean isAll() + { + return all; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftAllOrNoneValueSet other = (PrestoThriftAllOrNoneValueSet) obj; + return this.all == other.all; + } + + @Override + public int hashCode() + { + return Boolean.hashCode(all); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("all", all) + .toString(); + } + + public static PrestoThriftAllOrNoneValueSet fromAllOrNoneValueSet(AllOrNoneValueSet valueSet) + { + return new PrestoThriftAllOrNoneValueSet(valueSet.isAll()); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftEquatableValueSet.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftEquatableValueSet.java new file mode 100644 index 000000000000..883af7996c14 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftEquatableValueSet.java @@ -0,0 +1,105 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.predicate.EquatableValueSet; +import com.facebook.presto.spi.predicate.EquatableValueSet.ValueEntry; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.fromBlock; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +/** + * A set containing values that are uniquely identifiable. + * Assumes an infinite number of possible values. The values may be collectively included (aka whitelist) + * or collectively excluded (aka !whitelist). + * This structure is used with comparable, but not orderable types like "json", "map". + */ +@ThriftStruct +public final class PrestoThriftEquatableValueSet +{ + private final boolean whiteList; + private final List values; + + @ThriftConstructor + public PrestoThriftEquatableValueSet(boolean whiteList, List values) + { + this.whiteList = whiteList; + this.values = requireNonNull(values, "values are null"); + } + + @ThriftField(1) + public boolean isWhiteList() + { + return whiteList; + } + + @ThriftField(2) + public List getValues() + { + return values; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftEquatableValueSet other = (PrestoThriftEquatableValueSet) obj; + return this.whiteList == other.whiteList && + Objects.equals(this.values, other.values); + } + + @Override + public int hashCode() + { + return Objects.hash(whiteList, values); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("whiteList", whiteList) + .add("values", values) + .toString(); + } + + public static PrestoThriftEquatableValueSet fromEquatableValueSet(EquatableValueSet valueSet) + { + Type type = valueSet.getType(); + Set values = valueSet.getEntries(); + List thriftValues = new ArrayList<>(values.size()); + for (ValueEntry value : values) { + checkState(type.equals(value.getType()), "ValueEntrySet has elements of different types: %s vs %s", type, value.getType()); + thriftValues.add(fromBlock(value.getBlock(), type)); + } + return new PrestoThriftEquatableValueSet(valueSet.isWhiteList(), thriftValues); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftRangeValueSet.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftRangeValueSet.java new file mode 100644 index 000000000000..1c1475f85a5a --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftRangeValueSet.java @@ -0,0 +1,259 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.predicate.Marker; +import com.facebook.presto.spi.predicate.Marker.Bound; +import com.facebook.presto.spi.predicate.Range; +import com.facebook.presto.spi.predicate.SortedRangeSet; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftEnum; +import com.facebook.swift.codec.ThriftEnumValue; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.fromBlock; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftBound.fromBound; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftMarker.fromMarker; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +/** + * A set containing zero or more Ranges of the same type over a continuous space of possible values. + * Ranges are coalesced into the most compact representation of non-overlapping Ranges. + * This structure is used with comparable and orderable types like bigint, integer, double, varchar, etc. + */ +@ThriftStruct +public final class PrestoThriftRangeValueSet +{ + private final List ranges; + + @ThriftConstructor + public PrestoThriftRangeValueSet(@ThriftField(name = "ranges") List ranges) + { + this.ranges = requireNonNull(ranges, "ranges is null"); + } + + @ThriftField(1) + public List getRanges() + { + return ranges; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftRangeValueSet other = (PrestoThriftRangeValueSet) obj; + return Objects.equals(this.ranges, other.ranges); + } + + @Override + public int hashCode() + { + return Objects.hashCode(ranges); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRanges", ranges.size()) + .toString(); + } + + public static PrestoThriftRangeValueSet fromSortedRangeSet(SortedRangeSet valueSet) + { + List ranges = valueSet.getOrderedRanges().stream() + .map(PrestoThriftRange::fromRange) + .collect(toImmutableList()); + return new PrestoThriftRangeValueSet(ranges); + } + + @ThriftEnum + public enum PrestoThriftBound + { + BELOW(1), // lower than the value, but infinitesimally close to the value + EXACTLY(2), // exactly the value + ABOVE(3); // higher than the value, but infinitesimally close to the value + + private final int value; + + PrestoThriftBound(int value) + { + this.value = value; + } + + @ThriftEnumValue + public int getValue() + { + return value; + } + + public static PrestoThriftBound fromBound(Bound bound) + { + switch (bound) { + case BELOW: + return BELOW; + case EXACTLY: + return EXACTLY; + case ABOVE: + return ABOVE; + default: + throw new IllegalArgumentException("Unknown bound: " + bound); + } + } + } + + /** + * LOWER UNBOUNDED is specified with an empty value and an ABOVE bound + * UPPER UNBOUNDED is specified with an empty value and a BELOW bound + */ + @ThriftStruct + public static final class PrestoThriftMarker + { + private final PrestoThriftBlock value; + private final PrestoThriftBound bound; + + @ThriftConstructor + public PrestoThriftMarker(@Nullable PrestoThriftBlock value, PrestoThriftBound bound) + { + checkArgument(value == null || value.numberOfRecords() == 1, "value must contain exactly one record when present"); + this.value = value; + this.bound = requireNonNull(bound, "bound is null"); + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public PrestoThriftBlock getValue() + { + return value; + } + + @ThriftField(2) + public PrestoThriftBound getBound() + { + return bound; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftMarker other = (PrestoThriftMarker) obj; + return Objects.equals(this.value, other.value) && + Objects.equals(this.bound, other.bound); + } + + @Override + public int hashCode() + { + return Objects.hash(value, bound); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("value", value) + .add("bound", bound) + .toString(); + } + + public static PrestoThriftMarker fromMarker(Marker marker) + { + PrestoThriftBlock value = marker.getValueBlock().isPresent() ? fromBlock(marker.getValueBlock().get(), marker.getType()) : null; + return new PrestoThriftMarker(value, fromBound(marker.getBound())); + } + } + + @ThriftStruct + public static final class PrestoThriftRange + { + private final PrestoThriftMarker low; + private final PrestoThriftMarker high; + + @ThriftConstructor + public PrestoThriftRange(PrestoThriftMarker low, PrestoThriftMarker high) + { + this.low = requireNonNull(low, "low is null"); + this.high = requireNonNull(high, "high is null"); + } + + @ThriftField(1) + public PrestoThriftMarker getLow() + { + return low; + } + + @ThriftField(2) + public PrestoThriftMarker getHigh() + { + return high; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftRange other = (PrestoThriftRange) obj; + return Objects.equals(this.low, other.low) && + Objects.equals(this.high, other.high); + } + + @Override + public int hashCode() + { + return Objects.hash(low, high); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("low", low) + .add("high", high) + .toString(); + } + + public static PrestoThriftRange fromRange(Range range) + { + return new PrestoThriftRange(fromMarker(range.getLow()), fromMarker(range.getHigh())); + } + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftValueSet.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftValueSet.java new file mode 100644 index 000000000000..32dce545f669 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftValueSet.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.spi.predicate.AllOrNoneValueSet; +import com.facebook.presto.spi.predicate.EquatableValueSet; +import com.facebook.presto.spi.predicate.SortedRangeSet; +import com.facebook.presto.spi.predicate.ValueSet; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftAllOrNoneValueSet.fromAllOrNoneValueSet; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftEquatableValueSet.fromEquatableValueSet; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.fromSortedRangeSet; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +@ThriftStruct +public final class PrestoThriftValueSet +{ + private final PrestoThriftAllOrNoneValueSet allOrNoneValueSet; + private final PrestoThriftEquatableValueSet equatableValueSet; + private final PrestoThriftRangeValueSet rangeValueSet; + + @ThriftConstructor + public PrestoThriftValueSet( + @Nullable PrestoThriftAllOrNoneValueSet allOrNoneValueSet, + @Nullable PrestoThriftEquatableValueSet equatableValueSet, + @Nullable PrestoThriftRangeValueSet rangeValueSet) + { + checkArgument(isExactlyOneNonNull(allOrNoneValueSet, equatableValueSet, rangeValueSet), "exactly one value set must be present"); + this.allOrNoneValueSet = allOrNoneValueSet; + this.equatableValueSet = equatableValueSet; + this.rangeValueSet = rangeValueSet; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public PrestoThriftAllOrNoneValueSet getAllOrNoneValueSet() + { + return allOrNoneValueSet; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public PrestoThriftEquatableValueSet getEquatableValueSet() + { + return equatableValueSet; + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public PrestoThriftRangeValueSet getRangeValueSet() + { + return rangeValueSet; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftValueSet other = (PrestoThriftValueSet) obj; + return Objects.equals(this.allOrNoneValueSet, other.allOrNoneValueSet) && + Objects.equals(this.equatableValueSet, other.equatableValueSet) && + Objects.equals(this.rangeValueSet, other.rangeValueSet); + } + + @Override + public int hashCode() + { + return Objects.hash(allOrNoneValueSet, equatableValueSet, rangeValueSet); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("valueSet", firstNonNull(allOrNoneValueSet, equatableValueSet, rangeValueSet)) + .toString(); + } + + public static PrestoThriftValueSet fromValueSet(ValueSet valueSet) + { + if (valueSet.getClass() == AllOrNoneValueSet.class) { + return new PrestoThriftValueSet( + fromAllOrNoneValueSet((AllOrNoneValueSet) valueSet), + null, + null); + } + else if (valueSet.getClass() == EquatableValueSet.class) { + return new PrestoThriftValueSet( + null, + fromEquatableValueSet((EquatableValueSet) valueSet), + null); + } + else if (valueSet.getClass() == SortedRangeSet.class) { + return new PrestoThriftValueSet( + null, + null, + fromSortedRangeSet((SortedRangeSet) valueSet)); + } + else { + throw new IllegalArgumentException("Unknown implementation of a value set: " + valueSet.getClass()); + } + } + + private static boolean isExactlyOneNonNull(Object a, Object b, Object c) + { + return a != null && b == null && c == null || + a == null && b != null && c == null || + a == null && b == null && c != null; + } + + private static Object firstNonNull(Object a, Object b, Object c) + { + if (a != null) { + return a; + } + if (b != null) { + return b; + } + if (c != null) { + return c; + } + throw new IllegalArgumentException("All arguments are null"); + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestNameValidationUtils.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestNameValidationUtils.java new file mode 100644 index 000000000000..31b922425cb6 --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestNameValidationUtils.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import org.testng.annotations.Test; + +import static com.facebook.presto.connector.thrift.api.NameValidationUtils.checkValidName; +import static org.testng.Assert.assertThrows; + +public class TestNameValidationUtils +{ + @Test + public void testCheckValidColumnName() + throws Exception + { + checkValidName("abc01_def2"); + assertThrows(() -> checkValidName(null)); + assertThrows(() -> checkValidName("")); + assertThrows(() -> checkValidName("Abc")); + assertThrows(() -> checkValidName("0abc")); + assertThrows(() -> checkValidName("_abc")); + assertThrows(() -> checkValidName("aBc")); + assertThrows(() -> checkValidName("ab-c")); + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestPrestoThriftId.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestPrestoThriftId.java new file mode 100644 index 000000000000..695a4e41d25e --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestPrestoThriftId.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import org.testng.annotations.Test; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftId.summarize; +import static org.testng.Assert.assertEquals; + +public class TestPrestoThriftId +{ + @Test + public void testSummarize() + throws Exception + { + assertEquals(summarize(bytes()), ""); + assertEquals(summarize(bytes(1)), "01"); + assertEquals(summarize(bytes(255, 254, 253, 252, 251, 250, 249)), "FFFEFDFCFBFAF9"); + assertEquals(summarize(bytes(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 249, 250, 251, 252, 253, 254, 255)), + "00010203040506070809F9FAFBFCFDFEFF"); + assertEquals(summarize(bytes(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 247, 248, 249, 250, 251, 252, 253, 254, 255)), + "0001020304050607..F8F9FAFBFCFDFEFF"); + } + + private static byte[] bytes(int... values) + { + int length = values.length; + byte[] result = new byte[length]; + for (int i = 0; i < length; i++) { + result[i] = (byte) values[i]; + } + return result; + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestReadWrite.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestReadWrite.java new file mode 100644 index 000000000000..b1153a56c394 --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestReadWrite.java @@ -0,0 +1,469 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.VarcharType; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import io.airlift.stats.cardinality.HyperLogLog; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Calendar; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicLong; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.fromBlock; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.DateType.DATE; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.HyperLogLogType.HYPER_LOG_LOG; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.spi.type.VarcharType.createVarcharType; +import static com.facebook.presto.type.JsonType.JSON; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + +public class TestReadWrite +{ + private static final double NULL_FRACTION = 0.1; + private static final int MAX_VARCHAR_GENERATED_LENGTH = 64; + private static final char[] SYMBOLS; + private static final long MIN_GENERATED_TIMESTAMP; + private static final long MAX_GENERATED_TIMESTAMP; + private static final int MIN_GENERATED_DATE; + private static final int MAX_GENERATED_DATE; + private static final int MAX_GENERATED_JSON_KEY_LENGTH = 8; + private static final int HYPER_LOG_LOG_BUCKETS = 128; + private static final int MAX_HYPER_LOG_LOG_ELEMENTS = 32; + private static final int MAX_ARRAY_GENERATED_LENGTH = 64; + private final AtomicLong singleValueSeedGenerator = new AtomicLong(762103512L); + private final AtomicLong columnDataSeedGenerator = new AtomicLong(762103512L); + private final List columns = ImmutableList.of( + new IntegerColumn(), + new BigintColumn(), + new DoubleColumn(), + new VarcharColumn(createUnboundedVarcharType()), + new VarcharColumn(createVarcharType(MAX_VARCHAR_GENERATED_LENGTH / 2)), + new BooleanColumn(), + new DateColumn(), + new TimestampColumn(), + new JsonColumn(), + new HyperLogLogColumn(), + new BigintArrayColumn()); + + static { + char[] symbols = new char[2 * 26 + 10]; + int next = 0; + for (char ch = 'A'; ch <= 'Z'; ch++) { + symbols[next++] = ch; + } + for (char ch = 'a'; ch <= 'z'; ch++) { + symbols[next++] = ch; + } + for (char ch = '0'; ch <= '9'; ch++) { + symbols[next++] = ch; + } + SYMBOLS = symbols; + + Calendar calendar = Calendar.getInstance(); + + calendar.set(2000, Calendar.JANUARY, 1); + MIN_GENERATED_TIMESTAMP = calendar.getTimeInMillis(); + MIN_GENERATED_DATE = toIntExact(MILLISECONDS.toDays(MIN_GENERATED_TIMESTAMP)); + + calendar.set(2020, Calendar.DECEMBER, 31); + MAX_GENERATED_TIMESTAMP = calendar.getTimeInMillis(); + MAX_GENERATED_DATE = toIntExact(MILLISECONDS.toDays(MAX_GENERATED_TIMESTAMP)); + } + + @Test(invocationCount = 20) + public void testReadWriteSingleValue() + throws Exception + { + testReadWrite(new Random(singleValueSeedGenerator.incrementAndGet()), 1); + } + + @Test(invocationCount = 20) + public void testReadWriteColumnData() + throws Exception + { + Random random = new Random(columnDataSeedGenerator.incrementAndGet()); + int records = random.nextInt(10000) + 10000; + testReadWrite(random, records); + } + + private void testReadWrite(Random random, int records) + throws Exception + { + // generate columns data + List inputBlocks = new ArrayList<>(columns.size()); + for (ColumnDefinition column : columns) { + inputBlocks.add(generateColumn(column, random, records)); + } + + // convert column data to thrift ("write step") + List columnBlocks = new ArrayList<>(columns.size()); + for (int i = 0; i < columns.size(); i++) { + columnBlocks.add(fromBlock(inputBlocks.get(i), columns.get(i).getType())); + } + PrestoThriftPageResult batch = new PrestoThriftPageResult(columnBlocks, records, null); + + // convert thrift data to page/blocks ("read step") + Page page = batch.toPage(columns.stream().map(ColumnDefinition::getType).collect(toImmutableList())); + + // compare the result with original input + assertNotNull(page); + assertEquals(page.getChannelCount(), columns.size()); + for (int i = 0; i < columns.size(); i++) { + Block actual = page.getBlock(i); + Block expected = inputBlocks.get(i); + assertBlock(actual, expected, columns.get(i)); + } + } + + private static Block generateColumn(ColumnDefinition column, Random random, int records) + { + BlockBuilder builder = column.getType().createBlockBuilder(new BlockBuilderStatus(), records); + for (int i = 0; i < records; i++) { + if (random.nextDouble() < NULL_FRACTION) { + builder.appendNull(); + } + else { + column.writeNextRandomValue(random, builder); + } + } + return builder.build(); + } + + private static void assertBlock(Block actual, Block expected, ColumnDefinition columnDefinition) + { + assertEquals(actual.getPositionCount(), expected.getPositionCount()); + int positions = actual.getPositionCount(); + for (int i = 0; i < positions; i++) { + Object actualValue = columnDefinition.extractValue(actual, i); + Object expectedValue = columnDefinition.extractValue(expected, i); + assertEquals(actualValue, expectedValue); + } + } + + private static String nextString(Random random) + { + return nextString(random, MAX_VARCHAR_GENERATED_LENGTH); + } + + private static String nextString(Random random, int maxLength) + { + int size = random.nextInt(maxLength); + char[] result = new char[size]; + for (int i = 0; i < size; i++) { + result[i] = SYMBOLS[random.nextInt(SYMBOLS.length)]; + } + return new String(result); + } + + private static long nextTimestamp(Random random) + { + return MIN_GENERATED_TIMESTAMP + (long) (random.nextDouble() * (MAX_GENERATED_TIMESTAMP - MIN_GENERATED_TIMESTAMP)); + } + + private static int nextDate(Random random) + { + return MIN_GENERATED_DATE + random.nextInt(MAX_GENERATED_DATE - MIN_GENERATED_DATE); + } + + private static Slice nextHyperLogLog(Random random) + { + HyperLogLog hll = HyperLogLog.newInstance(HYPER_LOG_LOG_BUCKETS); + int size = random.nextInt(MAX_HYPER_LOG_LOG_ELEMENTS); + for (int i = 0; i < size; i++) { + hll.add(random.nextLong()); + } + return hll.serialize(); + } + + private static void generateBigintArray(Random random, BlockBuilder parentBuilder) + { + int numberOfElements = random.nextInt(MAX_ARRAY_GENERATED_LENGTH); + BlockBuilder builder = parentBuilder.beginBlockEntry(); + for (int i = 0; i < numberOfElements; i++) { + if (random.nextDouble() < NULL_FRACTION) { + builder.appendNull(); + } + else { + builder.writeLong(random.nextLong()); + } + } + parentBuilder.closeEntry(); + } + + private abstract static class ColumnDefinition + { + private final Type type; + + public ColumnDefinition(Type type) + { + this.type = requireNonNull(type, "type is null"); + } + + public Type getType() + { + return type; + } + + abstract Object extractValue(Block block, int position); + + abstract void writeNextRandomValue(Random random, BlockBuilder builder); + } + + private static final class IntegerColumn + extends ColumnDefinition + { + public IntegerColumn() + { + super(INTEGER); + } + + @Override + Object extractValue(Block block, int position) + { + return INTEGER.getLong(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + INTEGER.writeLong(builder, random.nextInt()); + } + } + + private static final class BigintColumn + extends ColumnDefinition + { + public BigintColumn() + { + super(BIGINT); + } + + @Override + Object extractValue(Block block, int position) + { + return BIGINT.getLong(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + BIGINT.writeLong(builder, random.nextLong()); + } + } + + private static final class DoubleColumn + extends ColumnDefinition + { + public DoubleColumn() + { + super(DOUBLE); + } + + @Override + Object extractValue(Block block, int position) + { + return DOUBLE.getDouble(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + DOUBLE.writeDouble(builder, random.nextDouble()); + } + } + + private static final class VarcharColumn + extends ColumnDefinition + { + private final VarcharType varcharType; + + public VarcharColumn(VarcharType varcharType) + { + super(varcharType); + this.varcharType = requireNonNull(varcharType, "varcharType is null"); + } + + @Override + Object extractValue(Block block, int position) + { + return varcharType.getSlice(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + varcharType.writeString(builder, nextString(random)); + } + } + + private static final class BooleanColumn + extends ColumnDefinition + { + public BooleanColumn() + { + super(BOOLEAN); + } + + @Override + Object extractValue(Block block, int position) + { + return BOOLEAN.getBoolean(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + BOOLEAN.writeBoolean(builder, random.nextBoolean()); + } + } + + private static final class DateColumn + extends ColumnDefinition + { + public DateColumn() + { + super(DATE); + } + + @Override + Object extractValue(Block block, int position) + { + return DATE.getLong(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + DATE.writeLong(builder, nextDate(random)); + } + } + + private static final class TimestampColumn + extends ColumnDefinition + { + public TimestampColumn() + { + super(TIMESTAMP); + } + + @Override + Object extractValue(Block block, int position) + { + return TIMESTAMP.getLong(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + TIMESTAMP.writeLong(builder, nextTimestamp(random)); + } + } + + private static final class JsonColumn + extends ColumnDefinition + { + public JsonColumn() + { + super(JSON); + } + + @Override + Object extractValue(Block block, int position) + { + return JSON.getSlice(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + String json = String.format("{\"%s\": %d, \"%s\": \"%s\"}", + nextString(random, MAX_GENERATED_JSON_KEY_LENGTH), + random.nextInt(), + nextString(random, MAX_GENERATED_JSON_KEY_LENGTH), + random.nextInt()); + JSON.writeString(builder, json); + } + } + + private static final class HyperLogLogColumn + extends ColumnDefinition + { + public HyperLogLogColumn() + { + super(HYPER_LOG_LOG); + } + + @Override + Object extractValue(Block block, int position) + { + return HYPER_LOG_LOG.getSlice(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + HYPER_LOG_LOG.writeSlice(builder, nextHyperLogLog(random)); + } + } + + private static final class BigintArrayColumn + extends ColumnDefinition + { + private final ArrayType arrayType; + + public BigintArrayColumn() + { + this(new ArrayType(BIGINT)); + } + + private BigintArrayColumn(ArrayType arrayType) + { + super(arrayType); + this.arrayType = requireNonNull(arrayType, "arrayType is null"); + } + + @Override + Object extractValue(Block block, int position) + { + return arrayType.getObjectValue(null, block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + generateBigintArray(random, builder); + } + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/datatypes/TestPrestoThriftBigint.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/datatypes/TestPrestoThriftBigint.java new file mode 100644 index 000000000000..625e1a04a613 --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/datatypes/TestPrestoThriftBigint.java @@ -0,0 +1,203 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.bigintData; +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.integerData; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftBigint.fromBlock; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static java.util.Collections.unmodifiableList; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + +public class TestPrestoThriftBigint +{ + @Test + public void testReadBlock() + throws Exception + { + PrestoThriftBlock columnsData = longColumn( + new boolean[] {false, true, false, false, false, false, true}, + new long[] {2, 0, 1, 3, 8, 4, 0} + ); + Block actual = columnsData.toBlock(BIGINT); + assertBlockEquals(actual, list(2L, null, 1L, 3L, 8L, 4L, null)); + } + + @Test + public void testReadBlockAllNullsOption1() + { + PrestoThriftBlock columnsData = longColumn( + new boolean[] {true, true, true, true, true, true, true}, + null + ); + Block actual = columnsData.toBlock(BIGINT); + assertBlockEquals(actual, list(null, null, null, null, null, null, null)); + } + + @Test + public void testReadBlockAllNullsOption2() + { + PrestoThriftBlock columnsData = longColumn( + new boolean[] {true, true, true, true, true, true, true}, + new long[] {0, 0, 0, 0, 0, 0, 0} + ); + Block actual = columnsData.toBlock(BIGINT); + assertBlockEquals(actual, list(null, null, null, null, null, null, null)); + } + + @Test + public void testReadBlockAllNonNullOption1() + throws Exception + { + PrestoThriftBlock columnsData = longColumn( + null, + new long[] {2, 7, 1, 3, 8, 4, 5} + ); + Block actual = columnsData.toBlock(BIGINT); + assertBlockEquals(actual, list(2L, 7L, 1L, 3L, 8L, 4L, 5L)); + } + + @Test + public void testReadBlockAllNonNullOption2() + throws Exception + { + PrestoThriftBlock columnsData = longColumn( + new boolean[] {false, false, false, false, false, false, false}, + new long[] {2, 7, 1, 3, 8, 4, 5} + ); + Block actual = columnsData.toBlock(BIGINT); + assertBlockEquals(actual, list(2L, 7L, 1L, 3L, 8L, 4L, 5L)); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testReadBlockWrongActualType() + throws Exception + { + PrestoThriftBlock columnsData = integerData(new PrestoThriftInteger(null, null)); + columnsData.toBlock(BIGINT); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testReadBlockWrongDesiredType() + throws Exception + { + PrestoThriftBlock columnsData = longColumn(null, null); + columnsData.toBlock(INTEGER); + } + + @Test + public void testWriteBlockAlternating() + throws Exception + { + Block source = longBlock(1, null, 2, null, 3, null, 4, null, 5, null, 6, null, 7, null); + PrestoThriftBlock column = fromBlock(source); + assertNotNull(column.getBigintData()); + assertEquals(column.getBigintData().getNulls(), + new boolean[] {false, true, false, true, false, true, false, true, false, true, false, true, false, true}); + assertEquals(column.getBigintData().getLongs(), + new long[] {1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0}); + } + + @Test + public void testWriteBlockAllNulls() + throws Exception + { + Block source = longBlock(null, null, null, null, null); + PrestoThriftBlock column = fromBlock(source); + assertNotNull(column.getBigintData()); + assertEquals(column.getBigintData().getNulls(), new boolean[] {true, true, true, true, true}); + assertNull(column.getBigintData().getLongs()); + } + + @Test + public void testWriteBlockAllNonNull() + throws Exception + { + Block source = longBlock(1, 2, 3, 4, 5); + PrestoThriftBlock column = fromBlock(source); + assertNotNull(column.getBigintData()); + assertNull(column.getBigintData().getNulls()); + assertEquals(column.getBigintData().getLongs(), new long[] {1, 2, 3, 4, 5}); + } + + @Test + public void testWriteBlockEmpty() + throws Exception + { + PrestoThriftBlock column = fromBlock(longBlock()); + assertNotNull(column.getBigintData()); + assertNull(column.getBigintData().getNulls()); + assertNull(column.getBigintData().getLongs()); + } + + @Test + public void testWriteBlockSingleValue() + throws Exception + { + PrestoThriftBlock column = fromBlock(longBlock(1)); + assertNotNull(column.getBigintData()); + assertNull(column.getBigintData().getNulls()); + assertEquals(column.getBigintData().getLongs(), new long[] {1}); + } + + private void assertBlockEquals(Block block, List expected) + { + assertEquals(block.getPositionCount(), expected.size()); + for (int i = 0; i < expected.size(); i++) { + if (expected.get(i) == null) { + assertTrue(block.isNull(i)); + } + else { + assertEquals(block.getLong(i, 0), expected.get(i).longValue()); + } + } + } + + private static Block longBlock(Integer... values) + { + BlockBuilder blockBuilder = BIGINT.createBlockBuilder(new BlockBuilderStatus(), values.length); + for (Integer value : values) { + if (value == null) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeLong(value).closeEntry(); + } + } + return blockBuilder.build(); + } + + private static PrestoThriftBlock longColumn(boolean[] nulls, long[] longs) + { + return bigintData(new PrestoThriftBigint(nulls, longs)); + } + + private static List list(Long... values) + { + return unmodifiableList(Arrays.asList(values)); + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftAllOrNoneValueSet.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftAllOrNoneValueSet.java new file mode 100644 index 000000000000..d1ee43697ad7 --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftAllOrNoneValueSet.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.spi.predicate.ValueSet; +import org.testng.annotations.Test; + +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftValueSet.fromValueSet; +import static com.facebook.presto.spi.type.HyperLogLogType.HYPER_LOG_LOG; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestPrestoThriftAllOrNoneValueSet +{ + @Test + public void testFromValueSetAll() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.all(HYPER_LOG_LOG)); + assertNotNull(thriftValueSet.getAllOrNoneValueSet()); + assertTrue(thriftValueSet.getAllOrNoneValueSet().isAll()); + } + + @Test + public void testFromValueSetNone() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.none(HYPER_LOG_LOG)); + assertNotNull(thriftValueSet.getAllOrNoneValueSet()); + assertFalse(thriftValueSet.getAllOrNoneValueSet().isAll()); + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftEquatableValueSet.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftEquatableValueSet.java new file mode 100644 index 000000000000..068efe3ced81 --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftEquatableValueSet.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftJson; +import com.facebook.presto.spi.predicate.ValueSet; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.jsonData; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftValueSet.fromValueSet; +import static com.facebook.presto.type.JsonType.JSON; +import static io.airlift.slice.Slices.utf8Slice; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestPrestoThriftEquatableValueSet +{ + private static final String JSON1 = "\"key1\":\"value1\""; + private static final String JSON2 = "\"key2\":\"value2\""; + + @Test + public void testFromValueSetAll() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.all(JSON)); + assertNotNull(thriftValueSet.getEquatableValueSet()); + assertFalse(thriftValueSet.getEquatableValueSet().isWhiteList()); + assertTrue(thriftValueSet.getEquatableValueSet().getValues().isEmpty()); + } + + @Test + public void testFromValueSetNone() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.none(JSON)); + assertNotNull(thriftValueSet.getEquatableValueSet()); + assertTrue(thriftValueSet.getEquatableValueSet().isWhiteList()); + assertTrue(thriftValueSet.getEquatableValueSet().getValues().isEmpty()); + } + + @Test + public void testFromValueSetOf() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.of(JSON, utf8Slice(JSON1), utf8Slice(JSON2))); + assertNotNull(thriftValueSet.getEquatableValueSet()); + assertTrue(thriftValueSet.getEquatableValueSet().isWhiteList()); + assertEquals(thriftValueSet.getEquatableValueSet().getValues(), ImmutableList.of( + jsonData(new PrestoThriftJson(null, new int[] {JSON1.length()}, JSON1.getBytes(UTF_8))), + jsonData(new PrestoThriftJson(null, new int[] {JSON2.length()}, JSON2.getBytes(UTF_8))))); + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftRangeValueSet.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftRangeValueSet.java new file mode 100644 index 000000000000..16c82af88e92 --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftRangeValueSet.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftBigint; +import com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftMarker; +import com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftRange; +import com.facebook.presto.spi.predicate.Range; +import com.facebook.presto.spi.predicate.ValueSet; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.bigintData; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftBound.ABOVE; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftBound.BELOW; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftBound.EXACTLY; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftValueSet.fromValueSet; +import static com.facebook.presto.spi.predicate.Range.range; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + +public class TestPrestoThriftRangeValueSet +{ + @Test + public void testFromValueSetAll() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.all(BIGINT)); + assertNotNull(thriftValueSet.getRangeValueSet()); + assertEquals(thriftValueSet.getRangeValueSet().getRanges(), ImmutableList.of( + new PrestoThriftRange(new PrestoThriftMarker(null, ABOVE), new PrestoThriftMarker(null, BELOW)))); + } + + @Test + public void testFromValueSetNone() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.none(BIGINT)); + assertNotNull(thriftValueSet.getRangeValueSet()); + assertEquals(thriftValueSet.getRangeValueSet().getRanges(), ImmutableList.of()); + } + + @Test + public void testFromValueSetOf() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.of(BIGINT, 1L, 2L, 3L)); + assertNotNull(thriftValueSet.getRangeValueSet()); + assertEquals(thriftValueSet.getRangeValueSet().getRanges(), ImmutableList.of( + new PrestoThriftRange(new PrestoThriftMarker(longValue(1), EXACTLY), new PrestoThriftMarker(longValue(1), EXACTLY)), + new PrestoThriftRange(new PrestoThriftMarker(longValue(2), EXACTLY), new PrestoThriftMarker(longValue(2), EXACTLY)), + new PrestoThriftRange(new PrestoThriftMarker(longValue(3), EXACTLY), new PrestoThriftMarker(longValue(3), EXACTLY)))); + } + + @Test + public void testFromValueSetOfRangesUnbounded() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.ofRanges(Range.greaterThanOrEqual(BIGINT, 0L))); + assertNotNull(thriftValueSet.getRangeValueSet()); + assertEquals(thriftValueSet.getRangeValueSet().getRanges(), ImmutableList.of( + new PrestoThriftRange(new PrestoThriftMarker(longValue(0), EXACTLY), new PrestoThriftMarker(null, BELOW)))); + } + + @Test + public void testFromValueSetOfRangesBounded() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.ofRanges( + range(BIGINT, -10L, true, -1L, false), + range(BIGINT, -1L, false, 100L, true))); + assertNotNull(thriftValueSet.getRangeValueSet()); + assertEquals(thriftValueSet.getRangeValueSet().getRanges(), ImmutableList.of( + new PrestoThriftRange(new PrestoThriftMarker(longValue(-10), EXACTLY), new PrestoThriftMarker(longValue(-1), BELOW)), + new PrestoThriftRange(new PrestoThriftMarker(longValue(-1), ABOVE), new PrestoThriftMarker(longValue(100), EXACTLY)))); + } + + private static PrestoThriftBlock longValue(long value) + { + return bigintData(new PrestoThriftBigint(null, new long[] {value})); + } +} diff --git a/presto-thrift-connector/README.md b/presto-thrift-connector/README.md new file mode 100644 index 000000000000..6a564e2b32a4 --- /dev/null +++ b/presto-thrift-connector/README.md @@ -0,0 +1,23 @@ +Thrift Connector +================ + +Thrift Connector makes it possible to integrate with external storage systems without a custom Presto connector implementation. + +In order to use Thrift Connector with external system you need to implement `PrestoThriftService` interface defined in `presto-thrift-connector-api` project. +Next, you configure Thrift Connector to point to a set of machines, called thrift servers, implementing it. +As part of the interface implementation thrift servers will provide metadata, splits and data. +Thrift server instances are assumed to be stateless and independent from each other. + +Using Thrift Connector over a custom Presto connector can be especially useful in the following cases. + +* Java client for a storage system is not available. +By using Thrift as transport and service definition Thrift Connector can integrate with systems written in non-Java languages. + +* Storage system's model doesn't easily map to metadata/table/row concept or there are multiple ways to do it. +For example, there are multiple ways how to map data from a key/value storage to relational representation. +Instead of supporting all of the variations in the connector this task can be moved to the external system itself. + +* You cannot or don't want to modify Presto code to add a custom connector to support your storage system. + +You can find thrift service interface that needs to be implemented together with related thrift structures in `presto-thrift-connector-api` project. +Documentation of [`PrestoThriftService`](../presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftService.java) is a good starting point. diff --git a/presto-thrift-connector/pom.xml b/presto-thrift-connector/pom.xml new file mode 100644 index 000000000000..a7d885bfaea2 --- /dev/null +++ b/presto-thrift-connector/pom.xml @@ -0,0 +1,156 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.181-tw-0.37 + + + presto-thrift-connector + Presto - Thrift Connector + presto-plugin + + + ${project.parent.basedir} + + + + + com.facebook.presto + presto-thrift-connector-api + + + + com.google.guava + guava + + + + com.google.code.findbugs + annotations + + + + com.facebook.swift + swift-codec + + + + com.facebook.swift + swift-service + + + + io.airlift + bootstrap + + + + io.airlift + json + + + + io.airlift + log + + + + org.weakref + jmxutils + + + + com.google.inject + guice + + + + javax.inject + javax.inject + + + + io.airlift + configuration + + + + com.facebook.nifty + nifty-client + + + + javax.validation + validation-api + + + + io.airlift + concurrent + + + + com.facebook.presto + presto-spi + provided + + + + io.airlift + slice + provided + + + + io.airlift + units + provided + + + + org.openjdk.jol + jol-core + provided + + + + com.fasterxml.jackson.core + jackson-annotations + provided + + + + + org.testng + testng + test + + + + com.facebook.presto + presto-thrift-testing-server + test + + + + io.airlift + testing + test + + + + com.facebook.presto + presto-tests + test + + + + com.facebook.presto + presto-main + test + + + diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftColumnHandle.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftColumnHandle.java new file mode 100644 index 000000000000..21b3acc6b43d --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftColumnHandle.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.type.Type; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class ThriftColumnHandle + implements ColumnHandle +{ + private final String columnName; + private final Type columnType; + private final String comment; + private final boolean hidden; + + @JsonCreator + public ThriftColumnHandle( + @JsonProperty("columnName") String columnName, + @JsonProperty("columnType") Type columnType, + @JsonProperty("comment") @Nullable String comment, + @JsonProperty("hidden") boolean hidden) + { + this.columnName = requireNonNull(columnName, "columnName is null"); + this.columnType = requireNonNull(columnType, "columnType is null"); + this.comment = comment; + this.hidden = hidden; + } + + public ThriftColumnHandle(ColumnMetadata columnMetadata) + { + this(columnMetadata.getName(), columnMetadata.getType(), columnMetadata.getComment(), columnMetadata.isHidden()); + } + + @JsonProperty + public String getColumnName() + { + return columnName; + } + + @JsonProperty + public Type getColumnType() + { + return columnType; + } + + @Nullable + @JsonProperty + public String getComment() + { + return comment; + } + + @JsonProperty + public boolean isHidden() + { + return hidden; + } + + public ColumnMetadata toColumnMetadata() + { + return new ColumnMetadata(columnName, columnType, comment, hidden); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + ThriftColumnHandle other = (ThriftColumnHandle) obj; + return Objects.equals(this.columnName, other.columnName) && + Objects.equals(this.columnType, other.columnType) && + Objects.equals(this.comment, other.comment) && + this.hidden == other.hidden; + } + + @Override + public int hashCode() + { + return Objects.hash(columnName, columnType, comment, hidden); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("columnName", columnName) + .add("columnType", columnType) + .add("comment", comment) + .add("hidden", hidden) + .toString(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnector.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnector.java new file mode 100644 index 000000000000..01d96f40026a --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnector.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.session.PropertyMetadata; +import com.facebook.presto.spi.transaction.IsolationLevel; +import io.airlift.bootstrap.LifeCycleManager; +import io.airlift.log.Logger; + +import javax.inject.Inject; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class ThriftConnector + implements Connector +{ + private static final Logger log = Logger.get(ThriftConnector.class); + + private final LifeCycleManager lifeCycleManager; + private final ThriftMetadata metadata; + private final ThriftSplitManager splitManager; + private final ThriftPageSourceProvider pageSourceProvider; + private final ThriftSessionProperties sessionProperties; + + @Inject + public ThriftConnector( + LifeCycleManager lifeCycleManager, + ThriftMetadata metadata, + ThriftSplitManager splitManager, + ThriftPageSourceProvider pageSourceProvider, + ThriftSessionProperties sessionProperties) + { + this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); + this.splitManager = requireNonNull(splitManager, "splitManager is null"); + this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null"); + this.sessionProperties = requireNonNull(sessionProperties, "sessionProperties is null"); + } + + @Override + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) + { + return ThriftTransactionHandle.INSTANCE; + } + + @Override + public ConnectorMetadata getMetadata(ConnectorTransactionHandle transactionHandle) + { + return metadata; + } + + @Override + public ConnectorSplitManager getSplitManager() + { + return splitManager; + } + + @Override + public ConnectorPageSourceProvider getPageSourceProvider() + { + return pageSourceProvider; + } + + @Override + public List> getSessionProperties() + { + return sessionProperties.getSessionProperties(); + } + + @Override + public final void shutdown() + { + try { + lifeCycleManager.stop(); + } + catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + log.error(ie, "Interrupted while shutting down connector"); + } + catch (Exception e) { + log.error(e, "Error shutting down connector"); + } + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorConfig.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorConfig.java new file mode 100644 index 000000000000..4eb43e17e557 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorConfig.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import io.airlift.configuration.Config; +import io.airlift.units.DataSize; +import io.airlift.units.MaxDataSize; +import io.airlift.units.MinDataSize; + +import javax.validation.constraints.Min; +import javax.validation.constraints.NotNull; + +import static io.airlift.units.DataSize.Unit.MEGABYTE; + +public class ThriftConnectorConfig +{ + private DataSize maxResponseSize = new DataSize(16, MEGABYTE); + private int metadataRefreshThreads = 1; + + @NotNull + @MinDataSize("1MB") + @MaxDataSize("32MB") + public DataSize getMaxResponseSize() + { + return maxResponseSize; + } + + @Config("presto-thrift.max-response-size") + public ThriftConnectorConfig setMaxResponseSize(DataSize maxResponseSize) + { + this.maxResponseSize = maxResponseSize; + return this; + } + + @Min(1) + public int getMetadataRefreshThreads() + { + return metadataRefreshThreads; + } + + @Config("presto-thrift.metadata-refresh-threads") + public ThriftConnectorConfig setMetadataRefreshThreads(int metadataRefreshThreads) + { + this.metadataRefreshThreads = metadataRefreshThreads; + return this; + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorFactory.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorFactory.java new file mode 100644 index 000000000000..d488d8fe3136 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorFactory.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.util.RebindSafeMBeanServer; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorContext; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.swift.codec.guice.ThriftCodecModule; +import com.facebook.swift.service.guice.ThriftClientModule; +import com.facebook.swift.service.guice.ThriftClientStatsModule; +import com.google.inject.Injector; +import com.google.inject.Module; +import io.airlift.bootstrap.Bootstrap; +import io.airlift.json.JsonModule; +import org.weakref.jmx.guice.MBeanModule; + +import javax.management.MBeanServer; + +import java.util.Map; + +import static com.google.common.base.Throwables.throwIfUnchecked; +import static java.lang.management.ManagementFactory.getPlatformMBeanServer; +import static java.util.Objects.requireNonNull; + +public class ThriftConnectorFactory + implements ConnectorFactory +{ + private final String name; + private final Module locationModule; + + public ThriftConnectorFactory(String name, Module locationModule) + { + this.name = requireNonNull(name, "name is null"); + this.locationModule = requireNonNull(locationModule, "locationModule is null"); + } + + @Override + public String getName() + { + return name; + } + + @Override + public ConnectorHandleResolver getHandleResolver() + { + return new ThriftHandleResolver(); + } + + @Override + public Connector create(String connectorId, Map config, ConnectorContext context) + { + try { + Bootstrap app = new Bootstrap( + new JsonModule(), + new MBeanModule(), + new ThriftCodecModule(), + new ThriftClientModule(), + new ThriftClientStatsModule(), + binder -> { + binder.bind(MBeanServer.class).toInstance(new RebindSafeMBeanServer(getPlatformMBeanServer())); + binder.bind(TypeManager.class).toInstance(context.getTypeManager()); + }, + locationModule, + new ThriftModule()); + + Injector injector = app + .strictConfig() + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .initialize(); + + return injector.getInstance(ThriftConnector.class); + } + catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while creating connector", ie); + } + catch (Exception e) { + throwIfUnchecked(e); + throw new RuntimeException(e); + } + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorSplit.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorSplit.java new file mode 100644 index 000000000000..8234a88932b8 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorSplit.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.api.PrestoThriftId; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.HostAddress; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class ThriftConnectorSplit + implements ConnectorSplit +{ + private final PrestoThriftId splitId; + private final List addresses; + + @JsonCreator + public ThriftConnectorSplit( + @JsonProperty("splitId") PrestoThriftId splitId, + @JsonProperty("addresses") List addresses) + { + this.splitId = requireNonNull(splitId, "splitId is null"); + this.addresses = ImmutableList.copyOf(requireNonNull(addresses, "addresses is null")); + } + + @JsonProperty + public PrestoThriftId getSplitId() + { + return splitId; + } + + @Override + @JsonProperty + public List getAddresses() + { + return addresses; + } + + @Override + public Object getInfo() + { + return ""; + } + + @Override + public boolean isRemotelyAccessible() + { + return true; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + ThriftConnectorSplit other = (ThriftConnectorSplit) obj; + return Objects.equals(this.splitId, other.splitId) && + Objects.equals(this.addresses, other.addresses); + } + + @Override + public int hashCode() + { + return Objects.hash(splitId, addresses); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("splitId", splitId) + .add("addresses", addresses) + .toString(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftErrorCode.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftErrorCode.java new file mode 100644 index 000000000000..e4c4149c1e5f --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftErrorCode.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.ErrorCode; +import com.facebook.presto.spi.ErrorCodeSupplier; +import com.facebook.presto.spi.ErrorType; + +import static com.facebook.presto.spi.ErrorType.EXTERNAL; + +public enum ThriftErrorCode + implements ErrorCodeSupplier +{ + THRIFT_SERVICE_CONNECTION_ERROR(1, EXTERNAL), + THRIFT_SERVICE_INVALID_RESPONSE(2, EXTERNAL); + + private final ErrorCode errorCode; + + ThriftErrorCode(int code, ErrorType type) + { + errorCode = new ErrorCode(code + 0x0105, name(), type); + } + + @Override + public ErrorCode toErrorCode() + { + return errorCode; + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftHandleResolver.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftHandleResolver.java new file mode 100644 index 000000000000..413b42020118 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftHandleResolver.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public class ThriftHandleResolver + implements ConnectorHandleResolver +{ + @Override + public Class getTableHandleClass() + { + return ThriftTableHandle.class; + } + + @Override + public Class getTableLayoutHandleClass() + { + return ThriftTableLayoutHandle.class; + } + + @Override + public Class getColumnHandleClass() + { + return ThriftColumnHandle.class; + } + + @Override + public Class getSplitClass() + { + return ThriftConnectorSplit.class; + } + + @Override + public Class getTransactionHandleClass() + { + return ThriftTransactionHandle.class; + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftMetadata.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftMetadata.java new file mode 100644 index 000000000000..57ba6ca398c1 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftMetadata.java @@ -0,0 +1,194 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.annotations.ForMetadataRefresh; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableSchemaName; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableTableMetadata; +import com.facebook.presto.connector.thrift.api.PrestoThriftSchemaTableName; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.clientproviders.PrestoThriftServiceProvider; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayout; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.ConnectorTableLayoutResult; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.SchemaTablePrefix; +import com.facebook.presto.spi.TableNotFoundException; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.type.TypeManager; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.collect.ImmutableList; +import io.airlift.units.Duration; + +import javax.annotation.Nonnull; +import javax.inject.Inject; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Executor; + +import static com.facebook.presto.connector.thrift.ThriftErrorCode.THRIFT_SERVICE_INVALID_RESPONSE; +import static com.facebook.presto.connector.thrift.api.PrestoThriftSchemaTableName.fromSchemaTableName; +import static com.google.common.cache.CacheLoader.asyncReloading; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.function.Function.identity; + +public class ThriftMetadata + implements ConnectorMetadata +{ + private static final Duration EXPIRE_AFTER_WRITE = new Duration(10, MINUTES); + private static final Duration REFRESH_AFTER_WRITE = new Duration(2, MINUTES); + + private final PrestoThriftServiceProvider clientProvider; + private final TypeManager typeManager; + private final LoadingCache> tableCache; + + @Inject + public ThriftMetadata( + PrestoThriftServiceProvider clientProvider, + TypeManager typeManager, + @ForMetadataRefresh Executor metadataRefreshExecutor) + { + this.clientProvider = requireNonNull(clientProvider, "clientProvider is null"); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.tableCache = CacheBuilder.newBuilder() + .expireAfterWrite(EXPIRE_AFTER_WRITE.toMillis(), MILLISECONDS) + .refreshAfterWrite(REFRESH_AFTER_WRITE.toMillis(), MILLISECONDS) + .build(asyncReloading(new CacheLoader>() + { + @Override + public Optional load(@Nonnull SchemaTableName schemaTableName) + throws Exception + { + return getTableMetadataInternal(schemaTableName); + } + }, metadataRefreshExecutor)); + } + + @Override + public List listSchemaNames(ConnectorSession session) + { + return clientProvider.runOnAnyHost(PrestoThriftService::listSchemaNames); + } + + @Override + public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName) + { + return tableCache.getUnchecked(tableName) + .map(ConnectorTableMetadata::getTable) + .map(ThriftTableHandle::new) + .orElse(null); + } + + @Override + public List getTableLayouts( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) + { + ThriftTableHandle tableHandle = (ThriftTableHandle) table; + ThriftTableLayoutHandle layoutHandle = new ThriftTableLayoutHandle( + tableHandle.getSchemaName(), + tableHandle.getTableName(), + desiredColumns, + constraint.getSummary()); + return ImmutableList.of(new ConnectorTableLayoutResult(new ConnectorTableLayout(layoutHandle), constraint.getSummary())); + } + + @Override + public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTableLayoutHandle handle) + { + return new ConnectorTableLayout(handle); + } + + @Override + public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) + { + ThriftTableHandle handle = ((ThriftTableHandle) tableHandle); + return getTableMetadata(new SchemaTableName(handle.getSchemaName(), handle.getTableName())); + } + + @Override + public List listTables(ConnectorSession session, String schemaNameOrNull) + { + return clientProvider.runOnAnyHost(client -> client.listTables(new PrestoThriftNullableSchemaName(schemaNameOrNull))).stream() + .map(PrestoThriftSchemaTableName::toSchemaTableName) + .collect(toImmutableList()); + } + + @Override + public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return getTableMetadata(session, tableHandle).getColumns().stream().collect(toImmutableMap(ColumnMetadata::getName, ThriftColumnHandle::new)); + } + + @Override + public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle) + { + return ((ThriftColumnHandle) columnHandle).toColumnMetadata(); + } + + @Override + public Map> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix) + { + return listTables(session, prefix.getSchemaName()).stream().collect(toImmutableMap(identity(), schemaTableName -> getTableMetadata(schemaTableName).getColumns())); + } + + private ConnectorTableMetadata getTableMetadata(SchemaTableName schemaTableName) + { + Optional table = tableCache.getUnchecked(schemaTableName); + if (!table.isPresent()) { + throw new TableNotFoundException(schemaTableName); + } + else { + return table.get(); + } + } + + // this method makes actual thrift request and should be called only by cache load method + private Optional getTableMetadataInternal(SchemaTableName schemaTableName) + { + requireNonNull(schemaTableName, "schemaTableName is null"); + return clientProvider.runOnAnyHost(client -> { + PrestoThriftNullableTableMetadata thriftTableMetadata = client.getTableMetadata(fromSchemaTableName(schemaTableName)); + if (thriftTableMetadata.getTableMetadata() == null) { + return Optional.empty(); + } + else { + ConnectorTableMetadata tableMetadata = thriftTableMetadata.getTableMetadata().toConnectorTableMetadata(typeManager); + if (!Objects.equals(schemaTableName, tableMetadata.getTable())) { + throw new PrestoException(THRIFT_SERVICE_INVALID_RESPONSE, "Requested and actual table names are different"); + } + return Optional.of(tableMetadata); + } + }); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftModule.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftModule.java new file mode 100644 index 000000000000..c87dc8a7c54d --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftModule.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.annotations.ForMetadataRefresh; +import com.facebook.presto.connector.thrift.annotations.NonRetrying; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.clientproviders.DefaultPrestoThriftServiceProvider; +import com.facebook.presto.connector.thrift.clientproviders.PrestoThriftServiceProvider; +import com.facebook.presto.connector.thrift.clientproviders.RetryingPrestoThriftServiceProvider; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Provides; +import com.google.inject.Scopes; + +import javax.inject.Singleton; + +import java.util.concurrent.Executor; + +import static com.facebook.swift.service.guice.ThriftClientBinder.thriftClientBinder; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.configuration.ConfigBinder.configBinder; +import static java.util.concurrent.Executors.newFixedThreadPool; + +public class ThriftModule + implements Module +{ + @Override + public void configure(Binder binder) + { + binder.bind(ThriftConnector.class).in(Scopes.SINGLETON); + thriftClientBinder(binder).bindThriftClient(PrestoThriftService.class); + binder.bind(ThriftMetadata.class).in(Scopes.SINGLETON); + binder.bind(ThriftSplitManager.class).in(Scopes.SINGLETON); + binder.bind(ThriftPageSourceProvider.class).in(Scopes.SINGLETON); + binder.bind(PrestoThriftServiceProvider.class).to(RetryingPrestoThriftServiceProvider.class).in(Scopes.SINGLETON); + binder.bind(PrestoThriftServiceProvider.class).annotatedWith(NonRetrying.class).to(DefaultPrestoThriftServiceProvider.class).in(Scopes.SINGLETON); + configBinder(binder).bindConfig(ThriftConnectorConfig.class); + binder.bind(ThriftSessionProperties.class).in(Scopes.SINGLETON); + } + + @Provides + @Singleton + @ForMetadataRefresh + public Executor createMetadataRefreshExecutor(ThriftConnectorConfig config) + { + return newFixedThreadPool(config.getMetadataRefreshThreads(), daemonThreadsNamed("metadata-refresh-%s")); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPageSource.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPageSource.java new file mode 100644 index 000000000000..4638644f2a69 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPageSource.java @@ -0,0 +1,196 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.api.PrestoThriftId; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableToken; +import com.facebook.presto.connector.thrift.api.PrestoThriftPageResult; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.clientproviders.PrestoThriftServiceProvider; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.type.Type; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.concurrent.MoreFutures.getFutureValue; +import static io.airlift.concurrent.MoreFutures.toCompletableFuture; +import static java.util.Objects.requireNonNull; + +public class ThriftPageSource + implements ConnectorPageSource +{ + private final PrestoThriftId splitId; + private final PrestoThriftService client; + private final List columnNames; + private final List columnTypes; + private final long maxBytesPerResponse; + private final AtomicLong readTimeNanos = new AtomicLong(0); + + private PrestoThriftId nextToken; + private boolean firstCall = true; + private CompletableFuture future; + private long completedBytes; + + public ThriftPageSource( + PrestoThriftServiceProvider clientProvider, + ThriftConnectorSplit split, + List columns, + long maxBytesPerResponse) + { + // init columns + requireNonNull(columns, "columns is null"); + ImmutableList.Builder columnNames = new ImmutableList.Builder<>(); + ImmutableList.Builder columnTypes = new ImmutableList.Builder<>(); + for (ColumnHandle columnHandle : columns) { + ThriftColumnHandle thriftColumnHandle = (ThriftColumnHandle) columnHandle; + columnNames.add(thriftColumnHandle.getColumnName()); + columnTypes.add(thriftColumnHandle.getColumnType()); + } + this.columnNames = columnNames.build(); + this.columnTypes = columnTypes.build(); + + // this parameter is read from config, so it should be checked by config validation + // however, here it's a raw constructor parameter, so adding this safety check + checkArgument(maxBytesPerResponse > 0, "maxBytesPerResponse is zero or negative"); + this.maxBytesPerResponse = maxBytesPerResponse; + + // init split + requireNonNull(split, "split is null"); + this.splitId = split.getSplitId(); + + // init client + requireNonNull(clientProvider, "clientProvider is null"); + if (split.getAddresses().isEmpty()) { + this.client = clientProvider.anyHostClient(); + } + else { + this.client = clientProvider.selectedHostClient(split.getAddresses()); + } + } + + @Override + public long getTotalBytes() + { + return 0; + } + + @Override + public long getCompletedBytes() + { + return completedBytes; + } + + @Override + public long getReadTimeNanos() + { + return readTimeNanos.get(); + } + + @Override + public long getSystemMemoryUsage() + { + return 0; + } + + @Override + public boolean isFinished() + { + return !firstCall && !canGetMoreData(nextToken); + } + + @Override + public Page getNextPage() + { + if (future == null) { + // no data request in progress + if (firstCall || canGetMoreData(nextToken)) { + // no data in the current batch, but can request more; will send a request + future = sendDataRequestInternal(); + } + return null; + } + + if (!future.isDone()) { + // data request is in progress + return null; + } + + // response for data request is ready + Page result = processBatch(getFutureValue(future)); + + // immediately try sending a new request + if (canGetMoreData(nextToken)) { + future = sendDataRequestInternal(); + } + else { + future = null; + } + + return result; + } + + private static boolean canGetMoreData(PrestoThriftId nextToken) + { + return nextToken != null; + } + + private CompletableFuture sendDataRequestInternal() + { + long start = System.nanoTime(); + ListenableFuture rowsBatchFuture = client.getRows( + splitId, + columnNames, + maxBytesPerResponse, + new PrestoThriftNullableToken(nextToken)); + rowsBatchFuture.addListener(() -> readTimeNanos.addAndGet(System.nanoTime() - start), directExecutor()); + return toCompletableFuture(nonCancellationPropagating(rowsBatchFuture)); + } + + private Page processBatch(PrestoThriftPageResult rowsBatch) + { + firstCall = false; + nextToken = rowsBatch.getNextToken(); + Page page = rowsBatch.toPage(columnTypes); + if (page != null) { + completedBytes += page.getSizeInBytes(); + } + return page; + } + + @Override + public CompletableFuture isBlocked() + { + return future == null ? NOT_BLOCKED : future; + } + + @Override + public void close() + throws IOException + { + if (future != null) { + future.cancel(true); + } + client.close(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPageSourceProvider.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPageSourceProvider.java new file mode 100644 index 000000000000..4a63442cc58e --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPageSourceProvider.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.clientproviders.PrestoThriftServiceProvider; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +import javax.inject.Inject; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class ThriftPageSourceProvider + implements ConnectorPageSourceProvider +{ + private final PrestoThriftServiceProvider clientProvider; + private final long maxBytesPerResponse; + + @Inject + public ThriftPageSourceProvider(PrestoThriftServiceProvider clientProvider, ThriftConnectorConfig config) + { + this.clientProvider = requireNonNull(clientProvider, "clientProvider is null"); + this.maxBytesPerResponse = requireNonNull(config, "config is null").getMaxResponseSize().toBytes(); + } + + @Override + public ConnectorPageSource createPageSource( + ConnectorTransactionHandle transactionHandle, + ConnectorSession session, + ConnectorSplit split, + List columns) + { + return new ThriftPageSource(clientProvider, (ThriftConnectorSplit) split, columns, maxBytesPerResponse); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPlugin.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPlugin.java new file mode 100644 index 000000000000..529a912dd9ab --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPlugin.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.google.common.collect.ImmutableList; +import com.google.inject.Module; + +import java.util.List; +import java.util.ServiceLoader; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; + +public class ThriftPlugin + implements Plugin +{ + private final String name; + private final Module locationModule; + + public ThriftPlugin() + { + this(getPluginInfo()); + } + + private ThriftPlugin(ThriftPluginInfo info) + { + this(info.getName(), info.getLocationModule()); + } + + public ThriftPlugin(String name, Module locationModule) + { + checkArgument(!isNullOrEmpty(name), "name is null or empty"); + this.name = name; + this.locationModule = requireNonNull(locationModule, "locationModule is null"); + } + + @Override + public Iterable getConnectorFactories() + { + return ImmutableList.of(new ThriftConnectorFactory(name, locationModule)); + } + + private static ThriftPluginInfo getPluginInfo() + { + ClassLoader classLoader = ThriftPlugin.class.getClassLoader(); + ServiceLoader loader = ServiceLoader.load(ThriftPluginInfo.class, classLoader); + List list = ImmutableList.copyOf(loader); + return list.isEmpty() ? new ThriftPluginInfo() : getOnlyElement(list); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPluginInfo.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPluginInfo.java new file mode 100644 index 000000000000..e8e193661c31 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPluginInfo.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.location.StaticLocationModule; +import com.google.inject.Module; + +public class ThriftPluginInfo +{ + public String getName() + { + return "presto-thrift"; + } + + public Module getLocationModule() + { + return new StaticLocationModule(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/util/QueryExplanation.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftSessionProperties.java similarity index 52% rename from presto-main/src/main/java/com/facebook/presto/util/QueryExplanation.java rename to presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftSessionProperties.java index b1d6b11cd919..f3c036b01bd2 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/QueryExplanation.java +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftSessionProperties.java @@ -11,31 +11,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.util; +package com.facebook.presto.connector.thrift; -import com.facebook.presto.execution.Input; -import com.fasterxml.jackson.annotation.JsonProperty; +import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; -import javax.annotation.concurrent.Immutable; +import javax.inject.Inject; import java.util.List; -import static java.util.Objects.requireNonNull; - -@Immutable -public class QueryExplanation +/** + * Internal session properties are those defined by the connector itself. + * These properties control certain aspects of connector's work. + */ +public final class ThriftSessionProperties { - private final List inputs; + private final List> sessionProperties; - public QueryExplanation(List inputs) + @Inject + public ThriftSessionProperties(ThriftConnectorConfig config) { - this.inputs = ImmutableList.copyOf(requireNonNull(inputs, "inputs is null")); + sessionProperties = ImmutableList.of(); } - @JsonProperty - public List getInputs() + public List> getSessionProperties() { - return inputs; + return sessionProperties; } } diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftSplitManager.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftSplitManager.java new file mode 100644 index 000000000000..a36f33de8e47 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftSplitManager.java @@ -0,0 +1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.api.PrestoThriftDomain; +import com.facebook.presto.connector.thrift.api.PrestoThriftHostAddress; +import com.facebook.presto.connector.thrift.api.PrestoThriftId; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableColumnSet; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableToken; +import com.facebook.presto.connector.thrift.api.PrestoThriftSchemaTableName; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.api.PrestoThriftSplit; +import com.facebook.presto.connector.thrift.api.PrestoThriftSplitBatch; +import com.facebook.presto.connector.thrift.api.PrestoThriftTupleDomain; +import com.facebook.presto.connector.thrift.clientproviders.PrestoThriftServiceProvider; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; + +import javax.annotation.concurrent.NotThreadSafe; +import javax.inject.Inject; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftDomain.fromDomain; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.airlift.concurrent.MoreFutures.toCompletableFuture; +import static java.util.Objects.requireNonNull; + +public class ThriftSplitManager + implements ConnectorSplitManager +{ + private final PrestoThriftServiceProvider clientProvider; + + @Inject + public ThriftSplitManager(PrestoThriftServiceProvider clientProvider) + { + this.clientProvider = requireNonNull(clientProvider, "clientProvider is null"); + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout) + { + ThriftTableLayoutHandle layoutHandle = (ThriftTableLayoutHandle) layout; + return new ThriftSplitSource( + clientProvider.anyHostClient(), + new PrestoThriftSchemaTableName(layoutHandle.getSchemaName(), layoutHandle.getTableName()), + layoutHandle.getColumns().map(ThriftSplitManager::columnNames), + tupleDomainToThriftTupleDomain(layoutHandle.getConstraint())); + } + + private static Set columnNames(Set columns) + { + return columns.stream() + .map(ThriftColumnHandle.class::cast) + .map(ThriftColumnHandle::getColumnName) + .collect(toImmutableSet()); + } + + private static PrestoThriftTupleDomain tupleDomainToThriftTupleDomain(TupleDomain tupleDomain) + { + if (!tupleDomain.getDomains().isPresent()) { + return new PrestoThriftTupleDomain(null); + } + Map thriftDomains = tupleDomain.getDomains().get() + .entrySet().stream() + .collect(toImmutableMap( + entry -> ((ThriftColumnHandle) entry.getKey()).getColumnName(), + entry -> fromDomain(entry.getValue()))); + return new PrestoThriftTupleDomain(thriftDomains); + } + + @NotThreadSafe + private static class ThriftSplitSource + implements ConnectorSplitSource + { + private final PrestoThriftService client; + private final PrestoThriftSchemaTableName schemaTableName; + private final Optional> columnNames; + private final PrestoThriftTupleDomain constraint; + + // the code assumes getNextBatch is called by a single thread + + private final AtomicBoolean hasMoreData; + private final AtomicReference nextToken; + private final AtomicReference> future; + + public ThriftSplitSource( + PrestoThriftService client, + PrestoThriftSchemaTableName schemaTableName, + Optional> columnNames, + PrestoThriftTupleDomain constraint) + { + this.client = requireNonNull(client, "client is null"); + this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); + this.columnNames = requireNonNull(columnNames, "columnNames is null"); + this.constraint = requireNonNull(constraint, "constraint is null"); + this.nextToken = new AtomicReference<>(null); + this.hasMoreData = new AtomicBoolean(true); + this.future = new AtomicReference<>(null); + } + + /** + * Returns a future with a list of splits. + * This method is assumed to be called in a single-threaded way. + * It can be called by multiple threads, but only if the previous call finished. + */ + @Override + public CompletableFuture> getNextBatch(int maxSize) + { + checkState(future.get() == null || future.get().isDone(), "previous batch not completed"); + checkState(hasMoreData.get(), "this method cannot be invoked when there's no more data"); + PrestoThriftId currentToken = nextToken.get(); + ListenableFuture splitsFuture = client.getSplits( + schemaTableName, + new PrestoThriftNullableColumnSet(columnNames.orElse(null)), + constraint, + maxSize, + new PrestoThriftNullableToken(currentToken)); + ListenableFuture> resultFuture = Futures.transform( + splitsFuture, + batch -> { + requireNonNull(batch, "batch is null"); + List splits = batch.getSplits().stream() + .map(ThriftSplitSource::toConnectorSplit) + .collect(toImmutableList()); + checkState(nextToken.compareAndSet(currentToken, batch.getNextToken())); + checkState(hasMoreData.compareAndSet(true, nextToken.get() != null)); + return splits; + }); + future.set(resultFuture); + return toCompletableFuture(resultFuture); + } + + @Override + public boolean isFinished() + { + return !hasMoreData.get(); + } + + @Override + public void close() + { + Future currentFuture = future.getAndSet(null); + if (currentFuture != null) { + currentFuture.cancel(true); + } + client.close(); + } + + private static ThriftConnectorSplit toConnectorSplit(PrestoThriftSplit thriftSplit) + { + return new ThriftConnectorSplit( + thriftSplit.getSplitId(), + toHostAddressList(thriftSplit.getHosts())); + } + + private static List toHostAddressList(List hosts) + { + return hosts.stream().map(PrestoThriftHostAddress::toHostAddress).collect(toImmutableList()); + } + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTableHandle.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTableHandle.java new file mode 100644 index 000000000000..c47db0adff71 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTableHandle.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.SchemaTableName; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class ThriftTableHandle + implements ConnectorTableHandle +{ + private final String schemaName; + private final String tableName; + + @JsonCreator + public ThriftTableHandle( + @JsonProperty("schemaName") String schemaName, + @JsonProperty("tableName") String tableName) + { + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); + } + + public ThriftTableHandle(SchemaTableName schemaTableName) + { + this(schemaTableName.getSchemaName(), schemaTableName.getTableName()); + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + ThriftTableHandle other = (ThriftTableHandle) obj; + return Objects.equals(this.schemaName, other.schemaName) && + Objects.equals(this.tableName, other.tableName); + } + + @Override + public int hashCode() + { + return Objects.hash(schemaName, tableName); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("schemaName", getSchemaName()) + .add("tableName", getTableName()) + .toString(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTableLayoutHandle.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTableLayoutHandle.java new file mode 100644 index 000000000000..a1f15667c53c --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTableLayoutHandle.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableSet; + +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class ThriftTableLayoutHandle + implements ConnectorTableLayoutHandle +{ + private final String schemaName; + private final String tableName; + private final Optional> columns; + private final TupleDomain constraint; + + @JsonCreator + public ThriftTableLayoutHandle( + @JsonProperty("schemaName") String schemaName, + @JsonProperty("tableName") String tableName, + @JsonProperty("columns") Optional> columns, + @JsonProperty("constraint") TupleDomain constraint) + { + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); + this.columns = requireNonNull(columns, "columns is null").map(ImmutableSet::copyOf); + this.constraint = requireNonNull(constraint, "constraint is null"); + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public Optional> getColumns() + { + return columns; + } + + @JsonProperty + public TupleDomain getConstraint() + { + return constraint; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ThriftTableLayoutHandle other = (ThriftTableLayoutHandle) o; + return schemaName.equals(other.schemaName) + && tableName.equals(other.tableName) + && columns.equals(other.columns) + && constraint.equals(other.constraint); + } + + @Override + public int hashCode() + { + return Objects.hash(schemaName, tableName, columns, constraint); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("schemaName", schemaName) + .add("tableName", tableName) + .add("columns", columns) + .add("constraint", constraint) + .toString(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTransactionHandle.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTransactionHandle.java new file mode 100644 index 000000000000..9ab3ef68d821 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTransactionHandle.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public enum ThriftTransactionHandle + implements ConnectorTransactionHandle +{ + INSTANCE +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/annotations/ForMetadataRefresh.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/annotations/ForMetadataRefresh.java new file mode 100644 index 000000000000..b2fff95db23a --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/annotations/ForMetadataRefresh.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.annotations; + +import javax.inject.Qualifier; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target({PARAMETER, METHOD, FIELD}) +@Qualifier +public @interface ForMetadataRefresh +{ +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/annotations/NonRetrying.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/annotations/NonRetrying.java new file mode 100644 index 000000000000..41c908eafccb --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/annotations/NonRetrying.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.annotations; + +import com.google.inject.BindingAnnotation; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@BindingAnnotation +@Target(PARAMETER) +@Retention(RUNTIME) +public @interface NonRetrying +{ +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/DefaultPrestoThriftServiceProvider.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/DefaultPrestoThriftServiceProvider.java new file mode 100644 index 000000000000..5170998530e6 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/DefaultPrestoThriftServiceProvider.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.clientproviders; + +import com.facebook.nifty.client.FramedClientConnector; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.location.HostLocationProvider; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.PrestoException; +import com.facebook.swift.service.ThriftClient; +import com.google.common.net.HostAndPort; +import io.airlift.units.Duration; + +import javax.inject.Inject; + +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static com.facebook.presto.connector.thrift.ThriftErrorCode.THRIFT_SERVICE_CONNECTION_ERROR; +import static java.util.Objects.requireNonNull; + +public class DefaultPrestoThriftServiceProvider + implements PrestoThriftServiceProvider +{ + private final ThriftClient thriftClient; + private final HostLocationProvider locationProvider; + private final long thriftConnectTimeoutMs; + + @Inject + public DefaultPrestoThriftServiceProvider(ThriftClient thriftClient, HostLocationProvider locationProvider) + { + this.thriftClient = requireNonNull(thriftClient, "thriftClient is null"); + this.locationProvider = requireNonNull(locationProvider, "locationProvider is null"); + this.thriftConnectTimeoutMs = Duration.valueOf(thriftClient.getConnectTimeout()).toMillis(); + } + + @Override + public PrestoThriftService anyHostClient() + { + return connectTo(locationProvider.getAnyHost()); + } + + @Override + public PrestoThriftService selectedHostClient(List hosts) + { + return connectTo(locationProvider.getAnyOf(hosts)); + } + + private PrestoThriftService connectTo(HostAddress host) + { + try { + return thriftClient.open(new FramedClientConnector(HostAndPort.fromParts(host.getHostText(), host.getPort()))) + .get(thriftConnectTimeoutMs, TimeUnit.MILLISECONDS); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while connecting to thrift host at " + host, e); + } + catch (ExecutionException | TimeoutException e) { + throw new PrestoException(THRIFT_SERVICE_CONNECTION_ERROR, "Cannot connect to thrift host at " + host, e); + } + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/PrestoThriftServiceProvider.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/PrestoThriftServiceProvider.java new file mode 100644 index 000000000000..c6e0c629063f --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/PrestoThriftServiceProvider.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.clientproviders; + +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.spi.HostAddress; + +import java.util.List; +import java.util.function.Function; + +public interface PrestoThriftServiceProvider +{ + PrestoThriftService anyHostClient(); + + PrestoThriftService selectedHostClient(List hosts); + + default V runOnAnyHost(Function call) + { + try (PrestoThriftService client = anyHostClient()) { + return call.apply(client); + } + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/RetryingPrestoThriftServiceProvider.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/RetryingPrestoThriftServiceProvider.java new file mode 100644 index 000000000000..cfc1ca628055 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/RetryingPrestoThriftServiceProvider.java @@ -0,0 +1,150 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.clientproviders; + +import com.facebook.presto.connector.thrift.annotations.NonRetrying; +import com.facebook.presto.connector.thrift.api.PrestoThriftId; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableColumnSet; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableSchemaName; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableTableMetadata; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableToken; +import com.facebook.presto.connector.thrift.api.PrestoThriftPageResult; +import com.facebook.presto.connector.thrift.api.PrestoThriftSchemaTableName; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.api.PrestoThriftServiceException; +import com.facebook.presto.connector.thrift.api.PrestoThriftSplitBatch; +import com.facebook.presto.connector.thrift.api.PrestoThriftTupleDomain; +import com.facebook.presto.connector.thrift.util.RetryDriver; +import com.facebook.presto.spi.HostAddress; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.log.Logger; +import io.airlift.units.Duration; + +import javax.annotation.concurrent.NotThreadSafe; +import javax.inject.Inject; + +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import static java.util.Objects.requireNonNull; + +public class RetryingPrestoThriftServiceProvider + implements PrestoThriftServiceProvider +{ + private static final Logger log = Logger.get(RetryingPrestoThriftServiceProvider.class); + private final PrestoThriftServiceProvider original; + private final RetryDriver retry; + + @Inject + public RetryingPrestoThriftServiceProvider(@NonRetrying PrestoThriftServiceProvider original) + { + this.original = requireNonNull(original, "original is null"); + retry = RetryDriver.retry() + .maxAttempts(5) + .stopRetryingWhen(e -> e instanceof PrestoThriftServiceException && !((PrestoThriftServiceException) e).isRetryable()) + .exponentialBackoff( + new Duration(10, TimeUnit.MILLISECONDS), + new Duration(20, TimeUnit.MILLISECONDS), + new Duration(30, TimeUnit.SECONDS), + 1.5); + } + + @Override + public PrestoThriftService anyHostClient() + { + return new RetryingService(original::anyHostClient, retry); + } + + @Override + public PrestoThriftService selectedHostClient(List hosts) + { + return new RetryingService(() -> original.selectedHostClient(hosts), retry); + } + + @NotThreadSafe + private static final class RetryingService + implements PrestoThriftService + { + private final Supplier clientSupplier; + private final RetryDriver retry; + private PrestoThriftService client; + + public RetryingService(Supplier clientSupplier, RetryDriver retry) + { + this.clientSupplier = requireNonNull(clientSupplier, "clientSupplier is null"); + this.retry = retry.onRetry(this::close); + } + + private PrestoThriftService getClient() + { + if (client != null) { + return client; + } + client = clientSupplier.get(); + return client; + } + + @Override + public List listSchemaNames() + { + return retry.run("listSchemaNames", () -> getClient().listSchemaNames()); + } + + @Override + public List listTables(PrestoThriftNullableSchemaName schemaNameOrNull) + { + return retry.run("listTables", () -> getClient().listTables(schemaNameOrNull)); + } + + @Override + public PrestoThriftNullableTableMetadata getTableMetadata(PrestoThriftSchemaTableName schemaTableName) + { + return retry.run("getTableMetadata", () -> getClient().getTableMetadata(schemaTableName)); + } + + @Override + public ListenableFuture getSplits( + PrestoThriftSchemaTableName schemaTableName, + PrestoThriftNullableColumnSet desiredColumns, + PrestoThriftTupleDomain outputConstraint, + int maxSplitCount, + PrestoThriftNullableToken nextToken) + throws PrestoThriftServiceException + { + return retry.run("getSplits", () -> getClient().getSplits(schemaTableName, desiredColumns, outputConstraint, maxSplitCount, nextToken)); + } + + @Override + public ListenableFuture getRows(PrestoThriftId splitId, List columns, long maxBytes, PrestoThriftNullableToken nextToken) + { + return retry.run("getRows", () -> getClient().getRows(splitId, columns, maxBytes, nextToken)); + } + + @Override + public void close() + { + if (client == null) { + return; + } + try { + client.close(); + } + catch (Exception e) { + log.warn("Error closing client", e); + } + client = null; + } + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/HostList.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/HostList.java new file mode 100644 index 000000000000..95c3d09f258c --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/HostList.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import com.facebook.presto.spi.HostAddress; +import com.google.common.base.Joiner; +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Arrays.asList; +import static java.util.Objects.requireNonNull; + +public final class HostList +{ + private final List hosts; + + private HostList(List hosts) + { + this.hosts = ImmutableList.copyOf(requireNonNull(hosts, "hosts is null")); + } + + // needed for automatic config parsing + @SuppressWarnings("unused") + public static HostList fromString(String hosts) + { + return new HostList(Splitter.on(',').trimResults().omitEmptyStrings().splitToList(hosts).stream().map(HostAddress::fromString).collect(toImmutableList())); + } + + public static HostList of(HostAddress... hosts) + { + return new HostList(asList(hosts)); + } + + public static HostList fromList(List hosts) + { + return new HostList(hosts); + } + + public List getHosts() + { + return hosts; + } + + public String stringValue() + { + return Joiner.on(',').join(hosts); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + HostList hostList = (HostList) o; + return hosts.equals(hostList.hosts); + } + + @Override + public int hashCode() + { + return hosts.hashCode(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("hosts", hosts) + .toString(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/HostLocationProvider.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/HostLocationProvider.java new file mode 100644 index 000000000000..143d518d41e4 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/HostLocationProvider.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import com.facebook.presto.spi.HostAddress; + +import java.util.List; + +public interface HostLocationProvider +{ + HostAddress getAnyHost(); + + HostAddress getAnyOf(List hosts); +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationConfig.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationConfig.java new file mode 100644 index 000000000000..c252f17c7fa6 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationConfig.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import io.airlift.configuration.Config; + +import javax.validation.constraints.NotNull; + +public class StaticLocationConfig +{ + private HostList hosts; + + @NotNull + public HostList getHosts() + { + return hosts; + } + + @Config("static-location.hosts") + public StaticLocationConfig setHosts(HostList hosts) + { + this.hosts = hosts; + return this; + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationModule.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationModule.java new file mode 100644 index 000000000000..7b012fdf9b04 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationModule.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static io.airlift.configuration.ConfigBinder.configBinder; + +public class StaticLocationModule + implements Module +{ + @Override + public void configure(Binder binder) + { + configBinder(binder).bindConfig(StaticLocationConfig.class); + binder.bind(HostLocationProvider.class).to(StaticLocationProvider.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationProvider.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationProvider.java new file mode 100644 index 000000000000..034667189771 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationProvider.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import com.facebook.presto.spi.HostAddress; + +import javax.inject.Inject; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class StaticLocationProvider + implements HostLocationProvider +{ + private final List hosts; + private final AtomicInteger index = new AtomicInteger(0); + + @Inject + public StaticLocationProvider(StaticLocationConfig config) + { + requireNonNull(config, "config is null"); + List hosts = config.getHosts().getHosts(); + checkArgument(!hosts.isEmpty(), "hosts is empty"); + this.hosts = new ArrayList<>(hosts); + Collections.shuffle(this.hosts); + } + + /** + * Provides the next host from a configured list of hosts in a round-robin fashion. + */ + @Override + public HostAddress getAnyHost() + { + return hosts.get(index.getAndUpdate(this::next)); + } + + @Override + public HostAddress getAnyOf(List requestedHosts) + { + checkArgument(requestedHosts != null && !requestedHosts.isEmpty(), "requestedHosts is null or empty"); + return requestedHosts.get(ThreadLocalRandom.current().nextInt(requestedHosts.size())); + } + + private int next(int x) + { + return (x + 1) % hosts.size(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/util/RebindSafeMBeanServer.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/util/RebindSafeMBeanServer.java new file mode 100644 index 000000000000..f8fec2de99f8 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/util/RebindSafeMBeanServer.java @@ -0,0 +1,335 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.util; + +import io.airlift.log.Logger; + +import javax.annotation.concurrent.ThreadSafe; +import javax.management.Attribute; +import javax.management.AttributeList; +import javax.management.AttributeNotFoundException; +import javax.management.InstanceAlreadyExistsException; +import javax.management.InstanceNotFoundException; +import javax.management.IntrospectionException; +import javax.management.InvalidAttributeValueException; +import javax.management.ListenerNotFoundException; +import javax.management.MBeanException; +import javax.management.MBeanInfo; +import javax.management.MBeanRegistrationException; +import javax.management.MBeanServer; +import javax.management.NotCompliantMBeanException; +import javax.management.NotificationFilter; +import javax.management.NotificationListener; +import javax.management.ObjectInstance; +import javax.management.ObjectName; +import javax.management.OperationsException; +import javax.management.QueryExp; +import javax.management.ReflectionException; +import javax.management.loading.ClassLoaderRepository; + +import java.io.ObjectInputStream; +import java.util.Set; + +// TODO: move this to airlift or jmxutils + +/** + * MBeanServer wrapper that a ignores calls to registerMBean when there is already + * a MBean registered with the specified object name. + */ +@ThreadSafe +public class RebindSafeMBeanServer + implements MBeanServer +{ + private static final Logger log = Logger.get(RebindSafeMBeanServer.class); + + private final MBeanServer mbeanServer; + + public RebindSafeMBeanServer(MBeanServer mbeanServer) + { + this.mbeanServer = mbeanServer; + } + + /** + * Delegates to the wrapped mbean server, but if a mbean is already registered + * with the specified name, the existing instance is returned. + */ + @Override + public ObjectInstance registerMBean(Object object, ObjectName name) + throws MBeanRegistrationException, NotCompliantMBeanException + { + while (true) { + try { + // try to register the mbean + return mbeanServer.registerMBean(object, name); + } + catch (InstanceAlreadyExistsException ignored) { + } + + try { + // a mbean is already installed, try to return the already registered instance + ObjectInstance objectInstance = mbeanServer.getObjectInstance(name); + log.debug("%s already bound to %s", name, objectInstance); + return objectInstance; + } + catch (InstanceNotFoundException ignored) { + // the mbean was removed before we could get the reference + // start the whole process over again + } + } + } + + @Override + public void unregisterMBean(ObjectName name) + throws InstanceNotFoundException, MBeanRegistrationException + { + mbeanServer.unregisterMBean(name); + } + + @Override + public ObjectInstance getObjectInstance(ObjectName name) + throws InstanceNotFoundException + { + return mbeanServer.getObjectInstance(name); + } + + @Override + public Set queryMBeans(ObjectName name, QueryExp query) + { + return mbeanServer.queryMBeans(name, query); + } + + @Override + public Set queryNames(ObjectName name, QueryExp query) + { + return mbeanServer.queryNames(name, query); + } + + @Override + public boolean isRegistered(ObjectName name) + { + return mbeanServer.isRegistered(name); + } + + @Override + public Integer getMBeanCount() + { + return mbeanServer.getMBeanCount(); + } + + @Override + public Object getAttribute(ObjectName name, String attribute) + throws MBeanException, AttributeNotFoundException, InstanceNotFoundException, ReflectionException + { + return mbeanServer.getAttribute(name, attribute); + } + + @Override + public AttributeList getAttributes(ObjectName name, String[] attributes) + throws InstanceNotFoundException, ReflectionException + { + return mbeanServer.getAttributes(name, attributes); + } + + @Override + public void setAttribute(ObjectName name, Attribute attribute) + throws InstanceNotFoundException, AttributeNotFoundException, InvalidAttributeValueException, MBeanException, ReflectionException + { + mbeanServer.setAttribute(name, attribute); + } + + @Override + public AttributeList setAttributes(ObjectName name, AttributeList attributes) + throws InstanceNotFoundException, ReflectionException + { + return mbeanServer.setAttributes(name, attributes); + } + + @Override + public Object invoke(ObjectName name, String operationName, Object[] params, String[] signature) + throws InstanceNotFoundException, MBeanException, ReflectionException + { + return mbeanServer.invoke(name, operationName, params, signature); + } + + @Override + public String getDefaultDomain() + { + return mbeanServer.getDefaultDomain(); + } + + @Override + public String[] getDomains() + { + return mbeanServer.getDomains(); + } + + @Override + public void addNotificationListener(ObjectName name, NotificationListener listener, NotificationFilter filter, Object context) + throws InstanceNotFoundException + { + mbeanServer.addNotificationListener(name, listener, filter, context); + } + + @Override + public void addNotificationListener(ObjectName name, ObjectName listener, NotificationFilter filter, Object context) + throws InstanceNotFoundException + { + mbeanServer.addNotificationListener(name, listener, filter, context); + } + + @Override + public void removeNotificationListener(ObjectName name, ObjectName listener) + throws InstanceNotFoundException, ListenerNotFoundException + { + mbeanServer.removeNotificationListener(name, listener); + } + + @Override + public void removeNotificationListener(ObjectName name, ObjectName listener, NotificationFilter filter, Object context) + throws InstanceNotFoundException, ListenerNotFoundException + { + mbeanServer.removeNotificationListener(name, listener, filter, context); + } + + @Override + public void removeNotificationListener(ObjectName name, NotificationListener listener) + throws InstanceNotFoundException, ListenerNotFoundException + { + mbeanServer.removeNotificationListener(name, listener); + } + + @Override + public void removeNotificationListener(ObjectName name, NotificationListener listener, NotificationFilter filter, Object context) + throws InstanceNotFoundException, ListenerNotFoundException + { + mbeanServer.removeNotificationListener(name, listener, filter, context); + } + + @Override + public MBeanInfo getMBeanInfo(ObjectName name) + throws InstanceNotFoundException, IntrospectionException, ReflectionException + { + return mbeanServer.getMBeanInfo(name); + } + + @Override + public boolean isInstanceOf(ObjectName name, String className) + throws InstanceNotFoundException + { + return mbeanServer.isInstanceOf(name, className); + } + + @Override + public Object instantiate(String className) + throws ReflectionException, MBeanException + { + return mbeanServer.instantiate(className); + } + + @Override + public Object instantiate(String className, ObjectName loaderName) + throws ReflectionException, MBeanException, InstanceNotFoundException + { + return mbeanServer.instantiate(className, loaderName); + } + + @Override + public Object instantiate(String className, Object[] params, String[] signature) + throws ReflectionException, MBeanException + { + return mbeanServer.instantiate(className, params, signature); + } + + @Override + public Object instantiate(String className, ObjectName loaderName, Object[] params, String[] signature) + throws ReflectionException, MBeanException, InstanceNotFoundException + { + return mbeanServer.instantiate(className, loaderName, params, signature); + } + + @SuppressWarnings("deprecation") + @Override + @Deprecated + public ObjectInputStream deserialize(ObjectName name, byte[] data) + throws OperationsException + { + return mbeanServer.deserialize(name, data); + } + + @SuppressWarnings("deprecation") + @Override + @Deprecated + public ObjectInputStream deserialize(String className, byte[] data) + throws OperationsException, ReflectionException + { + return mbeanServer.deserialize(className, data); + } + + @SuppressWarnings("deprecation") + @Override + @Deprecated + public ObjectInputStream deserialize(String className, ObjectName loaderName, byte[] data) + throws OperationsException, ReflectionException + { + return mbeanServer.deserialize(className, loaderName, data); + } + + @Override + public ClassLoader getClassLoaderFor(ObjectName mbeanName) + throws InstanceNotFoundException + { + return mbeanServer.getClassLoaderFor(mbeanName); + } + + @Override + public ClassLoader getClassLoader(ObjectName loaderName) + throws InstanceNotFoundException + { + return mbeanServer.getClassLoader(loaderName); + } + + @Override + public ClassLoaderRepository getClassLoaderRepository() + { + return mbeanServer.getClassLoaderRepository(); + } + + @Override + public ObjectInstance createMBean(String className, ObjectName name) + throws ReflectionException, InstanceAlreadyExistsException, MBeanException, NotCompliantMBeanException + { + return mbeanServer.createMBean(className, name); + } + + @Override + public ObjectInstance createMBean(String className, ObjectName name, ObjectName loaderName) + throws ReflectionException, InstanceAlreadyExistsException, MBeanException, NotCompliantMBeanException, InstanceNotFoundException + { + return mbeanServer.createMBean(className, name, loaderName); + } + + @Override + public ObjectInstance createMBean(String className, ObjectName name, Object[] params, String[] signature) + throws ReflectionException, InstanceAlreadyExistsException, MBeanException, NotCompliantMBeanException + { + return mbeanServer.createMBean(className, name, params, signature); + } + + @Override + public ObjectInstance createMBean(String className, ObjectName name, ObjectName loaderName, Object[] params, String[] signature) + throws ReflectionException, InstanceAlreadyExistsException, MBeanException, NotCompliantMBeanException, InstanceNotFoundException + { + return mbeanServer.createMBean(className, name, loaderName, params, signature); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/util/RetryDriver.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/util/RetryDriver.java new file mode 100644 index 000000000000..6559403dcfd1 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/util/RetryDriver.java @@ -0,0 +1,157 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.util; + +import io.airlift.log.Logger; +import io.airlift.units.Duration; + +import java.util.Optional; +import java.util.concurrent.Callable; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.function.Predicate; + +import static com.google.common.base.Throwables.throwIfUnchecked; +import static java.util.Objects.requireNonNull; + +public class RetryDriver +{ + private static final Logger log = Logger.get(RetryDriver.class); + private static final int DEFAULT_RETRY_ATTEMPTS = 10; + private static final Duration DEFAULT_SLEEP_TIME = Duration.valueOf("1s"); + private static final Duration DEFAULT_MAX_RETRY_TIME = Duration.valueOf("30s"); + private static final double DEFAULT_SCALE_FACTOR = 2.0; + + private final int maxAttempts; + private final Duration minSleepTime; + private final Duration maxSleepTime; + private final double scaleFactor; + private final Duration maxRetryTime; + private final Optional retryRunnable; + private final Predicate stopRetrying; + private final Function classifier; + + private RetryDriver( + int maxAttempts, + Duration minSleepTime, + Duration maxSleepTime, + double scaleFactor, + Duration maxRetryTime, + Optional retryRunnable, + Predicate stopRetrying, + Function classifier) + { + this.maxAttempts = maxAttempts; + this.minSleepTime = minSleepTime; + this.maxSleepTime = maxSleepTime; + this.scaleFactor = scaleFactor; + this.maxRetryTime = maxRetryTime; + this.retryRunnable = retryRunnable; + this.stopRetrying = stopRetrying; + this.classifier = classifier; + } + + private RetryDriver() + { + this(DEFAULT_RETRY_ATTEMPTS, + DEFAULT_SLEEP_TIME, + DEFAULT_SLEEP_TIME, + DEFAULT_SCALE_FACTOR, + DEFAULT_MAX_RETRY_TIME, + Optional.empty(), + e -> false, + Function.identity()); + } + + public static RetryDriver retry() + { + return new RetryDriver(); + } + + public final RetryDriver maxAttempts(int maxAttempts) + { + return new RetryDriver(maxAttempts, minSleepTime, maxSleepTime, scaleFactor, maxRetryTime, retryRunnable, stopRetrying, classifier); + } + + public final RetryDriver exponentialBackoff(Duration minSleepTime, Duration maxSleepTime, Duration maxRetryTime, double scaleFactor) + { + return new RetryDriver(maxAttempts, minSleepTime, maxSleepTime, scaleFactor, maxRetryTime, retryRunnable, stopRetrying, classifier); + } + + public final RetryDriver onRetry(Runnable retryRunnable) + { + return new RetryDriver(maxAttempts, minSleepTime, maxSleepTime, scaleFactor, maxRetryTime, Optional.ofNullable(retryRunnable), stopRetrying, classifier); + } + + public RetryDriver stopRetryingWhen(Predicate stopRetrying) + { + return new RetryDriver(maxAttempts, minSleepTime, maxSleepTime, scaleFactor, maxRetryTime, retryRunnable, stopRetrying, classifier); + } + + public RetryDriver withClassifier(Function classifier) + { + return new RetryDriver(maxAttempts, minSleepTime, maxSleepTime, scaleFactor, maxRetryTime, retryRunnable, stopRetrying, classifier); + } + + public V run(String callableName, Callable callable) + { + requireNonNull(callableName, "callableName is null"); + requireNonNull(callable, "callable is null"); + + long startTime = System.nanoTime(); + int attempt = 0; + while (true) { + attempt++; + + if (attempt > 1) { + retryRunnable.ifPresent(Runnable::run); + } + + try { + return callable.call(); + } + catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw propagate(ie); + } + catch (Exception e) { + if (stopRetrying.test(e)) { + throw propagate(e); + } + if (attempt >= maxAttempts || Duration.nanosSince(startTime).compareTo(maxRetryTime) >= 0) { + throw propagate(e); + } + log.warn("Failed on executing %s with attempt %d, will retry. Exception: %s", callableName, attempt, e.getMessage()); + + int delayInMs = (int) Math.min(minSleepTime.toMillis() * Math.pow(scaleFactor, attempt - 1), maxSleepTime.toMillis()); + int jitter = ThreadLocalRandom.current().nextInt(Math.max(1, (int) (delayInMs * 0.1))); + try { + TimeUnit.MILLISECONDS.sleep(delayInMs + jitter); + } + catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw propagate(ie); + } + } + } + } + + private RuntimeException propagate(Exception e) + { + Exception classified = classifier.apply(e); + throwIfUnchecked(classified); + throw new RuntimeException(classified); + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/TestThriftConnectorConfig.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/TestThriftConnectorConfig.java new file mode 100644 index 000000000000..562a1b3bc908 --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/TestThriftConnectorConfig.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.google.common.collect.ImmutableMap; +import io.airlift.configuration.testing.ConfigAssertions; +import io.airlift.units.DataSize; +import org.testng.annotations.Test; + +import java.util.Map; + +import static io.airlift.units.DataSize.Unit.MEGABYTE; + +public class TestThriftConnectorConfig +{ + @Test + public void testDefaults() + { + ConfigAssertions.assertRecordedDefaults(ConfigAssertions.recordDefaults(ThriftConnectorConfig.class) + .setMaxResponseSize(new DataSize(16, MEGABYTE)) + .setMetadataRefreshThreads(1) + ); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("presto-thrift.max-response-size", "2MB") + .put("presto-thrift.metadata-refresh-threads", "10") + .build(); + + ThriftConnectorConfig expected = new ThriftConnectorConfig() + .setMaxResponseSize(new DataSize(2, MEGABYTE)) + .setMetadataRefreshThreads(10); + + ConfigAssertions.assertFullMapping(properties, expected); + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/TestThriftPlugin.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/TestThriftPlugin.java new file mode 100644 index 000000000000..77c2b410ced7 --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/TestThriftPlugin.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.testing.TestingConnectorContext; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.ServiceLoader; + +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.testing.Assertions.assertInstanceOf; +import static org.testng.Assert.assertNotNull; + +public class TestThriftPlugin +{ + @Test + public void testPlugin() + throws Exception + { + ThriftPlugin plugin = loadPlugin(ThriftPlugin.class); + + ConnectorFactory factory = getOnlyElement(plugin.getConnectorFactories()); + assertInstanceOf(factory, ThriftConnectorFactory.class); + + Map config = ImmutableMap.of("static-location.hosts", "localhost:7777"); + + Connector connector = factory.create("test", config, new TestingConnectorContext()); + assertNotNull(connector); + assertInstanceOf(connector, ThriftConnector.class); + } + + @SuppressWarnings("unchecked") + private static T loadPlugin(Class clazz) + { + for (Plugin plugin : ServiceLoader.load(Plugin.class)) { + if (clazz.isInstance(plugin)) { + return (T) plugin; + } + } + throw new AssertionError("did not find plugin: " + clazz.getName()); + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/TestThriftDistributedQueries.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/TestThriftDistributedQueries.java new file mode 100644 index 000000000000..90e240898135 --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/TestThriftDistributedQueries.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.integration; + +import com.facebook.presto.tests.AbstractTestQueries; + +import static com.facebook.presto.connector.thrift.integration.ThriftQueryRunner.createThriftQueryRunner; + +public class TestThriftDistributedQueries + extends AbstractTestQueries +{ + public TestThriftDistributedQueries() + throws Exception + { + super(() -> createThriftQueryRunner(3, 3)); + } + + @Override + public void testAssignUniqueId() + { + // this test can take a long time + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/TestThriftIntegrationSmokeTest.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/TestThriftIntegrationSmokeTest.java new file mode 100644 index 000000000000..098e78902c66 --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/TestThriftIntegrationSmokeTest.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.integration; + +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest; +import org.testng.annotations.Test; + +import static com.facebook.presto.connector.thrift.integration.ThriftQueryRunner.createThriftQueryRunner; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.tests.QueryAssertions.assertContains; + +public class TestThriftIntegrationSmokeTest + extends AbstractTestIntegrationSmokeTest +{ + public TestThriftIntegrationSmokeTest() + throws Exception + { + super(() -> createThriftQueryRunner(2, 2)); + } + + @Override + @Test + public void testShowSchemas() + throws Exception + { + MaterializedResult actualSchemas = computeActual("SHOW SCHEMAS").toJdbcTypes(); + MaterializedResult.Builder resultBuilder = MaterializedResult.resultBuilder(getQueryRunner().getDefaultSession(), VARCHAR) + .row("tiny") + .row("sf1"); + assertContains(actualSchemas, resultBuilder.build()); + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java new file mode 100644 index 000000000000..f922deafe1af --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java @@ -0,0 +1,223 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.integration; + +import com.facebook.presto.Session; +import com.facebook.presto.connector.thrift.ThriftPlugin; +import com.facebook.presto.connector.thrift.location.HostList; +import com.facebook.presto.connector.thrift.server.ThriftTpchService; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.QualifiedObjectName; +import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.testing.TestingAccessControlManager; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.facebook.presto.transaction.TransactionManager; +import com.facebook.swift.codec.ThriftCodecManager; +import com.facebook.swift.service.ThriftServer; +import com.facebook.swift.service.ThriftServiceProcessor; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.log.Logger; +import io.airlift.testing.Closeables; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.locks.Lock; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public final class ThriftQueryRunner +{ + private ThriftQueryRunner() {} + + public static QueryRunner createThriftQueryRunner(int thriftServers, int workers) + throws Exception + { + List servers = null; + DistributedQueryRunner runner = null; + try { + servers = startThriftServers(thriftServers); + runner = createThriftQueryRunnerInternal(servers, workers); + return new ThriftQueryRunnerWithServers(runner, servers); + } + catch (Throwable t) { + Closeables.closeQuietly(runner); + // runner might be null, so closing servers explicitly + if (servers != null) { + for (ThriftServer server : servers) { + Closeables.closeQuietly(server); + } + } + throw t; + } + } + + public static void main(String[] args) + throws Exception + { + ThriftQueryRunnerWithServers queryRunner = (ThriftQueryRunnerWithServers) createThriftQueryRunner(3, 3); + Thread.sleep(10); + Logger log = Logger.get(ThriftQueryRunner.class); + log.info("======== SERVER STARTED ========"); + log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl()); + } + + private static List startThriftServers(int thriftServers) + { + List servers = new ArrayList<>(thriftServers); + for (int i = 0; i < thriftServers; i++) { + ThriftServiceProcessor processor = new ThriftServiceProcessor(new ThriftCodecManager(), ImmutableList.of(), new ThriftTpchService()); + servers.add(new ThriftServer(processor).start()); + } + return servers; + } + + private static DistributedQueryRunner createThriftQueryRunnerInternal(List servers, int workers) + throws Exception + { + List addresses = servers.stream() + .map(server -> HostAddress.fromParts("localhost", server.getPort())) + .collect(toImmutableList()); + HostList hosts = HostList.fromList(addresses); + + Session defaultSession = testSessionBuilder() + .setCatalog("thrift") + .setSchema("tiny") + .build(); + DistributedQueryRunner queryRunner = new DistributedQueryRunner(defaultSession, workers); + queryRunner.installPlugin(new ThriftPlugin()); + Map connectorProperties = ImmutableMap.of( + "static-location.hosts", hosts.stringValue(), + "PrestoThriftService.thrift.client.connect-timeout", "30s" + ); + queryRunner.createCatalog("thrift", "presto-thrift", connectorProperties); + return queryRunner; + } + + /** + * Wraps QueryRunner and a list of ThriftServers to clean them up together. + */ + private static class ThriftQueryRunnerWithServers + implements QueryRunner + { + private final DistributedQueryRunner source; + private final List thriftServers; + + private ThriftQueryRunnerWithServers(DistributedQueryRunner source, List thriftServers) + { + this.source = requireNonNull(source, "source is null"); + this.thriftServers = ImmutableList.copyOf(requireNonNull(thriftServers, "thriftServers is null")); + } + + public TestingPrestoServer getCoordinator() + { + return source.getCoordinator(); + } + + @Override + public void close() + { + Closeables.closeQuietly(source); + for (ThriftServer server : thriftServers) { + Closeables.closeQuietly(server); + } + } + + @Override + public int getNodeCount() + { + return source.getNodeCount(); + } + + @Override + public Session getDefaultSession() + { + return source.getDefaultSession(); + } + + @Override + public TransactionManager getTransactionManager() + { + return source.getTransactionManager(); + } + + @Override + public Metadata getMetadata() + { + return source.getMetadata(); + } + + @Override + public CostCalculator getCostCalculator() + { + return source.getCostCalculator(); + } + + @Override + public TestingAccessControlManager getAccessControl() + { + return source.getAccessControl(); + } + + @Override + public MaterializedResult execute(String sql) + { + return source.execute(sql); + } + + @Override + public MaterializedResult execute(Session session, String sql) + { + return source.execute(session, sql); + } + + @Override + public List listTables(Session session, String catalog, String schema) + { + return source.listTables(session, catalog, schema); + } + + @Override + public boolean tableExists(Session session, String table) + { + return source.tableExists(session, table); + } + + @Override + public void installPlugin(Plugin plugin) + { + source.installPlugin(plugin); + } + + @Override + public void createCatalog(String catalogName, String connectorName, Map properties) + { + source.createCatalog(catalogName, connectorName, properties); + } + + @Override + public Lock getExclusiveLock() + { + return source.getExclusiveLock(); + } + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/location/TestStaticLocationConfig.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/location/TestStaticLocationConfig.java new file mode 100644 index 000000000000..1d7968740369 --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/location/TestStaticLocationConfig.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import com.facebook.presto.spi.HostAddress; +import com.google.common.collect.ImmutableMap; +import io.airlift.configuration.testing.ConfigAssertions; +import org.testng.annotations.Test; + +import java.util.Map; + +public class TestStaticLocationConfig +{ + @Test + public void testDefaults() + { + ConfigAssertions.assertRecordedDefaults(ConfigAssertions.recordDefaults(StaticLocationConfig.class) + .setHosts(null)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("static-location.hosts", "localhost:7777,localhost:7779") + .build(); + + StaticLocationConfig expected = new StaticLocationConfig() + .setHosts(HostList.of( + HostAddress.fromParts("localhost", 7777), + HostAddress.fromParts("localhost", 7779))); + + ConfigAssertions.assertFullMapping(properties, expected); + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/location/TestStaticLocationProvider.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/location/TestStaticLocationProvider.java new file mode 100644 index 000000000000..25e9f8b21da7 --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/location/TestStaticLocationProvider.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import com.facebook.presto.spi.HostAddress; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.testng.Assert.assertEquals; + +public class TestStaticLocationProvider +{ + @Test + public void testGetAnyHostRoundRobin() + throws Exception + { + List expected = ImmutableList.of( + HostAddress.fromParts("localhost1", 11111), + HostAddress.fromParts("localhost2", 22222), + HostAddress.fromParts("localhost3", 33333)); + HostLocationProvider provider = new StaticLocationProvider(new StaticLocationConfig().setHosts(HostList.fromList(expected))); + List actual = new ArrayList<>(expected.size()); + for (int i = 0; i < expected.size(); i++) { + actual.add(provider.getAnyHost()); + } + assertEquals(ImmutableSet.copyOf(actual), ImmutableSet.copyOf(expected)); + } +} diff --git a/presto-thrift-testing-server/.build-airlift b/presto-thrift-testing-server/.build-airlift new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/presto-thrift-testing-server/README.txt b/presto-thrift-testing-server/README.txt new file mode 100644 index 000000000000..cf72ed369edd --- /dev/null +++ b/presto-thrift-testing-server/README.txt @@ -0,0 +1 @@ +Thrift server implementing Thrift Connector API using TPCH data. diff --git a/presto-thrift-testing-server/etc/config.properties b/presto-thrift-testing-server/etc/config.properties new file mode 100644 index 000000000000..afd3bbe56f6b --- /dev/null +++ b/presto-thrift-testing-server/etc/config.properties @@ -0,0 +1,2 @@ +thrift.port=7779 +thrift.max-frame-size=64MB diff --git a/presto-thrift-testing-server/etc/log.properties b/presto-thrift-testing-server/etc/log.properties new file mode 100644 index 000000000000..290ff616938c --- /dev/null +++ b/presto-thrift-testing-server/etc/log.properties @@ -0,0 +1 @@ +com.facebook.presto=DEBUG diff --git a/presto-thrift-testing-server/pom.xml b/presto-thrift-testing-server/pom.xml new file mode 100644 index 000000000000..bf87b07f0783 --- /dev/null +++ b/presto-thrift-testing-server/pom.xml @@ -0,0 +1,96 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.181-tw-0.37 + + + presto-thrift-testing-server + presto-thrift-testing-server + Presto - Thrift Testing Server + + + ${project.parent.basedir} + com.facebook.presto.connector.thrift.server.ThriftTpchServer + + + + + com.facebook.presto + presto-thrift-connector-api + + + + com.google.guava + guava + + + + com.google.code.findbugs + annotations + + + + com.facebook.swift + swift-codec + + + + com.facebook.swift + swift-service + + + + com.facebook.presto + presto-tpch + + + + io.airlift.tpch + tpch + + + + io.airlift + log + + + + io.airlift + bootstrap + + + + com.google.inject + guice + + + + javax.annotation + javax.annotation-api + + + + io.airlift + concurrent + + + + com.facebook.presto + presto-spi + + + + com.fasterxml.jackson.core + jackson-annotations + + + + io.airlift + json + + + diff --git a/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/SplitInfo.java b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/SplitInfo.java new file mode 100644 index 000000000000..9b295035b88a --- /dev/null +++ b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/SplitInfo.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.server; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import static java.util.Objects.requireNonNull; + +public final class SplitInfo +{ + private final String schemaName; + private final String tableName; + private final int partNumber; + private final int totalParts; + + @JsonCreator + public SplitInfo( + @JsonProperty("schemaName") String schemaName, + @JsonProperty("tableName") String tableName, + @JsonProperty("partNumber") int partNumber, + @JsonProperty("totalParts") int totalParts) + { + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); + this.partNumber = partNumber; + this.totalParts = totalParts; + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public int getPartNumber() + { + return partNumber; + } + + @JsonProperty + public int getTotalParts() + { + return totalParts; + } +} diff --git a/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchServer.java b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchServer.java new file mode 100644 index 000000000000..c46cae0637a8 --- /dev/null +++ b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchServer.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.server; + +import com.facebook.swift.codec.guice.ThriftCodecModule; +import com.facebook.swift.service.guice.ThriftClientModule; +import com.facebook.swift.service.guice.ThriftServerModule; +import com.facebook.swift.service.guice.ThriftServerStatsModule; +import com.google.common.collect.ImmutableList; +import com.google.inject.Module; +import io.airlift.bootstrap.Bootstrap; +import io.airlift.log.Logger; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public final class ThriftTpchServer +{ + private ThriftTpchServer() + { + } + + public static void start(List extraModules) + throws Exception + { + Bootstrap app = new Bootstrap( + ImmutableList.builder() + .add(new ThriftCodecModule()) + .add(new ThriftClientModule()) + .add(new ThriftServerModule()) + .add(new ThriftServerStatsModule()) + .add(new ThriftTpchServerModule()) + .addAll(requireNonNull(extraModules, "extraModules is null")) + .build()); + app.strictConfig().initialize(); + } + + public static void main(String[] args) + { + try { + ThriftTpchServer.start(ImmutableList.of()); + } + catch (Throwable t) { + Logger.get(ThriftTpchServer.class).error(t); + System.exit(1); + } + } +} diff --git a/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchServerModule.java b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchServerModule.java new file mode 100644 index 000000000000..9e0d292b3ab7 --- /dev/null +++ b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchServerModule.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.server; + +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.swift.service.guice.ThriftServiceExporter.thriftServerBinder; + +public class ThriftTpchServerModule + implements Module +{ + @Override + public void configure(Binder binder) + { + binder.bind(ThriftTpchService.class).in(Scopes.SINGLETON); + thriftServerBinder(binder).exportThriftService(ThriftTpchService.class); + } +} diff --git a/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchService.java b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchService.java new file mode 100644 index 000000000000..dcc40d2eb059 --- /dev/null +++ b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchService.java @@ -0,0 +1,275 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.server; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.connector.thrift.api.PrestoThriftColumnMetadata; +import com.facebook.presto.connector.thrift.api.PrestoThriftId; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableColumnSet; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableSchemaName; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableTableMetadata; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableToken; +import com.facebook.presto.connector.thrift.api.PrestoThriftPageResult; +import com.facebook.presto.connector.thrift.api.PrestoThriftSchemaTableName; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.api.PrestoThriftServiceException; +import com.facebook.presto.connector.thrift.api.PrestoThriftSplit; +import com.facebook.presto.connector.thrift.api.PrestoThriftSplitBatch; +import com.facebook.presto.connector.thrift.api.PrestoThriftTableMetadata; +import com.facebook.presto.connector.thrift.api.PrestoThriftTupleDomain; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.RecordPageSource; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.tpch.TpchMetadata; +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import io.airlift.json.JsonCodec; +import io.airlift.tpch.TpchColumn; +import io.airlift.tpch.TpchColumnType; +import io.airlift.tpch.TpchEntity; +import io.airlift.tpch.TpchTable; + +import javax.annotation.Nullable; +import javax.annotation.PreDestroy; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.fromBlock; +import static com.facebook.presto.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; +import static com.facebook.presto.tpch.TpchMetadata.getPrestoType; +import static com.facebook.presto.tpch.TpchRecordSet.createTpchRecordSet; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; +import static io.airlift.concurrent.Threads.threadsNamed; +import static io.airlift.json.JsonCodec.jsonCodec; +import static java.lang.Math.min; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.stream.Collectors.toList; + +public class ThriftTpchService + implements PrestoThriftService +{ + private static final int DEFAULT_NUMBER_OF_SPLITS = 3; + private static final List SCHEMAS = ImmutableList.of("tiny", "sf1"); + private static final JsonCodec SPLIT_INFO_CODEC = jsonCodec(SplitInfo.class); + + private final ListeningExecutorService splitsExecutor = + listeningDecorator(newCachedThreadPool(threadsNamed("splits-generator-%s"))); + private final ListeningExecutorService dataExecutor = + listeningDecorator(newCachedThreadPool(threadsNamed("data-generator-%s"))); + + @Override + public List listSchemaNames() + { + return SCHEMAS; + } + + @Override + public List listTables(PrestoThriftNullableSchemaName schemaNameOrNull) + { + List tables = new ArrayList<>(); + for (String schemaName : getSchemaNames(schemaNameOrNull.getSchemaName())) { + for (TpchTable tpchTable : TpchTable.getTables()) { + tables.add(new PrestoThriftSchemaTableName(schemaName, tpchTable.getTableName())); + } + } + return tables; + } + + private static List getSchemaNames(String schemaNameOrNull) + { + if (schemaNameOrNull == null) { + return SCHEMAS; + } + else if (SCHEMAS.contains(schemaNameOrNull)) { + return ImmutableList.of(schemaNameOrNull); + } + else { + return ImmutableList.of(); + } + } + + @Override + public PrestoThriftNullableTableMetadata getTableMetadata(PrestoThriftSchemaTableName schemaTableName) + { + String schemaName = schemaTableName.getSchemaName(); + String tableName = schemaTableName.getTableName(); + if (!SCHEMAS.contains(schemaName) || TpchTable.getTables().stream().noneMatch(table -> table.getTableName().equals(tableName))) { + return new PrestoThriftNullableTableMetadata(null); + } + TpchTable tpchTable = TpchTable.getTable(schemaTableName.getTableName()); + List columns = new ArrayList<>(); + for (TpchColumn column : tpchTable.getColumns()) { + columns.add(new PrestoThriftColumnMetadata(column.getSimplifiedColumnName(), getTypeString(column.getType()), null, false)); + } + return new PrestoThriftNullableTableMetadata(new PrestoThriftTableMetadata(schemaTableName, columns, null)); + } + + @Override + public ListenableFuture getSplits( + PrestoThriftSchemaTableName schemaTableName, + PrestoThriftNullableColumnSet desiredColumns, + PrestoThriftTupleDomain outputConstraint, + int maxSplitCount, + PrestoThriftNullableToken nextToken) + throws PrestoThriftServiceException + { + return splitsExecutor.submit(() -> getSplitsInternal(schemaTableName, maxSplitCount, nextToken.getToken())); + } + + private static PrestoThriftSplitBatch getSplitsInternal( + PrestoThriftSchemaTableName schemaTableName, + int maxSplitCount, + @Nullable PrestoThriftId nextToken) + { + int totalParts = DEFAULT_NUMBER_OF_SPLITS; + // last sent part + int partNumber = nextToken == null ? 0 : Ints.fromByteArray(nextToken.getId()); + int numberOfSplits = min(maxSplitCount, totalParts - partNumber); + + List splits = new ArrayList<>(numberOfSplits); + for (int i = 0; i < numberOfSplits; i++) { + SplitInfo splitInfo = new SplitInfo( + schemaTableName.getSchemaName(), + schemaTableName.getTableName(), + partNumber + 1, + totalParts); + splits.add(new PrestoThriftSplit(new PrestoThriftId(SPLIT_INFO_CODEC.toJsonBytes(splitInfo)), ImmutableList.of())); + partNumber++; + } + PrestoThriftId newNextToken = partNumber < totalParts ? new PrestoThriftId(Ints.toByteArray(partNumber)) : null; + return new PrestoThriftSplitBatch(splits, newNextToken); + } + + @Override + public ListenableFuture getRows( + PrestoThriftId splitId, + List columns, + long maxBytes, + PrestoThriftNullableToken nextToken) + { + return dataExecutor.submit(() -> getRowsInternal(splitId, columns, maxBytes, nextToken.getToken())); + } + + @PreDestroy + @Override + public void close() + { + splitsExecutor.shutdownNow(); + dataExecutor.shutdownNow(); + } + + private static PrestoThriftPageResult getRowsInternal(PrestoThriftId splitId, List columnNames, long maxBytes, @Nullable PrestoThriftId nextToken) + { + checkArgument(maxBytes >= DEFAULT_MAX_PAGE_SIZE_IN_BYTES, "requested maxBytes is too small"); + SplitInfo splitInfo = SPLIT_INFO_CODEC.fromJson(splitId.getId()); + ConnectorPageSource pageSource = createPageSource(splitInfo, columnNames); + + // very inefficient implementation as it needs to re-generate all previous results to get the next page + int skipPages = nextToken != null ? Ints.fromByteArray(nextToken.getId()) : 0; + skipPages(pageSource, skipPages); + + Page page = null; + while (!pageSource.isFinished() && page == null) { + page = pageSource.getNextPage(); + skipPages++; + } + PrestoThriftId newNextToken = pageSource.isFinished() ? null : new PrestoThriftId(Ints.toByteArray(skipPages)); + + return toThriftPage(page, types(splitInfo.getTableName(), columnNames), newNextToken); + } + + private static PrestoThriftPageResult toThriftPage(Page page, List columnTypes, @Nullable PrestoThriftId nextToken) + { + if (page == null) { + checkState(nextToken == null, "there must be no more data when page is null"); + return new PrestoThriftPageResult(ImmutableList.of(), 0, null); + } + checkState(page.getChannelCount() == columnTypes.size(), "number of columns in a page doesn't match the one in requested types"); + int numberOfColumns = columnTypes.size(); + List columnBlocks = new ArrayList<>(numberOfColumns); + for (int i = 0; i < numberOfColumns; i++) { + columnBlocks.add(fromBlock(page.getBlock(i), columnTypes.get(i))); + } + return new PrestoThriftPageResult(columnBlocks, page.getPositionCount(), nextToken); + } + + private static void skipPages(ConnectorPageSource pageSource, int skipPages) + { + for (int i = 0; i < skipPages; i++) { + checkState(!pageSource.isFinished(), "pageSource is unexpectedly finished"); + pageSource.getNextPage(); + } + } + + private static ConnectorPageSource createPageSource(SplitInfo splitInfo, List columnNames) + { + switch (splitInfo.getTableName()) { + case "orders": + return createPageSource(TpchTable.ORDERS, columnNames, splitInfo); + case "customer": + return createPageSource(TpchTable.CUSTOMER, columnNames, splitInfo); + case "lineitem": + return createPageSource(TpchTable.LINE_ITEM, columnNames, splitInfo); + case "nation": + return createPageSource(TpchTable.NATION, columnNames, splitInfo); + case "region": + return createPageSource(TpchTable.REGION, columnNames, splitInfo); + case "part": + return createPageSource(TpchTable.PART, columnNames, splitInfo); + default: + throw new IllegalArgumentException("Table not setup: " + splitInfo.getTableName()); + } + } + + private static ConnectorPageSource createPageSource(TpchTable table, List columnNames, SplitInfo splitInfo) + { + List> columns = columnNames.stream().map(table::getColumn).collect(toList()); + return new RecordPageSource(createTpchRecordSet( + table, + columns, + schemaNameToScaleFactor(splitInfo.getSchemaName()), + splitInfo.getPartNumber(), + splitInfo.getTotalParts(), + Optional.empty())); + } + + private static List types(String tableName, List columnNames) + { + TpchTable table = TpchTable.getTable(tableName); + return columnNames.stream().map(name -> getPrestoType(table.getColumn(name).getType())).collect(toList()); + } + + private static double schemaNameToScaleFactor(String schemaName) + { + switch (schemaName) { + case "tiny": + return 0.01; + case "sf1": + return 1.0; + } + throw new IllegalArgumentException("Schema is not setup: " + schemaName); + } + + private static String getTypeString(TpchColumnType tpchType) + { + return TpchMetadata.getPrestoType(tpchType).getTypeSignature().toString(); + } +} diff --git a/presto-tpch/pom.xml b/presto-tpch/pom.xml index 36d62494e9e6..1f6423a66dd1 100644 --- a/presto-tpch/pom.xml +++ b/presto-tpch/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-tpch diff --git a/presto-verifier/pom.xml b/presto-verifier/pom.xml index a9a8c2b688e0..d8db6ca11a3c 100644 --- a/presto-verifier/pom.xml +++ b/presto-verifier/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 presto-verifier diff --git a/presto-verifier/src/main/java/com/facebook/presto/verifier/PrestoVerifier.java b/presto-verifier/src/main/java/com/facebook/presto/verifier/PrestoVerifier.java index d3118536b721..d22964f3d92b 100644 --- a/presto-verifier/src/main/java/com/facebook/presto/verifier/PrestoVerifier.java +++ b/presto-verifier/src/main/java/com/facebook/presto/verifier/PrestoVerifier.java @@ -20,6 +20,7 @@ import com.facebook.presto.sql.tree.CreateTableAsSelect; import com.facebook.presto.sql.tree.CreateView; import com.facebook.presto.sql.tree.Delete; +import com.facebook.presto.sql.tree.DropColumn; import com.facebook.presto.sql.tree.DropTable; import com.facebook.presto.sql.tree.DropView; import com.facebook.presto.sql.tree.Explain; @@ -321,6 +322,9 @@ private static QueryType statementToQueryType(Statement statement) if (statement instanceof RenameColumn) { return MODIFY; } + if (statement instanceof DropColumn) { + return MODIFY; + } if (statement instanceof RenameTable) { return MODIFY; } diff --git a/presto-verifier/src/main/java/com/facebook/presto/verifier/QueryRewriter.java b/presto-verifier/src/main/java/com/facebook/presto/verifier/QueryRewriter.java index 52a87bca96b2..c7417197244f 100644 --- a/presto-verifier/src/main/java/com/facebook/presto/verifier/QueryRewriter.java +++ b/presto-verifier/src/main/java/com/facebook/presto/verifier/QueryRewriter.java @@ -169,7 +169,8 @@ private String createTemporaryTableName() return rewritePrefix.getSuffix() + UUID.randomUUID().toString().replace("-", ""); } - private List getColumnsForTable(Connection connection, String catalog, String schema, String table) throws SQLException + private List getColumnsForTable(Connection connection, String catalog, String schema, String table) + throws SQLException { ResultSet columns = connection.getMetaData().getColumns(catalog, escapeLikeExpression(connection, schema), escapeLikeExpression(connection, table), null); ImmutableList.Builder columnBuilder = new ImmutableList.Builder<>(); diff --git a/twitter-eventlistener-plugin/pom.xml b/twitter-eventlistener-plugin/pom.xml index ba1fbbbc0a3d..6dbbc9f3f689 100644 --- a/twitter-eventlistener-plugin/pom.xml +++ b/twitter-eventlistener-plugin/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.179-tw-0.36 + 0.181-tw-0.37 twitter-eventlistener-plugin @@ -19,7 +19,7 @@ com.facebook.presto presto-spi - 0.179-tw-0.36 + 0.181-tw-0.37 provided