From 2b63bf2267235e2bc1703bd3316416c0987789d5 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 13 Dec 2024 00:30:48 +0800 Subject: [PATCH] Support all integer types in ubjson. To support serializing integer index for categorical features. --- include/xgboost/json.h | 10 ++++++++++ include/xgboost/json_io.h | 8 ++++++-- src/common/json.cc | 27 +++++++++++++++++++++------ tests/cpp/common/test_json.cc | 14 ++++++++++++++ tests/cpp/test_serialization.cc | 8 ++++++++ 5 files changed, 59 insertions(+), 8 deletions(-) diff --git a/include/xgboost/json.h b/include/xgboost/json.h index 82bc56f76b52..ce7b8b1d8e95 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -43,7 +43,9 @@ class Value { // typed array for ubjson kF32Array, kF64Array, + kI8Array, kU8Array, + kI16Array, kI32Array, kI64Array }; @@ -179,10 +181,18 @@ using F32Array = JsonTypedArray; * @brief Typed UBJSON array for 64-bit floating point. */ using F64Array = JsonTypedArray; +/** + * @brief Typed UBJSON array for int8_t. + */ +using I8Array = JsonTypedArray; /** * @brief Typed UBJSON array for uint8_t. */ using U8Array = JsonTypedArray; +/** + * @brief Typed UBJSON array for int16_t. + */ +using I16Array = JsonTypedArray; /** * @brief Typed UBJSON array for int32_t. */ diff --git a/include/xgboost/json_io.h b/include/xgboost/json_io.h index 57f8005ab777..aea854a85b67 100644 --- a/include/xgboost/json_io.h +++ b/include/xgboost/json_io.h @@ -140,7 +140,9 @@ class JsonWriter { virtual void Visit(JsonArray const* arr); virtual void Visit(F32Array const* arr); virtual void Visit(F64Array const*) { LOG(FATAL) << "Only UBJSON format can handle f64 array."; } - virtual void Visit(U8Array const* arr); + virtual void Visit(I8Array const* arr); + virtual void Visit(U8Array const* arr); + virtual void Visit(I16Array const* arr); virtual void Visit(I32Array const* arr); virtual void Visit(I64Array const* arr); virtual void Visit(JsonObject const* obj); @@ -218,7 +220,7 @@ class UBJReader : public JsonReader { } template - auto ParseTypedArray(int64_t n) { + auto ParseTypedArray(std::int64_t n) { TypedArray results{static_cast(n)}; for (int64_t i = 0; i < n; ++i) { auto v = this->ReadPrimitive(); @@ -244,7 +246,9 @@ class UBJWriter : public JsonWriter { void Visit(JsonArray const* arr) override; void Visit(F32Array const* arr) override; void Visit(F64Array const* arr) override; + void Visit(I8Array const* arr) override; void Visit(U8Array const* arr) override; + void Visit(I16Array const* arr) override; void Visit(I32Array const* arr) override; void Visit(I64Array const* arr) override; void Visit(JsonObject const* obj) override; diff --git a/src/common/json.cc b/src/common/json.cc index da3df5366169..26d56bb1b03e 100644 --- a/src/common/json.cc +++ b/src/common/json.cc @@ -6,14 +6,13 @@ #include // for array #include // for isdigit #include // for isinf, isnan -#include // for uint8_t, uint16_t, uint32_t +#include // for uint8_t, int16_t, int32_t, int64_t #include // for EOF #include // for size_t, strtof #include // for memcpy #include // for initializer_list #include // for distance #include // for numeric_limits -#include // for allocator #include // for operator<<, basic_ostream, operator&, ios, stringstream #include // for errc @@ -39,7 +38,9 @@ void JsonWriter::Visit(F32Array const* arr) { namespace { auto to_i64 = [](auto v) { return Json{static_cast(v)}; }; } // anonymous namespace +void JsonWriter::Visit(I8Array const* arr) { this->WriteArray(arr, to_i64); } void JsonWriter::Visit(U8Array const* arr) { this->WriteArray(arr, to_i64); } +void JsonWriter::Visit(I16Array const* arr) { this->WriteArray(arr, to_i64); } void JsonWriter::Visit(I32Array const* arr) { this->WriteArray(arr, to_i64); } void JsonWriter::Visit(I64Array const* arr) { this->WriteArray(arr, to_i64); } @@ -149,8 +150,12 @@ std::string Value::TypeStr() const { return "F32Array"; case ValueKind::kF64Array: return "F64Array"; + case ValueKind::kI8Array: + return "I8Array"; case ValueKind::kU8Array: return "U8Array"; + case ValueKind::kI16Array: + return "I16Array"; case ValueKind::kI32Array: return "I32Array"; case ValueKind::kI64Array: @@ -268,7 +273,9 @@ bool JsonTypedArray::operator==(Value const& rhs) const { template class JsonTypedArray; template class JsonTypedArray; +template class JsonTypedArray; template class JsonTypedArray; +template class JsonTypedArray; template class JsonTypedArray; template class JsonTypedArray; @@ -715,8 +722,12 @@ Json UBJReader::ParseArray() { return ParseTypedArray(n); case 'D': return ParseTypedArray(n); + case 'i': + return ParseTypedArray(n); case 'U': return ParseTypedArray(n); + case 'I': + return ParseTypedArray(n); case 'l': return ParseTypedArray(n); case 'L': @@ -891,13 +902,15 @@ void WriteTypedArray(JsonTypedArray const* arr, std::vector* stre stream->push_back('d'); } else if (std::is_same_v) { stream->push_back('D'); - } else if (std::is_same_v) { + } else if (std::is_same_v) { stream->push_back('i'); - } else if (std::is_same_v) { + } else if (std::is_same_v) { stream->push_back('U'); - } else if (std::is_same_v) { + } else if (std::is_same_v) { + stream->push_back('I'); + } else if (std::is_same_v) { stream->push_back('l'); - } else if (std::is_same_v) { + } else if (std::is_same_v) { stream->push_back('L'); } else { LOG(FATAL) << "Not implemented"; @@ -920,7 +933,9 @@ void WriteTypedArray(JsonTypedArray const* arr, std::vector* stre void UBJWriter::Visit(F32Array const* arr) { WriteTypedArray(arr, stream_); } void UBJWriter::Visit(F64Array const* arr) { WriteTypedArray(arr, stream_); } +void UBJWriter::Visit(I8Array const* arr) { WriteTypedArray(arr, stream_); } void UBJWriter::Visit(U8Array const* arr) { WriteTypedArray(arr, stream_); } +void UBJWriter::Visit(I16Array const* arr) { WriteTypedArray(arr, stream_); } void UBJWriter::Visit(I32Array const* arr) { WriteTypedArray(arr, stream_); } void UBJWriter::Visit(I64Array const* arr) { WriteTypedArray(arr, stream_); } diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index e144bdc45b9f..7e290a3dcadc 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -600,8 +600,12 @@ TEST(Json, TypedArray) { size_t n = 16; F32Array f32{n}; std::iota(f32.GetArray().begin(), f32.GetArray().end(), -8); + I8Array i8{n}; + std::iota(i8.GetArray().begin(), i8.GetArray().end(), 0); U8Array u8{n}; std::iota(u8.GetArray().begin(), u8.GetArray().end(), 0); + I32Array i16{n}; + std::iota(i16.GetArray().begin(), i16.GetArray().end(), -8); I32Array i32{n}; std::iota(i32.GetArray().begin(), i32.GetArray().end(), -8); I64Array i64{n}; @@ -610,8 +614,12 @@ TEST(Json, TypedArray) { Json json{Object{}}; json["u8"] = std::move(u8); ASSERT_TRUE(IsA(json["u8"])); + json["i8"] = std::move(i8); + ASSERT_TRUE(IsA(json["i8"])); json["f32"] = std::move(f32); ASSERT_TRUE(IsA(json["f32"])); + json["i16"] = std::move(i16); + ASSERT_TRUE(IsA(json["i16"])); json["i32"] = std::move(i32); ASSERT_TRUE(IsA(json["i32"])); json["i64"] = std::move(i64); @@ -638,6 +646,12 @@ TEST(Json, TypedArray) { for (int32_t i = -8; i < 8; ++i) { ASSERT_EQ(arr[i + 8], i); } + + ASSERT_TRUE(IsA(loaded["i8"])) << loaded["i8"].GetValue().TypeStr(); + auto const& i8_arr = get(loaded["i8"]); + for (decltype(n) i = 0; i < n; ++i) { + ASSERT_EQ(i8_arr[i], i); + } } { diff --git a/tests/cpp/test_serialization.cc b/tests/cpp/test_serialization.cc index 283a56fc5ddf..83b4fd93bdb0 100644 --- a/tests/cpp/test_serialization.cc +++ b/tests/cpp/test_serialization.cc @@ -78,10 +78,18 @@ void CompareJSON(Json l, Json r) { } break; } + case Value::ValueKind::kI8Array: { + CompareIntArray(l, r); + break; + } case Value::ValueKind::kU8Array: { CompareIntArray(l, r); break; } + case Value::ValueKind::kI16Array: { + CompareIntArray(l, r); + break; + } case Value::ValueKind::kI32Array: { CompareIntArray(l, r); break;