diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java index cd520da54f2f5..58c51605c5600 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java @@ -31,6 +31,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.schema.ArrowFieldNode; import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.stream.MessageSerializer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -39,7 +40,7 @@ public class ArrowReader implements AutoCloseable { private static final Logger LOGGER = LoggerFactory.getLogger(ArrowReader.class); - private static final byte[] MAGIC = "ARROW1".getBytes(); + public static final byte[] MAGIC = "ARROW1".getBytes(); private final SeekableByteChannel in; @@ -73,13 +74,6 @@ private int readFully(ByteBuffer buffer) throws IOException { return total; } - private static int bytesToInt(byte[] bytes) { - return ((int)(bytes[3] & 255) << 24) + - ((int)(bytes[2] & 255) << 16) + - ((int)(bytes[1] & 255) << 8) + - ((int)(bytes[0] & 255) << 0); - } - public ArrowFooter readFooter() throws IOException { if (footer == null) { if (in.size() <= (MAGIC.length * 2 + 4)) { @@ -93,7 +87,7 @@ public ArrowFooter readFooter() throws IOException { if (!Arrays.equals(MAGIC, Arrays.copyOfRange(array, 4, array.length))) { throw new InvalidArrowFileException("missing Magic number " + Arrays.toString(buffer.array())); } - int footerLength = bytesToInt(array); + int footerLength = MessageSerializer.bytesToInt(array); if (footerLength <= 0 || footerLength + MAGIC.length * 2 + 4 > in.size()) { throw new InvalidArrowFileException("invalid footer length: " + footerLength); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java index 1cd87ebc33594..3febd11f4c76a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java @@ -18,7 +18,6 @@ package org.apache.arrow.vector.file; import java.io.IOException; -import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; import java.util.ArrayList; import java.util.Collections; @@ -26,32 +25,25 @@ import org.apache.arrow.vector.schema.ArrowBuffer; import org.apache.arrow.vector.schema.ArrowRecordBatch; -import org.apache.arrow.vector.schema.FBSerializable; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.google.flatbuffers.FlatBufferBuilder; - import io.netty.buffer.ArrowBuf; public class ArrowWriter implements AutoCloseable { private static final Logger LOGGER = LoggerFactory.getLogger(ArrowWriter.class); - private static final byte[] MAGIC = "ARROW1".getBytes(); - - private final WritableByteChannel out; + private final WriteChannel out; private final Schema schema; private final List recordBatches = new ArrayList<>(); - private long currentPosition = 0; - private boolean started = false; public ArrowWriter(WritableByteChannel out, Schema schema) { - this.out = out; + this.out = new WriteChannel(out); this.schema = schema; } @@ -59,53 +51,19 @@ private void start() throws IOException { writeMagic(); } - private long write(byte[] buffer) throws IOException { - return write(ByteBuffer.wrap(buffer)); - } - - 
private long writeZeros(int zeroCount) throws IOException { - return write(new byte[zeroCount]); - } - - private long align() throws IOException { - if (currentPosition % 8 != 0) { // align on 8 byte boundaries - return writeZeros(8 - (int)(currentPosition % 8)); - } - return 0; - } - - private long write(ByteBuffer buffer) throws IOException { - long length = buffer.remaining(); - out.write(buffer); - currentPosition += length; - return length; - } - - private static byte[] intToBytes(int value) { - byte[] outBuffer = new byte[4]; - outBuffer[3] = (byte)(value >>> 24); - outBuffer[2] = (byte)(value >>> 16); - outBuffer[1] = (byte)(value >>> 8); - outBuffer[0] = (byte)(value >>> 0); - return outBuffer; - } - - private long writeIntLittleEndian(int v) throws IOException { - return write(intToBytes(v)); - } // TODO: write dictionaries public void writeRecordBatch(ArrowRecordBatch recordBatch) throws IOException { checkStarted(); - align(); + out.align(); // write metadata header with int32 size prefix - long offset = currentPosition; - write(recordBatch, true); - align(); + long offset = out.getCurrentPosition(); + out.write(recordBatch, true); + out.align(); // write body - long bodyOffset = currentPosition; + long bodyOffset = out.getCurrentPosition(); List buffers = recordBatch.getBuffers(); List buffersLayout = recordBatch.getBuffersLayout(); if (buffers.size() != buffersLayout.size()) { @@ -115,31 +73,25 @@ public void writeRecordBatch(ArrowRecordBatch recordBatch) throws IOException { ArrowBuf buffer = buffers.get(i); ArrowBuffer layout = buffersLayout.get(i); long startPosition = bodyOffset + layout.getOffset(); - if (startPosition != currentPosition) { - writeZeros((int)(startPosition - currentPosition)); + if (startPosition != out.getCurrentPosition()) { + out.writeZeros((int)(startPosition - out.getCurrentPosition())); } - write(buffer); - if (currentPosition != startPosition + layout.getSize()) { - throw new IllegalStateException("wrong buffer size: " + currentPosition + " != " + startPosition + layout.getSize()); + out.write(buffer); + if (out.getCurrentPosition() != startPosition + layout.getSize()) { + throw new IllegalStateException("wrong buffer size: " + out.getCurrentPosition() + " != " + (startPosition + layout.getSize())); } } int metadataLength = (int)(bodyOffset - offset); if (metadataLength <= 0) { throw new InvalidArrowFileException("invalid recordBatch"); } - long bodyLength = currentPosition - bodyOffset; + long bodyLength = out.getCurrentPosition() - bodyOffset; LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d", offset, metadataLength, bodyLength)); // add metadata to footer recordBatches.add(new ArrowBlock(offset, metadataLength, bodyLength)); } - private void write(ArrowBuf buffer) throws IOException { - ByteBuffer nioBuffer = buffer.nioBuffer(buffer.readerIndex(), buffer.readableBytes()); - LOGGER.debug("Writing buffer with size: " + nioBuffer.remaining()); - write(nioBuffer); - } - private void checkStarted() throws IOException { if (!started) { started = true; @@ -147,15 +99,16 @@ private void checkStarted() throws IOException { } } + @Override public void close() throws IOException { try { - long footerStart = currentPosition; + long footerStart = out.getCurrentPosition(); writeFooter(); - int footerLength = (int)(currentPosition - footerStart); + int footerLength = (int)(out.getCurrentPosition() - footerStart); if (footerLength <= 0 ) { throw new InvalidArrowFileException("invalid footer"); } - writeIntLittleEndian(footerLength); +
out.writeIntLittleEndian(footerLength); LOGGER.debug(String.format("Footer starts at %d, length: %d", footerStart, footerLength)); writeMagic(); } finally { @@ -164,27 +117,12 @@ public void close() throws IOException { } private void writeMagic() throws IOException { - write(MAGIC); - LOGGER.debug(String.format("magic written, now at %d", currentPosition)); + out.write(ArrowReader.MAGIC); + LOGGER.debug(String.format("magic written, now at %d", out.getCurrentPosition())); } private void writeFooter() throws IOException { // TODO: dictionaries - write(new ArrowFooter(schema, Collections.emptyList(), recordBatches), false); - } - - private long write(FBSerializable writer, boolean withSizePrefix) throws IOException { - FlatBufferBuilder builder = new FlatBufferBuilder(); - int root = writer.writeTo(builder); - builder.finish(root); - - ByteBuffer buffer = builder.dataBuffer(); - - if (withSizePrefix) { - writeIntLittleEndian(buffer.remaining()); - } - - return write(buffer); + out.write(new ArrowFooter(schema, Collections.emptyList(), recordBatches), false); } - } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java new file mode 100644 index 0000000000000..b062f3826eab3 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java @@ -0,0 +1,75 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.file; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.buffer.ArrowBuf; + +public class ReadChannel implements AutoCloseable { + + private static final Logger LOGGER = LoggerFactory.getLogger(ReadChannel.class); + + private ReadableByteChannel in; + private long bytesRead = 0; + + public ReadChannel(ReadableByteChannel in) { + this.in = in; + } + + public long bytesRead() { return bytesRead; } + + /** + * Reads bytes into buffer until it is full (buffer.remaining() == 0). Returns the + * number of bytes read, which can be less than requested if the end of stream is reached. + */ + public int readFully(ByteBuffer buffer) throws IOException { + LOGGER.debug("Reading buffer with size: " + buffer.remaining()); + int totalRead = 0; + while (buffer.remaining() != 0) { + int read = in.read(buffer); + if (read < 0) break; // end of stream + totalRead += read; + if (read == 0) break; + } + this.bytesRead += totalRead; + return totalRead; + } + + /** + * Reads up to len bytes into buffer. Returns the number of bytes read.
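+ * Bytes are read into the buffer at its current writer index, and the writer index is updated to cover them.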
+ */ + public int readFully(ArrowBuf buffer, int len) throws IOException { + int n = readFully(buffer.nioBuffer(buffer.writerIndex(), len)); + buffer.writerIndex(n); + return n; + } + + @Override + public void close() throws IOException { + if (this.in != null) { + in.close(); + in = null; + } + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java b/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java new file mode 100644 index 0000000000000..d99c9a6c99958 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java @@ -0,0 +1,111 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.file; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +import org.apache.arrow.vector.schema.FBSerializable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.flatbuffers.FlatBufferBuilder; + +import io.netty.buffer.ArrowBuf; + +/** + * Wrapper around a WritableByteChannel that maintains the position as well as adding + * some common serialization utilities.
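+ * + * An illustrative usage sketch (stream and metadata are assumed to be in scope): + * <pre> + * WriteChannel out = new WriteChannel(Channels.newChannel(stream)); + * out.writeIntLittleEndian(metadata.remaining()); + * out.write(metadata); + * out.align(); + * long bytesSoFar = out.getCurrentPosition(); + * </pre>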
+ */ +public class WriteChannel implements AutoCloseable { + private static final Logger LOGGER = LoggerFactory.getLogger(WriteChannel.class); + + private long currentPosition = 0; + + private final WritableByteChannel out; + + public WriteChannel(WritableByteChannel out) { + this.out = out; + } + + @Override + public void close() throws IOException { + out.close(); + } + + public long getCurrentPosition() { + return currentPosition; + } + + public long write(byte[] buffer) throws IOException { + return write(ByteBuffer.wrap(buffer)); + } + + public long writeZeros(int zeroCount) throws IOException { + return write(new byte[zeroCount]); + } + + public long align() throws IOException { + if (currentPosition % 8 != 0) { // align on 8 byte boundaries + return writeZeros(8 - (int)(currentPosition % 8)); + } + return 0; + } + + public long write(ByteBuffer buffer) throws IOException { + long length = buffer.remaining(); + LOGGER.debug("Writing buffer with size: " + length); + out.write(buffer); + currentPosition += length; + return length; + } + + public static byte[] intToBytes(int value) { + byte[] outBuffer = new byte[4]; + outBuffer[3] = (byte)(value >>> 24); + outBuffer[2] = (byte)(value >>> 16); + outBuffer[1] = (byte)(value >>> 8); + outBuffer[0] = (byte)(value >>> 0); + return outBuffer; + } + + public long writeIntLittleEndian(int v) throws IOException { + return write(intToBytes(v)); + } + + public void write(ArrowBuf buffer) throws IOException { + ByteBuffer nioBuffer = buffer.nioBuffer(buffer.readerIndex(), buffer.readableBytes()); + write(nioBuffer); + } + + public long write(FBSerializable writer, boolean withSizePrefix) throws IOException { + ByteBuffer buffer = serialize(writer); + if (withSizePrefix) { + writeIntLittleEndian(buffer.remaining()); + } + return write(buffer); + } + + public static ByteBuffer serialize(FBSerializable writer) { + FlatBufferBuilder builder = new FlatBufferBuilder(); + int root = writer.writeTo(builder); + builder.finish(root); + return builder.dataBuffer(); + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java index adb99e2f3ffb7..40c2fbfd984f8 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java @@ -19,6 +19,7 @@ import static org.apache.arrow.vector.schema.FBSerializables.writeAllStructsToVector; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -130,6 +131,28 @@ public String toString() { + buffersLayout + ", closed=" + closed + "]"; } + /** + * Computes the size of the serialized body for this recordBatch. 
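+ * The returned size includes any alignment padding between buffers, as given by the + * buffers layout.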
+ */ + public int computeBodyLength() { + int size = 0; + + List buffers = getBuffers(); + List buffersLayout = getBuffersLayout(); + if (buffers.size() != buffersLayout.size()) { + throw new IllegalStateException("the layout does not match: " + + buffers.size() + " != " + buffersLayout.size()); + } + for (int i = 0; i < buffers.size(); i++) { + ArrowBuf buffer = buffers.get(i); + ArrowBuffer layout = buffersLayout.get(i); + size += (layout.getOffset() - size); + ByteBuffer nioBuffer = + buffer.nioBuffer(buffer.readerIndex(), buffer.readableBytes()); + size += nioBuffer.remaining(); + } + return size; + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java new file mode 100644 index 0000000000000..f32966c5d5217 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java @@ -0,0 +1,95 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.stream; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.file.ReadChannel; +import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.common.base.Preconditions; + +/** + * This class reads from an input stream and produces ArrowRecordBatches. + */ +public class ArrowStreamReader implements AutoCloseable { + private ReadChannel in; + private final BufferAllocator allocator; + private Schema schema; + + /** + * Constructs a streaming reader, reading bytes from 'in'. Non-blocking. + */ + public ArrowStreamReader(ReadableByteChannel in, BufferAllocator allocator) { + super(); + this.in = new ReadChannel(in); + this.allocator = allocator; + } + + public ArrowStreamReader(InputStream in, BufferAllocator allocator) { + this(Channels.newChannel(in), allocator); + } + + /** + * Initializes the reader. Must be called before the other APIs. This is blocking. + */ + public void init() throws IOException { + Preconditions.checkState(this.schema == null, "Cannot call init() more than once."); + this.schema = readSchema(); + } + + /** + * Returns the schema for all records in this stream. + */ + public Schema getSchema() { + Preconditions.checkState(this.schema != null, "Must call init() first."); + return schema; + } + + public long bytesRead() { return in.bytesRead(); } + + /** + * Reads and returns the next ArrowRecordBatch. Returns null if this is the end + * of stream.
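+ * + * A typical read loop, for illustration; the caller closes each batch to release its buffers: + * <pre> + * reader.init(); + * ArrowRecordBatch batch; + * while ((batch = reader.nextRecordBatch()) != null) { + * // consume the batch, e.g. load it with a VectorLoader, then: + * batch.close(); + * } + * </pre>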
+ */ + public ArrowRecordBatch nextRecordBatch() throws IOException { + Preconditions.checkState(this.in != null, "Cannot call after close()"); + Preconditions.checkState(this.schema != null, "Must call init() first."); + return MessageSerializer.deserializeRecordBatch(in, allocator); + } + + @Override + public void close() throws IOException { + if (this.in != null) { + in.close(); + in = null; + } + } + + /** + * Reads the schema message from the beginning of the stream. + */ + private Schema readSchema() throws IOException { + return MessageSerializer.deserializeSchema(in); + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java new file mode 100644 index 0000000000000..06acf9f2c140e --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java @@ -0,0 +1,71 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.stream; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; + +import org.apache.arrow.vector.file.WriteChannel; +import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; + +public class ArrowStreamWriter implements AutoCloseable { + private final WriteChannel out; + private final Schema schema; + private boolean headerSent = false; + + /** + * Creates the stream writer. Non-blocking. + * totalBatches can be set if the writer knows it beforehand; pass -1 if unknown. + * It is currently unused. + */ + public ArrowStreamWriter(WritableByteChannel out, Schema schema, int totalBatches) { + this.out = new WriteChannel(out); + this.schema = schema; + } + + public ArrowStreamWriter(OutputStream out, Schema schema, int totalBatches) + throws IOException { + this(Channels.newChannel(out), schema, totalBatches); + } + + public long bytesWritten() { return out.getCurrentPosition(); } + + public void writeRecordBatch(ArrowRecordBatch batch) throws IOException { + // Send the header if we have not yet. + checkAndSendHeader(); + MessageSerializer.serialize(out, batch); + } + + @Override + public void close() throws IOException { + // The header might not have been sent if this is an empty stream. Send it even in + // this case so readers see a valid empty stream.
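+ // An empty stream therefore consists of just the serialized schema message.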
checkAndSendHeader(); + out.close(); + } + + private void checkAndSendHeader() throws IOException { + if (!headerSent) { + MessageSerializer.serialize(out, schema); + headerSent = true; + } + } +} + diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java new file mode 100644 index 0000000000000..22c46e2817b1e --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java @@ -0,0 +1,216 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.stream; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +import org.apache.arrow.flatbuf.Buffer; +import org.apache.arrow.flatbuf.FieldNode; +import org.apache.arrow.flatbuf.Message; +import org.apache.arrow.flatbuf.MessageHeader; +import org.apache.arrow.flatbuf.MetadataVersion; +import org.apache.arrow.flatbuf.RecordBatch; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.file.ReadChannel; +import org.apache.arrow.vector.file.WriteChannel; +import org.apache.arrow.vector.schema.ArrowBuffer; +import org.apache.arrow.vector.schema.ArrowFieldNode; +import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.flatbuffers.FlatBufferBuilder; + +import io.netty.buffer.ArrowBuf; + +/** + * Utility class for serializing Messages. Messages are all serialized in a similar way: + * 1. 4 byte little endian message header prefix + * 2. FB serialized Message: this includes the body length (the length of the serialized + * body) and the type of the message. + * 3. The serialized message body. + * + * For schema messages, the serialized body is simply the FB serialized Schema. + * + * For RecordBatch messages the serialized body is: + * 1. 4 byte little endian metadata size prefix + * 2. FB serialized RecordBatch metadata + * 3. the serialized RecordBatch buffers. + */ +public class MessageSerializer { + + public static int bytesToInt(byte[] bytes) { + return ((bytes[3] & 255) << 24) + + ((bytes[2] & 255) << 16) + + ((bytes[1] & 255) << 8) + + ((bytes[0] & 255) << 0); + } + + /** + * Serialize a schema object.
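+ * Written as a 4 byte little endian size prefix, the FB serialized Message header, + * then the FB serialized Schema body. Returns the total number of bytes written.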
+ */ + public static long serialize(WriteChannel out, Schema schema) throws IOException { + FlatBufferBuilder builder = new FlatBufferBuilder(); + builder.finish(schema.getSchema(builder)); + ByteBuffer serializedBody = builder.dataBuffer(); + ByteBuffer serializedHeader = + serializeHeader(MessageHeader.Schema, serializedBody.remaining()); + + long size = out.writeIntLittleEndian(serializedHeader.remaining()); + size += out.write(serializedHeader); + size += out.write(serializedBody); + return size; + } + + /** + * Deserializes a schema object. Format is from serialize(). + */ + public static Schema deserializeSchema(ReadChannel in) throws IOException { + Message header = deserializeHeader(in, MessageHeader.Schema); + if (header == null) { + throw new IOException("Unexpected end of input. Missing schema."); + } + + // Now read the schema. + ByteBuffer buffer = ByteBuffer.allocate((int)header.bodyLength()); + if (in.readFully(buffer) != header.bodyLength()) { + throw new IOException("Unexpected end of input trying to read schema."); + } + buffer.rewind(); + return Schema.deserialize(buffer); + } + + /** + * Serializes an ArrowRecordBatch. + */ + public static long serialize(WriteChannel out, ArrowRecordBatch batch) + throws IOException { + long start = out.getCurrentPosition(); + int bodyLength = batch.computeBodyLength(); + + ByteBuffer metadata = WriteChannel.serialize(batch); + ByteBuffer serializedHeader = + serializeHeader(MessageHeader.RecordBatch, bodyLength + metadata.remaining() + 4); + + // Write message header. + out.writeIntLittleEndian(serializedHeader.remaining()); + out.write(serializedHeader); + + // Write the metadata, with the 4 byte little endian prefix. + out.writeIntLittleEndian(metadata.remaining()); + out.write(metadata); + + // Write the buffers that make up the batch body, padded per the layout. + long offset = out.getCurrentPosition(); + List buffers = batch.getBuffers(); + List buffersLayout = batch.getBuffersLayout(); + + for (int i = 0; i < buffers.size(); i++) { + ArrowBuf buffer = buffers.get(i); + ArrowBuffer layout = buffersLayout.get(i); + long startPosition = offset + layout.getOffset(); + if (startPosition != out.getCurrentPosition()) { + out.writeZeros((int)(startPosition - out.getCurrentPosition())); + } + out.write(buffer); + if (out.getCurrentPosition() != startPosition + layout.getSize()) { + throw new IllegalStateException("wrong buffer size: " + out.getCurrentPosition() + + " != " + (startPosition + layout.getSize())); + } + } + return out.getCurrentPosition() - start; + } + + /** + * Deserializes a RecordBatch. + */ + public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, + BufferAllocator alloc) throws IOException { + Message header = deserializeHeader(in, MessageHeader.RecordBatch); + if (header == null) return null; + + int messageLen = (int)header.bodyLength(); + // Now read the buffer. This has the metadata followed by the data. + ArrowBuf buffer = alloc.buffer(messageLen); + if (in.readFully(buffer, messageLen) != messageLen) { + throw new IOException("Unexpected end of input trying to read batch."); + } + + // Read the metadata. It starts with the 4 byte size of the metadata.
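+ // The buffer layout mirrors serialize() above: [int32 metadata size][metadata][body buffers].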
+ int metadataSize = buffer.readInt(); + RecordBatch recordBatchFB = + RecordBatch.getRootAsRecordBatch(buffer.nioBuffer().asReadOnlyBuffer()); + + // Now read the body. + final ArrowBuf body = buffer.slice(4 + metadataSize, messageLen - metadataSize - 4); + int nodesLength = recordBatchFB.nodesLength(); + List nodes = new ArrayList<>(); + for (int i = 0; i < nodesLength; ++i) { + FieldNode node = recordBatchFB.nodes(i); + nodes.add(new ArrowFieldNode(node.length(), node.nullCount())); + } + List buffers = new ArrayList<>(); + for (int i = 0; i < recordBatchFB.buffersLength(); ++i) { + Buffer bufferFB = recordBatchFB.buffers(i); + ArrowBuf vectorBuffer = body.slice((int)bufferFB.offset(), (int)bufferFB.length()); + buffers.add(vectorBuffer); + } + ArrowRecordBatch arrowRecordBatch = + new ArrowRecordBatch(recordBatchFB.length(), nodes, buffers); + buffer.release(); + return arrowRecordBatch; + } + + /** + * Serializes a message header. + */ + private static ByteBuffer serializeHeader(byte headerType, int bodyLength) { + FlatBufferBuilder headerBuilder = new FlatBufferBuilder(); + Message.startMessage(headerBuilder); + Message.addHeaderType(headerBuilder, headerType); + Message.addVersion(headerBuilder, MetadataVersion.V1); + Message.addBodyLength(headerBuilder, bodyLength); + headerBuilder.finish(Message.endMessage(headerBuilder)); + return headerBuilder.dataBuffer(); + } + + private static Message deserializeHeader(ReadChannel in, byte headerType) throws IOException { + // Read the header size. There is an i32 little endian prefix. + ByteBuffer buffer = ByteBuffer.allocate(4); + if (in.readFully(buffer) != 4) { + return null; + } + + int headerLength = bytesToInt(buffer.array()); + buffer = ByteBuffer.allocate(headerLength); + if (in.readFully(buffer) != headerLength) { + throw new IOException( + "Unexpected end of stream trying to read header."); + } + buffer.rewind(); + + Message header = Message.getRootAsMessage(buffer); + if (header.headerType() != headerType) { + throw new IOException("Invalid message: expecting " + headerType + + ".
Message contained: " + header.headerType()); + } + return header; + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Schema.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Schema.java index 5ca8ade7891ee..c33bd6e6e61b0 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Schema.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Schema.java @@ -22,6 +22,7 @@ import static org.apache.arrow.vector.types.pojo.Field.convertField; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -65,6 +66,10 @@ public static Schema fromJSON(String json) throws IOException { return reader.readValue(checkNotNull(json)); } + public static Schema deserialize(ByteBuffer buffer) { + return convertSchema(org.apache.arrow.flatbuf.Schema.getRootAsSchema(buffer)); + } + public static Schema convertSchema(org.apache.arrow.flatbuf.Schema schema) { ImmutableList.Builder childrenBuilder = ImmutableList.builder(); for (int i = 0; i < schema.fieldsLength(); i++) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java index 5fa18b3ca5339..bf635fb39f5b8 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java @@ -18,12 +18,16 @@ package org.apache.arrow.vector.file; import static org.apache.arrow.vector.TestVectorUnloadLoad.newVectorUnloader; +import static org.junit.Assert.assertTrue; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.IOException; +import java.io.OutputStream; import java.util.List; import org.apache.arrow.memory.BufferAllocator; @@ -35,6 +39,8 @@ import org.apache.arrow.vector.complex.NullableMapVector; import org.apache.arrow.vector.schema.ArrowBuffer; import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.stream.ArrowStreamReader; +import org.apache.arrow.vector.stream.ArrowStreamWriter; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; import org.junit.Test; @@ -52,7 +58,7 @@ public void testWrite() throws IOException { BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); MapVector parent = new MapVector("parent", vectorAllocator, null)) { writeData(count, parent); - write(parent.getChild("root"), file); + write(parent.getChild("root"), file, new ByteArrayOutputStream()); } } @@ -66,13 +72,14 @@ public void testWriteComplex() throws IOException { writeComplexData(count, parent); FieldVector root = parent.getChild("root"); validateComplexContent(count, new VectorSchemaRoot(root)); - write(root, file); + write(root, file, new ByteArrayOutputStream()); } } @Test public void testWriteRead() throws IOException { File file = new File("target/mytest.arrow"); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); int count = COUNT; // write @@ -80,7 +87,7 @@ public void testWriteRead() throws IOException { BufferAllocator originalVectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); MapVector parent = new MapVector("parent", originalVectorAllocator, null)) { writeData(count, parent); - 
write(parent.getChild("root"), file); + write(parent.getChild("root"), file, stream); } // read @@ -116,11 +123,40 @@ public void testWriteRead() throws IOException { } } } + + // Read from stream. + try ( + BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); + ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator); + BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); + MapVector parent = new MapVector("parent", vectorAllocator, null) + ) { + arrowReader.init(); + Schema schema = arrowReader.getSchema(); + LOGGER.debug("reading schema: " + schema); + + try (VectorSchemaRoot root = new VectorSchemaRoot(schema, vectorAllocator)) { + VectorLoader vectorLoader = new VectorLoader(root); + while (true) { + try (ArrowRecordBatch recordBatch = arrowReader.nextRecordBatch()) { + if (recordBatch == null) break; + List buffersLayout = recordBatch.getBuffersLayout(); + for (ArrowBuffer arrowBuffer : buffersLayout) { + Assert.assertEquals(0, arrowBuffer.getOffset() % 8); + } + vectorLoader.load(recordBatch); + } + } + validateContent(count, root); + } + } } @Test public void testWriteReadComplex() throws IOException { File file = new File("target/mytest_complex.arrow"); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); int count = COUNT; // write @@ -128,7 +164,7 @@ public void testWriteReadComplex() throws IOException { BufferAllocator originalVectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); MapVector parent = new MapVector("parent", originalVectorAllocator, null)) { writeComplexData(count, parent); - write(parent.getChild("root"), file); + write(parent.getChild("root"), file, stream); } // read @@ -156,11 +192,36 @@ public void testWriteReadComplex() throws IOException { } } } + + // Read from stream. 
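+ // The stream was filled in by write() above using the streaming format; replay it here and revalidate.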
+ try ( + BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); + ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator); + BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); + MapVector parent = new MapVector("parent", vectorAllocator, null) + ) { + arrowReader.init(); + Schema schema = arrowReader.getSchema(); + LOGGER.debug("reading schema: " + schema); + + try (VectorSchemaRoot root = new VectorSchemaRoot(schema, vectorAllocator)) { + VectorLoader vectorLoader = new VectorLoader(root); + while (true) { + try (ArrowRecordBatch recordBatch = arrowReader.nextRecordBatch()) { + if (recordBatch == null) break; + vectorLoader.load(recordBatch); + } + } + validateComplexContent(count, root); + } + } } @Test public void testWriteReadMultipleRBs() throws IOException { File file = new File("target/mytest_multiple.arrow"); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); int[] counts = { 10, 5 }; // write @@ -172,10 +233,12 @@ public void testWriteReadMultipleRBs() throws IOException { VectorUnloader vectorUnloader0 = newVectorUnloader(parent.getChild("root")); Schema schema = vectorUnloader0.getSchema(); Assert.assertEquals(2, schema.getFields().size()); - try (ArrowWriter arrowWriter = new ArrowWriter(fileOutputStream.getChannel(), schema);) { + try (ArrowWriter arrowWriter = new ArrowWriter(fileOutputStream.getChannel(), schema); + ArrowStreamWriter streamWriter = new ArrowStreamWriter(stream, schema, 2)) { try (ArrowRecordBatch recordBatch = vectorUnloader0.getRecordBatch()) { Assert.assertEquals("RB #0", counts[0], recordBatch.getLength()); arrowWriter.writeRecordBatch(recordBatch); + streamWriter.writeRecordBatch(recordBatch); } parent.allocateNew(); writeData(counts[1], parent); // if we write the same data we don't catch that the metadata is stored in the wrong order. 
@@ -183,6 +246,7 @@ public void testWriteReadMultipleRBs() throws IOException { try (ArrowRecordBatch recordBatch = vectorUnloader1.getRecordBatch()) { Assert.assertEquals("RB #1", counts[1], recordBatch.getLength()); arrowWriter.writeRecordBatch(recordBatch); + streamWriter.writeRecordBatch(recordBatch); } } } @@ -222,11 +286,42 @@ public void testWriteReadMultipleRBs() throws IOException { } } } + + // read stream + try ( + BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); + ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator); + BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); + MapVector parent = new MapVector("parent", vectorAllocator, null) + ) { + arrowReader.init(); + Schema schema = arrowReader.getSchema(); + LOGGER.debug("reading schema: " + schema); + int i = 0; + try (VectorSchemaRoot root = new VectorSchemaRoot(schema, vectorAllocator)) { + VectorLoader vectorLoader = new VectorLoader(root); + for (int n = 0; n < 2; n++) { + try (ArrowRecordBatch recordBatch = arrowReader.nextRecordBatch()) { + assertTrue(recordBatch != null); + Assert.assertEquals("RB #" + i, counts[i], recordBatch.getLength()); + List buffersLayout = recordBatch.getBuffersLayout(); + for (ArrowBuffer arrowBuffer : buffersLayout) { + Assert.assertEquals(0, arrowBuffer.getOffset() % 8); + } + vectorLoader.load(recordBatch); + validateContent(counts[i], root); + } + ++i; + } + } + } } @Test public void testWriteReadUnion() throws IOException { File file = new File("target/mytest_write_union.arrow"); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); int count = COUNT; try ( BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); @@ -238,9 +333,9 @@ public void testWriteReadUnion() throws IOException { validateUnionData(count, new VectorSchemaRoot(parent.getChild("root"))); - write(parent.getChild("root"), file); + write(parent.getChild("root"), file, stream); } - // read + // read try ( BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); FileInputStream fileInputStream = new FileInputStream(file); @@ -263,9 +358,37 @@ public void testWriteReadUnion() throws IOException { } } } + + // Read from stream. + try ( + BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); + ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator); + BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); + MapVector parent = new MapVector("parent", vectorAllocator, null) + ) { + arrowReader.init(); + Schema schema = arrowReader.getSchema(); + LOGGER.debug("reading schema: " + schema); + + try (VectorSchemaRoot root = new VectorSchemaRoot(schema, vectorAllocator)) { + VectorLoader vectorLoader = new VectorLoader(root); + while (true) { + try (ArrowRecordBatch recordBatch = arrowReader.nextRecordBatch()) { + if (recordBatch == null) break; + vectorLoader.load(recordBatch); + } + } + validateUnionData(count, root); + } + } } - private void write(FieldVector parent, File file) throws FileNotFoundException, IOException { + /** + * Writes the contents of parent to file.
If outStream is non-null, also writes it + * to outStream in the streaming serialized format. + */ + private void write(FieldVector parent, File file, OutputStream outStream) throws FileNotFoundException, IOException { VectorUnloader vectorUnloader = newVectorUnloader(parent); Schema schema = vectorUnloader.getSchema(); LOGGER.debug("writing schema: " + schema); @@ -276,5 +399,15 @@ private void write(FieldVector parent, File file) throws FileNotFoundException, ) { arrowWriter.writeRecordBatch(recordBatch); } + + // Also try serializing to the stream writer. + if (outStream != null) { + try ( + ArrowStreamWriter arrowWriter = new ArrowStreamWriter(outStream, schema, -1); + ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch(); + ) { + arrowWriter.writeRecordBatch(recordBatch); + } + } } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/stream/MessageSerializerTest.java b/java/vector/src/test/java/org/apache/arrow/vector/stream/MessageSerializerTest.java new file mode 100644 index 0000000000000..7b4de80ee03ea --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/stream/MessageSerializerTest.java @@ -0,0 +1,115 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.arrow.vector.stream; + +import static java.util.Arrays.asList; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.util.Collections; +import java.util.List; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.file.ReadChannel; +import org.apache.arrow.vector.file.WriteChannel; +import org.apache.arrow.vector.schema.ArrowFieldNode; +import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Test; + +import io.netty.buffer.ArrowBuf; + +public class MessageSerializerTest { + + public static ArrowBuf buf(BufferAllocator alloc, byte[] bytes) { + ArrowBuf buffer = alloc.buffer(bytes.length); + buffer.writeBytes(bytes); + return buffer; + } + + public static byte[] array(ArrowBuf buf) { + byte[] bytes = new byte[buf.readableBytes()]; + buf.readBytes(bytes); + return bytes; + } + + @Test + public void testSchemaMessageSerialization() throws IOException { + Schema schema = testSchema(); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + long size = MessageSerializer.serialize( + new WriteChannel(Channels.newChannel(out)), schema); + assertEquals(size, out.toByteArray().length); + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + Schema deserialized = MessageSerializer.deserializeSchema( + new ReadChannel(Channels.newChannel(in))); + assertEquals(schema, deserialized); + assertEquals(1, deserialized.getFields().size()); + } + + @Test + public void testSerializeRecordBatch() throws IOException { + byte[] validity = new byte[] { (byte)255, 0}; + // second half is "undefined" + byte[] values = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + + BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); + ArrowBuf validityb = buf(alloc, validity); + ArrowBuf valuesb = buf(alloc, values); + + ArrowRecordBatch batch = new ArrowRecordBatch( + 16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb)); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch); + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + ArrowRecordBatch deserialized = MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), alloc); + verifyBatch(deserialized, validity, values); + } + + public static Schema testSchema() { + return new Schema(asList(new Field( + "testField", true, new ArrowType.Int(8, true), Collections.emptyList()))); + } + + // Verifies batch contents matching test schema. 
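+ // Shared by TestArrowStream and TestArrowStreamPipe.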
+ public static void verifyBatch(ArrowRecordBatch batch, byte[] validity, byte[] values) { + assertTrue(batch != null); + List nodes = batch.getNodes(); + assertEquals(1, nodes.size()); + ArrowFieldNode node = nodes.get(0); + assertEquals(16, node.getLength()); + assertEquals(8, node.getNullCount()); + List buffers = batch.getBuffers(); + assertEquals(2, buffers.size()); + assertArrayEquals(validity, MessageSerializerTest.array(buffers.get(0))); + assertArrayEquals(values, MessageSerializerTest.array(buffers.get(1))); + } + +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/stream/TestArrowStream.java b/java/vector/src/test/java/org/apache/arrow/vector/stream/TestArrowStream.java new file mode 100644 index 0000000000000..ba1cdaeeb2262 --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/stream/TestArrowStream.java @@ -0,0 +1,96 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.stream; + +import static java.util.Arrays.asList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.file.BaseFileTest; +import org.apache.arrow.vector.schema.ArrowFieldNode; +import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Test; + +import io.netty.buffer.ArrowBuf; + +public class TestArrowStream extends BaseFileTest { + @Test + public void testEmptyStream() throws IOException { + Schema schema = MessageSerializerTest.testSchema(); + + // Write the stream. + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (ArrowStreamWriter writer = new ArrowStreamWriter(out, schema, -1)) { + } + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + try (ArrowStreamReader reader = new ArrowStreamReader(in, allocator)) { + reader.init(); + assertEquals(schema, reader.getSchema()); + // Empty should return null. Can be called repeatedly. 
+ assertTrue(reader.nextRecordBatch() == null); + assertTrue(reader.nextRecordBatch() == null); + } + } + + @Test + public void testReadWrite() throws IOException { + Schema schema = MessageSerializerTest.testSchema(); + byte[] validity = new byte[] { (byte)255, 0}; + // second half is "undefined" + byte[] values = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + + int numBatches = 5; + BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + long bytesWritten = 0; + try (ArrowStreamWriter writer = new ArrowStreamWriter(out, schema, numBatches)) { + ArrowBuf validityb = MessageSerializerTest.buf(alloc, validity); + ArrowBuf valuesb = MessageSerializerTest.buf(alloc, values); + for (int i = 0; i < numBatches; i++) { + writer.writeRecordBatch(new ArrowRecordBatch( + 16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb))); + } + bytesWritten = writer.bytesWritten(); + } + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + try (ArrowStreamReader reader = new ArrowStreamReader(in, alloc)) { + reader.init(); + Schema readSchema = reader.getSchema(); + for (int i = 0; i < numBatches; i++) { + assertEquals(schema, readSchema); + assertTrue( + readSchema.getFields().get(0).getTypeLayout().getVectorTypes().toString(), + readSchema.getFields().get(0).getTypeLayout().getVectors().size() > 0); + ArrowRecordBatch recordBatch = reader.nextRecordBatch(); + MessageSerializerTest.verifyBatch(recordBatch, validity, values); + assertTrue(recordBatch != null); + } + assertTrue(reader.nextRecordBatch() == null); + assertEquals(bytesWritten, reader.bytesRead()); + } + } +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/stream/TestArrowStreamPipe.java b/java/vector/src/test/java/org/apache/arrow/vector/stream/TestArrowStreamPipe.java new file mode 100644 index 0000000000000..e187fa535cada --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/stream/TestArrowStreamPipe.java @@ -0,0 +1,129 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.arrow.vector.stream; + +import static java.util.Arrays.asList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.nio.channels.Pipe; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.schema.ArrowFieldNode; +import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Test; + +import io.netty.buffer.ArrowBuf; + +public class TestArrowStreamPipe { + Schema schema = MessageSerializerTest.testSchema(); + // second half is "undefined" + byte[] values = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + + private final class WriterThread extends Thread { + private final int numBatches; + private final ArrowStreamWriter writer; + + public WriterThread(int numBatches, WritableByteChannel sinkChannel) + throws IOException { + this.numBatches = numBatches; + writer = new ArrowStreamWriter(sinkChannel, schema, -1); + } + + @Override + public void run() { + BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); + try { + ArrowBuf valuesb = MessageSerializerTest.buf(alloc, values); + for (int i = 0; i < numBatches; i++) { + // Send a changing byte id first. + byte[] validity = new byte[] { (byte)i, 0}; + ArrowBuf validityb = MessageSerializerTest.buf(alloc, validity); + writer.writeRecordBatch(new ArrowRecordBatch( + 16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb))); + } + writer.close(); + } catch (IOException e) { + e.printStackTrace(); + assertTrue(false); + } + } + + public long bytesWritten() { return writer.bytesWritten(); } + } + + private final class ReaderThread extends Thread { + private int batchesRead = 0; + private final ArrowStreamReader reader; + private final BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); + + public ReaderThread(ReadableByteChannel sourceChannel) + throws IOException { + reader = new ArrowStreamReader(sourceChannel, alloc); + } + + @Override + public void run() { + try { + reader.init(); + assertEquals(schema, reader.getSchema()); + assertTrue( + reader.getSchema().getFields().get(0).getTypeLayout().getVectorTypes().toString(), + reader.getSchema().getFields().get(0).getTypeLayout().getVectors().size() > 0); + + // Read all the batches. Each batch contains an incrementing id and then some + // constant data. Verify both. + while (true) { + ArrowRecordBatch batch = reader.nextRecordBatch(); + if (batch == null) break; + byte[] validity = new byte[] { (byte)batchesRead, 0}; + MessageSerializerTest.verifyBatch(batch, validity, values); + batchesRead++; + } + } catch (IOException e) { + e.printStackTrace(); + assertTrue(false); + } + } + + public int getBatchesRead() { return batchesRead; } + public long bytesRead() { return reader.bytesRead(); } + } + + // Starts up a producer and consumer thread to read/write batches. + @Test + public void pipeTest() throws IOException, InterruptedException { + int NUM_BATCHES = 1000; + Pipe pipe = Pipe.open(); + WriterThread writer = new WriterThread(NUM_BATCHES, pipe.sink()); + ReaderThread reader = new ReaderThread(pipe.source()); + + writer.start(); + reader.start(); + reader.join(); + writer.join(); + + assertEquals(NUM_BATCHES, reader.getBatchesRead()); + assertEquals(writer.bytesWritten(), reader.bytesRead()); + } +}