diff --git a/swift/Arrow/Sources/Arrow/ArrowArray.swift b/swift/Arrow/Sources/Arrow/ArrowArray.swift index 32b6ba1704511..b0f20ee06c2e4 100644 --- a/swift/Arrow/Sources/Arrow/ArrowArray.swift +++ b/swift/Arrow/Sources/Arrow/ArrowArray.swift @@ -21,7 +21,7 @@ public protocol ArrowArrayHolder { var type: ArrowType {get} var length: UInt {get} var nullCount: UInt {get} - var array: Any {get} + var array: AnyArray {get} var data: ArrowData {get} var getBufferData: () -> [Data] {get} var getBufferDataSizes: () -> [Int] {get} @@ -29,11 +29,11 @@ public protocol ArrowArrayHolder { } public class ArrowArrayHolderImpl: ArrowArrayHolder { - public let array: Any public let data: ArrowData public let type: ArrowType public let length: UInt public let nullCount: UInt + public let array: AnyArray public let getBufferData: () -> [Data] public let getBufferDataSizes: () -> [Int] public let getArrowColumn: (ArrowField, [ArrowArrayHolder]) throws -> ArrowColumn @@ -73,6 +73,50 @@ public class ArrowArrayHolderImpl: ArrowArrayHolder { return ArrowColumn(field, chunked: ChunkedArrayHolder(try ChunkedArray(arrays))) } } + + public static func loadArray( // swiftlint:disable:this cyclomatic_complexity + _ arrowType: ArrowType, with: ArrowData) throws -> ArrowArrayHolder { + switch arrowType.id { + case .int8: + return ArrowArrayHolderImpl(FixedArray(with)) + case .int16: + return ArrowArrayHolderImpl(FixedArray(with)) + case .int32: + return ArrowArrayHolderImpl(FixedArray(with)) + case .int64: + return ArrowArrayHolderImpl(FixedArray(with)) + case .uint8: + return ArrowArrayHolderImpl(FixedArray(with)) + case .uint16: + return ArrowArrayHolderImpl(FixedArray(with)) + case .uint32: + return ArrowArrayHolderImpl(FixedArray(with)) + case .uint64: + return ArrowArrayHolderImpl(FixedArray(with)) + case .double: + return ArrowArrayHolderImpl(FixedArray(with)) + case .float: + return ArrowArrayHolderImpl(FixedArray(with)) + case .date32: + return ArrowArrayHolderImpl(Date32Array(with)) + case .date64: + return ArrowArrayHolderImpl(Date64Array(with)) + case .time32: + return ArrowArrayHolderImpl(Time32Array(with)) + case .time64: + return ArrowArrayHolderImpl(Time64Array(with)) + case .string: + return ArrowArrayHolderImpl(StringArray(with)) + case .boolean: + return ArrowArrayHolderImpl(BoolArray(with)) + case .binary: + return ArrowArrayHolderImpl(BinaryArray(with)) + case .strct: + return ArrowArrayHolderImpl(StructArray(with)) + default: + throw ArrowError.invalid("Array not found for type: \(arrowType)") + } + } } public class ArrowArray: AsString, AnyArray { @@ -221,10 +265,7 @@ public class BinaryArray: ArrowArray { } public override func asString(_ index: UInt) -> String { - if self[index] == nil { - return "" - } - + if self[index] == nil {return ""} let data = self[index]! if options.printAsHex { return data.hexEncodedString() @@ -233,3 +274,58 @@ public class BinaryArray: ArrowArray { } } } + +public class StructArray: ArrowArray<[Any?]> { + public private(set) var arrowFields: [ArrowArrayHolder]? + public required init(_ arrowData: ArrowData) { + super.init(arrowData) + } + + public func initialize() throws -> StructArray { + var fields = [ArrowArrayHolder]() + for child in arrowData.children { + fields.append(try ArrowArrayHolderImpl.loadArray(child.type, with: child)) + } + + self.arrowFields = fields + return self + } + + public override subscript(_ index: UInt) -> [Any?]? { + if self.arrowData.isNull(index) { + return nil + } + + if let fields = arrowFields { + var result = [Any?]() + for field in fields { + result.append(field.array.asAny(index)) + } + + return result + } + + return nil + } + + public override func asString(_ index: UInt) -> String { + if self.arrowData.isNull(index) { + return "" + } + + var output = "{" + if let fields = arrowFields { + for fieldIndex in 0.. [Data] {self.holder.getBufferData} public var getBufferDataSizes: () -> [Int] {self.holder.getBufferDataSizes} diff --git a/swift/Arrow/Sources/Arrow/ArrowDecoder.swift b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift index 7e684f360ac51..35dd4dcd1e899 100644 --- a/swift/Arrow/Sources/Arrow/ArrowDecoder.swift +++ b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift @@ -96,11 +96,7 @@ public class ArrowDecoder: Decoder { throw ArrowError.invalid("Column for key \"\(name)\" not found") } - guard let anyArray = col.array as? AnyArray else { - throw ArrowError.invalid("Unable to convert array to AnyArray") - } - - return anyArray + return col.array } func getCol(_ index: Int) throws -> AnyArray { @@ -108,11 +104,7 @@ public class ArrowDecoder: Decoder { throw ArrowError.outOfBounds(index: Int64(index)) } - guard let anyArray = self.columns[index].array as? AnyArray else { - throw ArrowError.invalid("Unable to convert array to AnyArray") - } - - return anyArray + return self.columns[index].array } func doDecode(_ key: CodingKey) throws -> T? { diff --git a/swift/Arrow/Sources/Arrow/ArrowTable.swift b/swift/Arrow/Sources/Arrow/ArrowTable.swift index b9d15154c4f94..dedf90f791cce 100644 --- a/swift/Arrow/Sources/Arrow/ArrowTable.swift +++ b/swift/Arrow/Sources/Arrow/ArrowTable.swift @@ -185,7 +185,7 @@ public class RecordBatch { public func anyData(for columnIndex: Int) -> AnyArray { let arrayHolder = column(columnIndex) - return (arrayHolder.array as! AnyArray) // swiftlint:disable:this force_cast + return arrayHolder.array } public func column(_ index: Int) -> ArrowArrayHolder { diff --git a/swift/Arrow/Sources/Arrow/ChunkedArray.swift b/swift/Arrow/Sources/Arrow/ChunkedArray.swift index c5ccfe4aec0e6..fb5734f64b6ba 100644 --- a/swift/Arrow/Sources/Arrow/ChunkedArray.swift +++ b/swift/Arrow/Sources/Arrow/ChunkedArray.swift @@ -18,6 +18,7 @@ import Foundation public protocol AnyArray { + var arrowData: ArrowData {get} func asAny(_ index: UInt) -> Any? var length: UInt {get} } diff --git a/swift/Arrow/Tests/ArrowTests/CodableTests.swift b/swift/Arrow/Tests/ArrowTests/CodableTests.swift index a0c4e111e4360..b8f389a5e0089 100644 --- a/swift/Arrow/Tests/ArrowTests/CodableTests.swift +++ b/swift/Arrow/Tests/ArrowTests/CodableTests.swift @@ -227,7 +227,7 @@ final class CodableTests: XCTestCase { // swiftlint:disable:this type_body_lengt } func getArrayValue(_ rb: RecordBatch, colIndex: Int, rowIndex: UInt) -> T? { - let anyArray = rb.columns[colIndex].array as! AnyArray // swiftlint:disable:this force_cast + let anyArray = rb.columns[colIndex].array return anyArray.asAny(UInt(rowIndex)) as? T } @@ -324,7 +324,7 @@ final class CodableTests: XCTestCase { // swiftlint:disable:this type_body_lengt XCTAssertEqual(rb.columns[0].type.id, ArrowTypeId.int32) for index in 0..<100 { if index == 10 { - let anyArray = rb.columns[0].array as! AnyArray // swiftlint:disable:this force_cast + let anyArray = rb.columns[0].array XCTAssertNil(anyArray.asAny(UInt(index))) } else { XCTAssertEqual(getArrayValue(rb, colIndex: 0, rowIndex: UInt(index)), Int32(index))