diff --git a/csharp/src/Apache.Arrow/Arrays/Array.cs b/csharp/src/Apache.Arrow/Arrays/Array.cs
index a453b0807267f..0838134b19c6d 100644
--- a/csharp/src/Apache.Arrow/Arrays/Array.cs
+++ b/csharp/src/Apache.Arrow/Arrays/Array.cs
@@ -62,16 +62,7 @@ internal static void Accept(T array, IArrowArrayVisitor visitor)
 
         public Array Slice(int offset, int length)
         {
-            if (offset > Length)
-            {
-                throw new ArgumentException($"Offset {offset} cannot be greater than Length {Length} for Array.Slice");
-            }
-
-            length = Math.Min(Data.Length - offset, length);
-            offset += Data.Offset;
-
-            ArrayData newData = Data.Slice(offset, length);
-            return ArrowArrayFactory.BuildArray(newData) as Array;
+            return ArrowArrayFactory.Slice(this, offset, length) as Array;
         }
 
         public void Dispose()
@@ -88,4 +79,4 @@ protected virtual void Dispose(bool disposing)
             }
         }
     }
-}
\ No newline at end of file
+}
diff --git a/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs b/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs
index 8859ecd7f05b9..806defdc7ce66 100644
--- a/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs
+++ b/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs
@@ -49,7 +49,8 @@ private class ArrayDataConcatenationVisitor :
             IArrowTypeVisitor,
             IArrowTypeVisitor,
             IArrowTypeVisitor,
-            IArrowTypeVisitor
+            IArrowTypeVisitor,
+            IArrowTypeVisitor
         {
             public ArrayData Result { get; private set; }
             private readonly IReadOnlyList _arrayDataList;
@@ -123,6 +124,33 @@ public void Visit(StructType type)
                 Result = new ArrayData(type, _arrayDataList[0].Length, _arrayDataList[0].NullCount, 0, _arrayDataList[0].Buffers, children);
             }
 
+            public void Visit(UnionType type)
+            {
+                int bufferCount = type.Mode switch
+                {
+                    UnionMode.Sparse => 1,
+                    UnionMode.Dense => 2,
+                    _ => throw new InvalidOperationException($"Unsupported union mode: {type.Mode}"),
+                };
+
+                CheckData(type, bufferCount);
+                List<ArrayData> children = new List<ArrayData>(type.Fields.Count);
+
+                for (int i = 0; i < type.Fields.Count; i++)
+                {
+                    children.Add(Concatenate(SelectChildren(i), _allocator));
+                }
+
+                ArrowBuffer[] buffers = new ArrowBuffer[bufferCount];
+                buffers[0] = ConcatenateUnionTypeBuffer();
+                if (bufferCount > 1)
+                {
+                    buffers[1] = ConcatenateUnionOffsetBuffer();
+                }
+
+                Result = new ArrayData(type, _totalLength, _totalNullCount, 0, buffers, children);
+            }
+
             public void Visit(IArrowType type)
             {
                 throw new NotImplementedException($"Concatenation for {type.Name} is not supported yet.");
@@ -231,6 +259,49 @@ private ArrowBuffer ConcatenateOffsetBuffer()
                 return builder.Build(_allocator);
             }
 
+            private ArrowBuffer ConcatenateUnionTypeBuffer()
+            {
+                var builder = new ArrowBuffer.Builder<byte>(_totalLength);
+
+                foreach (ArrayData arrayData in _arrayDataList)
+                {
+                    // Copy only the logical slots: the backing buffer may be padded or
+                    // belong to a sliced array (non-zero Offset).
+                    builder.Append(arrayData.Buffers[0].Span.Slice(arrayData.Offset, arrayData.Length));
+                }
+
+                return builder.Build(_allocator);
+            }
+
+            private ArrowBuffer ConcatenateUnionOffsetBuffer()
+            {
+                var builder = new ArrowBuffer.Builder<int>(_totalLength);
+
+                // A dense-union offset indexes into the child selected by that slot's type id,
+                // so every child needs its own running base offset (the number of values that
+                // child contributed in the preceding arrays).
+                // NOTE(review): assumes type ids are child indices - confirm against UnionType.TypeIds.
+                int[] baseOffsets = new int[_arrayDataList[0].Children.Length];
+
+                foreach (ArrayData arrayData in _arrayDataList)
+                {
+                    ReadOnlySpan<byte> typeIds = arrayData.Buffers[0].Span.Slice(arrayData.Offset, arrayData.Length);
+                    ReadOnlySpan<int> offsets = arrayData.Buffers[1].Span.CastTo<int>().Slice(arrayData.Offset, arrayData.Length);
+
+                    for (int i = 0; i < arrayData.Length; i++)
+                    {
+                        builder.Append(checked(baseOffsets[typeIds[i]] + offsets[i]));
+                    }
+
+                    for (int child = 0; child < baseOffsets.Length; child++)
+                    {
+                        baseOffsets[child] = checked(baseOffsets[child] + arrayData.Children[child].Length);
+                    }
+                }
+
+                return builder.Build(_allocator);
+            }
+
             private List SelectChildren(int index)
             {
                 var children = new List(_arrayDataList.Count);
diff --git a/csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs b/csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs
index 8a6bfed29abb6..6b54ec1edb573 100644
--- a/csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs
+++ b/csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs
@@ -27,7 +27,8 @@ internal sealed class ArrayDataTypeComparer :
         IArrowTypeVisitor,
         IArrowTypeVisitor,
         IArrowTypeVisitor,
-        IArrowTypeVisitor
+        IArrowTypeVisitor,
+        IArrowTypeVisitor
     {
         private readonly IArrowType _expectedType;
         private bool _dataTypeMatch;
@@ -122,6 +123,16 @@ public void Visit(StructType actualType)
             }
         }
 
+        public void Visit(UnionType actualType)
+        {
+            if (_expectedType is UnionType expectedType
+                && expectedType.Mode == actualType.Mode
+                && CompareNested(expectedType, actualType))
+            {
+                _dataTypeMatch = true;
+            }
+        }
+
         private static bool CompareNested(NestedType expectedType, NestedType actualType)
         {
             if (expectedType.Fields.Count != actualType.Fields.Count)
diff --git a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs
index f82037bff47b1..aa407203d1858 100644
--- a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs
+++ b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs
@@ -62,7 +62,7 @@ public static IArrowArray BuildArray(ArrayData data)
                 case ArrowTypeId.Struct:
                     return new StructArray(data);
                 case ArrowTypeId.Union:
-                    return new UnionArray(data);
+                    return UnionArray.Create(data);
                 case ArrowTypeId.Date64:
                     return new Date64Array(data);
                 case ArrowTypeId.Date32:
@@ -91,5 +91,19 @@ public static IArrowArray BuildArray(ArrayData data)
                     throw new NotSupportedException($"An ArrowArray cannot be built for type {data.DataType.TypeId}.");
             }
         }
+
+        public static IArrowArray Slice(IArrowArray array, int offset, int length)
+        {
+            if (offset > array.Length)
+            {
+                throw new ArgumentException($"Offset {offset} cannot be greater than Length {array.Length} for Array.Slice");
+            }
+
+            length = Math.Min(array.Data.Length - offset, length);
+            offset += array.Data.Offset;
+
+            ArrayData newData = array.Data.Slice(offset, length);
+            return BuildArray(newData);
+        }
     }
 }
