Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support all integer types in ubjson. #11094

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/xgboost/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ class Value {
// typed array for ubjson
kF32Array,
kF64Array,
kI8Array,
kU8Array,
kI16Array,
kI32Array,
kI64Array
};
Expand Down Expand Up @@ -179,10 +181,18 @@ using F32Array = JsonTypedArray<float, Value::ValueKind::kF32Array>;
* @brief Typed UBJSON array for 64-bit floating point.
*/
using F64Array = JsonTypedArray<double, Value::ValueKind::kF64Array>;
/**
* @brief Typed UBJSON array for int8_t.
*/
using I8Array = JsonTypedArray<std::int8_t, Value::ValueKind::kI8Array>;
/**
* @brief Typed UBJSON array for uint8_t.
*/
using U8Array = JsonTypedArray<std::uint8_t, Value::ValueKind::kU8Array>;
/**
* @brief Typed UBJSON array for int16_t.
*/
using I16Array = JsonTypedArray<std::int16_t, Value::ValueKind::kI16Array>;
/**
* @brief Typed UBJSON array for int32_t.
*/
Expand Down
8 changes: 6 additions & 2 deletions include/xgboost/json_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ class JsonWriter {
virtual void Visit(JsonArray const* arr);
virtual void Visit(F32Array const* arr);
virtual void Visit(F64Array const*) { LOG(FATAL) << "Only UBJSON format can handle f64 array."; }
virtual void Visit(U8Array const* arr);
virtual void Visit(I8Array const* arr);
virtual void Visit(U8Array const* arr);
virtual void Visit(I16Array const* arr);
virtual void Visit(I32Array const* arr);
virtual void Visit(I64Array const* arr);
virtual void Visit(JsonObject const* obj);
Expand Down Expand Up @@ -218,7 +220,7 @@ class UBJReader : public JsonReader {
}

template <typename TypedArray>
auto ParseTypedArray(int64_t n) {
auto ParseTypedArray(std::int64_t n) {
TypedArray results{static_cast<size_t>(n)};
for (int64_t i = 0; i < n; ++i) {
auto v = this->ReadPrimitive<typename TypedArray::Type>();
Expand All @@ -244,7 +246,9 @@ class UBJWriter : public JsonWriter {
void Visit(JsonArray const* arr) override;
void Visit(F32Array const* arr) override;
void Visit(F64Array const* arr) override;
void Visit(I8Array const* arr) override;
void Visit(U8Array const* arr) override;
void Visit(I16Array const* arr) override;
void Visit(I32Array const* arr) override;
void Visit(I64Array const* arr) override;
void Visit(JsonObject const* obj) override;
Expand Down
27 changes: 21 additions & 6 deletions src/common/json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
#include <array> // for array
#include <cctype> // for isdigit
#include <cmath> // for isinf, isnan
#include <cstdint> // for uint8_t, uint16_t, uint32_t
#include <cstdint> // for uint8_t, int16_t, int32_t, int64_t
#include <cstdio> // for EOF
#include <cstdlib> // for size_t, strtof
#include <cstring> // for memcpy
#include <initializer_list> // for initializer_list
#include <iterator> // for distance
#include <limits> // for numeric_limits
#include <memory> // for allocator
#include <sstream> // for operator<<, basic_ostream, operator&, ios, stringstream
#include <system_error> // for errc

Expand All @@ -39,7 +38,9 @@ void JsonWriter::Visit(F32Array const* arr) {
namespace {
auto to_i64 = [](auto v) { return Json{static_cast<int64_t>(v)}; };
} // anonymous namespace
void JsonWriter::Visit(I8Array const* arr) { this->WriteArray(arr, to_i64); }
void JsonWriter::Visit(U8Array const* arr) { this->WriteArray(arr, to_i64); }
void JsonWriter::Visit(I16Array const* arr) { this->WriteArray(arr, to_i64); }
void JsonWriter::Visit(I32Array const* arr) { this->WriteArray(arr, to_i64); }
void JsonWriter::Visit(I64Array const* arr) { this->WriteArray(arr, to_i64); }

Expand Down Expand Up @@ -149,8 +150,12 @@ std::string Value::TypeStr() const {
return "F32Array";
case ValueKind::kF64Array:
return "F64Array";
case ValueKind::kI8Array:
return "I8Array";
case ValueKind::kU8Array:
return "U8Array";
case ValueKind::kI16Array:
return "I16Array";
case ValueKind::kI32Array:
return "I32Array";
case ValueKind::kI64Array:
Expand Down Expand Up @@ -268,7 +273,9 @@ bool JsonTypedArray<T, kind>::operator==(Value const& rhs) const {

template class JsonTypedArray<float, Value::ValueKind::kF32Array>;
template class JsonTypedArray<double, Value::ValueKind::kF64Array>;
template class JsonTypedArray<std::int8_t, Value::ValueKind::kI8Array>;
template class JsonTypedArray<std::uint8_t, Value::ValueKind::kU8Array>;
template class JsonTypedArray<std::int16_t, Value::ValueKind::kI16Array>;
template class JsonTypedArray<std::int32_t, Value::ValueKind::kI32Array>;
template class JsonTypedArray<std::int64_t, Value::ValueKind::kI64Array>;

Expand Down Expand Up @@ -715,8 +722,12 @@ Json UBJReader::ParseArray() {
return ParseTypedArray<F32Array>(n);
case 'D':
return ParseTypedArray<F64Array>(n);
case 'i':
return ParseTypedArray<I8Array>(n);
case 'U':
return ParseTypedArray<U8Array>(n);
case 'I':
return ParseTypedArray<I16Array>(n);
case 'l':
return ParseTypedArray<I32Array>(n);
case 'L':
Expand Down Expand Up @@ -891,13 +902,15 @@ void WriteTypedArray(JsonTypedArray<T, kind> const* arr, std::vector<char>* stre
stream->push_back('d');
} else if (std::is_same_v<T, double>) {
stream->push_back('D');
} else if (std::is_same_v<T, int8_t>) {
} else if (std::is_same_v<T, std::int8_t>) {
stream->push_back('i');
} else if (std::is_same_v<T, uint8_t>) {
} else if (std::is_same_v<T, std::uint8_t>) {
stream->push_back('U');
} else if (std::is_same_v<T, int32_t>) {
} else if (std::is_same_v<T, std::int16_t>) {
stream->push_back('I');
} else if (std::is_same_v<T, std::int32_t>) {
stream->push_back('l');
} else if (std::is_same_v<T, int64_t>) {
} else if (std::is_same_v<T, std::int64_t>) {
stream->push_back('L');
} else {
LOG(FATAL) << "Not implemented";
Expand All @@ -920,7 +933,9 @@ void WriteTypedArray(JsonTypedArray<T, kind> const* arr, std::vector<char>* stre

void UBJWriter::Visit(F32Array const* arr) { WriteTypedArray(arr, stream_); }
void UBJWriter::Visit(F64Array const* arr) { WriteTypedArray(arr, stream_); }
void UBJWriter::Visit(I8Array const* arr) { WriteTypedArray(arr, stream_); }
void UBJWriter::Visit(U8Array const* arr) { WriteTypedArray(arr, stream_); }
void UBJWriter::Visit(I16Array const* arr) { WriteTypedArray(arr, stream_); }
void UBJWriter::Visit(I32Array const* arr) { WriteTypedArray(arr, stream_); }
void UBJWriter::Visit(I64Array const* arr) { WriteTypedArray(arr, stream_); }

Expand Down
14 changes: 14 additions & 0 deletions tests/cpp/common/test_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,12 @@ TEST(Json, TypedArray) {
size_t n = 16;
F32Array f32{n};
std::iota(f32.GetArray().begin(), f32.GetArray().end(), -8);
I8Array i8{n};
std::iota(i8.GetArray().begin(), i8.GetArray().end(), 0);
U8Array u8{n};
std::iota(u8.GetArray().begin(), u8.GetArray().end(), 0);
I32Array i16{n};
std::iota(i16.GetArray().begin(), i16.GetArray().end(), -8);
I32Array i32{n};
std::iota(i32.GetArray().begin(), i32.GetArray().end(), -8);
I64Array i64{n};
Expand All @@ -610,8 +614,12 @@ TEST(Json, TypedArray) {
Json json{Object{}};
json["u8"] = std::move(u8);
ASSERT_TRUE(IsA<U8Array>(json["u8"]));
json["i8"] = std::move(i8);
ASSERT_TRUE(IsA<I8Array>(json["i8"]));
json["f32"] = std::move(f32);
ASSERT_TRUE(IsA<F32Array>(json["f32"]));
json["i16"] = std::move(i16);
ASSERT_TRUE(IsA<I32Array>(json["i16"]));
json["i32"] = std::move(i32);
ASSERT_TRUE(IsA<I32Array>(json["i32"]));
json["i64"] = std::move(i64);
Expand All @@ -638,6 +646,12 @@ TEST(Json, TypedArray) {
for (int32_t i = -8; i < 8; ++i) {
ASSERT_EQ(arr[i + 8], i);
}

ASSERT_TRUE(IsA<I8Array>(loaded["i8"])) << loaded["i8"].GetValue().TypeStr();
auto const& i8_arr = get<I8Array>(loaded["i8"]);
for (decltype(n) i = 0; i < n; ++i) {
ASSERT_EQ(i8_arr[i], i);
}
}

{
Expand Down
8 changes: 8 additions & 0 deletions tests/cpp/test_serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,18 @@ void CompareJSON(Json l, Json r) {
}
break;
}
case Value::ValueKind::kI8Array: {
CompareIntArray<I8Array>(l, r);
break;
}
case Value::ValueKind::kU8Array: {
CompareIntArray<U8Array>(l, r);
break;
}
case Value::ValueKind::kI16Array: {
CompareIntArray<I16Array>(l, r);
break;
}
case Value::ValueKind::kI32Array: {
CompareIntArray<I32Array>(l, r);
break;
Expand Down
Loading