diff --git a/src/vector.ts b/src/vector.ts index 83704b8..6a0e573 100644 --- a/src/vector.ts +++ b/src/vector.ts @@ -64,7 +64,12 @@ export function parseData( if (DataType.isInt(dataType)) { const [validityPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); const byteLength = (length * dataType.bitWidth) / 8; const data = copy ? new dataType.ArrayType(copyBuffer(dataView.buffer, dataPtr, byteLength)) @@ -81,7 +86,12 @@ export function parseData( if (DataType.isFloat(dataType)) { const [validityPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); // bitwidth doesn't exist on float types I guess const byteLength = length * dataType.ArrayType.BYTES_PER_ELEMENT; const data = copy @@ -99,7 +109,12 @@ export function parseData( if (DataType.isBool(dataType)) { const [validityPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); // Boolean arrays are bit-packed. This means the byte length should be the number of elements, // rounded up to the nearest byte to account for the remainder. @@ -120,7 +135,12 @@ export function parseData( if (DataType.isDecimal(dataType)) { const [validityPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); const byteLength = (length * dataType.bitWidth) / 8; const data = copy ? 
new dataType.ArrayType(copyBuffer(dataView.buffer, dataPtr, byteLength)) @@ -137,7 +157,12 @@ export function parseData( if (DataType.isDate(dataType)) { const [validityPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); let byteWidth = getDateByteWidth(dataType); const data = copy @@ -157,7 +182,12 @@ export function parseData( if (DataType.isTime(dataType)) { const [validityPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); const byteLength = (length * dataType.bitWidth) / 8; const data = copy ? new dataType.ArrayType(copyBuffer(dataView.buffer, dataPtr, byteLength)) @@ -174,7 +204,12 @@ export function parseData( if (DataType.isTimestamp(dataType)) { const [validityPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); let byteWidth = getTimeByteWidth(dataType); const data = copy @@ -194,7 +229,12 @@ export function parseData( if (DataType.isInterval(dataType)) { const [validityPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); // What's the bitwidth here? if (copy) { @@ -215,7 +255,12 @@ export function parseData( if (DataType.isBinary(dataType)) { const [validityPtr, offsetsPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); const valueOffsets = copy ? 
new Int32Array( @@ -247,7 +292,12 @@ export function parseData( if (isLargeBinary(dataType)) { const [validityPtr, offsetsPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); // The original value offsets are an Int64Array, which Arrow JS does not yet support natively const originalValueOffsets = new BigInt64Array( @@ -285,7 +335,12 @@ export function parseData( if (DataType.isUtf8(dataType)) { const [validityPtr, offsetsPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); const valueOffsets = copy ? new Int32Array( @@ -317,7 +372,12 @@ export function parseData( if (isLargeUtf8(dataType)) { const [validityPtr, offsetsPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); // The original value offsets are an Int64Array, which Arrow JS does not yet support natively const originalValueOffsets = new BigInt64Array( @@ -355,7 +415,12 @@ export function parseData( if (DataType.isFixedSizeBinary(dataType)) { const [validityPtr, dataPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); const data = copy ? 
new dataType.ArrayType( copyBuffer(dataView.buffer, dataPtr, length * dataType.byteWidth) @@ -378,7 +443,12 @@ export function parseData( if (DataType.isList(dataType)) { assert(nChildren === 1); const [validityPtr, offsetsPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); const valueOffsets = copy ? new Int32Array( copyBuffer( @@ -404,7 +474,12 @@ export function parseData( dataType; assert(nChildren === 1); const [validityPtr, offsetsPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); // The original value offsets are an Int64Array, which Arrow JS does not yet support natively const originalValueOffsets = new BigInt64Array( @@ -435,7 +510,12 @@ export function parseData( if (DataType.isFixedSizeList(dataType)) { assert(nChildren === 1); const [validityPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); return arrow.makeData({ type: dataType, @@ -449,7 +529,12 @@ export function parseData( if (DataType.isStruct(dataType)) { const [validityPtr] = bufferPtrs; - const nullBitmap = parseNullBitmap(dataView.buffer, validityPtr, copy); + const nullBitmap = parseNullBitmap( + dataView.buffer, + validityPtr, + length, + copy + ); return arrow.makeData({ type: dataType, @@ -490,11 +575,21 @@ function getTimeByteWidth(type: arrow.Time | arrow.Timestamp): number { function parseNullBitmap( buffer: ArrayBuffer, validityPtr: number, + length: number, copy: boolean ): NullBitmap { - // TODO: parse validity bitmaps - const nullBitmap = validityPtr === 0 ? 
null : null; - return nullBitmap; + if (validityPtr === 0) { + return null; + } + + // Validity is a bit-packed bitmap: one bit per value, rounded up to a whole byte + const byteLength = (length + 7) >> 3; + + if (copy) { + return new Uint8Array(copyBuffer(buffer, validityPtr, byteLength)); + } else { + return new Uint8Array(buffer, validityPtr, byteLength); + } } /** Copy existing buffer into new buffer */ diff --git a/tests/ffi.test.ts b/tests/ffi.test.ts index a5818f4..c3be397 100644 --- a/tests/ffi.test.ts +++ b/tests/ffi.test.ts @@ -2,7 +2,12 @@ import { readFileSync } from "fs"; import { describe, it, expect } from "vitest"; import * as arrow from "apache-arrow"; import * as wasm from "rust-arrow-ffi"; -import { arrowTableToFFI, arraysEqual, loadIPCTableFromDisk } from "./utils"; +import { + arrowTableToFFI, + arraysEqual, + loadIPCTableFromDisk, + validityEqual, +} from "./utils"; import { parseField, parseVector } from "../src"; import { Type } from "../src/types"; @@ -622,3 +627,37 @@ describe("date32", (t) => { // expect(originalVector.get(i), wasmVector.get(i)); // } // }); + +describe("nullable int", (t) => { + function test(copy: boolean) { + let columnIndex = TEST_TABLE.schema.fields.findIndex( + (field) => field.name == "nullable_int" + ); + + const originalField = TEST_TABLE.schema.fields[columnIndex]; + // declare it's not null + const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector; + const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); + const field = parseField(WASM_MEMORY.buffer, fieldPtr); + + expect(field.name).toStrictEqual(originalField.name); + expect(field.typeId).toStrictEqual(originalField.typeId); + expect(field.nullable).toStrictEqual(originalField.nullable); + + const arrayPtr = FFI_TABLE.arrayAddr(0, columnIndex); + const wasmVector = parseVector( + WASM_MEMORY.buffer, + arrayPtr, + field.type, + copy + ); + + expect( + validityEqual(originalVector, wasmVector), + "validity should be equal" + ).toBeTruthy(); + } + + it("copy=false", () => test(false)); + 
it("copy=true", () => test(true)); +}); diff --git a/tests/pyarrow_generate_data.py b/tests/pyarrow_generate_data.py index 03c6606..a5613f6 100644 --- a/tests/pyarrow_generate_data.py +++ b/tests/pyarrow_generate_data.py @@ -115,6 +115,15 @@ def timestamp_array() -> pa.Array: return arr +def nullable_int() -> pa.Array: + # True means null + mask = [True, False, True] + arr = pa.array([1, 2, 3], type=pa.uint8(), mask=mask) + assert isinstance(arr, pa.UInt8Array) + assert not arr[0].is_valid + return arr + + class MyExtensionType(pa.ExtensionType): """ Refer to https://arrow.apache.org/docs/python/extending_types.html for @@ -160,6 +169,7 @@ def table() -> pa.Table: "date32": date32_array(), "date64": date64_array(), "timestamp": timestamp_array(), + "nullable_int": nullable_int(), } ) diff --git a/tests/table.arrow b/tests/table.arrow index 2326863..ec0d119 100644 Binary files a/tests/table.arrow and b/tests/table.arrow differ diff --git a/tests/utils.ts b/tests/utils.ts index 5037628..b55e3ea 100644 --- a/tests/utils.ts +++ b/tests/utils.ts @@ -35,3 +35,32 @@ export function arraysEqual( return true; } + +export function validityEqual(v1: arrow.Vector, v2: arrow.Vector): boolean { + if (v1.length !== v2.length) { + return false; + } + + if (v1.data.length !== v2.data.length) { + console.log("todo: support different data lengths"); + return false; + } + for (let i = 0; i < v1.data.length; i++) { + const d1 = v1.data[i]; + const d2 = v2.data[i]; + // Check that null bitmaps have same length + if (d1 !== null && d2 !== null) { + if (d1.nullBitmap.length !== d2.nullBitmap.length) { + return false; + } + } + } + + for (let i = 0; i < v1.length; i++) { + if (v1.isValid(i) !== v2.isValid(i)) { + return false; + } + } + + return true; +}