diff --git a/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs b/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs
new file mode 100644
index 0000000000000..1aacbe11f08b9
--- /dev/null
+++ b/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +using Apache.Arrow.Types; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Apache.Arrow +{ + public class DenseUnionArray : UnionArray + { + public ArrowBuffer ValueOffsetBuffer => Data.Buffers[1]; + + public ReadOnlySpan ValueOffsets => ValueOffsetBuffer.Span.CastTo(); + + public DenseUnionArray( + IArrowType dataType, + int length, + IEnumerable children, + ArrowBuffer typeIds, + ArrowBuffer valuesOffsetBuffer, + int nullCount = 0, + int offset = 0) + : base(new ArrayData( + dataType, length, nullCount, offset, new[] { typeIds, valuesOffsetBuffer }, + children.Select(child => child.Data))) + { + _fields = children.ToArray(); + ValidateMode(UnionMode.Dense, Type.Mode); + } + + public DenseUnionArray(ArrayData data) + : base(data) + { + ValidateMode(UnionMode.Dense, Type.Mode); + data.EnsureBufferCount(2); + } + } +} diff --git a/csharp/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs b/csharp/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs index a50d4b52c3257..67fe46633c18f 100644 --- a/csharp/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs +++ b/csharp/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs @@ -137,6 +137,9 @@ public TBuilder Append(T value) return Instance; } + public TBuilder Append(T? value) => + (value == null) ? AppendNull() : Append(value.Value); + public TBuilder Append(ReadOnlySpan span) { int len = ValueBuffer.Length; diff --git a/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs b/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs new file mode 100644 index 0000000000000..b79c44c979e47 --- /dev/null +++ b/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Types; +using System.Collections.Generic; +using System.Linq; + +namespace Apache.Arrow +{ + public class SparseUnionArray : UnionArray + { + public SparseUnionArray( + IArrowType dataType, + int length, + IEnumerable children, + ArrowBuffer typeIds, + int nullCount = 0, + int offset = 0) + : base(new ArrayData( + dataType, length, nullCount, offset, new[] { typeIds }, + children.Select(child => child.Data))) + { + _fields = children.ToArray(); + ValidateMode(UnionMode.Sparse, Type.Mode); + } + + public SparseUnionArray(ArrayData data) + : base(data) + { + ValidateMode(UnionMode.Sparse, Type.Mode); + data.EnsureBufferCount(1); + } + } +} diff --git a/csharp/src/Apache.Arrow/Arrays/UnionArray.cs b/csharp/src/Apache.Arrow/Arrays/UnionArray.cs index 8bccea2b59e31..0a7ae288fd0c5 100644 --- a/csharp/src/Apache.Arrow/Arrays/UnionArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/UnionArray.cs @@ -15,37 +15,88 @@ using Apache.Arrow.Types; using System; +using System.Collections.Generic; +using System.Threading; namespace Apache.Arrow { - public class UnionArray: Array + public abstract class UnionArray : IArrowArray { - public UnionType Type => Data.DataType as UnionType; + protected IReadOnlyList _fields; - public UnionMode Mode => Type.Mode; + public IReadOnlyList Fields => + LazyInitializer.EnsureInitialized(ref _fields, () => InitializeFields()); + + public ArrayData Data { 
get; } - public ArrowBuffer TypeBuffer => Data.Buffers[1]; + public UnionType Type => (UnionType)Data.DataType; - public ArrowBuffer ValueOffsetBuffer => Data.Buffers[2]; + public UnionMode Mode => Type.Mode; + + public ArrowBuffer TypeBuffer => Data.Buffers[0]; public ReadOnlySpan TypeIds => TypeBuffer.Span; - public ReadOnlySpan ValueOffsets => ValueOffsetBuffer.Span.CastTo().Slice(0, Length + 1); + public int Length => Data.Length; + + public int Offset => Data.Offset; - public UnionArray(ArrayData data) - : base(data) + public int NullCount => Data.NullCount; + + public bool IsValid(int index) => NullCount == 0 || Fields[TypeIds[index]].IsValid(index); + + public bool IsNull(int index) => !IsValid(index); + + protected UnionArray(ArrayData data) { + Data = data; data.EnsureDataType(ArrowTypeId.Union); - data.EnsureBufferCount(3); } - public IArrowArray GetChild(int index) + public static UnionArray Create(ArrayData data) { - // TODO: Implement - throw new NotImplementedException(); + return ((UnionType)data.DataType).Mode switch + { + UnionMode.Dense => new DenseUnionArray(data), + UnionMode.Sparse => new SparseUnionArray(data), + _ => throw new InvalidOperationException("unknown union mode in array creation") + }; } - public override void Accept(IArrowArrayVisitor visitor) => Accept(this, visitor); + public void Accept(IArrowArrayVisitor visitor) => Array.Accept(this, visitor); + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + Data.Dispose(); + } + } + + protected static void ValidateMode(UnionMode expected, UnionMode actual) + { + if (expected != actual) + { + throw new ArgumentException( + $"Specified union mode <{actual}> does not match expected mode <{expected}>", + "Mode"); + } + } + + private IReadOnlyList InitializeFields() + { + IArrowArray[] result = new IArrowArray[Data.Children.Length]; + for (int i = 0; i < Data.Children.Length; i++) + { + 
result[i] = ArrowArrayFactory.BuildArray(Data.Children[i]);
+            }
+            return result;
+        }
     }
 }
diff --git a/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs b/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs
index 9b7bcb7abe5a5..da1b0f31b8f08 100644
--- a/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs
+++ b/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs
@@ -170,6 +170,15 @@ private ArrayData GetAsArrayData(CArrowArray* cArray, IArrowType type)
                         buffers = new ArrowBuffer[] { ImportValidityBuffer(cArray) };
                         break;
                     case ArrowTypeId.Union:
+                        UnionType unionType = (UnionType)type;
+                        children = ProcessStructChildren(cArray, unionType.Fields);
+                        buffers = unionType.Mode switch
+                        {
+                            UnionMode.Dense => ImportDenseUnionBuffers(cArray),
+                            UnionMode.Sparse => ImportSparseUnionBuffers(cArray),
+                            _ => throw new InvalidOperationException("unknown union mode in import")
+                        };
+                        break;
                     case ArrowTypeId.Map:
                         break;
                     case ArrowTypeId.Null:
@@ -286,6 +295,37 @@ private ArrowBuffer[] ImportFixedSizeListBuffers(CArrowArray* cArray)
                 return buffers;
             }
 
