diff --git a/RELEASE.md b/RELEASE.md index 1041a134c..29cb428ff 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,7 @@ +# Release 1.3.0 +## Major Features and Improvements +* Hetero Secureboosting communication optimization: communication round is reduced to 1 by letting the host send a pre-computed host node route, which is used for inferencing, to the guest. + # Release 1.2.0 ## Major Features and Improvements * Replace serving-router with a brand new service called serving-proxy, which supports authentication and inference request with HTTP or gRPC diff --git a/fate-serving-core/src/main/java/com/webank/ai/fate/serving/core/bean/Dict.java b/fate-serving-core/src/main/java/com/webank/ai/fate/serving/core/bean/Dict.java index c6ae81f74..6c778a1e1 100644 --- a/fate-serving-core/src/main/java/com/webank/ai/fate/serving/core/bean/Dict.java +++ b/fate-serving-core/src/main/java/com/webank/ai/fate/serving/core/bean/Dict.java @@ -100,6 +100,8 @@ public class Dict { public static final String INPUT_DATA_HIT_RATE = "inputDataHitRate"; public static final String GUEST_MODEL_WEIGHT_HIT_RATE = "guestModelWeightHitRate"; public static final String GUEST_INPUT_DATA_HIT_RATE = "guestInputDataHitRate"; + public static final String TAG_INPUT_FORMAT = "tag"; + public static final String SPARSE_INPUT_FORMAT = "sparse"; public static final String MIN_MAX_SCALE = "min_max_scale"; public static final String STANDARD_SCALE = "standard_scale"; public static final String DSL_COMPONENTS = "components"; diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/DataIO.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/DataIO.java index 07d20bd4c..5ed4147fc 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/DataIO.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/DataIO.java @@ -20,11 +20,13 @@ import com.webank.ai.fate.core.mlmodel.buffer.DataIOMetaProto.DataIOMeta; import com.webank.ai.fate.core.mlmodel.buffer.DataIOParamProto.DataIOParam; import com.webank.ai.fate.serving.core.bean.Context; +import com.webank.ai.fate.serving.core.bean.Dict; import com.webank.ai.fate.serving.core.bean.FederatedParams; import com.webank.ai.fate.serving.core.bean.StatusCode; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -32,6 +34,8 @@ public class DataIO extends BaseModel { private static final Logger logger = LoggerFactory.getLogger(DataIO.class); private DataIOMeta dataIOMeta; private DataIOParam dataIOParam; + private List header; + private String inputformat; private Imputer imputer; private Outlier outlier; private boolean isImputer; @@ -56,9 +60,12 @@ public int initModel(byte[] protoMeta, byte[] protoParam) { this.outlier = new Outlier(this.dataIOMeta.getOutlierMeta().getOutlierValueList(), this.dataIOParam.getOutlierParam().getOutlierReplaceValue()); } + + this.header = this.dataIOParam.getHeaderList(); + this.inputformat = this.dataIOMeta.getInputFormat(); } catch (Exception ex) { ex.printStackTrace(); - logger.error("init DataIo error",ex); + logger.error("init DataIo error", ex); return StatusCode.ILLEGALDATA; } logger.info("Finish init DataIO class"); @@ -67,16 +74,43 @@ public int initModel(byte[] protoMeta, byte[] protoParam) { @Override public Map handlePredict(Context context, List> inputData, FederatedParams predictParams) { - Map input = inputData.get(0); + Map data = inputData.get(0); + Map outputData = new HashMap<>(); + + if(logger.isDebugEnabled()) { + logger.debug("input-data, not filling, {}", data); + } + + if (this.inputformat.equals(Dict.TAG_INPUT_FORMAT) || this.inputformat.equals(Dict.SPARSE_INPUT_FORMAT + )) { + if(logger.isDebugEnabled()) { + logger.debug("Sparse Data Filling Zeros"); + } + for (String col: this.header) { + outputData.put(col, data.getOrDefault(col, 0)); + } + } else { + outputData = data; + if(logger.isDebugEnabled()) { + logger.debug("Dense input-format, not filling, {}", outputData); + } + } if (this.isImputer) { - input = this.imputer.transform(input); + outputData = this.imputer.transform(outputData); } if (this.isOutlier) { - input = this.outlier.transform(input); + outputData = this.outlier.transform(outputData); } - return input; + /* + for (String col: data.keySet()) { + if (!output.containsKey(col)) { + output.put(col, data.get(col)); + } + }*/ + + return outputData; } } diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java index 4fad71acb..87f9e935c 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java @@ -15,6 +15,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Collections; +import java.lang.Math; public class HeteroFeatureBinning extends BaseModel { @@ -56,38 +58,50 @@ public int initModel(byte[] protoMeta, byte[] protoParam) { @Override public Map handlePredict(Context context, List> inputData, FederatedParams predictParams) { HashMap outputData = new HashMap<>(8); + HashMap headerMap = new HashMap<>(); Map firstData = inputData.get(0); if (!this.needRun) { return firstData; } + for (int i = 0; i < this.header.size(); i++) { + headerMap.put(this.header.get(i), (long) i); + } + for (String colName : firstData.keySet()) { - try{ - if (! this.splitPoints.containsKey(colName)) { + try { + if (!this.splitPoints.containsKey(colName)) { outputData.put(colName, firstData.get(colName)); - continue; + continue; } - Long thisColIndex = (long) this.header.indexOf(colName); - if (! this.transformCols.contains(thisColIndex)) { +// Long thisColIndex = (long) this.header.indexOf(colName); + Long thisColIndex = headerMap.get(colName); + if (!this.transformCols.contains(thisColIndex)) { outputData.put(colName, firstData.get(colName)); continue; } List splitPoint = this.splitPoints.get(colName); Double colValue = Double.valueOf(firstData.get(colName).toString()); - int colIndex = 0; - for (colIndex = 0; colIndex < splitPoint.size(); colIndex ++) { - if (colValue <= splitPoint.get(colIndex)) { - break; - } - } - outputData.put(colName, colIndex); - }catch(Throwable e){ - logger.error("HeteroFeatureBinning error" ,e); + int colIndex = Collections.binarySearch(splitPoint, colValue); + if (colIndex < 0) { + colIndex = Math.min((- colIndex - 1), splitPoint.size() - 1); + } +// for (colIndex = 0; colIndex < splitPoint.size(); colIndex ++) { +// +// +// if (colValue <= splitPoint.get(colIndex)) { +// break; +// } +// } + outputData.put(colName, colIndex); + } catch (Throwable e) { + logger.error("HeteroFeatureBinning error", e); } } - if(logger.isDebugEnabled()) { - logger.debug("HeteroFeatureBinning output {}", outputData); + if (logger.isDebugEnabled()) { + logger.debug("DEBUG: HeteroFeatureBinning output {}", outputData); } + return outputData; } diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java index 9c1adebce..7e18b1de5 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java @@ -40,6 +40,7 @@ public abstract class HeteroSecureBoost extends BaseModel { protected List classes; protected int treeDim; protected double learningRate; + protected boolean fastMode = true; @Override public int initModel(byte[] protoMeta, byte[] protoParam) { @@ -93,6 +94,7 @@ protected int gotoNextLevel(int treeId, int treeNodeId, Map inpu int fid = this.trees.get(treeId).getTree(treeNodeId).getFid(); double splitValue = this.trees.get(treeId).getSplitMaskdict().get(treeNodeId); String fidStr = String.valueOf(fid); + if (input.containsKey(fidStr)) { if (Double.parseDouble(input.get(fidStr).toString()) <= splitValue + 1e-20) { nextTreeNodeId = this.trees.get(treeId).getTree(treeNodeId).getLeftNodeid(); @@ -100,6 +102,7 @@ protected int gotoNextLevel(int treeId, int treeNodeId, Map inpu nextTreeNodeId = this.trees.get(treeId).getTree(treeNodeId).getRightNodeid(); } } else { + logger.info("go missing dir"); if (this.trees.get(treeId).getMissingDirMaskdict().containsKey(treeNodeId)) { int missingDir = this.trees.get(treeId).getMissingDirMaskdict().get(treeNodeId); if (missingDir == 1) { diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java index 6bf6d53fd..5969143e7 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java @@ -35,7 +35,9 @@ private double sigmoid(double x) { return 1. / (1. + Math.exp(-x)); } - private Map softmax(double[] weights) { + private boolean fastMode = true; + + private Map softmax(double weights[]) { int n = weights.length; double max = weights[0]; int maxIndex = 0; @@ -99,6 +101,7 @@ private double getTreeLeafWeight(int treeId, int treeNodeId) { } private int traverseTree(int treeId, int treeNodeId, Map input) { + while (!this.isLocateInLeaf(treeId, treeNodeId) && this.getSite(treeId, treeNodeId).equals(this.site)) { treeNodeId = this.gotoNextLevel(treeId, treeNodeId, input); } @@ -106,6 +109,28 @@ private int traverseTree(int treeId, int treeNodeId, Map input) return treeNodeId; } + private int fastTraverseTree(int treeId, int treeNodeId, Map input, Map lookUpTable) { + + while(!this.isLocateInLeaf(treeId, treeNodeId)){ + if(this.getSite(treeId, treeNodeId).equals(this.site)){ + treeNodeId = this.gotoNextLevel(treeId, treeNodeId, input); + } + else{ + Map lookUp = (Map) lookUpTable.get(String.valueOf(treeId)); + if(lookUp.get(String.valueOf(treeNodeId))){ + treeNodeId = this.trees.get(treeId).getTree(treeNodeId).getLeftNodeid(); + } + else { + treeNodeId = this.trees.get(treeId).getTree(treeNodeId).getRightNodeid(); + } + } + if(logger.isDebugEnabled()) { + logger.info("tree id is {}, tree node is {}", treeId, treeNodeId); + } + } + + return treeNodeId; + } private Map getFinalPredict(double[] weights) { Map ret = new HashMap(8); @@ -121,9 +146,8 @@ private Map getFinalPredict(double[] weights) { sumWeights[i % this.treeDim] += weights[i] * this.learningRate; } - for (int i = 0; i < this.treeDim; i++) { + for (int i = 0; i < this.treeDim; i++) sumWeights[i] += this.initScore.get(i); - } ret = softmax(sumWeights); } else { @@ -139,14 +163,17 @@ private Map getFinalPredict(double[] weights) { @Override public Map handlePredict(Context context, List> inputData, FederatedParams predictParams) { - if(logger.isDebugEnabled()) { - logger.debug("HeteroSecureBoostingTreeGuest FederatedParams {}", predictParams); - } + + logger.info("HeteroSecureBoostingTreeGuest FederatedParams {}", predictParams); Map input = inputData.get(0); HashMap fidValueMapping = new HashMap(8); - ReturnResult returnResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE, false); + if(!this.fastMode){ + // ask host to prepare data, if fast mode is not enabled + ReturnResult returnResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE, false); + } + int featureHit = 0; for (String key : input.keySet()) { @@ -155,12 +182,13 @@ public Map handlePredict(Context context, List treeLocation = new HashMap(8); for (int i = 0; i < this.treeNum; ++i) { @@ -185,26 +213,57 @@ public Map handlePredict(Context context, List afterLocation = tempResult.getData(); - if(logger.isDebugEnabled()) { - logger.debug("after loccation is {}", afterLocation); + boolean getNodeRoute = false; + ReturnResult tempResult; + if(this.fastMode){ + getNodeRoute = true; + tempResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE, false); } - for (String location : afterLocation.keySet()) { - treeNodeIds[new Integer(location)] = ((Number) afterLocation.get(location)).intValue(); + else{ + tempResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE_FOR_TREE, false); } - if (afterLocation == null) { - logger.error("receive predict result of host is null"); - throw new Exception("Null Data"); + + Map returnData = tempResult.getData(); + + if(this.fastMode && getNodeRoute){ + + if(logger.isDebugEnabled()){ + logger.info("running fast mode, look up table is {}",returnData); + } + + for(String treeIdx: treeLocation.keySet()){ + int idx = Integer.valueOf(treeIdx); + int curNodeId = (Integer)treeLocation.get(treeIdx); + int final_node_id = this.fastTraverseTree(idx, curNodeId, fidValueMapping, returnData); + treeNodeIds[idx] = final_node_id; + } + } + else{ + Map afterLocation = tempResult.getData(); + + if(logger.isDebugEnabled()){ + logger.info("after location is {}", afterLocation); + } + + for (String location : afterLocation.keySet()) { + treeNodeIds[new Integer(location)] = ((Number) afterLocation.get(location)).intValue(); + } + if (afterLocation == null) { + logger.info("receive predict result of host is null"); + throw new Exception("Null Data"); + } } } catch (Exception ex) { ex.printStackTrace(); - logger.error("HeteroSecureBoostingTreeGuest handle error",ex); return null; } } @@ -212,12 +271,12 @@ public Map handlePredict(Context context, List forward(List> inputDatas) { ++featureHit; } } - logger.info("feature hit rate : {}", 1.0 * featureHit / this.featureNameFidMapping.size()); + LOGGER.info("feature hit rate : {}", 1.0 * featureHit / this.featureNameFidMapping.size()); } */ @@ -60,7 +62,6 @@ private int traverseTree(int treeId, int treeNodeId, Map input) public void saveData(Context context, String tag, Map data) { CacheManager.getInstance().store(context, tag, data); - } public Map getData(Context context, String tag) { @@ -70,11 +71,64 @@ public Map getData(Context context, String tag) { return data; } + public Map extractHostNodeRoute(Map input){ + + // > + + logger.info("running extractHostNodeRoute"); + + Map result = new HashMap(8); + for(int i=0;i nodes = treeParam.getTreeList(); + Map treeRoute = new HashMap(8); + + for(int j=0;j handlePredict(Context context, List> inputData, FederatedParams predictParams) { - if(logger.isDebugEnabled()) { - logger.debug("HeteroSecureBoostingTreeHost FederatedParams {}", predictParams); - } + + logger.info("HeteroSecureBoostingTreeHost FederatedParams {}", predictParams); + Map input = inputData.get(0); String tag = predictParams.getCaseId() + "." + this.componentName + "." + Dict.INPUT_DATA; @@ -88,21 +142,33 @@ public Map handlePredict(Context context, List> in first + // communication round + ret = this.extractHostNodeRoute(fidValueMapping); + } + else{ + this.saveData(context, tag, fidValueMapping); + } + return ret; } - public Map predictSingleRound(Context context, Map interactiveData, FederatedParams predictParams) { + String tag = predictParams.getCaseId() + "." + this.componentName + "." + Dict.INPUT_DATA; Map input = this.getData(context, tag); + Map ret = new HashMap(8); for (String treeIdx : interactiveData.keySet()) { int idx = Integer.valueOf(treeIdx); int nodeId = this.traverseTree(idx, (Integer) interactiveData.get(treeIdx), input); ret.put(treeIdx, nodeId); } - return ret; } } diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/Imputer.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/Imputer.java index 3deadd441..eac6b5c71 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/Imputer.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/Imputer.java @@ -19,6 +19,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -34,21 +35,21 @@ public Imputer(List missingValues, Map missingReplaceVal } public Map transform(Map inputData) { - if(inputData!=null) { - for (String key : inputData.keySet()) { - if(inputData.get(key)!=null) { - String value = inputData.get(key).toString(); - if (this.missingValueSet.contains(value.toLowerCase())) { - try { - inputData.put(key, this.missingReplaceValues.get(key)); - } catch (Exception ex) { - logger.error("Imputer transform error",ex); - inputData.put(key, 0.); - } - } + Map output = new HashMap<>(); + + for (String col: this.missingReplaceValues.keySet()) { + if (inputData.containsKey(col)) { + String value = inputData.get(col).toString(); + if (this.missingValueSet.contains(value.toLowerCase())) { + output.put(col, this.missingReplaceValues.get(col)); + } else { + output.put(col, inputData.get(col)); } + } else { + output.put(col, this.missingReplaceValues.get(col)); } } - return inputData; + + return output; } } diff --git a/pom.xml b/pom.xml index 2ffceb41d..c954d1dcd 100644 --- a/pom.xml +++ b/pom.xml @@ -35,7 +35,7 @@ - 1.2.0 + 1.3.0 0.3 1.8 UTF-8 @@ -554,4 +554,4 @@ - \ No newline at end of file + diff --git a/serving-proxy/bin/service.sh b/serving-proxy/bin/service.sh index bcffd7b26..60c5ec3c6 100644 --- a/serving-proxy/bin/service.sh +++ b/serving-proxy/bin/service.sh @@ -24,7 +24,7 @@ configpath=$(cd $basepath/conf;pwd) module=serving-proxy main_class=com.webank.ai.fate.serving.proxy.bootstrap.Bootstrap -module_version=1.2.0 +module_version=1.3.0 case "$1" in diff --git a/serving-server/bin/service.sh b/serving-server/bin/service.sh index 45a70cb02..f14a4ac74 100644 --- a/serving-server/bin/service.sh +++ b/serving-server/bin/service.sh @@ -22,7 +22,7 @@ set -e source ./bin/common.sh module=serving-server main_class=com.webank.ai.fate.serving.ServingServer -module_version=1.2.0 +module_version=1.3.0 case "$1" in @@ -47,4 +47,4 @@ case "$1" in *) echo "usage: $0 {start|stop|status|restart}" exit 1 -esac \ No newline at end of file +esac diff --git a/serving-server/src/main/java/com/webank/ai/fate/serving/adapter/dataaccess/TestFilePick.java b/serving-server/src/main/java/com/webank/ai/fate/serving/adapter/dataaccess/TestFilePick.java index 18c7d00a8..683b2111f 100644 --- a/serving-server/src/main/java/com/webank/ai/fate/serving/adapter/dataaccess/TestFilePick.java +++ b/serving-server/src/main/java/com/webank/ai/fate/serving/adapter/dataaccess/TestFilePick.java @@ -56,8 +56,10 @@ public ReturnResult getData(Context context, Map featureIds) { }); } Map fdata = featureMaps.get(featureIds.get(Dict.DEVICE_ID)); - if(fdata != null) { - returnResult.setData(fdata); + + Map clone = (Map) ((HashMap)fdata).clone(); + if(clone != null) { + returnResult.setData(clone); returnResult.setRetcode(InferenceRetCode.OK); } else{ logger.error("cant not find features for {}.", featureIds.get(Dict.DEVICE_ID)); diff --git a/serving-server/src/main/resources/serving-server.properties b/serving-server/src/main/resources/serving-server.properties index f90e32006..5c083b631 100644 --- a/serving-server/src/main/resources/serving-server.properties +++ b/serving-server/src/main/resources/serving-server.properties @@ -1,4 +1,4 @@ -# +1# # Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,7 +28,7 @@ port=8000 # external cache redis.ip=127.0.0.1 redis.port=6379 -#redis.password=fate_dev +redis.password=fate_dev #redis.timeout=10 #redis.maxTotal=100 #redis.maxIdle=100 @@ -54,4 +54,6 @@ zk.url=zookeeper://localhost:2181?backup=localhost:2182,localhost:2183 # zk acl #acl.enable=false #acl.username= -#acl.password= \ No newline at end of file +#acl.password= + +#proxy=127.0.0.1:8879