Skip to content

Commit

Permalink
Added data_converter
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz committed Apr 10, 2024
1 parent 4e5ba5d commit bf1695f
Show file tree
Hide file tree
Showing 5 changed files with 501 additions and 0 deletions.
124 changes: 124 additions & 0 deletions nvflare/app_common/xgb/sec/dam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import struct
from io import BytesIO
from typing import List

SIGNATURE = "NVDADAM1"  # DAM (Direct Accessible Marshalling) V1
PREFIX_LEN = 24  # fixed header: 8-byte signature + int64 size + int64 data_set_id

# Scalar entry types (reserved; the encoder below only emits arrays)
DATA_TYPE_INT = 1
DATA_TYPE_FLOAT = 2
DATA_TYPE_STRING = 3
# Array entry types: values stored as int64 / double respectively
DATA_TYPE_INT_ARRAY = 257
DATA_TYPE_FLOAT_ARRAY = 258


class DamEncoder:
    """Serializes int/float arrays into a DAM (Direct Accessible Marshalling) buffer.

    Layout: 8-byte signature, int64 total size, int64 data-set id, then for
    each entry an int64 data type, an int64 element count, and the elements
    written as int64 or double values (native byte order, per struct "q"/"d").
    """

    def __init__(self, data_set_id: int):
        """Args:
        data_set_id: identifier written into the buffer header.
        """
        self.data_set_id = data_set_id
        self.entries = []  # (data_type, value_list) tuples, in insertion order
        self.buffer = BytesIO()

    def add_int_array(self, value: List[int]):
        """Queue a list of ints to be serialized by finish()."""
        self.entries.append((DATA_TYPE_INT_ARRAY, value))

    def add_float_array(self, value: List[float]):
        """Queue a list of floats to be serialized by finish()."""
        self.entries.append((DATA_TYPE_FLOAT_ARRAY, value))

    def finish(self) -> bytes:
        """Write header and all queued entries; return the serialized bytes."""
        size = PREFIX_LEN
        for _, value in self.entries:
            size += 16  # per-entry header: int64 data type + int64 count
            # 8 bytes per element. BUG FIX: the original computed
            # len(entry) * 8, but entry is the (data_type, value) tuple so
            # len(entry) is always 2, making the size header wrong for any
            # array whose length is not 2.
            size += len(value) * 8

        self.write_str(SIGNATURE)
        self.write_int64(size)
        self.write_int64(self.data_set_id)

        for data_type, value in self.entries:
            self.write_int64(data_type)
            self.write_int64(len(value))

            for x in value:
                if data_type == DATA_TYPE_INT_ARRAY:
                    self.write_int64(x)
                else:
                    self.write_float(x)

        return self.buffer.getvalue()

    def write_int64(self, value: int):
        """Append value as an 8-byte signed integer."""
        self.buffer.write(struct.pack("q", value))

    def write_float(self, value: float):
        """Append value as an 8-byte double."""
        self.buffer.write(struct.pack("d", value))

    def write_str(self, value: str):
        """Append the UTF-8 encoding of value (no length prefix)."""
        self.buffer.write(value.encode("utf-8"))