+            // Imports the two buffers of a dense union: one type-id byte and one 32-bit offset per slot.
+            private ArrowBuffer[] ImportDenseUnionBuffers(CArrowArray* cArray)
+            {
+                if (cArray->n_buffers != 2)
+                {
+                    throw new InvalidOperationException("Dense union arrays are expected to have exactly two buffers");
+                }
+                int length = checked((int)cArray->length);
+                int offsetsLength = checked(length * sizeof(int));
+
+                ArrowBuffer[] buffers = new ArrowBuffer[2];
+                buffers[0] = new ArrowBuffer(AddMemory((IntPtr)cArray->buffers[0], 0, length));
+                buffers[1] = new ArrowBuffer(AddMemory((IntPtr)cArray->buffers[1], 0, offsetsLength));
+
+                return buffers;
+            }
+
+            // Imports the single buffer of a sparse union: one type-id byte per slot.
+            private ArrowBuffer[] ImportSparseUnionBuffers(CArrowArray* cArray)
+            {
+                if (cArray->n_buffers != 1)
+                {
+                    throw new InvalidOperationException("Sparse union arrays are expected to have exactly one buffer");
+                }
+
+                ArrowBuffer[] buffers = new ArrowBuffer[1];
+                buffers[0] = new ArrowBuffer(AddMemory((IntPtr)cArray->buffers[0], 0, checked((int)cArray->length)));
+
+                return buffers;
+            }
+
             private ArrowBuffer[]
