forked from NVIDIA/NVFlare
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
501 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import struct | ||
from io import BytesIO | ||
from typing import List | ||
|
||
SIGNATURE = "NVDADAM1"  # DAM (Direct Accessible Marshalling) V1
PREFIX_LEN = 24  # 8-byte signature + int64 size + int64 data_set_id

DATA_TYPE_INT = 1
DATA_TYPE_FLOAT = 2
DATA_TYPE_STRING = 3
DATA_TYPE_INT_ARRAY = 257
DATA_TYPE_FLOAT_ARRAY = 258


class DamEncoder:
    """Serializes arrays into the DAM (Direct Accessible Marshalling) V1 format.

    Buffer layout: 8-byte signature, int64 total size, int64 data-set id,
    then for each entry an int64 data type, an int64 item count, and the
    8-byte items themselves (int64 or float64).
    """

    def __init__(self, data_set_id: int):
        self.data_set_id = data_set_id
        self.entries = []  # list of (data_type, value-list) tuples, in add order
        self.buffer = BytesIO()

    def add_int_array(self, value: List[int]):
        """Queue a list of ints to be written as an int64 array."""
        self.entries.append((DATA_TYPE_INT_ARRAY, value))

    def add_float_array(self, value: List[float]):
        """Queue a list of floats to be written as a float64 array."""
        self.entries.append((DATA_TYPE_FLOAT_ARRAY, value))

    def finish(self) -> bytes:
        """Serialize all queued entries and return the encoded bytes.

        Returns:
            The complete DAM buffer; its size field matches len() of the result.
        """
        size = PREFIX_LEN
        for _, value in self.entries:
            # 16 bytes of per-entry header (type + count) plus 8 bytes per item.
            # Bug fix: count the items of the array (len(value)), not the
            # 2-element (data_type, value) tuple, which made size wrong for
            # any array whose length != 2.
            size += 16
            size += len(value) * 8

        self.write_str(SIGNATURE)
        self.write_int64(size)
        self.write_int64(self.data_set_id)

        for data_type, value in self.entries:
            self.write_int64(data_type)
            self.write_int64(len(value))

            for x in value:
                if data_type == DATA_TYPE_INT_ARRAY:
                    self.write_int64(x)
                else:
                    self.write_float(x)

        return self.buffer.getvalue()

    def write_int64(self, value: int):
        # Native byte order, 8-byte signed int (matches DamDecoder's "q").
        self.buffer.write(struct.pack("q", value))

    def write_float(self, value: float):
        # Native byte order, 8-byte double (matches DamDecoder's "d").
        self.buffer.write(struct.pack("d", value))

    def write_str(self, value: str):
        self.buffer.write(value.encode("utf-8"))
||
|
||
class DamDecoder:
    """Sequential reader for a buffer encoded in the DAM V1 format."""

    def __init__(self, buffer: bytes):
        self.buffer = buffer
        self.pos = 0
        # The 24-byte prefix: 8-byte signature, int64 size, int64 data-set id.
        self.signature = self.read_string(8)
        self.size = self.read_int64()
        self.data_set_id = self.read_int64()

    def is_valid(self):
        """Return True when the buffer starts with the expected DAM signature."""
        return self.signature == SIGNATURE

    def get_data_set_id(self):
        return self.data_set_id

    def decode_int_array(self) -> List[int]:
        """Read the next entry as an int64 array.

        Raises:
            RuntimeError: if the next entry is not an int array.
        """
        if self.read_int64() != DATA_TYPE_INT_ARRAY:
            raise RuntimeError("Invalid data type for int array")

        count = self.read_int64()
        return [self.read_int64() for _ in range(count)]

    def decode_float_array(self):
        """Read the next entry as a float64 array.

        Raises:
            RuntimeError: if the next entry is not a float array.
        """
        if self.read_int64() != DATA_TYPE_FLOAT_ARRAY:
            raise RuntimeError("Invalid data type for float array")

        count = self.read_int64()
        return [self.read_float() for _ in range(count)]

    def read_string(self, length: int) -> str:
        end = self.pos + length
        text = self.buffer[self.pos : end].decode("utf-8")
        self.pos = end
        return text

    def read_int64(self) -> int:
        (value,) = struct.unpack_from("q", self.buffer, self.pos)
        self.pos += 8
        return value

    def read_float(self) -> float:
        (value,) = struct.unpack_from("d", self.buffer, self.pos)
        self.pos += 8
        return value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Dict, List, Tuple | ||
|
||
from nvflare.apis.fl_context import FLContext | ||
|
||
|
||
class FeatureContext:
    """Binning information for a single feature.

    Attributes:
        feature_id: identifier of the feature
        num_bins: how many bins this feature has
        sample_bin_assignment: per-sample bin index, normalized to [0 .. num_bins-1]
    """

    def __init__(self, feature_id, sample_bin_assignment, num_bins: int):
        self.feature_id = feature_id
        self.num_bins = num_bins
        self.sample_bin_assignment = sample_bin_assignment