class DamDecoder:
    """Decodes a buffer produced by DamEncoder.

    The header (signature, total size, data-set id) is read eagerly in the
    constructor; arrays are then decoded sequentially with the decode_* calls.
    """

    def __init__(self, buffer: bytes):
        self.buffer = buffer
        self.pos = 0
        # ROBUSTNESS FIX: a buffer shorter than the fixed prefix cannot be
        # valid DAM. Guard here so garbage input surfaces as
        # is_valid() == False instead of a struct.error from the header reads.
        if len(buffer) < PREFIX_LEN:
            self.signature = None
            self.size = 0
            self.data_set_id = 0
        else:
            # NOTE(review): read_string assumes the first 8 bytes are valid
            # UTF-8; arbitrary binary junk could still raise here — confirm
            # whether callers ever pass non-DAM buffers of sufficient length.
            self.signature = self.read_string(8)
            self.size = self.read_int64()
            self.data_set_id = self.read_int64()

    def is_valid(self) -> bool:
        """Return True if the buffer starts with the DAM signature."""
        return self.signature == SIGNATURE

    def get_data_set_id(self) -> int:
        """Return the data-set id from the header."""
        return self.data_set_id

    def decode_int_array(self) -> List[int]:
        """Decode the next entry as an int array.

        Raises:
            RuntimeError: if the next entry is not an int array.
        """
        data_type = self.read_int64()
        if data_type != DATA_TYPE_INT_ARRAY:
            raise RuntimeError("Invalid data type for int array")

        num = self.read_int64()
        return [self.read_int64() for _ in range(num)]

    def decode_float_array(self) -> List[float]:
        """Decode the next entry as a float array.

        Raises:
            RuntimeError: if the next entry is not a float array.
        """
        data_type = self.read_int64()
        if data_type != DATA_TYPE_FLOAT_ARRAY:
            raise RuntimeError("Invalid data type for float array")

        num = self.read_int64()
        return [self.read_float() for _ in range(num)]

    def read_string(self, length: int) -> str:
        """Read `length` bytes as a UTF-8 string and advance the cursor."""
        result = self.buffer[self.pos : self.pos + length].decode("utf-8")
        self.pos += length
        return result

    def read_int64(self) -> int:
        """Read one 8-byte signed integer and advance the cursor."""
        (result,) = struct.unpack_from("q", self.buffer, self.pos)
        self.pos += 8
        return result

    def read_float(self) -> float:
        """Read one 8-byte double and advance the cursor."""
        (result,) = struct.unpack_from("d", self.buffer, self.pos)
        self.pos += 8
        return result
78 changes: 78 additions & 0 deletions nvflare/app_common/xgb/sec/data_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Tuple

from nvflare.apis.fl_context import FLContext


class FeatureContext:
    """Describes the bin layout of one feature and where each sample falls."""

    def __init__(self, feature_id, sample_bin_assignment, num_bins: int):
        self.feature_id = feature_id
        # Total number of bins this feature has
        self.num_bins = num_bins
        # Bin index for every sample, normalized to [0 .. num_bins - 1]
        self.sample_bin_assignment = sample_bin_assignment


class AggregationContext:
    """Inputs for histogram aggregation: feature contexts plus sample groupings."""

    def __init__(self, features: List[FeatureContext], sample_groups: Dict[int, List[int]]):
        # All features participating in the aggregation
        self.features = features
        # Maps each group_id to the list of sample ids in that group
        self.sample_groups = sample_groups


class FeatureAggregationResult:
    """Aggregated histogram for a single feature."""

    def __init__(self, feature_id: int, aggregated_hist: List[Tuple[int, int]]):
        self.feature_id = feature_id
        # One (G, H) pair per bin of the feature
        self.aggregated_hist = aggregated_hist


class DataConverter:
    """Interface for converting between XGBoost wire buffers and aggregation data.

    Implementations decode gradient pairs and aggregation contexts from
    XGBoost-produced buffers and encode aggregation results back into buffers.
    """

    def decode_gh_pairs(self, buffer: bytes, fl_ctx: FLContext) -> List[Tuple[int, int]]:
        """Decode the buffer to extract (g, h) pairs.

        Args:
            buffer: the buffer to be decoded
            fl_ctx: FLContext info

        Returns: if the buffer contains (g, h) pairs, return a list of
        (g, h) tuples, one per sample; otherwise, return None
        """
        pass

    def decode_aggregation_context(self, buffer: bytes, fl_ctx: FLContext) -> AggregationContext:
        """Decode the buffer to extract aggregation context info.

        Args:
            buffer: buffer to be decoded
            fl_ctx: FLContext info

        Returns: if the buffer contains aggregation context, return an
        AggregationContext object; otherwise, return None
        """
        pass

    def encode_aggregation_result(
        self, aggr_results: Dict[int, List[FeatureAggregationResult]], fl_ctx: FLContext
    ) -> bytes:
        """Encode an individual rank's aggregation result to a buffer based on the XGB data structure.

        Args:
            aggr_results: aggregation result for all features and all groups from all clients,
                group_id => list of feature aggr results
            fl_ctx: FLContext info

        Returns: a buffer of bytes
        """
        pass
137 changes: 137 additions & 0 deletions nvflare/app_common/xgb/sec/processor_data_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Tuple

from nvflare.apis.fl_context import FLContext
from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder
from nvflare.app_common.xgb.sec.data_converter import (
AggregationContext,
DataConverter,
FeatureAggregationResult,
FeatureContext,
)