ImportFixedWidthBuffers(CArrowArray* cArray, int bitWidth) { if (cArray->n_buffers != 2) diff --git a/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs b/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs index 66142da331ac8..c1a12362a942a 100644 --- a/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs @@ -124,6 +124,23 @@ public static unsafe void ExportSchema(Schema schema, CArrowSchema* out_schema) _ => throw new InvalidDataException($"Unsupported time unit for export: {unit}"), }; + private static string FormatUnion(UnionType unionType) + { + StringBuilder builder = new StringBuilder(); + builder.Append(unionType.Mode switch + { + UnionMode.Sparse => "+us:", + UnionMode.Dense => "+ud:", + _ => throw new InvalidDataException($"Unsupported union mode for export: {unionType.Mode}"), + }); + for (int i = 0; i < unionType.TypeIds.Length; i++) + { + if (i > 0) { builder.Append(','); } + builder.Append(unionType.TypeIds[i]); + } + return builder.ToString(); + } + private static string GetFormat(IArrowType datatype) { switch (datatype) @@ -170,6 +187,7 @@ private static string GetFormat(IArrowType datatype) case FixedSizeListType fixedListType: return $"+w:{fixedListType.ListSize}"; case StructType _: return "+s"; + case UnionType u: return FormatUnion(u); // Dictionary case DictionaryType dictionaryType: return GetFormat(dictionaryType.IndexType); diff --git a/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs b/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs index 2a750d5e8250d..42c8cdd5ef548 100644 --- a/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs @@ -184,21 +184,7 @@ public ArrowType GetAsType() } else if (format == "+s") { - var child_schemas = new ImportedArrowSchema[_cSchema->n_children]; - - for (int i = 0; i < _cSchema->n_children; i++) - { - if (_cSchema->GetChild(i) == null) - { - throw new InvalidDataException("Expected struct type child to 
be non-null."); - } - child_schemas[i] = new ImportedArrowSchema(_cSchema->GetChild(i), isRoot: false); - } - - - List childFields = child_schemas.Select(schema => schema.GetAsField()).ToList(); - - return new StructType(childFields); + return new StructType(ParseChildren("struct")); } else if (format.StartsWith("+w:")) { @@ -265,6 +251,30 @@ public ArrowType GetAsType() return new FixedSizeBinaryType(width); } + // Unions + if (format.StartsWith("+ud:") || format.StartsWith("+us:")) + { + UnionMode unionMode = format[2] == 'd' ? UnionMode.Dense : UnionMode.Sparse; + List typeIds = new List(); + int pos = 4; + do + { + int next = format.IndexOf(',', pos); + if (next < 0) { next = format.Length; } + + int code; + if (!int.TryParse(format.Substring(pos, next - pos), out code)) + { + throw new InvalidDataException($"Invalid type code for union import: {format.Substring(pos, next - pos)}"); + } + typeIds.Add(code); + + pos = next + 1; + } while (pos < format.Length); + + return new UnionType(ParseChildren("union"), typeIds, unionMode); + } + return format switch { // Primitives @@ -324,6 +334,22 @@ public Schema GetAsSchema() } } + private List ParseChildren(string typeName) + { + var child_schemas = new ImportedArrowSchema[_cSchema->n_children]; + + for (int i = 0; i < _cSchema->n_children; i++) + { + if (_cSchema->GetChild(i) == null) + { + throw new InvalidDataException($"Expected {typeName} type child to be non-null."); + } + child_schemas[i] = new ImportedArrowSchema(_cSchema->GetChild(i), isRoot: false); + } + + return child_schemas.Select(schema => schema.GetAsField()).ToList(); + } + private unsafe static IReadOnlyDictionary GetMetadata(byte* metadata) { if (metadata == null) diff --git a/csharp/src/Apache.Arrow/ChunkedArray.cs b/csharp/src/Apache.Arrow/ChunkedArray.cs index 5f25acfe04a2f..f5909f5adfe48 100644 --- a/csharp/src/Apache.Arrow/ChunkedArray.cs +++ b/csharp/src/Apache.Arrow/ChunkedArray.cs @@ -15,7 +15,6 @@ using System; using 
System.Collections.Generic; -using Apache.Arrow; using Apache.Arrow.Types; namespace Apache.Arrow @@ -25,7 +24,7 @@ namespace Apache.Arrow /// public class ChunkedArray { - private IList Arrays { get; } + private IList Arrays { get; } public IArrowType DataType { get; } public long Length { get; } public long NullCount { get; } @@ -35,9 +34,16 @@ public int ArrayCount get => Arrays.Count; } - public Array Array(int index) => Arrays[index]; + public Array Array(int index) => Arrays[index] as Array; + + public IArrowArray ArrowArray(int index) => Arrays[index]; public ChunkedArray(IList arrays) + : this(Cast(arrays)) + { + } + + public ChunkedArray(IList arrays) { Arrays = arrays ?? throw new ArgumentNullException(nameof(arrays)); if (arrays.Count < 1) @@ -45,14 +51,14 @@ public ChunkedArray(IList arrays) throw new ArgumentException($"Count must be at least 1. Got {arrays.Count} instead"); } DataType = arrays[0].Data.DataType; - foreach (Array array in arrays) + foreach (IArrowArray array in arrays) { Length += array.Length; NullCount += array.NullCount; } } - public ChunkedArray(Array array) : this(new[] { array }) { } + public ChunkedArray(Array array) : this(new IArrowArray[] { array }) { } public ChunkedArray Slice(long offset, long length) { @@ -69,10 +75,10 @@ public ChunkedArray Slice(long offset, long length) curArrayIndex++; } - IList newArrays = new List(); + IList newArrays = new List(); while (curArrayIndex < numArrays && length > 0) { - newArrays.Add(Arrays[curArrayIndex].Slice((int)offset, + newArrays.Add(ArrowArrayFactory.Slice(Arrays[curArrayIndex], (int)offset, length > Arrays[curArrayIndex].Length ? 
Arrays[curArrayIndex].Length : (int)length)); length -= Arrays[curArrayIndex].Length - offset; offset = 0; @@ -86,6 +92,16 @@ public ChunkedArray Slice(long offset) return Slice(offset, Length - offset); } + private static IArrowArray[] Cast(IList arrays) + { + IArrowArray[] arrowArrays = new IArrowArray[arrays.Count]; + for (int i = 0; i < arrays.Count; i++) + { + arrowArrays[i] = arrays[i]; + } + return arrowArrays; + } + // TODO: Flatten for Structs } } diff --git a/csharp/src/Apache.Arrow/Column.cs b/csharp/src/Apache.Arrow/Column.cs index 4eaf9a559e75d..0709b9142cafd 100644 --- a/csharp/src/Apache.Arrow/Column.cs +++ b/csharp/src/Apache.Arrow/Column.cs @@ -28,19 +28,23 @@ public class Column public ChunkedArray Data { get; } public Column(Field field, IList arrays) + : this(field, new ChunkedArray(arrays), doValidation: true) + { + } + + public Column(Field field, IList arrays) + : this(field, new ChunkedArray(arrays), doValidation: true) { - Data = new ChunkedArray(arrays); - Field = field; - if (!ValidateArrayDataTypes()) - { - throw new ArgumentException($"{Field.DataType} must match {Data.DataType}"); - } } - private Column(Field field, ChunkedArray arrays) + private Column(Field field, ChunkedArray data, bool doValidation = false) { + Data = data; Field = field; - Data = arrays; + if (doValidation && !ValidateArrayDataTypes()) + { + throw new ArgumentException($"{Field.DataType} must match {Data.DataType}"); + } } public long Length => Data.Length; @@ -64,12 +68,12 @@ private bool ValidateArrayDataTypes() for (int i = 0; i < Data.ArrayCount; i++) { - if (Data.Array(i).Data.DataType.TypeId != Field.DataType.TypeId) + if (Data.ArrowArray(i).Data.DataType.TypeId != Field.DataType.TypeId) { return false; } - Data.Array(i).Data.DataType.Accept(dataTypeComparer); + Data.ArrowArray(i).Data.DataType.Accept(dataTypeComparer); if (!dataTypeComparer.DataTypeMatch) { diff --git a/csharp/src/Apache.Arrow/Extensions/FlatbufExtensions.cs 
b/csharp/src/Apache.Arrow/Extensions/FlatbufExtensions.cs index d2a70bca9e4ec..35c5b3e55157d 100644 --- a/csharp/src/Apache.Arrow/Extensions/FlatbufExtensions.cs +++ b/csharp/src/Apache.Arrow/Extensions/FlatbufExtensions.cs @@ -80,6 +80,16 @@ public static Types.TimeUnit ToArrow(this Flatbuf.TimeUnit unit) throw new ArgumentException($"Unexpected Flatbuf TimeUnit", nameof(unit)); } } + + public static Types.UnionMode ToArrow(this Flatbuf.UnionMode mode) + { + return mode switch + { + Flatbuf.UnionMode.Dense => Types.UnionMode.Dense, + Flatbuf.UnionMode.Sparse => Types.UnionMode.Sparse, + _ => throw new ArgumentException($"Unsupported Flatbuf UnionMode", nameof(mode)), + }; + } } } diff --git a/csharp/src/Apache.Arrow/Interfaces/IArrowArray.cs b/csharp/src/Apache.Arrow/Interfaces/IArrowArray.cs index 50fbc3af6dd72..9bcee36ef4eaf 100644 --- a/csharp/src/Apache.Arrow/Interfaces/IArrowArray.cs +++ b/csharp/src/Apache.Arrow/Interfaces/IArrowArray.cs @@ -32,9 +32,5 @@ public interface IArrowArray : IDisposable ArrayData Data { get; } void Accept(IArrowArrayVisitor visitor); - - //IArrowArray Slice(int offset); - - //IArrowArray Slice(int offset, int length); } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index c9c1b21673316..d3115da52cc6c 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -116,11 +116,11 @@ protected RecordBatch CreateArrowObjectFromMessage( break; case Flatbuf.MessageHeader.DictionaryBatch: Flatbuf.DictionaryBatch dictionaryBatch = message.Header().Value; - ReadDictionaryBatch(dictionaryBatch, bodyByteBuffer, memoryOwner); + ReadDictionaryBatch(message.Version, dictionaryBatch, bodyByteBuffer, memoryOwner); break; case Flatbuf.MessageHeader.RecordBatch: Flatbuf.RecordBatch rb = message.Header().Value; - List arrays = BuildArrays(Schema, bodyByteBuffer, rb); + List arrays = 
BuildArrays(message.Version, Schema, bodyByteBuffer, rb); return new RecordBatch(Schema, memoryOwner, arrays, (int)rb.Length); default: // NOTE: Skip unsupported message type @@ -136,7 +136,11 @@ internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory buffer) return new ByteBuffer(new ReadOnlyMemoryBufferAllocator(buffer), 0); } - private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBuffer bodyByteBuffer, IMemoryOwner memoryOwner) + private void ReadDictionaryBatch( + MetadataVersion version, + Flatbuf.DictionaryBatch dictionaryBatch, + ByteBuffer bodyByteBuffer, + IMemoryOwner memoryOwner) { long id = dictionaryBatch.Id; IArrowType valueType = DictionaryMemo.GetDictionaryType(id); @@ -149,7 +153,7 @@ private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBu Field valueField = new Field("dummy", valueType, true); var schema = new Schema(new[] { valueField }, default); - IList arrays = BuildArrays(schema, bodyByteBuffer, recordBatch.Value); + IList arrays = BuildArrays(version, schema, bodyByteBuffer, recordBatch.Value); if (arrays.Count != 1) { @@ -167,6 +171,7 @@ private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBu } private List BuildArrays( + MetadataVersion version, Schema schema, ByteBuffer messageBuffer, Flatbuf.RecordBatch recordBatchMessage) @@ -187,8 +192,8 @@ private List BuildArrays( Flatbuf.FieldNode fieldNode = recordBatchEnumerator.CurrentNode; ArrayData arrayData = field.DataType.IsFixedPrimitive() - ? LoadPrimitiveField(ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator) - : LoadVariableField(ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator); + ? 
LoadPrimitiveField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator) + : LoadVariableField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator); arrays.Add(ArrowArrayFactory.BuildArray(arrayData)); } while (recordBatchEnumerator.MoveNextNode()); @@ -225,6 +230,7 @@ private IBufferCreator GetBufferCreator(BodyCompression? compression) } private ArrayData LoadPrimitiveField( + MetadataVersion version, ref RecordBatchEnumerator recordBatchEnumerator, Field field, in Flatbuf.FieldNode fieldNode, @@ -245,31 +251,44 @@ private ArrayData LoadPrimitiveField( throw new InvalidDataException("Null count length must be >= 0"); // TODO:Localize exception message } - if (field.DataType.TypeId == ArrowTypeId.Null) + int buffers; + switch (field.DataType.TypeId) { - return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, System.Array.Empty()); - } - - ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator); - if (!recordBatchEnumerator.MoveNextBuffer()) - { - throw new Exception("Unable to move to the next buffer."); + case ArrowTypeId.Null: + return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, System.Array.Empty()); + case ArrowTypeId.Union: + if (version < MetadataVersion.V5) + { + if (fieldNullCount > 0) + { + if (recordBatchEnumerator.CurrentBuffer.Length > 0) + { + // With older metadata we can get a validity bitmap. Fixing up union data is hard, + // so we will just quit. + throw new NotSupportedException("Cannot read pre-1.0.0 Union array with top-level validity bitmap"); + } + } + recordBatchEnumerator.MoveNextBuffer(); + } + buffers = ((UnionType)field.DataType).Mode == Types.UnionMode.Dense ? 
2 : 1; + break; + case ArrowTypeId.Struct: + case ArrowTypeId.FixedSizeList: + buffers = 1; + break; + default: + buffers = 2; + break; } - ArrowBuffer[] arrowBuff; - if (field.DataType.TypeId == ArrowTypeId.Struct || field.DataType.TypeId == ArrowTypeId.FixedSizeList) + ArrowBuffer[] arrowBuff = new ArrowBuffer[buffers]; + for (int i = 0; i < buffers; i++) { - arrowBuff = new[] { nullArrowBuffer }; - } - else - { - ArrowBuffer valueArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator); + arrowBuff[i] = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator); recordBatchEnumerator.MoveNextBuffer(); - - arrowBuff = new[] { nullArrowBuffer, valueArrowBuffer }; } - ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData, bufferCreator); + ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator); IArrowArray dictionary = null; if (field.DataType.TypeId == ArrowTypeId.Dictionary) @@ -282,6 +301,7 @@ private ArrayData LoadPrimitiveField( } private ArrayData LoadVariableField( + MetadataVersion version, ref RecordBatchEnumerator recordBatchEnumerator, Field field, in Flatbuf.FieldNode fieldNode, @@ -316,7 +336,7 @@ private ArrayData LoadVariableField( } ArrowBuffer[] arrowBuff = new[] { nullArrowBuffer, offsetArrowBuffer, valueArrowBuffer }; - ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData, bufferCreator); + ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator); IArrowArray dictionary = null; if (field.DataType.TypeId == ArrowTypeId.Dictionary) @@ -329,6 +349,7 @@ private ArrayData LoadVariableField( } private ArrayData[] GetChildren( + MetadataVersion version, ref RecordBatchEnumerator recordBatchEnumerator, Field field, ByteBuffer bodyData, @@ -345,8 +366,8 @@ private ArrayData[] GetChildren( Field childField = type.Fields[index]; ArrayData child = 
childField.DataType.IsFixedPrimitive() - ? LoadPrimitiveField(ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator) - : LoadVariableField(ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator); + ? LoadPrimitiveField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator) + : LoadVariableField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator); children[index] = child; } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index a5d8db3f509d7..2b3815af71142 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -55,6 +55,7 @@ internal class ArrowRecordBatchFlatBufferBuilder : IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, + IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, @@ -156,6 +157,22 @@ public void Visit(StructArray array) } } + public void Visit(UnionArray array) + { + _buffers.Add(CreateBuffer(array.TypeBuffer)); + + ArrowBuffer? offsets = (array as DenseUnionArray)?.ValueOffsetBuffer; + if (offsets != null) + { + _buffers.Add(CreateBuffer(offsets.Value)); + } + + for (int i = 0; i < array.Fields.Count; i++) + { + array.Fields[i].Accept(this); + } + } + public void Visit(DictionaryArray array) { // Dictionary is serialized separately in Dictionary serialization. 
@@ -218,7 +235,7 @@ public void Visit(IArrowArray array) private readonly bool _leaveOpen; private readonly IpcOptions _options; - private protected const Flatbuf.MetadataVersion CurrentMetadataVersion = Flatbuf.MetadataVersion.V4; + private protected const Flatbuf.MetadataVersion CurrentMetadataVersion = Flatbuf.MetadataVersion.V5; private static readonly byte[] s_padding = new byte[64]; diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs index 203aa72d93ea3..b11467538dd04 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs @@ -120,7 +120,9 @@ public void Visit(FixedSizeListType type) public void Visit(UnionType type) { - throw new NotImplementedException(); + Result = FieldType.Build( + Flatbuf.Type.Union, + Flatbuf.Union.CreateUnion(Builder, ToFlatBuffer(type.Mode), Flatbuf.Union.CreateTypeIdsVector(Builder, type.TypeIds))); } public void Visit(StringType type) @@ -279,5 +281,15 @@ private static Flatbuf.TimeUnit ToFlatBuffer(TimeUnit unit) return result; } + + private static Flatbuf.UnionMode ToFlatBuffer(Types.UnionMode mode) + { + return mode switch + { + Types.UnionMode.Dense => Flatbuf.UnionMode.Dense, + Types.UnionMode.Sparse => Flatbuf.UnionMode.Sparse, + _ => throw new ArgumentException($"unsupported union mode <{mode}>", nameof(mode)), + }; + } } } diff --git a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs index 8ca69b61165bf..6249063ba81f4 100644 --- a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs +++ b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs @@ -203,6 +203,10 @@ private static Types.IArrowType GetFieldArrowType(Flatbuf.Field field, Field[] c case Flatbuf.Type.Struct_: Debug.Assert(childFields != null); return new Types.StructType(childFields); + case Flatbuf.Type.Union: + Debug.Assert(childFields != null); + Flatbuf.Union 
unionMetadata = field.Type().Value; + return new Types.UnionType(childFields, unionMetadata.GetTypeIdsArray(), unionMetadata.Mode.ToArrow()); default: throw new InvalidDataException($"Arrow primitive '{field.TypeType}' is unsupported."); } diff --git a/csharp/src/Apache.Arrow/Table.cs b/csharp/src/Apache.Arrow/Table.cs index 0b9f31557bec8..939ec23f54ff2 100644 --- a/csharp/src/Apache.Arrow/Table.cs +++ b/csharp/src/Apache.Arrow/Table.cs @@ -37,10 +37,10 @@ public static Table TableFromRecordBatches(Schema schema, IList rec List columns = new List(nColumns); for (int icol = 0; icol < nColumns; icol++) { - List columnArrays = new List(nBatches); + List columnArrays = new List(nBatches); for (int jj = 0; jj < nBatches; jj++) { - columnArrays.Add(recordBatches[jj].Column(icol) as Array); + columnArrays.Add(recordBatches[jj].Column(icol)); } columns.Add(new Column(schema.GetFieldByIndex(icol), columnArrays)); } diff --git a/csharp/src/Apache.Arrow/Types/UnionType.cs b/csharp/src/Apache.Arrow/Types/UnionType.cs index 293271018aa26..23fa3b45ab278 100644 --- a/csharp/src/Apache.Arrow/Types/UnionType.cs +++ b/csharp/src/Apache.Arrow/Types/UnionType.cs @@ -24,20 +24,21 @@ public enum UnionMode Dense } - public sealed class UnionType : ArrowType + public sealed class UnionType : NestedType { public override ArrowTypeId TypeId => ArrowTypeId.Union; public override string Name => "union"; public UnionMode Mode { get; } - - public IEnumerable TypeCodes { get; } + + public int[] TypeIds { get; } public UnionType( - IEnumerable fields, IEnumerable typeCodes, + IEnumerable fields, IEnumerable typeIds, UnionMode mode = UnionMode.Sparse) + : base(fields.ToArray()) { - TypeCodes = typeCodes.ToList(); + TypeIds = typeIds.ToArray(); Mode = mode; } diff --git a/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs b/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs index abf7451e5e98c..1e76ee505a516 100644 --- 
a/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs +++ b/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs @@ -128,7 +128,7 @@ private RecordBatch CreateRecordBatch(Schema schema, JsonRecordBatch jsonRecordB for (int i = 0; i < jsonRecordBatch.Columns.Count; i++) { JsonFieldData data = jsonRecordBatch.Columns[i]; - Field field = schema.GetFieldByName(data.Name); + Field field = schema.FieldsList[i]; ArrayCreator creator = new ArrayCreator(data); field.DataType.Accept(creator); arrays.Add(creator.Array); @@ -188,6 +188,7 @@ private static IArrowType ToArrowType(JsonArrowType type, Field[] children) "list" => ToListArrowType(type, children), "fixedsizelist" => ToFixedSizeListArrowType(type, children), "struct" => ToStructArrowType(type, children), + "union" => ToUnionArrowType(type, children), "null" => NullType.Default, _ => throw new NotSupportedException($"JsonArrowType not supported: {type.Name}") }; @@ -281,6 +282,17 @@ private static IArrowType ToStructArrowType(JsonArrowType type, Field[] children return new StructType(children); } + private static IArrowType ToUnionArrowType(JsonArrowType type, Field[] children) + { + UnionMode mode = type.Mode switch + { + "SPARSE" => UnionMode.Sparse, + "DENSE" => UnionMode.Dense, + _ => throw new NotSupportedException($"Union mode not supported: {type.Mode}"), + }; + return new UnionType(children, type.TypeIds, mode); + } + private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor, @@ -306,6 +318,7 @@ private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, + IArrowTypeVisitor, IArrowTypeVisitor { private JsonFieldData JsonFieldData { get; set; } @@ -556,6 +569,43 @@ public void Visit(StructType type) Array = new StructArray(arrayData); } + public void Visit(UnionType type) + { + ArrowBuffer[] buffers; + if (type.Mode == UnionMode.Dense) + { + buffers = new ArrowBuffer[2]; + buffers[1] = GetOffsetBuffer(); + } + else + { + buffers = new ArrowBuffer[1]; + } + 
buffers[0] = GetTypeIdBuffer(); + + ArrayData[] children = GetChildren(type); + + int nullCount = 0; + ArrayData arrayData = new ArrayData(type, JsonFieldData.Count, nullCount, 0, buffers, children); + Array = UnionArray.Create(arrayData); + } + + private ArrayData[] GetChildren(NestedType type) + { + ArrayData[] children = new ArrayData[type.Fields.Count]; + + var data = JsonFieldData; + for (int i = 0; i < children.Length; i++) + { + JsonFieldData = data.Children[i]; + type.Fields[i].DataType.Accept(this); + children[i] = Array.Data; + } + JsonFieldData = data; + + return children; + } + private static byte[] ConvertHexStringToByteArray(string hexString) { byte[] data = new byte[hexString.Length / 2]; @@ -619,11 +669,22 @@ private void GenerateLongArray(Func valueOffsets = new ArrowBuffer.Builder(JsonFieldData.Offset.Length); valueOffsets.AppendRange(JsonFieldData.Offset); return valueOffsets.Build(default); } + private ArrowBuffer GetTypeIdBuffer() + { + ArrowBuffer.Builder typeIds = new ArrowBuffer.Builder(JsonFieldData.TypeId.Length); + for (int i = 0; i < JsonFieldData.TypeId.Length; i++) + { + typeIds.Append(checked((byte)JsonFieldData.TypeId[i])); + } + return typeIds.Build(default); + } + private ArrowBuffer GetValidityBuffer(out int nullCount) { if (JsonFieldData.Validity == null) diff --git a/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs b/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs index f0f63d3e19b8c..112eeabcb9931 100644 --- a/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs +++ b/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs @@ -71,6 +71,10 @@ public class JsonArrowType // FixedSizeList fields public int ListSize { get; set; } + // union fields + public string Mode { get; set; } + public int[] TypeIds { get; set; } + [JsonExtensionData] public Dictionary ExtensionData { get; set; } } diff --git a/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs b/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs index 
77584aefb1bf4..c8bcc3cee0f99 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs @@ -28,7 +28,8 @@ public class ArrayTypeComparer : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { private readonly IArrowType _expectedType; @@ -114,6 +115,22 @@ public void Visit(StructType actualType) CompareNested(expectedType, actualType); } + public void Visit(UnionType actualType) + { + Assert.IsAssignableFrom(_expectedType); + UnionType expectedType = (UnionType)_expectedType; + + Assert.Equal(expectedType.Mode, actualType.Mode); + + Assert.Equal(expectedType.TypeIds.Length, actualType.TypeIds.Length); + for (int i = 0; i < expectedType.TypeIds.Length; i++) + { + Assert.Equal(expectedType.TypeIds[i], actualType.TypeIds[i]); + } + + CompareNested(expectedType, actualType); + } + private static void CompareNested(NestedType expectedType, NestedType actualType) { Assert.Equal(expectedType.Fields.Count, actualType.Fields.Count); diff --git a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs index 36cffe7eb4da1..f5a2c345e2ae6 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs @@ -77,6 +77,22 @@ private static IEnumerable, IArrowArray>> GenerateTestDa new Field.Builder().Name("Ints").DataType(Int32Type.Default).Nullable(true).Build() }), new FixedSizeListType(Int32Type.Default, 1), + new UnionType( + new List{ + new Field.Builder().Name("Strings").DataType(StringType.Default).Nullable(true).Build(), + new Field.Builder().Name("Ints").DataType(Int32Type.Default).Nullable(true).Build() + }, + new[] { 0, 1 }, + UnionMode.Sparse + ), + new UnionType( + new List{ + new Field.Builder().Name("Strings").DataType(StringType.Default).Nullable(true).Build(), + new 
Field.Builder().Name("Ints").DataType(Int32Type.Default).Nullable(true).Build() + }, + new[] { 0, 1 }, + UnionMode.Dense + ), }; foreach (IArrowType type in targetTypes) @@ -119,7 +135,8 @@ private class TestDataGenerator : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { private List> _baseData; @@ -392,6 +409,91 @@ public void Visit(StructType type) ExpectedArray = new StructArray(type, 3, new List { resultStringArray, resultInt32Array }, nullBitmapBuffer, 1); } + public void Visit(UnionType type) + { + bool isDense = type.Mode == UnionMode.Dense; + + StringArray.Builder stringResultBuilder = new StringArray.Builder().Reserve(_baseDataTotalElementCount); + Int32Array.Builder intResultBuilder = new Int32Array.Builder().Reserve(_baseDataTotalElementCount); + ArrowBuffer.Builder typeResultBuilder = new ArrowBuffer.Builder().Reserve(_baseDataTotalElementCount); + ArrowBuffer.Builder offsetResultBuilder = new ArrowBuffer.Builder().Reserve(_baseDataTotalElementCount); + int resultNullCount = 0; + + for (int i = 0; i < _baseDataListCount; i++) + { + List dataList = _baseData[i]; + StringArray.Builder stringBuilder = new StringArray.Builder().Reserve(dataList.Count); + Int32Array.Builder intBuilder = new Int32Array.Builder().Reserve(dataList.Count); + ArrowBuffer.Builder typeBuilder = new ArrowBuffer.Builder().Reserve(dataList.Count); + ArrowBuffer.Builder offsetBuilder = new ArrowBuffer.Builder().Reserve(dataList.Count); + int nullCount = 0; + + for (int j = 0; j < dataList.Count; j++) + { + byte index = (byte)Math.Max(j % 3, 1); + int? intValue = (index == 1) ? dataList[j] : null; + string stringValue = (index == 1) ? 
null : dataList[j]?.ToString(); + typeBuilder.Append(index); + + if (isDense) + { + if (index == 0) + { + offsetBuilder.Append(stringBuilder.Length); + offsetResultBuilder.Append(stringResultBuilder.Length); + stringBuilder.Append(stringValue); + stringResultBuilder.Append(stringValue); + } + else + { + offsetBuilder.Append(intBuilder.Length); + offsetResultBuilder.Append(intResultBuilder.Length); + intBuilder.Append(intValue); + intResultBuilder.Append(intValue); + } + } + else + { + stringBuilder.Append(stringValue); + stringResultBuilder.Append(stringValue); + intBuilder.Append(intValue); + intResultBuilder.Append(intValue); + } + + if (dataList[j] == null) + { + nullCount++; + resultNullCount++; + } + } + + ArrowBuffer[] buffers; + if (isDense) + { + buffers = new[] { typeBuilder.Build(), offsetBuilder.Build() }; + } + else + { + buffers = new[] { typeBuilder.Build() }; + } + TestTargetArrayList.Add(UnionArray.Create(new ArrayData( + type, dataList.Count, nullCount, 0, buffers, + new[] { stringBuilder.Build().Data, intBuilder.Build().Data }))); + } + + ArrowBuffer[] resultBuffers; + if (isDense) + { + resultBuffers = new[] { typeResultBuilder.Build(), offsetResultBuilder.Build() }; + } + else + { + resultBuffers = new[] { typeResultBuilder.Build() }; + } + ExpectedArray = UnionArray.Create(new ArrayData( + type, _baseDataTotalElementCount, resultNullCount, 0, resultBuffers, + new[] { stringResultBuilder.Build().Data, intResultBuilder.Build().Data })); + } public void Visit(IArrowType type) { diff --git a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs index e588eab51e1fc..8b41763a70ac8 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs @@ -91,6 +91,7 @@ private class ArrayComparer : IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, + IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, @@ 
-151,6 +152,24 @@ public void Visit(StructArray array) } } + public void Visit(UnionArray array) + { + Assert.IsAssignableFrom(_expectedArray); + UnionArray expectedArray = (UnionArray)_expectedArray; + + Assert.Equal(expectedArray.Mode, array.Mode); + Assert.Equal(expectedArray.Length, array.Length); + Assert.Equal(expectedArray.NullCount, array.NullCount); + Assert.Equal(expectedArray.Offset, array.Offset); + Assert.Equal(expectedArray.Data.Children.Length, array.Data.Children.Length); + Assert.Equal(expectedArray.Fields.Count, array.Fields.Count); + + for (int i = 0; i < array.Fields.Count; i++) + { + array.Fields[i].Accept(new ArrayComparer(expectedArray.Fields[i], _strictCompare)); + } + } + public void Visit(DictionaryArray array) { Assert.IsAssignableFrom(_expectedArray); diff --git a/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs b/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs index 29b1b9e7db74a..f28b89a9cd17e 100644 --- a/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs +++ b/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs @@ -112,6 +112,9 @@ private static Schema GetTestSchema() .Field(f => f.Name("dict_string_ordered").DataType(new DictionaryType(Int32Type.Default, StringType.Default, true)).Nullable(false)) .Field(f => f.Name("list_dict_string").DataType(new ListType(new DictionaryType(Int32Type.Default, StringType.Default, false))).Nullable(false)) + .Field(f => f.Name("dense_union").DataType(new UnionType(new[] { new Field("i64", Int64Type.Default, false), new Field("f32", FloatType.Default, true), }, new[] { 0, 1 }, UnionMode.Dense))) + .Field(f => f.Name("sparse_union").DataType(new UnionType(new[] { new Field("i32", Int32Type.Default, true), new Field("f64", DoubleType.Default, false), }, new[] { 0, 1 }, UnionMode.Sparse))) + // Checking wider characters. 
.Field(f => f.Name("hello 你好 😄").DataType(BooleanType.Default).Nullable(true)) @@ -172,6 +175,9 @@ private static IEnumerable GetPythonFields() yield return pa.field("dict_string_ordered", pa.dictionary(pa.int32(), pa.utf8(), true), false); yield return pa.field("list_dict_string", pa.list_(pa.dictionary(pa.int32(), pa.utf8(), false)), false); + yield return pa.field("dense_union", pa.dense_union(List(pa.field("i64", pa.int64(), false), pa.field("f32", pa.float32(), true)))); + yield return pa.field("sparse_union", pa.sparse_union(List(pa.field("i32", pa.int32(), true), pa.field("f64", pa.float64(), false)))); + yield return pa.field("hello 你好 😄", pa.bool_(), true); } } @@ -485,22 +491,29 @@ public unsafe void ImportRecordBatch() pa.array(List(0.0, 1.4, 2.5, 3.6, 4.7)), pa.array(new PyObject[] { List(1, 2), List(3, 4), PyObject.None, PyObject.None, List(5, 4, 3) }), pa.StructArray.from_arrays( - new PyList(new PyObject[] - { + List( List(10, 9, null, null, null), List("banana", "apple", "orange", "cherry", "grape"), - List(null, 4.3, -9, 123.456, 0), - }), + List(null, 4.3, -9, 123.456, 0) + ), new[] { "fld1", "fld2", "fld3" }), pa.DictionaryArray.from_arrays( pa.array(List(1, 0, 1, 1, null)), - pa.array(List("foo", "bar")) - ), + pa.array(List("foo", "bar"))), pa.FixedSizeListArray.from_arrays( pa.array(List(1, 2, 3, 4, null, 6, 7, null, null, null)), 2), + pa.UnionArray.from_dense( + pa.array(List(0, 1, 1, 0, 0), type: "int8"), + pa.array(List(0, 0, 1, 1, 2), type: "int32"), + List( + pa.array(List(1, 4, null)), + pa.array(List("two", "three")) + ), + /* field name */ List("i32", "s"), + /* type codes */ List(3, 2)), }), - new[] { "col1", "col2", "col3", "col4", "col5", "col6", "col7", "col8" }); + new[] { "col1", "col2", "col3", "col4", "col5", "col6", "col7", "col8", "col9" }); dynamic batch = table.to_batches()[0]; @@ -568,6 +581,10 @@ public unsafe void ImportRecordBatch() Assert.Equal(new long[] { 1, 2, 3, 4, 0, 6, 7, 0, 0, 0 }, col8a.Values.ToArray()); 
Assert.True(col8a.IsValid(3)); Assert.False(col8a.IsValid(9)); + + UnionArray col9 = (UnionArray)recordBatch.Column("col9"); + Assert.Equal(5, col9.Length); + Assert.True(col9 is DenseUnionArray); } [SkippableFact] @@ -789,6 +806,11 @@ private static PyObject List(params string[] values) return new PyList(values.Select(i => i == null ? PyObject.None : new PyString(i)).ToArray()); } + private static PyObject List(params PyObject[] values) + { + return new PyList(values); + } + sealed class TestArrayStream : IArrowArrayStream { private readonly RecordBatch[] _batches; diff --git a/csharp/test/Apache.Arrow.Tests/ColumnTests.cs b/csharp/test/Apache.Arrow.Tests/ColumnTests.cs index b90c681622d5f..2d867b79176aa 100644 --- a/csharp/test/Apache.Arrow.Tests/ColumnTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ColumnTests.cs @@ -39,7 +39,7 @@ public void TestColumn() Array intArrayCopy = MakeIntArray(10); Field field = new Field.Builder().Name("f0").DataType(Int32Type.Default).Build(); - Column column = new Column(field, new[] { intArray, intArrayCopy }); + Column column = new Column(field, new IArrowArray[] { intArray, intArrayCopy }); Assert.True(column.Name == field.Name); Assert.True(column.Field == field); diff --git a/csharp/test/Apache.Arrow.Tests/TableTests.cs b/csharp/test/Apache.Arrow.Tests/TableTests.cs index b4c4b1faed190..8b07a38c1b8c0 100644 --- a/csharp/test/Apache.Arrow.Tests/TableTests.cs +++ b/csharp/test/Apache.Arrow.Tests/TableTests.cs @@ -30,7 +30,7 @@ public static Table MakeTableWithOneColumnOfTwoIntArrays(int lengthOfEachArray) Field field = new Field.Builder().Name("f0").DataType(Int32Type.Default).Build(); Schema s0 = new Schema.Builder().Field(field).Build(); - Column column = new Column(field, new List { intArray, intArrayCopy }); + Column column = new Column(field, new List { intArray, intArrayCopy }); Table table = new Table(s0, new List { column }); return table; } @@ -60,7 +60,7 @@ public void TestTableFromRecordBatches() Table table1 = 
Table.TableFromRecordBatches(recordBatch1.Schema, recordBatches); Assert.Equal(20, table1.RowCount); - Assert.Equal(24, table1.ColumnCount); + Assert.Equal(26, table1.ColumnCount); FixedSizeBinaryType type = new FixedSizeBinaryType(17); Field newField1 = new Field(type.Name, type, false); @@ -86,13 +86,13 @@ public void TestTableAddRemoveAndSetColumn() Array nonEqualLengthIntArray = ColumnTests.MakeIntArray(10); Field field1 = new Field.Builder().Name("f1").DataType(Int32Type.Default).Build(); - Column nonEqualLengthColumn = new Column(field1, new[] { nonEqualLengthIntArray}); + Column nonEqualLengthColumn = new Column(field1, new IArrowArray[] { nonEqualLengthIntArray }); Assert.Throws(() => table.InsertColumn(-1, nonEqualLengthColumn)); Assert.Throws(() => table.InsertColumn(1, nonEqualLengthColumn)); Array equalLengthIntArray = ColumnTests.MakeIntArray(20); Field field2 = new Field.Builder().Name("f2").DataType(Int32Type.Default).Build(); - Column equalLengthColumn = new Column(field2, new[] { equalLengthIntArray}); + Column equalLengthColumn = new Column(field2, new IArrowArray[] { equalLengthIntArray }); Column existingColumn = table.Column(0); Table newTable = table.InsertColumn(0, equalLengthColumn); @@ -118,7 +118,7 @@ public void TestBuildFromRecordBatch() RecordBatch batch = TestData.CreateSampleRecordBatch(schema, 10); Table table = Table.TableFromRecordBatches(schema, new[] { batch }); - Assert.NotNull(table.Column(0).Data.Array(0) as Int64Array); + Assert.NotNull(table.Column(0).Data.ArrowArray(0) as Int64Array); } } diff --git a/csharp/test/Apache.Arrow.Tests/TestData.cs b/csharp/test/Apache.Arrow.Tests/TestData.cs index 41507311f6a04..9e2061e3428a9 100644 --- a/csharp/test/Apache.Arrow.Tests/TestData.cs +++ b/csharp/test/Apache.Arrow.Tests/TestData.cs @@ -60,6 +60,8 @@ public static RecordBatch CreateSampleRecordBatch(int length, int columnSetCount builder.Field(CreateField(new DictionaryType(Int32Type.Default, StringType.Default, false), i)); 
builder.Field(CreateField(new FixedSizeBinaryType(16), i)); builder.Field(CreateField(new FixedSizeListType(Int32Type.Default, 3), i)); + builder.Field(CreateField(new UnionType(new[] { CreateField(StringType.Default, i), CreateField(Int32Type.Default, i) }, new[] { 0, 1 }, UnionMode.Sparse), i)); + builder.Field(CreateField(new UnionType(new[] { CreateField(StringType.Default, i), CreateField(Int32Type.Default, i) }, new[] { 0, 1 }, UnionMode.Dense), -i)); } //builder.Field(CreateField(HalfFloatType.Default)); @@ -125,6 +127,7 @@ private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, + IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, @@ -315,6 +318,67 @@ public void Visit(StructType type) Array = new StructArray(type, Length, childArrays, nullBitmap.Build()); } + public void Visit(UnionType type) + { + int[] lengths = new int[type.Fields.Count]; + if (type.Mode == UnionMode.Sparse) + { + for (int i = 0; i < lengths.Length; i++) + { + lengths[i] = Length; + } + } + else + { + int totalLength = Length; + int oneLength = Length / lengths.Length; + for (int i = 1; i < lengths.Length; i++) + { + lengths[i] = oneLength; + totalLength -= oneLength; + } + lengths[0] = totalLength; + } + + ArrayData[] childArrays = new ArrayData[type.Fields.Count]; + for (int i = 0; i < childArrays.Length; i++) + { + childArrays[i] = CreateArray(type.Fields[i], lengths[i]).Data; + } + + ArrowBuffer.Builder typeIdBuilder = new ArrowBuffer.Builder(Length); + byte index = 0; + for (int i = 0; i < Length; i++) + { + typeIdBuilder.Append(index); + index++; + if (index == lengths.Length) + { + index = 0; + } + } + + ArrowBuffer[] buffers; + if (type.Mode == UnionMode.Sparse) + { + buffers = new ArrowBuffer[1]; + } + else + { + ArrowBuffer.Builder offsetBuilder = new ArrowBuffer.Builder(Length); + for (int i = 0; i < Length; i++) + { + offsetBuilder.Append(i / lengths.Length); + } + + buffers = new ArrowBuffer[2]; + buffers[1] = 
offsetBuilder.Build(); + } + buffers[0] = typeIdBuilder.Build(); + + Array = UnionArray.Create(new ArrayData(type, Length, 0, 0, buffers, childArrays)); + } + public void Visit(DictionaryType type) { Int32Array.Builder indicesBuilder = new Int32Array.Builder().Reserve(Length); diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index 5ac32da56a8de..299881c4b613a 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -1833,8 +1833,7 @@ def _temp_path(): .skip_tester('C#') .skip_tester('JS'), - generate_unions_case() - .skip_tester('C#'), + generate_unions_case(), generate_custom_metadata_case() .skip_tester('C#'), diff --git a/docs/source/status.rst b/docs/source/status.rst index 36c29fcdc4da6..6314fd4c8d31f 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -83,9 +83,9 @@ Data Types +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | Map | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ -| Dense Union | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | +| Dense Union | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ -| Sparse Union | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | +| Sparse Union | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+