|
||
|
||
class AggregationContext:
    """Inputs for one aggregation round.

    Attributes:
        features: the features to aggregate over
        sample_groups: mapping of group_id => list of sample ids
    """

    def __init__(self, features: List[FeatureContext], sample_groups: Dict[int, List[int]]):
        self.features = features
        self.sample_groups = sample_groups
|
||
|
||
class FeatureAggregationResult:
    """Aggregated histogram for a single feature.

    Attributes:
        feature_id: identifier of the feature
        aggregated_hist: list of (G, H) values, one for each bin of the feature
    """

    def __init__(self, feature_id: int, aggregated_hist: List[Tuple[int, int]]):
        self.feature_id = feature_id
        self.aggregated_hist = aggregated_hist
|
||
|
||
class DataConverter:
    """Interface for converting between raw XGBoost buffers and aggregation objects.

    Concrete subclasses implement the decoding/encoding; the stub methods here
    return None by default.
    """

    def decode_gh_pairs(self, buffer: bytes, fl_ctx: FLContext) -> List[Tuple[int, int]]:
        """Decode the buffer to extract (g, h) pairs.

        Args:
            buffer: the buffer to be decoded
            fl_ctx: FLContext info

        Returns: if the buffer contains (g, h) pairs, return a list of (g, h) pairs
            (matching the declared return type); otherwise, return None
        """
        pass

    def decode_aggregation_context(self, buffer: bytes, fl_ctx: FLContext) -> AggregationContext:
        """Decode the buffer to extract aggregation context info.

        Args:
            buffer: buffer to be decoded
            fl_ctx: FLContext info

        Returns: if the buffer contains aggregation context, return an AggregationContext object;
            otherwise, return None
        """
        pass

    def encode_aggregation_result(
        self, aggr_results: Dict[int, List[FeatureAggregationResult]], fl_ctx: FLContext
    ) -> bytes:
        """Encode an individual rank's aggregation result to a buffer based on XGB data structure.

        Args:
            aggr_results: aggregation result for all features and all groups from all clients,
                group_id => list of feature aggr results
            fl_ctx: FLContext info

        Returns: a buffer of bytes
        """
        pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Dict, List, Tuple | ||
|
||
from nvflare.apis.fl_context import FLContext | ||
from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder | ||
from nvflare.app_common.xgb.sec.data_converter import ( | ||
AggregationContext, | ||
DataConverter, | ||
FeatureAggregationResult, | ||
FeatureContext, | ||
) | ||
|
||
DATA_SET_GH_PAIRS = 1
DATA_SET_AGGREGATION = 2
DATA_SET_AGGREGATION_WITH_FEATURES = 3
DATA_SET_AGGREGATION_RESULT = 4

SCALE_FACTOR = 1000000.0  # Preserve 6 decimal places when mapping floats to ints


