Skip to content

Commit

Permalink
Pipe: add decompressed length in RPC compression payload to avoid pot…
Browse files Browse the repository at this point in the history
…ential OOM on receiver (#12701)
  • Loading branch information
DanielWang2035 committed Jun 11, 2024
1 parent 724f2ad commit 278da48
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,27 @@ protected PipeCompressor(PipeCompressionType compressionType) {

public abstract byte[] compress(byte[] data) throws IOException;

/**
* Decompress the byte array to a byte array. NOTE: the length of the decompressed byte array is
* not provided in this method, and some decompressors (LZ4) may construct large byte arrays,
* leading to potential OOM.
*
* @param byteArray the byte array to be decompressed
* @return the decompressed byte array
* @throws IOException
*/
public abstract byte[] decompress(byte[] byteArray) throws IOException;

/**
* Decompress the byte array to a byte array with a known length.
*
* @param byteArray the byte array to be decompressed
* @param decompressedLength the length of the decompressed byte array
* @return the decompressed byte array
* @throws IOException
*/
public abstract byte[] decompress(byte[] byteArray, int decompressedLength) throws IOException;

public byte serialize() {
return compressionType.getIndex();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,11 @@ public byte[] compress(byte[] data) throws IOException {
public byte[] decompress(byte[] byteArray) throws IOException {
return DECOMPRESSOR.uncompress(byteArray);
}

@Override
public byte[] decompress(byte[] byteArray, int decompressedLength) throws IOException {
byte[] uncompressed = new byte[decompressedLength];
DECOMPRESSOR.uncompress(byteArray, 0, byteArray.length, uncompressed, 0);
return uncompressed;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,11 @@ public byte[] compress(byte[] data) throws IOException {
public byte[] decompress(byte[] byteArray) throws IOException {
return DECOMPRESSOR.uncompress(byteArray);
}

@Override
public byte[] decompress(byte[] byteArray, int decompressedLength) throws IOException {
byte[] uncompressed = new byte[decompressedLength];
DECOMPRESSOR.uncompress(byteArray, 0, byteArray.length, uncompressed, 0);
return uncompressed;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,11 @@ public byte[] compress(byte[] data) throws IOException {
public byte[] decompress(byte[] byteArray) throws IOException {
return DECOMPRESSOR.uncompress(byteArray);
}

@Override
public byte[] decompress(byte[] byteArray, int decompressedLength) throws IOException {
byte[] uncompressed = new byte[decompressedLength];
DECOMPRESSOR.uncompress(byteArray, 0, byteArray.length, uncompressed, 0);
return uncompressed;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,11 @@ public byte[] compress(byte[] data) throws IOException {
public byte[] decompress(byte[] byteArray) throws IOException {
return DECOMPRESSOR.uncompress(byteArray);
}

@Override
public byte[] decompress(byte[] byteArray, int decompressedLength) throws IOException {
byte[] uncompressed = new byte[decompressedLength];
DECOMPRESSOR.uncompress(byteArray, 0, byteArray.length, uncompressed, 0);
return uncompressed;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,9 @@ public byte[] compress(byte[] data) throws IOException {
public byte[] decompress(byte[] byteArray) {
return Zstd.decompress(byteArray, (int) Zstd.decompressedSize(byteArray, 0, byteArray.length));
}

@Override
public byte[] decompress(byte[] byteArray, int decompressedLength) {
return Zstd.decompress(byteArray, decompressedLength);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ public static TPipeTransferReq toTPipeTransferReq(
// version
// type: TRANSFER_COMPRESSED
// body:
// (byte) count of compressors
// (bytes) 1 byte for each compressor
// (byte) count of compressors (n)
// (n*3 bytes) for each compressor:
// (byte) compressor type
// (int) length of uncompressed bytes
// compressed req:
// (byte) version
// (2 bytes) type
Expand All @@ -56,18 +58,17 @@ public static TPipeTransferReq toTPipeTransferReq(

try (final PublicBAOS byteArrayOutputStream = new PublicBAOS();
final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream)) {
ReadWriteIOUtils.write((byte) compressors.size(), outputStream);
for (final PipeCompressor compressor : compressors) {
ReadWriteIOUtils.write(compressor.serialize(), outputStream);
}

byte[] body =
BytesUtils.concatByteArrayList(
Arrays.asList(
new byte[] {originalReq.version},
BytesUtils.shortToBytes(originalReq.type),
originalReq.getBody()));

ReadWriteIOUtils.write((byte) compressors.size(), outputStream);
for (final PipeCompressor compressor : compressors) {
ReadWriteIOUtils.write(compressor.serialize(), outputStream);
ReadWriteIOUtils.write(body.length, outputStream);
body = compressor.compress(body);
}
outputStream.write(body);
Expand All @@ -84,17 +85,19 @@ public static TPipeTransferReq fromTPipeTransferReq(final TPipeTransferReq trans
final ByteBuffer compressedBuffer = transferReq.body;

final List<PipeCompressor> compressors = new ArrayList<>();
final List<Integer> uncompressedLengths = new ArrayList<>();
final int compressorsSize = ReadWriteIOUtils.readByte(compressedBuffer);
for (int i = 0; i < compressorsSize; ++i) {
compressors.add(
PipeCompressorFactory.getCompressor(ReadWriteIOUtils.readByte(compressedBuffer)));
uncompressedLengths.add(ReadWriteIOUtils.readInt(compressedBuffer));
}

byte[] body = new byte[compressedBuffer.remaining()];
compressedBuffer.get(body);

for (int i = compressors.size() - 1; i >= 0; --i) {
body = compressors.get(i).decompress(body);
body = compressors.get(i).decompress(body, uncompressedLengths.get(i));
}

final ByteBuffer decompressedBuffer = ByteBuffer.wrap(body);
Expand All @@ -115,26 +118,27 @@ public static byte[] toTPipeTransferReqBytes(
// The generated bytes consists of:
// (byte) version
// (2 bytes) type: TRANSFER_COMPRESSED
// (byte) count of compressors
// (bytes) 1 byte for each compressor
// (byte) count of compressors (n)
// (n*3 bytes) for each compressor:
// (byte) compressor type
// (int) length of uncompressed bytes
// compressed req:
// (byte) version
// (2 bytes) type
// (bytes) body
try (final PublicBAOS byteArrayOutputStream = new PublicBAOS();
final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream)) {
byte[] body = rawReqInBytes;

ReadWriteIOUtils.write(IoTDBConnectorRequestVersion.VERSION_1.getVersion(), outputStream);
ReadWriteIOUtils.write(PipeRequestType.TRANSFER_COMPRESSED.getType(), outputStream);
ReadWriteIOUtils.write((byte) compressors.size(), outputStream);
for (final PipeCompressor compressor : compressors) {
ReadWriteIOUtils.write(compressor.serialize(), outputStream);
ReadWriteIOUtils.write(body.length, outputStream);
body = compressor.compress(body);
}

byte[] compressedReq = rawReqInBytes;
for (final PipeCompressor compressor : compressors) {
compressedReq = compressor.compress(compressedReq);
}
outputStream.write(compressedReq);
outputStream.write(body);

return byteArrayOutputStream.toByteArray();
}
Expand Down

0 comments on commit 278da48

Please sign in to comment.