Skip to content

Commit

Permalink
Switch to aircompressor 2.0 API
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Jul 19, 2024
1 parent 2a06c88 commit b4ce9a2
Show file tree
Hide file tree
Showing 30 changed files with 209 additions and 264 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@
package io.trino.execution.buffer;

import com.google.common.base.VerifyException;
import io.airlift.compress.Decompressor;
import io.airlift.compress.lz4.Lz4Decompressor;
import io.airlift.compress.lz4.Lz4RawCompressor;
import io.airlift.compress.v2.Decompressor;
import io.airlift.compress.v2.lz4.Lz4Decompressor;
import io.airlift.compress.v2.lz4.Lz4JavaCompressor;
import io.airlift.compress.v2.snappy.SnappyDecompressor;
import io.airlift.compress.v2.snappy.SnappyJavaCompressor;
import io.airlift.compress.v2.zstd.ZstdDecompressor;
import io.airlift.compress.v2.zstd.ZstdJavaCompressor;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceInput;
import io.airlift.slice.Slices;
Expand Down Expand Up @@ -122,7 +126,7 @@ private SerializedPageInput(Optional<Decompressor> decompressor, Optional<Secret
int bufferSize;
if (decompressor.isPresent()) {
// to store compressed block size
bufferSize = Lz4RawCompressor.maxCompressedLength(blockSizeInBytes)
bufferSize = maxCompressedLength(decompressor.get(), blockSizeInBytes)
// to store compressed block size
+ Integer.BYTES
// to guarantee a single long can always be read entirely
Expand Down Expand Up @@ -427,6 +431,16 @@ private static int getCompressedBlockSize(int compressedBlockMarker)
return compressedBlockMarker & ~SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK;
}

private static int maxCompressedLength(Decompressor decompressor, int blockSizeInBytes)
{
return switch (decompressor) {
case ZstdDecompressor _ -> new ZstdJavaCompressor().maxCompressedLength(blockSizeInBytes);
case SnappyDecompressor _ -> new SnappyJavaCompressor().maxCompressedLength(blockSizeInBytes);
case Lz4Decompressor _ -> new Lz4JavaCompressor().maxCompressedLength(blockSizeInBytes);
default -> throw new IllegalArgumentException("Cannot estimate max compressed length for decompressor: %s".formatted(decompressor));
};
}

private static boolean isCompressed(int compressedBlockMarker)
{
return (compressedBlockMarker & SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK) == SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
package io.trino.execution.buffer;

import com.google.common.base.VerifyException;
import io.airlift.compress.Compressor;
import io.airlift.compress.lz4.Lz4Compressor;
import io.airlift.compress.lz4.Lz4RawCompressor;
import io.airlift.compress.v2.Compressor;
import io.airlift.compress.v2.lz4.Lz4Compressor;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
Expand All @@ -38,7 +37,6 @@
import static io.airlift.slice.SizeOf.instanceSize;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.airlift.slice.SizeOf.sizeOfByteArray;
import static io.airlift.slice.SizeOf.sizeOfIntArray;
import static io.trino.execution.buffer.PageCodecMarker.COMPRESSED;
import static io.trino.execution.buffer.PageCodecMarker.ENCRYPTED;
import static io.trino.execution.buffer.PagesSerdeUtil.ESTIMATED_AES_CIPHER_RETAINED_SIZE;
Expand Down Expand Up @@ -95,7 +93,8 @@ private static class SerializedPageOutput
{
private static final int INSTANCE_SIZE = instanceSize(SerializedPageOutput.class);
// TODO: implement getRetainedSizeInBytes in Lz4Compressor
private static final int COMPRESSOR_RETAINED_SIZE = toIntExact(instanceSize(Lz4Compressor.class) + sizeOfIntArray(Lz4RawCompressor.MAX_TABLE_SIZE));
// TODO: need a fix
private static final int COMPRESSOR_RETAINED_SIZE = toIntExact(instanceSize(Lz4Compressor.class));
private static final int ENCRYPTION_KEY_RETAINED_SIZE = toIntExact(instanceSize(SecretKeySpec.class) + sizeOfByteArray(256 / 8));

private static final double MINIMUM_COMPRESSION_RATIO = 0.8;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
*/
package io.trino.execution.buffer;

import io.airlift.compress.Compressor;
import io.airlift.compress.Decompressor;
import io.airlift.compress.lz4.Lz4Compressor;
import io.airlift.compress.lz4.Lz4Decompressor;
import io.airlift.compress.zstd.ZstdCompressor;
import io.airlift.compress.zstd.ZstdDecompressor;
import io.airlift.compress.v2.Compressor;
import io.airlift.compress.v2.Decompressor;
import io.airlift.compress.v2.lz4.Lz4JavaCompressor;
import io.airlift.compress.v2.lz4.Lz4JavaDecompressor;
import io.airlift.compress.v2.lz4.Lz4NativeCompressor;
import io.airlift.compress.v2.lz4.Lz4NativeDecompressor;
import io.airlift.compress.v2.zstd.ZstdJavaCompressor;
import io.airlift.compress.v2.zstd.ZstdJavaDecompressor;
import io.airlift.compress.v2.zstd.ZstdNativeCompressor;
import io.airlift.compress.v2.zstd.ZstdNativeDecompressor;
import io.trino.spi.block.BlockEncodingSerde;

import javax.crypto.SecretKey;
Expand Down Expand Up @@ -54,17 +58,17 @@ public static Optional<Compressor> createCompressor(CompressionCodec compression
{
return switch (compressionCodec) {
case NONE -> Optional.empty();
case LZ4 -> Optional.of(new Lz4Compressor());
case ZSTD -> Optional.of(new ZstdCompressor());
case LZ4 -> Optional.of(Lz4NativeCompressor.isEnabled() ? new Lz4NativeCompressor() : new Lz4JavaCompressor());
case ZSTD -> Optional.of(ZstdNativeCompressor.isEnabled() ? new ZstdNativeCompressor() : new ZstdJavaCompressor());
};
}

public static Optional<Decompressor> createDecompressor(CompressionCodec compressionCodec)
{
return switch (compressionCodec) {
case NONE -> Optional.empty();
case LZ4 -> Optional.of(new Lz4Decompressor());
case ZSTD -> Optional.of(new ZstdDecompressor());
case LZ4 -> Optional.of(Lz4NativeCompressor.isEnabled() ? new Lz4NativeDecompressor() : new Lz4JavaDecompressor());
case ZSTD -> Optional.of(ZstdNativeDecompressor.isEnabled() ? new ZstdNativeDecompressor() : new ZstdJavaDecompressor());
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
package io.trino.server.protocol;

import com.google.inject.Inject;
import io.airlift.compress.zstd.ZstdCompressor;
import io.airlift.compress.zstd.ZstdDecompressor;
import io.airlift.compress.v2.zstd.ZstdCompressor;
import io.airlift.compress.v2.zstd.ZstdDecompressor;
import io.airlift.compress.v2.zstd.ZstdJavaCompressor;
import io.airlift.compress.v2.zstd.ZstdJavaDecompressor;
import io.airlift.compress.v2.zstd.ZstdNativeCompressor;
import io.airlift.compress.v2.zstd.ZstdNativeDecompressor;
import io.trino.server.ProtocolConfig;

import static com.google.common.io.BaseEncoding.base64Url;
Expand Down Expand Up @@ -43,7 +47,7 @@ public String encodePreparedStatementForHeader(String preparedStatement)
return preparedStatement;
}

ZstdCompressor compressor = new ZstdCompressor();
ZstdCompressor compressor = ZstdNativeCompressor.isEnabled() ? new ZstdNativeCompressor() : new ZstdJavaCompressor();
byte[] inputBytes = preparedStatement.getBytes(UTF_8);
byte[] compressed = new byte[compressor.maxCompressedLength(inputBytes.length)];
int outputSize = compressor.compress(inputBytes, 0, inputBytes.length, compressed, 0, compressed.length);
Expand All @@ -63,9 +67,9 @@ public String decodePreparedStatementFromHeader(String headerValue)

String encoded = headerValue.substring(PREFIX.length());
byte[] compressed = base64Url().decode(encoded);

byte[] preparedStatement = new byte[toIntExact(ZstdDecompressor.getDecompressedSize(compressed, 0, compressed.length))];
new ZstdDecompressor().decompress(compressed, 0, compressed.length, preparedStatement, 0, preparedStatement.length);
ZstdDecompressor decompressor = ZstdNativeDecompressor.isEnabled() ? new ZstdNativeDecompressor() : new ZstdJavaDecompressor();
byte[] preparedStatement = new byte[toIntExact(decompressor.getDecompressedSize(compressed, 0, compressed.length))];
decompressor.decompress(compressed, 0, compressed.length, preparedStatement, 0, preparedStatement.length);
return new String(preparedStatement, UTF_8);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,51 +13,26 @@
*/
package io.trino.server.security.oauth2;

import io.airlift.compress.zstd.ZstdCompressor;
import io.airlift.compress.zstd.ZstdDecompressor;
import io.airlift.compress.zstd.ZstdInputStream;
import io.airlift.compress.zstd.ZstdOutputStream;
import io.jsonwebtoken.CompressionCodec;
import io.jsonwebtoken.CompressionException;
import io.airlift.compress.v2.zstd.ZstdInputStream;
import io.airlift.compress.v2.zstd.ZstdOutputStream;
import io.jsonwebtoken.io.CompressionAlgorithm;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UncheckedIOException;

import static java.lang.Math.toIntExact;
import static java.util.Arrays.copyOfRange;

public class ZstdCodec
implements CompressionCodec
implements CompressionAlgorithm
{
public static final String CODEC_NAME = "ZSTD";

@Override
public String getAlgorithmName()
public String getId()
{
return CODEC_NAME;
}

@Override
public byte[] compress(byte[] bytes)
throws CompressionException
{
ZstdCompressor compressor = new ZstdCompressor();
byte[] compressed = new byte[compressor.maxCompressedLength(bytes.length)];
int outputSize = compressor.compress(bytes, 0, bytes.length, compressed, 0, compressed.length);
return copyOfRange(compressed, 0, outputSize);
}

@Override
public byte[] decompress(byte[] bytes)
throws CompressionException
{
byte[] output = new byte[toIntExact(ZstdDecompressor.getDecompressedSize(bytes, 0, bytes.length))];
new ZstdDecompressor().decompress(bytes, 0, bytes.length, output, 0, output.length);
return output;
}

@Override
public OutputStream compress(OutputStream out)
{
Expand All @@ -74,10 +49,4 @@ public InputStream decompress(InputStream in)
{
return new ZstdInputStream(in);
}

@Override
public String getId()
{
return CODEC_NAME;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/
package io.trino.hive.formats.compression;

import io.airlift.compress.hadoop.HadoopStreams;
import io.airlift.compress.v2.hadoop.HadoopStreams;

import java.io.IOException;
import java.io.InputStream;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.compress.bzip2.BZip2HadoopStreams;
import io.airlift.compress.deflate.JdkDeflateHadoopStreams;
import io.airlift.compress.gzip.JdkGzipHadoopStreams;
import io.airlift.compress.hadoop.HadoopStreams;
import io.airlift.compress.lz4.Lz4HadoopStreams;
import io.airlift.compress.lzo.LzoHadoopStreams;
import io.airlift.compress.lzo.LzopHadoopStreams;
import io.airlift.compress.snappy.SnappyHadoopStreams;
import io.airlift.compress.zstd.ZstdHadoopStreams;
import io.airlift.compress.v2.bzip2.BZip2HadoopStreams;
import io.airlift.compress.v2.deflate.JdkDeflateHadoopStreams;
import io.airlift.compress.v2.gzip.JdkGzipHadoopStreams;
import io.airlift.compress.v2.hadoop.HadoopStreams;
import io.airlift.compress.v2.lz4.Lz4HadoopStreams;
import io.airlift.compress.v2.lzo.LzoHadoopStreams;
import io.airlift.compress.v2.lzo.LzopHadoopStreams;
import io.airlift.compress.v2.snappy.SnappyHadoopStreams;
import io.airlift.compress.v2.zstd.ZstdHadoopStreams;

import java.util.Arrays;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/
package io.trino.hive.formats.compression;

import io.airlift.compress.hadoop.HadoopStreams;
import io.airlift.compress.v2.hadoop.HadoopStreams;
import io.airlift.slice.Slice;
import io.trino.plugin.base.io.ChunkedSliceOutput;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/
package io.trino.hive.formats.compression;

import io.airlift.compress.hadoop.HadoopStreams;
import io.airlift.compress.v2.hadoop.HadoopStreams;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/
package io.trino.hive.formats.compression;

import io.airlift.compress.hadoop.HadoopStreams;
import io.airlift.compress.v2.hadoop.HadoopStreams;
import io.airlift.slice.Slice;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
package com.hadoop.compression.lzo;

public class LzopCodec
extends io.airlift.compress.lzo.LzopCodec {}
extends io.airlift.compress.v2.lzo.LzopCodec {}
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
package org.apache.hadoop.io.compress;

public class LzoCodec
extends io.airlift.compress.lzo.LzoCodec
extends io.airlift.compress.v2.lzo.LzoCodec
{}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
*/
package io.trino.orc;

import io.airlift.compress.Compressor;
import io.airlift.compress.v2.Compressor;

import java.nio.ByteBuffer;
import java.lang.foreign.MemorySegment;
import java.util.zip.Deflater;

import static java.util.zip.Deflater.FULL_FLUSH;
Expand Down Expand Up @@ -58,8 +58,8 @@ public int compress(byte[] input, int inputOffset, int inputLength, byte[] outpu
}

@Override
public void compress(ByteBuffer input, ByteBuffer output)
public int compress(MemorySegment input, MemorySegment output)
{
throw new UnsupportedOperationException("Compression of byte buffer not supported for deflate");
throw new UnsupportedOperationException();
}
}
Loading

0 comments on commit b4ce9a2

Please sign in to comment.