class ProcessorDataConverter(DataConverter):
    """DataConverter implementation backed by the DAM buffer format."""

    def __init__(self):
        super().__init__()
        self.features = []  # cached FeatureContext objects, built when a buffer carries features
        self.feature_list = None  # feature ids from the last buffer that carried features
        self.num_samples = 0  # number of (g, h) pairs seen by decode_gh_pairs

    def decode_gh_pairs(self, buffer: bytes, fl_ctx: FLContext) -> List[Tuple[int, int]]:
        """Decode a DATA_SET_GH_PAIRS buffer into integer (g, h) pairs.

        The buffer holds interleaved floats [g0, h0, g1, h1, ...]; each value is
        scaled by SCALE_FACTOR and truncated to an int.

        Raises:
            RuntimeError: if the buffer is not valid DAM or is the wrong data set.
        """
        decoder = DamDecoder(buffer)
        if not decoder.is_valid():
            raise RuntimeError("GH Buffer is not properly encoded")

        if decoder.get_data_set_id() != DATA_SET_GH_PAIRS:
            raise RuntimeError(f"Data is not for GH Pairs: {decoder.get_data_set_id()}")

        float_array = decoder.decode_float_array()
        result = []
        # Two floats (g and h) per sample; integer division avoids the
        # float round-trip of int(len(...) / 2).
        self.num_samples = len(float_array) // 2

        for i in range(self.num_samples):
            result.append((self.float_to_int(float_array[2 * i]), self.float_to_int(float_array[2 * i + 1])))

        return result

    def decode_aggregation_context(self, buffer: bytes, fl_ctx: FLContext) -> AggregationContext:
        """Decode an aggregation buffer into an AggregationContext.

        When the buffer carries features (DATA_SET_AGGREGATION_WITH_FEATURES),
        the feature list and bin assignments are decoded and cached on self;
        a plain DATA_SET_AGGREGATION buffer reuses the cached features.

        Raises:
            RuntimeError: if the buffer is not valid DAM or the data-set id is unknown.
        """
        decoder = DamDecoder(buffer)
        if not decoder.is_valid():
            raise RuntimeError("Aggregation Buffer is not properly encoded")
        data_set_id = decoder.get_data_set_id()
        # cuts[i]..cuts[i+1] delimit the global slot range of feature i.
        cuts = decoder.decode_int_array()

        if data_set_id == DATA_SET_AGGREGATION_WITH_FEATURES:
            self.feature_list = decoder.decode_int_array()
            num = len(self.feature_list)
            # slots is row-major: one slot per (sample, feature) pair.
            slots = decoder.decode_int_array()
            for i in range(num):
                bin_assignment = []
                for row_id in range(self.num_samples):
                    _, bin_num = self.slot_to_bin(cuts, slots[row_id * num + i])
                    bin_assignment.append(bin_num)

                bin_size = self.get_bin_size(cuts, self.feature_list[i])
                feature_ctx = FeatureContext(self.feature_list[i], bin_assignment, bin_size)
                self.features.append(feature_ctx)
        elif data_set_id != DATA_SET_AGGREGATION:
            raise RuntimeError(f"Invalid DataSet: {data_set_id}")

        node_list = decoder.decode_int_array()
        sample_groups = {}
        for node in node_list:
            row_ids = decoder.decode_int_array()
            sample_groups[node] = row_ids

        return AggregationContext(self.features, sample_groups)

    def encode_aggregation_result(
        self, aggr_results: Dict[int, List[FeatureAggregationResult]], fl_ctx: FLContext
    ) -> bytes:
        """Encode per-group aggregation results into a DATA_SET_AGGREGATION_RESULT buffer.

        Args:
            aggr_results: group_id => list of feature aggregation results
            fl_ctx: FLContext info

        Returns: the encoded bytes.

        Raises:
            RuntimeError: if no feature list has been decoded yet.
        """
        if self.feature_list is None:
            # Clearer than the TypeError that iterating None would raise below.
            raise RuntimeError("no feature list available; decode an aggregation context with features first")

        encoder = DamEncoder(DATA_SET_AGGREGATION_RESULT)
        node_list = sorted(aggr_results.keys())
        encoder.add_int_array(node_list)

        for node in node_list:
            result_list = aggr_results.get(node)
            for f in self.feature_list:
                encoder.add_float_array(self.find_histo_for_feature(result_list, f))

        return encoder.finish()

    @staticmethod
    def get_bin_size(cuts: List[int], feature_id: int) -> int:
        """Number of bins for feature_id, derived from consecutive cut points."""
        return cuts[feature_id + 1] - cuts[feature_id]

    @staticmethod
    def slot_to_bin(cuts: List[int], slot: int) -> Tuple[int, int]:
        """Map a global slot number to (feature_index, bin_number_within_feature).

        Raises:
            RuntimeError: if slot is outside [0, cuts[-1]).
        """
        if slot < 0 or slot >= cuts[-1]:
            raise RuntimeError(f"Invalid slot {slot}, out of range [0-{cuts[-1]-1}]")

        for i in range(len(cuts) - 1):
            if cuts[i] <= slot < cuts[i + 1]:
                bin_num = slot - cuts[i]
                return i, bin_num

        raise RuntimeError(f"Logic error. Slot {slot}, out of range [0-{cuts[-1] - 1}]")

    @staticmethod
    def float_to_int(value: float) -> int:
        """Scale a float by SCALE_FACTOR and truncate to int."""
        return int(value * SCALE_FACTOR)

    @staticmethod
    def int_to_float(value: int) -> float:
        """Inverse of float_to_int: scale an int back down to a float."""
        return value / SCALE_FACTOR

    @staticmethod
    def find_histo_for_feature(result_list: List[FeatureAggregationResult], feature_id: int) -> List[float]:
        """Return the flattened [g0, h0, g1, h1, ...] histogram for feature_id.

        Raises:
            RuntimeError: if feature_id is not present in result_list.
        """
        for result in result_list:
            if result.feature_id == feature_id:
                float_array = []
                for (g, h) in result.aggregated_hist:
                    float_array.append(ProcessorDataConverter.int_to_float(g))
                    float_array.append(ProcessorDataConverter.int_to_float(h))

                return float_array

        raise RuntimeError(f"Logic error. Feature {feature_id} not found in the list")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder | ||
|
||
DATA_SET = 123456
INT_ARRAY = [123, 456, 789]
FLOAT_ARRAY = [1.2, 2.3, 3.4, 4.5]


class TestDam:
    def test_encode_decode(self):
        """Round-trip an int array and a float array through DAM encode/decode."""
        writer = DamEncoder(DATA_SET)
        writer.add_int_array(INT_ARRAY)
        writer.add_float_array(FLOAT_ARRAY)
        encoded = writer.finish()

        reader = DamDecoder(encoded)
        assert reader.is_valid()
        assert reader.get_data_set_id() == DATA_SET

        # Entries come back in the order they were added.
        assert reader.decode_int_array() == INT_ARRAY
        assert reader.decode_float_array() == FLOAT_ARRAY
Oops, something went wrong.