# Data-set ids identifying the kind of payload carried in a DAM buffer
DATA_SET_GH_PAIRS = 1
DATA_SET_AGGREGATION = 2
DATA_SET_AGGREGATION_WITH_FEATURES = 3
DATA_SET_AGGREGATION_RESULT = 4

SCALE_FACTOR = 1000000.0  # Preserve 6 decimal places when storing floats as ints


class ProcessorDataConverter(DataConverter):
    """DataConverter implementation based on the DAM buffer format.

    Stateful: decode_gh_pairs must run before decode_aggregation_context,
    because the latter relies on self.num_samples set by the former.
    """

    def __init__(self):
        super().__init__()
        # FeatureContext objects collected by decode_aggregation_context.
        # NOTE(review): never reset, so calling decode_aggregation_context
        # more than once on one instance accumulates features — presumably
        # one instance per job; verify against callers.
        self.features = []
        # Feature ids from the last DATA_SET_AGGREGATION_WITH_FEATURES buffer
        self.feature_list = None
        # Number of samples, set by decode_gh_pairs
        self.num_samples = 0

    def decode_gh_pairs(self, buffer: bytes, fl_ctx: FLContext) -> List[Tuple[int, int]]:
        """Decode a DATA_SET_GH_PAIRS buffer into per-sample (g, h) int pairs.

        The buffer holds floats interleaved as [g0, h0, g1, h1, ...]; each
        value is scaled by SCALE_FACTOR and truncated to an int.

        Args:
            buffer: DAM-encoded buffer containing the float array
            fl_ctx: FLContext info (unused here)

        Returns:
            A list of (g, h) int pairs, one per sample.

        Raises:
            RuntimeError: if the buffer is not valid DAM or carries the wrong data-set id.
        """
        decoder = DamDecoder(buffer)
        if not decoder.is_valid():
            raise RuntimeError("GH Buffer is not properly encoded")

        if decoder.get_data_set_id() != DATA_SET_GH_PAIRS:
            raise RuntimeError(f"Data is not for GH Pairs: {decoder.get_data_set_id()}")

        float_array = decoder.decode_float_array()
        result = []
        # g and h are interleaved, so the sample count is half the array length
        self.num_samples = int(len(float_array) / 2)

        for i in range(self.num_samples):
            result.append((self.float_to_int(float_array[2 * i]), self.float_to_int(float_array[2 * i + 1])))

        return result

    def decode_aggregation_context(self, buffer: bytes, fl_ctx: FLContext) -> AggregationContext:
        """Decode an aggregation-context buffer into an AggregationContext.

        Must be called after decode_gh_pairs, which sets self.num_samples.

        Args:
            buffer: DAM-encoded buffer with data-set id DATA_SET_AGGREGATION
                or DATA_SET_AGGREGATION_WITH_FEATURES
            fl_ctx: FLContext info (unused here)

        Returns:
            An AggregationContext with feature contexts and the
            group_id => sample-ids mapping.

        Raises:
            RuntimeError: if the buffer is invalid or has an unexpected data-set id.
        """
        decoder = DamDecoder(buffer)
        if not decoder.is_valid():
            raise RuntimeError("Aggregation Buffer is not properly encoded")
        data_set_id = decoder.get_data_set_id()
        # cuts[f] .. cuts[f+1] delimit the global slot range of feature f
        cuts = decoder.decode_int_array()

        if data_set_id == DATA_SET_AGGREGATION_WITH_FEATURES:
            self.feature_list = decoder.decode_int_array()
            num = len(self.feature_list)
            # slots appears to be row-major: one slot per (sample, feature) pair
            slots = decoder.decode_int_array()
            for i in range(num):
                bin_assignment = []
                for row_id in range(self.num_samples):
                    # Convert the global slot number to a per-feature bin number
                    _, bin_num = self.slot_to_bin(cuts, slots[row_id * num + i])
                    bin_assignment.append(bin_num)

                bin_size = self.get_bin_size(cuts, self.feature_list[i])
                feature_ctx = FeatureContext(self.feature_list[i], bin_assignment, bin_size)
                self.features.append(feature_ctx)
        elif data_set_id != DATA_SET_AGGREGATION:
            raise RuntimeError(f"Invalid DataSet: {data_set_id}")

        # Remaining arrays: the node list, then one row-id array per node
        node_list = decoder.decode_int_array()
        sample_groups = {}
        for node in node_list:
            row_ids = decoder.decode_int_array()
            sample_groups[node] = row_ids

        return AggregationContext(self.features, sample_groups)

    def encode_aggregation_result(
        self, aggr_results: Dict[int, List[FeatureAggregationResult]], fl_ctx: FLContext
    ) -> bytes:
        """Encode per-node aggregation results into a DATA_SET_AGGREGATION_RESULT buffer.

        Args:
            aggr_results: group_id => list of feature aggregation results
            fl_ctx: FLContext info (unused here)

        Returns:
            The DAM-encoded result buffer.
        """
        encoder = DamEncoder(DATA_SET_AGGREGATION_RESULT)
        # Nodes are written in sorted order so the receiver can match arrays to nodes
        node_list = sorted(aggr_results.keys())
        encoder.add_int_array(node_list)

        for node in node_list:
            result_list = aggr_results.get(node)
            # For each node, histograms are written in self.feature_list order
            for f in self.feature_list:
                encoder.add_float_array(self.find_histo_for_feature(result_list, f))

        return encoder.finish()

    @staticmethod
    def get_bin_size(cuts: List[int], feature_id: int) -> int:
        """Return the number of bins of feature_id, derived from the cuts array."""
        return cuts[feature_id + 1] - cuts[feature_id]

    @staticmethod
    def slot_to_bin(cuts: List[int], slot: int) -> Tuple[int, int]:
        """Map a global slot number to (feature_index, bin_number_within_feature).

        Raises:
            RuntimeError: if slot is outside [0, cuts[-1] - 1].
        """
        if slot < 0 or slot >= cuts[-1]:
            raise RuntimeError(f"Invalid slot {slot}, out of range [0-{cuts[-1]-1}]")

        for i in range(len(cuts) - 1):
            if cuts[i] <= slot < cuts[i + 1]:
                bin_num = slot - cuts[i]
                return i, bin_num

        raise RuntimeError(f"Logic error. Slot {slot}, out of range [0-{cuts[-1] - 1}]")

    @staticmethod
    def float_to_int(value: float) -> int:
        """Scale value by SCALE_FACTOR and truncate to int (keeps 6 decimal places)."""
        return int(value * SCALE_FACTOR)

    @staticmethod
    def int_to_float(value: int) -> float:
        """Inverse of float_to_int: scale the int back down to a float."""
        return value / SCALE_FACTOR

    @staticmethod
    def find_histo_for_feature(result_list: List[FeatureAggregationResult], feature_id: int) -> List[float]:
        """Find feature_id's histogram in result_list and flatten it to floats.

        The (g, h) int pairs are converted back to floats and interleaved
        as [g0, h0, g1, h1, ...].

        Raises:
            RuntimeError: if feature_id is not present in result_list.
        """
        for result in result_list:
            if result.feature_id == feature_id:
                float_array = []
                for (g, h) in result.aggregated_hist:
                    float_array.append(ProcessorDataConverter.int_to_float(g))
                    float_array.append(ProcessorDataConverter.int_to_float(h))

                return float_array

        raise RuntimeError(f"Logic error. Feature {feature_id} not found in the list")
36 changes: 36 additions & 0 deletions tests/unit_test/app_common/xgb/sec/dam_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder

# Fixtures for the round-trip test: an arbitrary data-set id plus sample arrays
DATA_SET = 123456
INT_ARRAY = [123, 456, 789]
FLOAT_ARRAY = [1.2, 2.3, 3.4, 4.5]


class TestDam:
    def test_encode_decode(self):
        """Round-trip: arrays fed to the encoder come back intact from the decoder."""
        enc = DamEncoder(DATA_SET)
        enc.add_int_array(INT_ARRAY)
        enc.add_float_array(FLOAT_ARRAY)
        encoded = enc.finish()

        dec = DamDecoder(encoded)
        assert dec.is_valid()
        assert dec.get_data_set_id() == DATA_SET

        # Entries must decode in the same order they were added
        assert dec.decode_int_array() == INT_ARRAY
        assert dec.decode_float_array() == FLOAT_ARRAY
Loading

0 comments on commit bf1695f

Please sign in to comment.