Commit 74de834

add test for dictionary arrays of types that have custom encoding

1 parent 6276df6
File tree

1 file changed: +126 -7

arrow-json/src/writer/mod.rs

+126 -7
@@ -437,6 +437,7 @@ where
 #[cfg(test)]
 mod tests {
     use core::str;
+    use std::collections::HashMap;
     use std::fs::{read_to_string, File};
     use std::io::{BufReader, Seek};
     use std::sync::Arc;
@@ -2126,16 +2127,12 @@ mod tests {
             DataType::Null => None,
             DataType::Int32 => Some(UnionValue::Int32(
                 buffers[type_id as usize]
-                    .as_any()
-                    .downcast_ref::<Int32Array>()
-                    .unwrap()
+                    .as_primitive::<Int32Type>()
                     .value(idx),
             )),
             DataType::Utf8 => Some(UnionValue::String(
                 buffers[type_id as usize]
-                    .as_any()
-                    .downcast_ref::<StringArray>()
-                    .unwrap()
+                    .as_string::<i32>()
                     .value(idx)
                     .to_string(),
             )),
@@ -2329,7 +2326,7 @@ mod tests {
             ) -> Result<Option<Box<dyn Encoder + 'a>>, ArrowError> {
                 match array.data_type() {
                     DataType::Binary => {
-                        let array = array.as_any().downcast_ref::<BinaryArray>().unwrap();
+                        let array = array.as_binary::<i32>();
                         let encoder = IntArrayBinaryEncoder { array };
                         Ok(Some(Box::new(encoder)))
                     }
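
The two hunks above swap manual `as_any().downcast_ref::<...>().unwrap()` chains for the `AsArray` helper methods (`as_primitive`, `as_string`, `as_binary`). As a side illustration only, and not part of this commit, here is a minimal standalone sketch of that pattern, assuming the `arrow_array` crate with its `cast::AsArray` trait in scope:

use std::sync::Arc;

use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_array::{Array, ArrayRef, Int32Array, StringArray};

fn main() {
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));

    // The verbose form being removed: downcast through Any, then unwrap.
    let first = ints.as_any().downcast_ref::<Int32Array>().unwrap().value(0);

    // The form being introduced: the AsArray helpers perform the same
    // downcast and panic with a descriptive message if the type mismatches.
    let same = ints.as_primitive::<Int32Type>().value(0);
    assert_eq!(first, same);
    assert_eq!(strings.as_string::<i32>().value(1), "b");
}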
@@ -2371,4 +2368,126 @@ mod tests {
 
         assert_eq!(json_value, expected);
     }
+
+    #[test]
+    fn test_encoder_factory_customize_dictionary() {
+        // Test that we can customize the encoding of T even when it shows up as Dictionary<_, T>.
+
+        // No particular reason to choose this example.
+        // Just trying to add some variety to the test cases and demonstrate use cases of the encoder factory.
+        struct PaddedInt32Encoder<'a> {
+            array: Int32Array,
+            nulls: Option<&'a NullBuffer>,
+        }
+
+        impl<'a> Encoder for PaddedInt32Encoder<'a> {
+            fn encode(&mut self, idx: usize, out: &mut Vec<u8>) {
+                let value = self.array.value(idx);
+                write!(out, "\"{value:0>8}\"").unwrap();
+            }
+
+            fn has_nulls(&self) -> bool {
+                self.array.is_nullable()
+            }
+
+            fn is_null(&self, idx: usize) -> bool {
+                self.nulls
+                    .map(|nulls| nulls.is_null(idx))
+                    .unwrap_or_default()
+            }
+        }
+
+        #[derive(Debug)]
+        struct CustomEncoderFactory;
+
+        impl EncoderFactory for CustomEncoderFactory {
+            fn make_default_encoder<'a>(
+                &self,
+                field: &FieldRef,
+                array: &'a dyn Array,
+                _options: &EncoderOptions,
+            ) -> Result<Option<Box<dyn Encoder + 'a>>, ArrowError> {
+                // The point here is:
+                // 1. You can use information from Field to determine how to do the encoding.
+                // 2. For dictionary arrays the Field is always the outer field but the array may be the keys or values array
+                //    and thus the data type of `field` may not match the data type of `array`.
+                let padded = field
+                    .metadata()
+                    .get("padded")
+                    .map(|v| v == "true")
+                    .unwrap_or_default();
+                match (array.data_type(), padded) {
+                    (DataType::Int32, true) => {
+                        let array = array.as_primitive::<Int32Type>();
+                        let nulls = array.nulls();
+                        let encoder = PaddedInt32Encoder {
+                            array: array.clone(),
+                            nulls,
+                        };
+                        Ok(Some(Box::new(encoder)))
+                    }
+                    _ => Ok(None),
+                }
+            }
+        }
+
+        let to_json = |batch| {
+            let mut buf = Vec::new();
+            let mut writer = WriterBuilder::new()
+                .with_encoder_factory(Arc::new(CustomEncoderFactory))
+                .build::<_, JsonArray>(&mut buf);
+            writer.write_batches(&[batch]).unwrap();
+            writer.finish().unwrap();
+            serde_json::from_slice::<Value>(&buf).unwrap()
+        };
+
+        // Control case: no dictionary wrapping works as expected.
+        let array = Int32Array::from(vec![Some(1), None, Some(2)]);
+        let field = Arc::new(Field::new("int", DataType::Int32, true).with_metadata(
+            HashMap::from_iter(vec![("padded".to_string(), "true".to_string())]),
+        ));
+        let batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![field.clone()])),
+            vec![Arc::new(array)],
+        )
+        .unwrap();
+
+        let json_value = to_json(&batch);
+
+        let expected = json!([
+            {"int": "00000001"},
+            {},
+            {"int": "00000002"},
+        ]);
+
+        assert_eq!(json_value, expected);
+
+        // Now make a dictionary batch
+        let mut array_builder = PrimitiveDictionaryBuilder::<UInt16Type, Int32Type>::new();
+        array_builder.append_value(1);
+        array_builder.append_null();
+        array_builder.append_value(1);
+        let array = array_builder.finish();
+        let field = Field::new(
+            "int",
+            DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Int32)),
+            true,
+        )
+        .with_metadata(HashMap::from_iter(vec![(
+            "padded".to_string(),
+            "true".to_string(),
+        )]));
+        let batch = RecordBatch::try_new(Arc::new(Schema::new(vec![field])), vec![Arc::new(array)])
+            .unwrap();
+
+        let json_value = to_json(&batch);
+
+        let expected = json!([
+            {"int": "00000001"},
+            {},
+            {"int": "00000001"},
+        ]);
+
+        assert_eq!(json_value, expected);
+    }
 }
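
To see why the new test's encoder factory matches on `DataType::Int32` even though the field's type is `Dictionary(UInt16, Int32)`: a dictionary array stores its distinct values in a separate child array, and the comment in the test notes that the factory may be handed that values array together with the outer field. A small standalone sketch of that keys/values split, assuming the `arrow_array` crate (illustration only, not part of this commit):

use arrow_array::builder::PrimitiveDictionaryBuilder;
use arrow_array::types::{Int32Type, UInt16Type};
use arrow_array::Array;

fn main() {
    // Same construction as the test: Dictionary<UInt16, Int32> with a repeated value.
    let mut builder = PrimitiveDictionaryBuilder::<UInt16Type, Int32Type>::new();
    builder.append_value(1);
    builder.append_null();
    builder.append_value(1);
    let dict = builder.finish();

    // The outer array reports the dictionary type...
    println!("{:?}", dict.data_type()); // Dictionary(UInt16, Int32)
    // ...while the stored values are a plain Int32 array, which is why an
    // Int32-specific encoder can still apply to this column.
    println!("{:?}", dict.values().data_type()); // Int32
    // The repeated 1 is stored once; the keys reference it by index.
    println!("{:?}", dict.keys());
}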
