support dictionary-encoded arrays (#52)
* wip: dictionary encoding

* Pass through args

* update check on readme

* support dictionary encoded arrays

* update readme
kylebarron authored Jan 31, 2024
1 parent 6b83961 commit 83aa1ba
Showing 6 changed files with 208 additions and 26 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -152,8 +152,10 @@ Most of the unsupported types should be pretty straightforward to implement; the
 
 ### Decimal
 
-- [ ] Decimal128 (failing a test)
-- [ ] Decimal256 (failing a test)
+- [ ] Decimal128 (failing a test, this may be [#37920])
+- [ ] Decimal256 (failing a test, this may be [#37920])
 
+[#37920]: https://github.com/apache/arrow/issues/37920
+
 ### Temporal Types
 
@@ -174,7 +176,7 @@ Most of the unsupported types should be pretty straightforward to implement; the
 - [ ] Map
 - [x] Dense Union
 - [x] Sparse Union
-- [ ] Dictionary-encoded arrays
+- [x] Dictionary-encoded arrays
 
 ### Extension Types
 
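The newly checked dictionary item is the user-visible change in this commit. A minimal usage sketch, assuming this package's `parseField`/`parseVector` entry points (the same ones exercised in the tests below) and a wasm Arrow producer that exposes its memory plus the addresses of its `ArrowSchema`/`ArrowArray` structs; the import path and the `readColumn` wrapper are hypothetical:

```ts
import * as arrow from "apache-arrow";
import { parseField, parseVector } from "arrow-js-ffi"; // import path assumed

// Hypothetical wrapper: decode one column exported over the Arrow C data
// interface. fieldPtr addresses an ArrowSchema struct in WASM memory,
// arrayPtr an ArrowArray struct.
function readColumn(
  memory: WebAssembly.Memory,
  fieldPtr: number,
  arrayPtr: number,
): arrow.Vector {
  const field = parseField(memory.buffer, fieldPtr);
  // For a dictionary-encoded column, field.type is an arrow.Dictionary, and
  // parseVector reassembles both the index buffer and the values array.
  return parseVector(memory.buffer, arrayPtr, field.type, false); // copy=false: zero-copy view
}
```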
42 changes: 42 additions & 0 deletions src/field.ts
@@ -63,6 +63,8 @@ export function parseField(buffer: ArrayBuffer, ptr: number): arrow.Field {
   const nChildren = dataView.getBigInt64(ptr + 24, true);
 
   const ptrToChildrenPtrs = dataView.getUint32(ptr + 32, true);
+  const dictionaryPtr = dataView.getUint32(ptr + 36, true);
+
   const childrenFields: arrow.Field[] = new Array(Number(nChildren));
   for (let i = 0; i < nChildren; i++) {
     childrenFields[i] = parseField(
@@ -71,6 +73,46 @@ export function parseField(buffer: ArrayBuffer, ptr: number): arrow.Field {
     );
   }
 
+  const field = parseFieldContent({
+    formatString,
+    flags,
+    name,
+    childrenFields,
+    metadata,
+  });
+
+  if (dictionaryPtr !== 0) {
+    const dictionaryValuesField = parseField(buffer, dictionaryPtr);
+    const dictionaryType = new arrow.Dictionary(
+      dictionaryValuesField,
+      field.type,
+      null,
+      flags.dictionaryOrdered,
+    );
+    return new arrow.Field(
+      field.name,
+      dictionaryType,
+      flags.nullable,
+      metadata,
+    );
+  }
+
+  return field;
+}
+
+function parseFieldContent({
+  formatString,
+  flags,
+  name,
+  childrenFields,
+  metadata,
+}: {
+  formatString: string;
+  flags: Flags;
+  name: string;
+  childrenFields: arrow.Field[];
+  metadata: Map<string, string> | null;
+}): arrow.Field {
   const primitiveType = formatMapping[formatString];
   if (primitiveType) {
     return new arrow.Field(name, primitiveType, flags.nullable, metadata);
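Both new reads in this file follow the Arrow C data interface's `ArrowSchema` layout under wasm32, where pointers are 4 bytes wide. As a standalone illustration, here is a hypothetical helper (not part of the library) that performs just the dictionary detection; the offsets are copied from the reads above, and the earlier struct members are omitted:

```ts
// Partial wasm32 ArrowSchema layout, per the reads in parseField:
//   ptr + 24  int64   n_children
//   ptr + 32  uint32  children    -- pointer to n_children ArrowSchema*
//   ptr + 36  uint32  dictionary  -- ArrowSchema* of the value type, or 0
function readSchemaDictionaryPtr(buffer: ArrayBuffer, ptr: number): number {
  const dataView = new DataView(buffer);
  // A non-zero dictionary pointer marks the field as dictionary-encoded.
  // The parent schema's format string then describes the *index* type, and
  // the pointed-to child schema describes the dictionary's value type.
  return dataView.getUint32(ptr + 36, true); // true = little-endian
}
```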
74 changes: 73 additions & 1 deletion src/vector.ts
@@ -69,6 +69,8 @@ export function parseData<T extends DataType>(
   }
 
   const ptrToChildrenPtrs = dataView.getUint32(ptr + 44, true);
+  const dictionaryPtr = dataView.getUint32(ptr + 48, true);
+
   const children: arrow.Data[] = new Array(Number(nChildren));
   for (let i = 0; i < nChildren; i++) {
     children[i] = parseData(
@@ -79,6 +81,77 @@
     );
   }
 
+  // Special case for handling dictionary-encoded arrays
+  if (dictionaryPtr !== 0) {
+    const dictionaryType = dataType as unknown as arrow.Dictionary;
+
+    // the parent structure points to the index data, the ArrowArray.dictionary
+    // points to the dictionary values array.
+    const indicesType = dictionaryType.indices;
+    const dictionaryIndices = parseDataContent({
+      dataType: indicesType,
+      dataView,
+      copy,
+      length,
+      nullCount,
+      offset,
+      nChildren,
+      children,
+      bufferPtrs,
+    });
+
+    const valueType = dictionaryType.dictionary.type;
+    const dictionaryValues = parseData(buffer, dictionaryPtr, valueType, copy);
+
+    // @ts-expect-error we're casting to dictionary type
+    return arrow.makeData({
+      type: dictionaryType,
+      // TODO: double check that this offset should be set on both the values
+      // and indices of the dictionary array
+      offset,
+      length,
+      nullCount,
+      nullBitmap: dictionaryIndices.nullBitmap,
+      // Note: Here we need to pass in the _raw TypedArray_ not a Data instance
+      data: dictionaryIndices.values,
+      dictionary: arrow.makeVector(dictionaryValues),
+    });
+  } else {
+    return parseDataContent({
+      dataType,
+      dataView,
+      copy,
+      length,
+      nullCount,
+      offset,
+      nChildren,
+      children,
+      bufferPtrs,
+    });
+  }
+}
+
+function parseDataContent<T extends DataType>({
+  dataType,
+  dataView,
+  copy,
+  length,
+  nullCount,
+  offset,
+  nChildren,
+  children,
+  bufferPtrs,
+}: {
+  dataType: T;
+  dataView: DataView;
+  copy: boolean;
+  length: number;
+  nullCount: number;
+  offset: number;
+  nChildren: number;
+  children: arrow.Data[];
+  bufferPtrs: Uint32Array;
+}): arrow.Data<T> {
   if (DataType.isNull(dataType)) {
     return arrow.makeData({
       type: dataType,
@@ -653,7 +726,6 @@
     });
   }
 
-  // TODO: map arrays, dictionary encoding
   throw new Error(`Unsupported type ${dataType}`);
 }
 
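To make the dictionary branch of `parseData` concrete, the same `makeData` shape can be driven from plain JS arrays instead of WASM buffers. A minimal sketch, assuming apache-arrow's documented `Dictionary(valueType, indexType)` constructor together with the `makeData`/`makeVector` calls used above:

```ts
import * as arrow from "apache-arrow";

// A dictionary type with int32 indices into utf8 values.
const type = new arrow.Dictionary(new arrow.Utf8(), new arrow.Int32());

// The indices become the raw `data` buffer of the dictionary Data itself
// (like dictionaryIndices.values above); the values ride along as a separate
// Vector attached via `dictionary` (like arrow.makeVector(dictionaryValues)).
const data = arrow.makeData({
  type,
  length: 3,
  nullCount: 0,
  data: new Int32Array([0, 0, 1]),
  dictionary: arrow.vectorFromArray(["a", "b"], new arrow.Utf8()),
});

console.log([...arrow.makeVector(data)]); // -> ["a", "a", "b"]
```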
97 changes: 75 additions & 22 deletions tests/ffi.test.ts
@@ -196,8 +196,7 @@ describe("binary", (t) => {
     );
 
     const originalField = TEST_TABLE.schema.fields[columnIndex];
-    // declare it's not null
-    const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
+    const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
     const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
     const field = parseField(WASM_MEMORY.buffer, fieldPtr);
 
@@ -277,8 +276,7 @@ describe("string", (t) => {
     );
 
     const originalField = TEST_TABLE.schema.fields[columnIndex];
-    // declare it's not null
-    const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
+    const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
     const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
     const field = parseField(WASM_MEMORY.buffer, fieldPtr);
 
@@ -346,8 +344,7 @@ describe("boolean", (t) => {
     );
 
     const originalField = TEST_TABLE.schema.fields[columnIndex];
-    // declare it's not null
-    const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
+    const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
     const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
     const field = parseField(WASM_MEMORY.buffer, fieldPtr);
 
@@ -379,8 +376,7 @@ describe("null array", (t) => {
     );
 
     const originalField = TEST_TABLE.schema.fields[columnIndex];
-    // declare it's not null
-    const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
+    const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
     const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
     const field = parseField(WASM_MEMORY.buffer, fieldPtr);
 
@@ -412,8 +408,7 @@ describe("list array", (t) => {
     );
 
     const originalField = TEST_TABLE.schema.fields[columnIndex];
-    // declare it's not null
-    const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
+    const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
     const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
     const field = parseField(WASM_MEMORY.buffer, fieldPtr);
 
@@ -499,8 +494,7 @@ describe("extension array", (t) => {
     );
 
     const originalField = TEST_TABLE.schema.fields[columnIndex];
-    // declare it's not null
-    const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
+    const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
     const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
     const field = parseField(WASM_MEMORY.buffer, fieldPtr);
 
@@ -544,8 +538,7 @@ describe("extension array", (t) => {
     // );
 
     // const originalField = TEST_TABLE.schema.fields[columnIndex];
-    // // declare it's not null
-    // const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
+    // const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
     // const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
     // const field = parseField(WASM_MEMORY.buffer, fieldPtr);
 
@@ -572,8 +565,7 @@ describe("date32", (t) => {
     );
 
     const originalField = TEST_TABLE.schema.fields[columnIndex];
-    // declare it's not null
-    const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
+    const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
     const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
     const field = parseField(WASM_MEMORY.buffer, fieldPtr);
 
@@ -606,8 +598,7 @@ describe("date32", (t) => {
     // );
 
     // const originalField = TEST_TABLE.schema.fields[columnIndex];
-    // // declare it's not null
-    // const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
+    // const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
     // const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
     // const field = parseField(WASM_MEMORY.buffer, fieldPtr);
 
@@ -634,8 +625,7 @@ describe("duration", (t) => {
     );
 
     const originalField = TEST_TABLE.schema.fields[columnIndex];
-    // declare it's not null
-    const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
+    const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
     const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
     const field = parseField(WASM_MEMORY.buffer, fieldPtr);
 
@@ -667,8 +657,7 @@ describe("nullable int", (t) => {
     );
 
     const originalField = TEST_TABLE.schema.fields[columnIndex];
-    // declare it's not null
-    const originalVector = TEST_TABLE.getChildAt(columnIndex) as arrow.Vector;
+    const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
    const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
     const field = parseField(WASM_MEMORY.buffer, fieldPtr);
 
@@ -693,3 +682,67 @@
   it("copy=false", () => test(false));
   it("copy=true", () => test(true));
 });
+
+describe("dictionary encoded string", (t) => {
+  function test(copy: boolean) {
+    let columnIndex = TEST_TABLE.schema.fields.findIndex(
+      (field) => field.name == "dictionary_encoded_string"
+    );
+
+    const originalField = TEST_TABLE.schema.fields[columnIndex];
+    const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
+    const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
+    const field = parseField(WASM_MEMORY.buffer, fieldPtr);
+
+    expect(field.name).toStrictEqual(originalField.name);
+    expect(field.typeId).toStrictEqual(originalField.typeId);
+    expect(field.nullable).toStrictEqual(originalField.nullable);
+
+    const arrayPtr = FFI_TABLE.arrayAddr(0, columnIndex);
+    const wasmVector = parseVector(
+      WASM_MEMORY.buffer,
+      arrayPtr,
+      field.type,
+      copy
+    );
+
+    for (let i = 0; i < 3; i++) {
+      expect(originalVector.get(i)).toStrictEqual(wasmVector.get(i));
+    }
+  }
+
+  it("copy=false", () => test(false));
+  it("copy=true", () => test(true));
+});
+
+describe("dictionary encoded string (with nulls)", (t) => {
+  function test(copy: boolean) {
+    let columnIndex = TEST_TABLE.schema.fields.findIndex(
+      (field) => field.name == "dictionary_encoded_string_null"
+    );
+
+    const originalField = TEST_TABLE.schema.fields[columnIndex];
+    const originalVector = TEST_TABLE.getChildAt(columnIndex)!;
+    const fieldPtr = FFI_TABLE.schemaAddr(columnIndex);
+    const field = parseField(WASM_MEMORY.buffer, fieldPtr);
+
+    expect(field.name).toStrictEqual(originalField.name);
+    expect(field.typeId).toStrictEqual(originalField.typeId);
+    expect(field.nullable).toStrictEqual(originalField.nullable);
+
+    const arrayPtr = FFI_TABLE.arrayAddr(0, columnIndex);
+    const wasmVector = parseVector(
+      WASM_MEMORY.buffer,
+      arrayPtr,
+      field.type,
+      copy
+    );
+
+    for (let i = 0; i < 3; i++) {
+      expect(originalVector.get(i)).toStrictEqual(wasmVector.get(i));
+    }
+  }
+
+  it("copy=false", () => test(false));
+  it("copy=true", () => test(true));
+});
13 changes: 13 additions & 0 deletions tests/pyarrow_generate_data.py
@@ -4,6 +4,7 @@
 import numpy as np
 import pandas as pd
 import pyarrow as pa
+import pyarrow.compute as pc
 import pyarrow.feather as feather
 
 
@@ -194,6 +195,16 @@ def dense_union_array() -> pa.Array:
     return union_arr
 
 
+def dictionary_encoded_string_array() -> pa.DictionaryArray:
+    arr = pa.StringArray.from_pandas(["a", "a", "b"])
+    return pc.dictionary_encode(arr)
+
+
+def dictionary_encoded_string_array_null() -> pa.DictionaryArray:
+    arr = pa.StringArray.from_pandas(["a", "a", None])
+    return pc.dictionary_encode(arr)
+
+
 class MyExtensionType(pa.ExtensionType):
     """
     Refer to https://arrow.apache.org/docs/python/extending_types.html for
@@ -243,6 +254,8 @@ def table() -> pa.Table:
             "sparse_union": sparse_union_array(),
             "dense_union": dense_union_array(),
             "duration": duration_array(),
+            "dictionary_encoded_string": dictionary_encoded_string_array(),
+            "dictionary_encoded_string_null": dictionary_encoded_string_array_null(),
         }
     )
 
Binary file modified tests/table.arrow