From 83aa1ba3be41dcad56bda6178399e11869ed22ae Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Tue, 30 Jan 2024 23:34:35 -0500 Subject: [PATCH] support dictionary-encoded arrays (#52) * wip: dictionary encoding * Pass through args * update check on readme * support dictionary encoded arrays * update readme --- README.md | 8 ++- src/field.ts | 42 ++++++++++++++ src/vector.ts | 74 ++++++++++++++++++++++++- tests/ffi.test.ts | 97 +++++++++++++++++++++++++-------- tests/pyarrow_generate_data.py | 13 +++++ tests/table.arrow | Bin 4610 -> 5626 bytes 6 files changed, 208 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 545943e..32ce846 100644 --- a/README.md +++ b/README.md @@ -152,8 +152,10 @@ Most of the unsupported types should be pretty straightforward to implement; the ### Decimal -- [ ] Decimal128 (failing a test) -- [ ] Decimal256 (failing a test) +- [ ] Decimal128 (failing a test, this may be [#37920]) +- [ ] Decimal256 (failing a test, this may be [#37920]) + +[#37920]: https://github.com/apache/arrow/issues/37920 ### Temporal Types @@ -174,7 +176,7 @@ Most of the unsupported types should be pretty straightforward to implement; the - [ ] Map - [x] Dense Union - [x] Sparse Union -- [ ] Dictionary-encoded arrays +- [x] Dictionary-encoded arrays ### Extension Types diff --git a/src/field.ts b/src/field.ts index 7a97015..409a08a 100644 --- a/src/field.ts +++ b/src/field.ts @@ -63,6 +63,8 @@ export function parseField(buffer: ArrayBuffer, ptr: number): arrow.Field { const nChildren = dataView.getBigInt64(ptr + 24, true); const ptrToChildrenPtrs = dataView.getUint32(ptr + 32, true); + const dictionaryPtr = dataView.getUint32(ptr + 36, true); + const childrenFields: arrow.Field[] = new Array(Number(nChildren)); for (let i = 0; i < nChildren; i++) { childrenFields[i] = parseField( @@ -71,6 +73,46 @@ export function parseField(buffer: ArrayBuffer, ptr: number): arrow.Field { ); } + const field = parseFieldContent({ + formatString, + flags, + name, + childrenFields, + metadata, + }); + + if (dictionaryPtr !== 0) { + const dictionaryValuesField = parseField(buffer, dictionaryPtr); + const dictionaryType = new arrow.Dictionary( + dictionaryValuesField, + field.type, + null, + flags.dictionaryOrdered, + ); + return new arrow.Field( + field.name, + dictionaryType, + flags.nullable, + metadata, + ); + } + + return field; +} + +function parseFieldContent({ + formatString, + flags, + name, + childrenFields, + metadata, +}: { + formatString: string; + flags: Flags; + name: string; + childrenFields: arrow.Field[]; + metadata: Map | null; +}): arrow.Field { const primitiveType = formatMapping[formatString]; if (primitiveType) { return new arrow.Field(name, primitiveType, flags.nullable, metadata); diff --git a/src/vector.ts b/src/vector.ts index 431f5f7..404afe3 100644 --- a/src/vector.ts +++ b/src/vector.ts @@ -69,6 +69,8 @@ export function parseData( } const ptrToChildrenPtrs = dataView.getUint32(ptr + 44, true); + const dictionaryPtr = dataView.getUint32(ptr + 48, true); + const children: arrow.Data[] = new Array(Number(nChildren)); for (let i = 0; i < nChildren; i++) { children[i] = parseData( @@ -79,6 +81,77 @@ export function parseData( ); } + // Special case for handling dictionary-encoded arrays + if (dictionaryPtr !== 0) { + const dictionaryType = dataType as unknown as arrow.Dictionary; + + // the parent structure points to the index data, the ArrowArray.dictionary + // points to the dictionary values array. 
+ const indicesType = dictionaryType.indices; + const dictionaryIndices = parseDataContent({ + dataType: indicesType, + dataView, + copy, + length, + nullCount, + offset, + nChildren, + children, + bufferPtrs, + }); + + const valueType = dictionaryType.dictionary.type; + const dictionaryValues = parseData(buffer, dictionaryPtr, valueType, copy); + + // @ts-expect-error we're casting to dictionary type + return arrow.makeData({ + type: dictionaryType, + // TODO: double check that this offset should be set on both the values + // and indices of the dictionary array + offset, + length, + nullCount, + nullBitmap: dictionaryIndices.nullBitmap, + // Note: Here we need to pass in the _raw TypedArray_ not a Data instance + data: dictionaryIndices.values, + dictionary: arrow.makeVector(dictionaryValues), + }); + } else { + return parseDataContent({ + dataType, + dataView, + copy, + length, + nullCount, + offset, + nChildren, + children, + bufferPtrs, + }); + } +} + +function parseDataContent({ + dataType, + dataView, + copy, + length, + nullCount, + offset, + nChildren, + children, + bufferPtrs, +}: { + dataType: T; + dataView: DataView; + copy: boolean; + length: number; + nullCount: number; + offset: number; + nChildren: number; + children: arrow.Data[]; + bufferPtrs: Uint32Array; +}): arrow.Data { if (DataType.isNull(dataType)) { return arrow.makeData({ type: dataType, @@ -653,7 +726,6 @@ export function parseData( }); } - // TODO: map arrays, dictionary encoding throw new Error(`Unsupported type ${dataType}`); } diff --git a/tests/ffi.test.ts b/tests/ffi.test.ts index 107f432..a934987 100644 --- a/tests/ffi.test.ts +++ b/tests/ffi.test.ts @@ -196,8 +196,7 @@ describe("binary", (t) => { ); const originalField = TEST_TABLE.schema.fields[columnIndex]; - // declare it's not null - const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector; + const originalVector = TEST_TABLE.getChildAt(columnIndex)!; const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); const field = parseField(WASM_MEMORY.buffer, fieldPtr); @@ -277,8 +276,7 @@ describe("string", (t) => { ); const originalField = TEST_TABLE.schema.fields[columnIndex]; - // declare it's not null - const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector; + const originalVector = TEST_TABLE.getChildAt(columnIndex)!; const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); const field = parseField(WASM_MEMORY.buffer, fieldPtr); @@ -346,8 +344,7 @@ describe("boolean", (t) => { ); const originalField = TEST_TABLE.schema.fields[columnIndex]; - // declare it's not null - const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector; + const originalVector = TEST_TABLE.getChildAt(columnIndex)!; const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); const field = parseField(WASM_MEMORY.buffer, fieldPtr); @@ -379,8 +376,7 @@ describe("null array", (t) => { ); const originalField = TEST_TABLE.schema.fields[columnIndex]; - // declare it's not null - const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector; + const originalVector = TEST_TABLE.getChildAt(columnIndex)!; const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); const field = parseField(WASM_MEMORY.buffer, fieldPtr); @@ -412,8 +408,7 @@ describe("list array", (t) => { ); const originalField = TEST_TABLE.schema.fields[columnIndex]; - // declare it's not null - const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector; + const originalVector = TEST_TABLE.getChildAt(columnIndex)!; const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); const 
field = parseField(WASM_MEMORY.buffer, fieldPtr); @@ -499,8 +494,7 @@ describe("extension array", (t) => { ); const originalField = TEST_TABLE.schema.fields[columnIndex]; - // declare it's not null - const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector; + const originalVector = TEST_TABLE.getChildAt(columnIndex)!; const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); const field = parseField(WASM_MEMORY.buffer, fieldPtr); @@ -544,8 +538,7 @@ describe("extension array", (t) => { // ); // const originalField = TEST_TABLE.schema.fields[columnIndex]; -// // declare it's not null -// const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector; +// const originalVector = TEST_TABLE.getChildAt(columnIndex)!; // const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); // const field = parseField(WASM_MEMORY.buffer, fieldPtr); @@ -572,8 +565,7 @@ describe("date32", (t) => { ); const originalField = TEST_TABLE.schema.fields[columnIndex]; - // declare it's not null - const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector; + const originalVector = TEST_TABLE.getChildAt(columnIndex)!; const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); const field = parseField(WASM_MEMORY.buffer, fieldPtr); @@ -606,8 +598,7 @@ describe("date32", (t) => { // ); // const originalField = TEST_TABLE.schema.fields[columnIndex]; -// // declare it's not null -// const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector; +// const originalVector = TEST_TABLE.getChildAt(columnIndex)!; // const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); // const field = parseField(WASM_MEMORY.buffer, fieldPtr); @@ -634,8 +625,7 @@ describe("duration", (t) => { ); const originalField = TEST_TABLE.schema.fields[columnIndex]; - // declare it's not null - const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector; + const originalVector = TEST_TABLE.getChildAt(columnIndex)!; const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); const field = parseField(WASM_MEMORY.buffer, fieldPtr); @@ -667,8 +657,7 @@ describe("nullable int", (t) => { ); const originalField = TEST_TABLE.schema.fields[columnIndex]; - // declare it's not null - const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector; + const originalVector = TEST_TABLE.getChildAt(columnIndex)!; const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); const field = parseField(WASM_MEMORY.buffer, fieldPtr); @@ -693,3 +682,67 @@ describe("nullable int", (t) => { it("copy=false", () => test(false)); it("copy=true", () => test(true)); }); + +describe("dictionary encoded string", (t) => { + function test(copy: boolean) { + let columnIndex = TEST_TABLE.schema.fields.findIndex( + (field) => field.name == "dictionary_encoded_string" + ); + + const originalField = TEST_TABLE.schema.fields[columnIndex]; + const originalVector = TEST_TABLE.getChildAt(columnIndex)!; + const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); + const field = parseField(WASM_MEMORY.buffer, fieldPtr); + + expect(field.name).toStrictEqual(originalField.name); + expect(field.typeId).toStrictEqual(originalField.typeId); + expect(field.nullable).toStrictEqual(originalField.nullable); + + const arrayPtr = FFI_TABLE.arrayAddr(0, columnIndex); + const wasmVector = parseVector( + WASM_MEMORY.buffer, + arrayPtr, + field.type, + copy + ); + + for (let i = 0; i < 3; i++) { + expect(originalVector.get(i)).toStrictEqual(wasmVector.get(i)); + } + } + + it("copy=false", () => test(false)); + it("copy=true", () => test(true)); +}); + +describe("dictionary encoded 
string (with nulls)", (t) => { + function test(copy: boolean) { + let columnIndex = TEST_TABLE.schema.fields.findIndex( + (field) => field.name == "dictionary_encoded_string_null" + ); + + const originalField = TEST_TABLE.schema.fields[columnIndex]; + const originalVector = TEST_TABLE.getChildAt(columnIndex)!; + const fieldPtr = FFI_TABLE.schemaAddr(columnIndex); + const field = parseField(WASM_MEMORY.buffer, fieldPtr); + + expect(field.name).toStrictEqual(originalField.name); + expect(field.typeId).toStrictEqual(originalField.typeId); + expect(field.nullable).toStrictEqual(originalField.nullable); + + const arrayPtr = FFI_TABLE.arrayAddr(0, columnIndex); + const wasmVector = parseVector( + WASM_MEMORY.buffer, + arrayPtr, + field.type, + copy + ); + + for (let i = 0; i < 3; i++) { + expect(originalVector.get(i)).toStrictEqual(wasmVector.get(i)); + } + } + + it("copy=false", () => test(false)); + it("copy=true", () => test(true)); +}); diff --git a/tests/pyarrow_generate_data.py b/tests/pyarrow_generate_data.py index 60f7d3c..b853c75 100644 --- a/tests/pyarrow_generate_data.py +++ b/tests/pyarrow_generate_data.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import pyarrow as pa +import pyarrow.compute as pc import pyarrow.feather as feather @@ -194,6 +195,16 @@ def dense_union_array() -> pa.Array: return union_arr +def dictionary_encoded_string_array() -> pa.DictionaryArray: + arr = pa.StringArray.from_pandas(["a", "a", "b"]) + return pc.dictionary_encode(arr) + + +def dictionary_encoded_string_array_null() -> pa.DictionaryArray: + arr = pa.StringArray.from_pandas(["a", "a", None]) + return pc.dictionary_encode(arr) + + class MyExtensionType(pa.ExtensionType): """ Refer to https://arrow.apache.org/docs/python/extending_types.html for @@ -243,6 +254,8 @@ def table() -> pa.Table: "sparse_union": sparse_union_array(), "dense_union": dense_union_array(), "duration": duration_array(), + "dictionary_encoded_string": dictionary_encoded_string_array(), + "dictionary_encoded_string_null": dictionary_encoded_string_array_null(), } ) diff --git a/tests/table.arrow b/tests/table.arrow index 08279d56d024e233ecb7f1bbe79719bf900db36a..4fb305befbb3d0448be222e555080ff77e0f4254 100644 GIT binary patch delta 1265 zcmcgs!Ac`R5Uq*n$?Q0UtPU7g(NRPeRM_Cfqr!R<6fb%ZVl+VlA&TpQx$G>O56Idc zJbKLD7P5yG{0L8;JbKtaP!t_s&7@~Wka!U*yy@<$*HztBT{#%Ln@GlQn7(#K<=H^tRi^peQKnx$Yh)+FOOVrgzr$5xQDdCLb2(@A z%0*M83qCD)eTI;iVsxq*lAp($Ow9ibw+0NgR>^C$?TU8N zo+vx-`qX+b4mIg~$ZDFbnKwPZ>t`i4M58?D{_&AjEXo&`sqe^g+#-~p0e&=WoM|T^ zgZ5#G?iWs(FBrng^Rh#Wy6@4W9a_|lUbjjOpZINiv_p$_BkK+V NhCOQDBEiAG+!O3Sw9o(m delta 339 zcmeyR-K4_j7!>3mZpgs!9|#VxPUMsGdGYuEe-;J?1_2;eU}a#K0_0=>u?-Ll0Pz!M z28IqG-opeG2Vx5#<^kdzj0_A7Kzs+N&jN^nN*Oi>FvhYnvP`z)P-1MDoX8H zWHQhIIMAIukwbm50z1p*9ULo|fL1avNNirp-OM=IgZG0_P0Y?zkPw2N?8R3%`2(Lq sy$MiQ1BgLZ0c~cu0M^UEumL0h1R&?211frt5#+xJR;~{~Ya!tX0O_4N?f?J)