From ffe17144b575b7a9966bf3959a64952bb94b83be Mon Sep 17 00:00:00 2001 From: Chen <cwjghglbdcj@gmail.com> Date: Tue, 10 Mar 2020 21:09:13 +0800 Subject: [PATCH 01/22] update tree model --- .../federatedml/model/HeteroSecureBoost.java | 5 + .../model/HeteroSecureBoostingTreeGuest.java | 95 +++++++++++++----- .../model/HeteroSecureBoostingTreeHost.java | 96 +++++++++++++++---- 3 files changed, 156 insertions(+), 40 deletions(-) diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java index 9c1adebce..657d4b8ab 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java @@ -93,6 +93,10 @@ protected int gotoNextLevel(int treeId, int treeNodeId, Map<String, Object> inpu int fid = this.trees.get(treeId).getTree(treeNodeId).getFid(); double splitValue = this.trees.get(treeId).getSplitMaskdict().get(treeNodeId); String fidStr = String.valueOf(fid); + logger.info("treeId {}, treeNodeId {}",treeId, treeNodeId); + logger.info("treenode fid {}",fidStr); + logger.info("input is {}",input); + if (input.containsKey(fidStr)) { if (Double.parseDouble(input.get(fidStr).toString()) <= splitValue + 1e-20) { nextTreeNodeId = this.trees.get(treeId).getTree(treeNodeId).getLeftNodeid(); @@ -100,6 +104,7 @@ protected int gotoNextLevel(int treeId, int treeNodeId, Map<String, Object> inpu nextTreeNodeId = this.trees.get(treeId).getTree(treeNodeId).getRightNodeid(); } } else { + logger.info("go missing dir"); if (this.trees.get(treeId).getMissingDirMaskdict().containsKey(treeNodeId)) { int missingDir = this.trees.get(treeId).getMissingDirMaskdict().get(treeNodeId); if (missingDir == 1) { diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java index 6bf6d53fd..7aab333e6 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java @@ -35,7 +35,9 @@ private double sigmoid(double x) { return 1. / (1. + Math.exp(-x)); } - private Map<String, Object> softmax(double[] weights) { + private boolean fastMode = true; + + private Map<String, Object> softmax(double weights[]) { int n = weights.length; double max = weights[0]; int maxIndex = 0; @@ -99,6 +101,7 @@ private double getTreeLeafWeight(int treeId, int treeNodeId) { } private int traverseTree(int treeId, int treeNodeId, Map<String, Object> input) { + while (!this.isLocateInLeaf(treeId, treeNodeId) && this.getSite(treeId, treeNodeId).equals(this.site)) { treeNodeId = this.gotoNextLevel(treeId, treeNodeId, input); } @@ -106,6 +109,28 @@ private int traverseTree(int treeId, int treeNodeId, Map<String, Object> input) return treeNodeId; } + private int fastTraverseTree(int treeId, int treeNodeId, Map<String, Object> input, Map<String, Object> lookUpTable) { + + while(!this.isLocateInLeaf(treeId, treeNodeId)){ + if(this.getSite(treeId, treeNodeId).equals(this.site)){ + treeNodeId = this.gotoNextLevel(treeId, treeNodeId, input); + } + else{ + Map<String, Boolean> lookUp = (Map<String, Boolean>) lookUpTable.get(String.valueOf(treeId)); + if(lookUp.get(String.valueOf(treeNodeId))){ + treeNodeId = this.trees.get(treeId).getTree(treeNodeId).getLeftNodeid(); + } + else { + treeNodeId = this.trees.get(treeId).getTree(treeNodeId).getRightNodeid(); + } + } + if(logger.isDebugEnabled()) { + logger.info("tree id is {}, tree node is {}", treeId, treeNodeId); + } + } + + return treeNodeId; + } private Map<String, Object> getFinalPredict(double[] weights) { Map<String, Object> ret = new HashMap<String, Object>(8); @@ -121,9 +146,8 @@ private Map<String, Object> getFinalPredict(double[] weights) { sumWeights[i % this.treeDim] += weights[i] * this.learningRate; } - for (int i = 0; i < this.treeDim; i++) { + for (int i = 0; i < this.treeDim; i++) sumWeights[i] += this.initScore.get(i); - } ret = softmax(sumWeights); } else { @@ -139,9 +163,8 @@ private Map<String, Object> getFinalPredict(double[] weights) { @Override public Map<String, Object> handlePredict(Context context, List<Map<String, Object>> inputData, FederatedParams predictParams) { - if(logger.isDebugEnabled()) { - logger.debug("HeteroSecureBoostingTreeGuest FederatedParams {}", predictParams); - } + + logger.info("HeteroSecureBoostingTreeGuest FederatedParams {}", predictParams); Map<String, Object> input = inputData.get(0); HashMap<String, Object> fidValueMapping = new HashMap<String, Object>(8); @@ -155,9 +178,8 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec ++featureHit; } } - if(logger.isDebugEnabled()) { - logger.debug("feature hit rate : {}", 1.0 * featureHit / this.featureNameFidMapping.size()); - } + + logger.info("feature hit rate : {}", 1.0 * featureHit / this.featureNameFidMapping.size()); int[] treeNodeIds = new int[this.treeNum]; double[] weights = new double[this.treeNum]; int communicationRound = 0; @@ -185,26 +207,53 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec predictParams.getData().put(Dict.TREE_LOCATION, treeLocation); + if(logger.isDebugEnabled()) { + logger.info("fast mode is {}", this.fastMode); + } + try { + logger.info("begin to federated"); ReturnResult tempResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE_FOR_TREE, false); + Map<String, Object> returnData = tempResult.getData(); - Map<String, Object> afterLocation = tempResult.getData(); - if(logger.isDebugEnabled()) { - logger.debug("after loccation is {}", afterLocation); - } - for (String location : afterLocation.keySet()) { - treeNodeIds[new Integer(location)] = ((Number) afterLocation.get(location)).intValue(); + boolean getNodeRoute = false; + for(Object obj: returnData.values()){ + if(!(obj instanceof Integer)) getNodeRoute = true; // get node position if value is integer + break; } - if (afterLocation == null) { - logger.error("receive predict result of host is null"); - throw new Exception("Null Data"); + if(this.fastMode && getNodeRoute){ + + if(logger.isDebugEnabled()){ + logger.info("running fast mode, look up table is {}",returnData); + } + + for(String treeIdx: treeLocation.keySet()){ + int idx = Integer.valueOf(treeIdx); + int curNodeId = (Integer)treeLocation.get(treeIdx); + int final_node_id = this.fastTraverseTree(idx, curNodeId, fidValueMapping, returnData); + treeNodeIds[idx] = final_node_id; + } + } + else{ + Map<String, Object> afterLocation = tempResult.getData(); + + if(logger.isDebugEnabled()){ + logger.info("after location is {}", afterLocation); + } + + for (String location : afterLocation.keySet()) { + treeNodeIds[new Integer(location)] = ((Number) afterLocation.get(location)).intValue(); + } + if (afterLocation == null) { + logger.info("receive predict result of host is null"); + throw new Exception("Null Data"); + } } } catch (Exception ex) { ex.printStackTrace(); - logger.error("HeteroSecureBoostingTreeGuest handle error",ex); return null; } } @@ -212,12 +261,12 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec for (int i = 0; i < this.treeNum; ++i) { weights[i] = getTreeLeafWeight(i, treeNodeIds[i]); } + if(logger.isDebugEnabled()){ - logger.debug("tree leaf ids is {}", treeNodeIds); - logger.debug("weights is {}", weights); + logger.info("tree leaf ids is {}", treeNodeIds); + logger.info("weights is {}", weights); } - return getFinalPredict(weights); } -} +} \ No newline at end of file diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeHost.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeHost.java index fd498bff5..746f72a0f 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeHost.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeHost.java @@ -16,20 +16,23 @@ package com.webank.ai.fate.serving.federatedml.model; +import com.webank.ai.fate.core.mlmodel.buffer.BoostTreeModelParamProto; import com.webank.ai.fate.serving.core.bean.CacheManager; import com.webank.ai.fate.serving.core.bean.Context; import com.webank.ai.fate.serving.core.bean.Dict; import com.webank.ai.fate.serving.core.bean.FederatedParams; +import com.webank.ai.fate.core.mlmodel.buffer.BoostTreeModelParamProto.DecisionTreeModelParam; +import com.webank.ai.fate.core.mlmodel.buffer.BoostTreeModelParamProto.NodeParam; import java.util.HashMap; import java.util.List; import java.util.Map; public class HeteroSecureBoostingTreeHost extends HeteroSecureBoost { - private final String site = "host"; - // need to change - private final String modelId = "HeteroSecureBoostingTreeHost"; + private final String site = "host"; + private final String modelId = "HeteroSecureBoostingTreeHost"; // need to change + private boolean fastMode = true; // DefaultCacheManager cacheManager = BaseContext.applicationContext.getBean(DefaultCacheManager.class); @@ -44,7 +47,7 @@ Map<String, Double> forward(List<Map<String, Object>> inputDatas) { ++featureHit; } } - logger.info("feature hit rate : {}", 1.0 * featureHit / this.featureNameFidMapping.size()); + LOGGER.info("feature hit rate : {}", 1.0 * featureHit / this.featureNameFidMapping.size()); } */ @@ -60,7 +63,6 @@ private int traverseTree(int treeId, int treeNodeId, Map<String, Object> input) public void saveData(Context context, String tag, Map<String, Object> data) { CacheManager.getInstance().store(context, tag, data); - } public Map<String, Object> getData(Context context, String tag) { @@ -70,11 +72,64 @@ public Map<String, Object> getData(Context context, String tag) { return data; } + public Map<String, Object> extractHostNodeRoute(Map<String, Object> input){ + + // <tree_idx, < node_idx, direction>> + + logger.info("running extractHostNodeRoute"); + + Map<String, Object> result = new HashMap<String, Object>(8); + for(int i=0;i<this.treeNum;i++){ + + DecisionTreeModelParam treeParam = this.trees.get(i); + List<NodeParam> nodes = treeParam.getTreeList(); + Map<String, Boolean> treeRoute = new HashMap<String, Boolean>(8); + + for(int j=0;j<nodes.size();j++){ + + + NodeParam node = nodes.get(j); + + if(!this.getSite(i, j).equals(this.site)){ + continue; + } + + int fid = this.trees.get(i).getTree(j).getFid(); + double splitValue = this.trees.get(i).getSplitMaskdict().get(j); + + boolean direction = false; // false go right, true go left + + if(logger.isDebugEnabled()){ + logger.info("i is {}, j is {}",i,j); + logger.info("best fid is {}", fid); + logger.info("best split val is {}", splitValue); + } + + if (input.containsKey(Integer.toString(fid))){ + Object featVal = input.get(Integer.toString(fid)); + direction = Double.parseDouble(featVal.toString()) <= splitValue + 1e-20; + } + else { + if (this.trees.get(i).getMissingDirMaskdict().containsKey(j)) { + int missingDir = this.trees.get(i).getMissingDirMaskdict().get(j); + direction = (missingDir == 1); + } + } + treeRoute.put(Integer.toString(j),direction); + } + result.put(Integer.toString(i),treeRoute); + } + if(logger.isDebugEnabled()){ + logger.info("show return route:{}",result); + } + return result; + } + @Override public Map<String, Object> handlePredict(Context context, List<Map<String, Object>> inputData, FederatedParams predictParams) { - if(logger.isDebugEnabled()) { - logger.debug("HeteroSecureBoostingTreeHost FederatedParams {}", predictParams); - } + + logger.info("HeteroSecureBoostingTreeHost FederatedParams {}", predictParams); + Map<String, Object> input = inputData.get(0); String tag = predictParams.getCaseId() + "." + this.componentName + "." + Dict.INPUT_DATA; @@ -92,17 +147,24 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec return ret; } - public Map<String, Object> predictSingleRound(Context context, Map<String, Object> interactiveData, FederatedParams predictParams) { + String tag = predictParams.getCaseId() + "." + this.componentName + "." + Dict.INPUT_DATA; Map<String, Object> input = this.getData(context, tag); - Map<String, Object> ret = new HashMap<String, Object>(8); - for (String treeIdx : interactiveData.keySet()) { - int idx = Integer.valueOf(treeIdx); - int nodeId = this.traverseTree(idx, (Integer) interactiveData.get(treeIdx), input); - ret.put(treeIdx, nodeId); - } - return ret; + if(!this.fastMode){ + Map<String, Object> ret = new HashMap<String, Object>(8); + for (String treeIdx : interactiveData.keySet()) { + int idx = Integer.valueOf(treeIdx); + int nodeId = this.traverseTree(idx, (Integer) interactiveData.get(treeIdx), input); + ret.put(treeIdx, nodeId); + } + return ret; + } + else { + // if use fast mode, return data is the look up table: <tree_idx, < node_idx, direction>> + Map<String, Object> ret = this.extractHostNodeRoute(input); + return ret; + } } -} +} \ No newline at end of file From 53a0572f8dda9f77d79e8044d1a064684369ec55 Mon Sep 17 00:00:00 2001 From: kaideng <forgive_dengkai@163.com> Date: Thu, 27 Feb 2020 11:49:05 +0800 Subject: [PATCH 02/22] fix registe too many times Signed-off-by: kaideng <forgive_dengkai@163.com> --- .../com/webank/ai/fate/serving/manager/DefaultModelManager.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serving-server/src/main/java/com/webank/ai/fate/serving/manager/DefaultModelManager.java b/serving-server/src/main/java/com/webank/ai/fate/serving/manager/DefaultModelManager.java index 194c3b6be..0f156d563 100644 --- a/serving-server/src/main/java/com/webank/ai/fate/serving/manager/DefaultModelManager.java +++ b/serving-server/src/main/java/com/webank/ai/fate/serving/manager/DefaultModelManager.java @@ -134,8 +134,8 @@ public ReturnResult publishLoadModel(Context context, FederatedParty federatedPa logger.debug("transform key {} to md5key {}", key, keyMd5); } zookeeperRegistry.addDynamicEnvironment(keyMd5); - zookeeperRegistry.register(FateServer.serviceSets); }); + zookeeperRegistry.register(FateServer.serviceSets); } } if (logger.isDebugEnabled()) { From 1a8f79ea245989e514a745300a72b075c9c01ae2 Mon Sep 17 00:00:00 2001 From: utu <jinhuitu@gmail.com> Date: Thu, 27 Feb 2020 17:31:22 +0800 Subject: [PATCH 03/22] modify: uncomment redis.password. Signed-off-by: kaideng <forgive_dengkai@163.com> --- serving-server/src/main/resources/serving-server.properties | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/serving-server/src/main/resources/serving-server.properties b/serving-server/src/main/resources/serving-server.properties index f90e32006..628f44401 100644 --- a/serving-server/src/main/resources/serving-server.properties +++ b/serving-server/src/main/resources/serving-server.properties @@ -1,4 +1,4 @@ -# +1# # Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,7 +28,7 @@ port=8000 # external cache redis.ip=127.0.0.1 redis.port=6379 -#redis.password=fate_dev +redis.password=fate_dev #redis.timeout=10 #redis.maxTotal=100 #redis.maxIdle=100 From bc8b0cbe13ce3aa9c6a41a4d7a7ecfd7bb1dff2b Mon Sep 17 00:00:00 2001 From: utu <jinhuitu@gmail.com> Date: Thu, 27 Feb 2020 18:17:41 +0800 Subject: [PATCH 04/22] modify: add proxy config. Signed-off-by: kaideng <forgive_dengkai@163.com> --- serving-server/src/main/resources/serving-server.properties | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/serving-server/src/main/resources/serving-server.properties b/serving-server/src/main/resources/serving-server.properties index 628f44401..5c083b631 100644 --- a/serving-server/src/main/resources/serving-server.properties +++ b/serving-server/src/main/resources/serving-server.properties @@ -54,4 +54,6 @@ zk.url=zookeeper://localhost:2181?backup=localhost:2182,localhost:2183 # zk acl #acl.enable=false #acl.username= -#acl.password= \ No newline at end of file +#acl.password= + +#proxy=127.0.0.1:8879 From cd10ccf595a5af846b2c43a28056cc8b022b0696 Mon Sep 17 00:00:00 2001 From: kaideng <forgive_dengkai@163.com> Date: Sat, 29 Feb 2020 20:00:39 +0800 Subject: [PATCH 05/22] change use_zk_router default value to true Signed-off-by: kaideng <forgive_dengkai@163.com> --- .../com/webank/ai/fate/serving/federatedml/model/BaseModel.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/BaseModel.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/BaseModel.java index d38e9b5a6..e1079eb8e 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/BaseModel.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/BaseModel.java @@ -185,7 +185,7 @@ protected ReturnResult getFederatedPredictFromRemote(Context context, FederatedP packetBuilder.setAuth(authBuilder.build()); GrpcConnectionPool grpcConnectionPool = GrpcConnectionPool.getPool(); - String routerByZkString = Configuration.getProperty(Dict.USE_ZK_ROUTER, Dict.FALSE); + String routerByZkString = Configuration.getProperty(Dict.USE_ZK_ROUTER, "true"); boolean routerByzk = Boolean.valueOf(routerByZkString); String address = null; if (!routerByzk) { From 236b5c04584db002e4fee57af29e3d760ec87092 Mon Sep 17 00:00:00 2001 From: v_dylanxu <136539068@qq.com> Date: Fri, 6 Mar 2020 20:16:21 +0800 Subject: [PATCH 06/22] fix restart release process bug Signed-off-by: v_dylanxu <136539068@qq.com> Signed-off-by: kaideng <forgive_dengkai@163.com> --- serving-proxy/bin/service.sh | 2 +- .../serving/proxy/bootstrap/Bootstrap.java | 52 +++++++++---------- .../proxy/rpc/grpc/InterGrpcServer.java | 4 ++ .../proxy/rpc/grpc/IntraGrpcServer.java | 4 ++ serving-server/bin/service.sh | 2 +- .../webank/ai/fate/serving/ServingServer.java | 23 ++++---- 6 files changed, 43 insertions(+), 44 deletions(-) diff --git a/serving-proxy/bin/service.sh b/serving-proxy/bin/service.sh index a687a35a2..bcffd7b26 100644 --- a/serving-proxy/bin/service.sh +++ b/serving-proxy/bin/service.sh @@ -42,7 +42,7 @@ case "$1" in restart) stop $module - sleep 0.5 + sleep 4 start $module status $module ;; diff --git a/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/bootstrap/Bootstrap.java b/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/bootstrap/Bootstrap.java index f277a4641..24c720af4 100644 --- a/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/bootstrap/Bootstrap.java +++ b/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/bootstrap/Bootstrap.java @@ -1,10 +1,10 @@ package com.webank.ai.fate.serving.proxy.bootstrap; -import com.google.common.collect.Sets; -import com.webank.ai.fate.register.url.URL; import com.webank.ai.fate.register.zookeeper.ZookeeperRegistry; import com.webank.ai.fate.serving.core.bean.Dict; import com.webank.ai.fate.serving.core.rpc.core.AbstractServiceAdaptor; +import com.webank.ai.fate.serving.proxy.rpc.grpc.InterGrpcServer; +import com.webank.ai.fate.serving.proxy.rpc.grpc.IntraGrpcServer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.boot.SpringApplication; @@ -14,8 +14,6 @@ import org.springframework.context.annotation.PropertySource; import org.springframework.scheduling.annotation.EnableScheduling; -import java.util.Set; - /** * @Description TODO * @Author @@ -36,38 +34,36 @@ public void start(String[] args) { public void stop() { logger.info("try to shutdown server ==============!!!!!!!!!!!!!!!!!!!!!"); + AbstractServiceAdaptor.isOpen = false; - int tryNum = 0; - /** - * 3秒 - */ - while (AbstractServiceAdaptor.requestInHandle.get() > 0 && tryNum < 30) { - logger.info("try to shundown,try count {}, remain {}", tryNum, AbstractServiceAdaptor.requestInHandle.get()); - try { - Thread.sleep(100); - } catch (InterruptedException e) { - e.printStackTrace(); + int retryCount = 0; + long requestInProcess = AbstractServiceAdaptor.requestInHandle.get(); + do { + logger.info("try to stop server,there is {} request in process,try count {}", requestInProcess, retryCount + 1); + if (requestInProcess > 0 && retryCount < 30) { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + e.printStackTrace(); + } + retryCount++; + requestInProcess = AbstractServiceAdaptor.requestInHandle.get(); + } else { + break; } - } + } while (requestInProcess > 0 && retryCount < 30); boolean useZkRouter = Boolean.parseBoolean(applicationContext.getEnvironment().getProperty(Dict.USE_ZK_ROUTER, "false")); if (useZkRouter) { ZookeeperRegistry zookeeperRegistry = applicationContext.getBean(ZookeeperRegistry.class); - Set<URL> registered = zookeeperRegistry.getRegistered(); - Set<URL> urls = Sets.newHashSet(); - urls.addAll(registered); - urls.forEach(url -> { - logger.info("unregister {}", url); - zookeeperRegistry.unregister(url); - }); - zookeeperRegistry.destroy(); - try { - Thread.sleep(5000); - } catch (InterruptedException e) { - e.printStackTrace(); - } } + + IntraGrpcServer intraGrpcServer = applicationContext.getBean(IntraGrpcServer.class); + intraGrpcServer.getServer().shutdown(); + + InterGrpcServer interGrpcServer = applicationContext.getBean(InterGrpcServer.class); + interGrpcServer.getServer().shutdown(); } public static void main(String[] args) { diff --git a/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/rpc/grpc/InterGrpcServer.java b/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/rpc/grpc/InterGrpcServer.java index 9b2d596e4..147f2800a 100644 --- a/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/rpc/grpc/InterGrpcServer.java +++ b/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/rpc/grpc/InterGrpcServer.java @@ -27,6 +27,10 @@ public class InterGrpcServer implements InitializingBean { Server server ; + public Server getServer() { + return server; + } + @Autowired InterRequestHandler interRequestHandler; diff --git a/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/rpc/grpc/IntraGrpcServer.java b/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/rpc/grpc/IntraGrpcServer.java index 61c30dc0a..9713546b6 100644 --- a/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/rpc/grpc/IntraGrpcServer.java +++ b/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/rpc/grpc/IntraGrpcServer.java @@ -33,6 +33,10 @@ public class IntraGrpcServer implements InitializingBean { Server server; + public Server getServer() { + return server; + } + @Override public void afterPropertiesSet() throws Exception { FateServerBuilder serverBuilder = (FateServerBuilder) ServerBuilder.forPort(port); diff --git a/serving-server/bin/service.sh b/serving-server/bin/service.sh index eae2d3d3a..45a70cb02 100644 --- a/serving-server/bin/service.sh +++ b/serving-server/bin/service.sh @@ -40,7 +40,7 @@ case "$1" in restart) stop $module - sleep 0.5 + sleep 4 start $module status $module ;; diff --git a/serving-server/src/main/java/com/webank/ai/fate/serving/ServingServer.java b/serving-server/src/main/java/com/webank/ai/fate/serving/ServingServer.java index 5dc0c2a95..cc30ae669 100644 --- a/serving-server/src/main/java/com/webank/ai/fate/serving/ServingServer.java +++ b/serving-server/src/main/java/com/webank/ai/fate/serving/ServingServer.java @@ -174,25 +174,14 @@ public void run() { private void stop() { if (server != null) { - if (useRegister) { - ZookeeperRegistry zookeeperRegistry = applicationContext.getBean(ZookeeperRegistry.class); - Set<URL> registered = zookeeperRegistry.getRegistered(); - Set<URL> urls = Sets.newHashSet(); - urls.addAll(registered); - urls.forEach(url -> { - logger.info("unregister {}", url); - zookeeperRegistry.unregister(url); - }); - zookeeperRegistry.destroy(); - } + logger.info("try to shutdown server ==============!!!!!!!!!!!!!!!!!!!!!"); + int retryCount = 0; long requestInProcess = BaseContext.requestInProcess.get(); do { - logger.info("try to stop server,there is {} request in process,try count {}", requestInProcess, retryCount + 1); if (requestInProcess > 0 && retryCount < 30) { try { - Thread.sleep(100); } catch (InterruptedException e) { e.printStackTrace(); @@ -203,7 +192,13 @@ private void stop() { break; } - } while (requestInProcess > 0 && retryCount < 3); + } while (requestInProcess > 0 && retryCount < 30); + + if (useRegister) { + ZookeeperRegistry zookeeperRegistry = applicationContext.getBean(ZookeeperRegistry.class); + zookeeperRegistry.destroy(); + } + server.shutdown(); } } From c543e2ea34bae1c7a7acf23c2035db1eee8f4081 Mon Sep 17 00:00:00 2001 From: v_dylanxu <136539068@qq.com> Date: Fri, 6 Mar 2020 21:12:59 +0800 Subject: [PATCH 07/22] fix restart release process bug Signed-off-by: v_dylanxu <136539068@qq.com> Signed-off-by: kaideng <forgive_dengkai@163.com> --- .../serving/proxy/bootstrap/Bootstrap.java | 28 ++++++++----------- .../webank/ai/fate/serving/ServingServer.java | 27 ++++++++---------- 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/bootstrap/Bootstrap.java b/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/bootstrap/Bootstrap.java index 24c720af4..92d9ed861 100644 --- a/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/bootstrap/Bootstrap.java +++ b/serving-proxy/src/main/java/com/webank/ai/fate/serving/proxy/bootstrap/Bootstrap.java @@ -35,30 +35,26 @@ public void start(String[] args) { public void stop() { logger.info("try to shutdown server ==============!!!!!!!!!!!!!!!!!!!!!"); + boolean useZkRouter = Boolean.parseBoolean(applicationContext.getEnvironment().getProperty(Dict.USE_ZK_ROUTER, "false")); + if (useZkRouter) { + ZookeeperRegistry zookeeperRegistry = applicationContext.getBean(ZookeeperRegistry.class); + zookeeperRegistry.destroy(); + } + AbstractServiceAdaptor.isOpen = false; int retryCount = 0; long requestInProcess = AbstractServiceAdaptor.requestInHandle.get(); do { logger.info("try to stop server,there is {} request in process,try count {}", requestInProcess, retryCount + 1); - if (requestInProcess > 0 && retryCount < 30) { - try { - Thread.sleep(100); - } catch (InterruptedException e) { - e.printStackTrace(); - } - retryCount++; - requestInProcess = AbstractServiceAdaptor.requestInHandle.get(); - } else { - break; + try { + Thread.sleep(100); + } catch (InterruptedException e) { + e.printStackTrace(); } + retryCount++; + requestInProcess = AbstractServiceAdaptor.requestInHandle.get(); } while (requestInProcess > 0 && retryCount < 30); - boolean useZkRouter = Boolean.parseBoolean(applicationContext.getEnvironment().getProperty(Dict.USE_ZK_ROUTER, "false")); - if (useZkRouter) { - ZookeeperRegistry zookeeperRegistry = applicationContext.getBean(ZookeeperRegistry.class); - zookeeperRegistry.destroy(); - } - IntraGrpcServer intraGrpcServer = applicationContext.getBean(IntraGrpcServer.class); intraGrpcServer.getServer().shutdown(); diff --git a/serving-server/src/main/java/com/webank/ai/fate/serving/ServingServer.java b/serving-server/src/main/java/com/webank/ai/fate/serving/ServingServer.java index cc30ae669..66b0808a9 100644 --- a/serving-server/src/main/java/com/webank/ai/fate/serving/ServingServer.java +++ b/serving-server/src/main/java/com/webank/ai/fate/serving/ServingServer.java @@ -176,29 +176,24 @@ private void stop() { if (server != null) { logger.info("try to shutdown server ==============!!!!!!!!!!!!!!!!!!!!!"); + if (useRegister) { + ZookeeperRegistry zookeeperRegistry = applicationContext.getBean(ZookeeperRegistry.class); + zookeeperRegistry.destroy(); + } + int retryCount = 0; long requestInProcess = BaseContext.requestInProcess.get(); do { logger.info("try to stop server,there is {} request in process,try count {}", requestInProcess, retryCount + 1); - if (requestInProcess > 0 && retryCount < 30) { - try { - Thread.sleep(100); - } catch (InterruptedException e) { - e.printStackTrace(); - } - retryCount++; - requestInProcess = BaseContext.requestInProcess.get(); - } else { - break; + try { + Thread.sleep(100); + } catch (InterruptedException e) { + e.printStackTrace(); } - + retryCount++; + requestInProcess = BaseContext.requestInProcess.get(); } while (requestInProcess > 0 && retryCount < 30); - if (useRegister) { - ZookeeperRegistry zookeeperRegistry = applicationContext.getBean(ZookeeperRegistry.class); - zookeeperRegistry.destroy(); - } - server.shutdown(); } } From b2d08dd8e86852cd0c73d6ccc0c2bfa9fe23e8fe Mon Sep 17 00:00:00 2001 From: FanTao <289765648@qq.com> Date: Sun, 22 Mar 2020 17:37:32 +0800 Subject: [PATCH 08/22] Update README.md Signed-off-by: kaideng <forgive_dengkai@163.com> --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 364ab383d..dcc3a1e6d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # FATE-Serving -[](https://opensource.org/licenses/Apache-2.0) [](https://checkstyle.sourceforge.io/google_style.html) [](https://github.com/mmyjona/FATE-Serving/pulls) [](https://checkstyle.sourceforge.io/google_style.html) +[](https://opensource.org/licenses/Apache-2.0) +[](https://checkstyle.sourceforge.io/google_style.html) +[](https://checkstyle.sourceforge.io/google_style.html) ## Introduction From 5cfe9a0f41c659a88dd65b3ea6a4086a971a6461 Mon Sep 17 00:00:00 2001 From: v_dylanxu <136539068@qq.com> Date: Wed, 8 Apr 2020 11:18:09 +0800 Subject: [PATCH 09/22] remove empty registry cache Signed-off-by: v_dylanxu <136539068@qq.com> Signed-off-by: kaideng <forgive_dengkai@163.com> --- .../ai/fate/register/common/AbstractRegistry.java | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/register/src/main/java/com/webank/ai/fate/register/common/AbstractRegistry.java b/register/src/main/java/com/webank/ai/fate/register/common/AbstractRegistry.java index 968029291..f7cae18e7 100644 --- a/register/src/main/java/com/webank/ai/fate/register/common/AbstractRegistry.java +++ b/register/src/main/java/com/webank/ai/fate/register/common/AbstractRegistry.java @@ -67,7 +67,7 @@ public AbstractRegistry(URL url) { setUrl(url); // Start file save timer syncSaveFile = url.getParameter(REGISTRY_FILESAVE_SYNC_KEY, false); - String filename = url.getParameter(FILE_KEY, System.getProperty(USER_HOME) + "/.fate/fate-registry-" + url.getParameter(PROJECT_KEY) + "-" + url.getAddress() + ".cache"); + String filename = url.getParameter(FILE_KEY, System.getProperty(USER_HOME) + "/.fate/fate-registry-" + url.getParameter(PROJECT_KEY) + "-" + url.getHost() + "-" + url.getPort() + ".cache"); File file = null; if (StringUtils.isNotEmpty(filename)) { file = new File(filename); @@ -441,7 +441,9 @@ private void saveProperties(URL url) { if (buf.length() > 0) { buf.append(URL_SEPARATOR); } - buf.append(u.toFullString()); + if (!EMPTY_PROTOCOL.equals(u.getProtocol())) { + buf.append(u.toFullString()); + } } } } @@ -449,7 +451,12 @@ private void saveProperties(URL url) { if (logger.isDebugEnabled()) { logger.debug("properties set property key {} value {}", url.getServiceKey(), buf.toString()); } - properties.setProperty(url.getServiceKey(), buf.toString()); + + if (buf.length() == 0) { + properties.remove(url.getServiceKey()); + } else { + properties.setProperty(url.getServiceKey(), buf.toString()); + } long version = lastCacheChanged.incrementAndGet(); if (syncSaveFile) { doSaveProperties(version); From f2b59ccd029d24711623c66bdb25709836c84070 Mon Sep 17 00:00:00 2001 From: kaideng <forgive_dengkai@163.com> Date: Thu, 16 Apr 2020 14:38:15 +0800 Subject: [PATCH 10/22] fix common pre process Signed-off-by: kaideng <forgive_dengkai@163.com> --- .../adapter/processing/CommonPreProcessing.java | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/serving-server/src/main/java/com/webank/ai/fate/serving/adapter/processing/CommonPreProcessing.java b/serving-server/src/main/java/com/webank/ai/fate/serving/adapter/processing/CommonPreProcessing.java index be2a0505e..be3bb99c7 100644 --- a/serving-server/src/main/java/com/webank/ai/fate/serving/adapter/processing/CommonPreProcessing.java +++ b/serving-server/src/main/java/com/webank/ai/fate/serving/adapter/processing/CommonPreProcessing.java @@ -1,6 +1,7 @@ package com.webank.ai.fate.serving.adapter.processing; +import com.alibaba.fastjson.JSON; import com.webank.ai.fate.serving.bean.PreProcessingResult; import com.webank.ai.fate.serving.core.bean.Context; import jdk.nashorn.internal.runtime.ParserException; @@ -10,22 +11,13 @@ import java.util.Map; public class CommonPreProcessing implements PreProcessing { + @Override public PreProcessingResult getResult(Context context , String paras) { PreProcessingResult preProcessingResult = new PreProcessingResult(); - preProcessingResult.setProcessingResult(preProcessing(paras)); - Map<String, Object> featureIds = new HashMap<>(); - JSONObject paraObj = new JSONObject(paras); - preProcessingResult.setFeatureIds(featureIds); + preProcessingResult.setProcessingResult( JSON.parseObject(paras, HashMap.class)); + preProcessingResult.setFeatureIds(preProcessingResult.getProcessingResult()); return preProcessingResult; } - private Map<String, Object> preProcessing(String paras) throws ClassCastException, ParserException { - Map<String, Object> feature = new HashMap<>(); - - return feature; - } - - public static void main(String[] args){ - } } From d00266f17c79570971299031da983d84f010b08d Mon Sep 17 00:00:00 2001 From: Chen <cwjghglbdcj@gmail.com> Date: Tue, 21 Apr 2020 11:13:00 +0800 Subject: [PATCH 11/22] fix a small bug --- .../federatedml/model/HeteroSecureBoostingTreeHost.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeHost.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeHost.java index 746f72a0f..495815b89 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeHost.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeHost.java @@ -112,7 +112,7 @@ public Map<String, Object> extractHostNodeRoute(Map<String, Object> input){ else { if (this.trees.get(i).getMissingDirMaskdict().containsKey(j)) { int missingDir = this.trees.get(i).getMissingDirMaskdict().get(j); - direction = (missingDir == 1); + direction = (missingDir != 1); } } treeRoute.put(Integer.toString(j),direction); @@ -167,4 +167,4 @@ public Map<String, Object> predictSingleRound(Context context, Map<String, Objec return ret; } } -} \ No newline at end of file +} From 837c7864b97f3b1c9d9c7e36794421af3b69fae1 Mon Sep 17 00:00:00 2001 From: v_dylanxu <136539068@qq.com> Date: Wed, 13 May 2020 22:30:13 +0800 Subject: [PATCH 12/22] repair Signed-off-by: v_dylanxu <136539068@qq.com> --- .../ai/fate/serving/adapter/dataaccess/TestFilePick.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/serving-server/src/main/java/com/webank/ai/fate/serving/adapter/dataaccess/TestFilePick.java b/serving-server/src/main/java/com/webank/ai/fate/serving/adapter/dataaccess/TestFilePick.java index 18c7d00a8..683b2111f 100644 --- a/serving-server/src/main/java/com/webank/ai/fate/serving/adapter/dataaccess/TestFilePick.java +++ b/serving-server/src/main/java/com/webank/ai/fate/serving/adapter/dataaccess/TestFilePick.java @@ -56,8 +56,10 @@ public ReturnResult getData(Context context, Map<String, Object> featureIds) { }); } Map<String, Object> fdata = featureMaps.get(featureIds.get(Dict.DEVICE_ID)); - if(fdata != null) { - returnResult.setData(fdata); + + Map clone = (Map) ((HashMap)fdata).clone(); + if(clone != null) { + returnResult.setData(clone); returnResult.setRetcode(InferenceRetCode.OK); } else{ logger.error("cant not find features for {}.", featureIds.get(Dict.DEVICE_ID)); From 11efde97ff284fbc44297ac0d26feacf375f0750 Mon Sep 17 00:00:00 2001 From: kaideng <forgive_dengkai@163.com> Date: Tue, 19 May 2020 16:02:36 +0800 Subject: [PATCH 13/22] change version to 1.3 Signed-off-by: kaideng <forgive_dengkai@163.com> --- pom.xml | 4 ++-- serving-proxy/bin/service.sh | 2 +- serving-server/bin/service.sh | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pom.xml b/pom.xml index 2ffceb41d..c954d1dcd 100644 --- a/pom.xml +++ b/pom.xml @@ -35,7 +35,7 @@ </modules> <properties> - <fate.version>1.2.0</fate.version> + <fate.version>1.3.0</fate.version> <fate.core.version>0.3</fate.core.version> <java.version>1.8</java.version> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> @@ -554,4 +554,4 @@ </plugins> </build> -</project> \ No newline at end of file +</project> diff --git a/serving-proxy/bin/service.sh b/serving-proxy/bin/service.sh index bcffd7b26..60c5ec3c6 100644 --- a/serving-proxy/bin/service.sh +++ b/serving-proxy/bin/service.sh @@ -24,7 +24,7 @@ configpath=$(cd $basepath/conf;pwd) module=serving-proxy main_class=com.webank.ai.fate.serving.proxy.bootstrap.Bootstrap -module_version=1.2.0 +module_version=1.3.0 case "$1" in diff --git a/serving-server/bin/service.sh b/serving-server/bin/service.sh index 45a70cb02..f14a4ac74 100644 --- a/serving-server/bin/service.sh +++ b/serving-server/bin/service.sh @@ -22,7 +22,7 @@ set -e source ./bin/common.sh module=serving-server main_class=com.webank.ai.fate.serving.ServingServer -module_version=1.2.0 +module_version=1.3.0 case "$1" in @@ -47,4 +47,4 @@ case "$1" in *) echo "usage: $0 {start|stop|status|restart}" exit 1 -esac \ No newline at end of file +esac From 0739090c677b1d0e4c1330787739865ea4acff0f Mon Sep 17 00:00:00 2001 From: v_dylanxu <136539068@qq.com> Date: Wed, 20 May 2020 19:33:06 +0800 Subject: [PATCH 14/22] update release note Signed-off-by: v_dylanxu <136539068@qq.com> --- RELEASE.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/RELEASE.md b/RELEASE.md index 1041a134c..29cb428ff 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,7 @@ +# Release 1.3.0 +## Major Features and Improvements +* Hetero Secureboosting communication optimization: communication round is reduced to 1 by letting the host send a pre-computed host node route, which is used for inferencing, to the guest. + # Release 1.2.0 ## Major Features and Improvements * Replace serving-router with a brand new service called serving-proxy, which supports authentication and inference request with HTTP or gRPC From 20600bc4b6d26070de1a724f4949939dc238b60a Mon Sep 17 00:00:00 2001 From: kaideng <forgive_dengkai@163.com> Date: Wed, 20 May 2020 20:31:25 +0800 Subject: [PATCH 15/22] rm log Signed-off-by: kaideng <forgive_dengkai@163.com> --- .../ai/fate/serving/federatedml/model/HeteroSecureBoost.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java index 657d4b8ab..cbe5b3659 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java @@ -93,9 +93,6 @@ protected int gotoNextLevel(int treeId, int treeNodeId, Map<String, Object> inpu int fid = this.trees.get(treeId).getTree(treeNodeId).getFid(); double splitValue = this.trees.get(treeId).getSplitMaskdict().get(treeNodeId); String fidStr = String.valueOf(fid); - logger.info("treeId {}, treeNodeId {}",treeId, treeNodeId); - logger.info("treenode fid {}",fidStr); - logger.info("input is {}",input); if (input.containsKey(fidStr)) { if (Double.parseDouble(input.get(fidStr).toString()) <= splitValue + 1e-20) { From 206d3b2236c7f686d657ee7b305e882bce9425c5 Mon Sep 17 00:00:00 2001 From: mgqa34 <mgq3374541@163.com> Date: Fri, 22 May 2020 12:53:56 +0800 Subject: [PATCH 16/22] fix sparse data bug Signed-off-by: mgqa34 <mgq3374541@163.com> --- .../ai/fate/serving/core/bean/Dict.java | 2 ++ .../serving/federatedml/model/DataIO.java | 32 ++++++++++++++++--- .../serving/federatedml/model/Imputer.java | 27 ++++++++-------- 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/fate-serving-core/src/main/java/com/webank/ai/fate/serving/core/bean/Dict.java b/fate-serving-core/src/main/java/com/webank/ai/fate/serving/core/bean/Dict.java index c6ae81f74..6c778a1e1 100644 --- a/fate-serving-core/src/main/java/com/webank/ai/fate/serving/core/bean/Dict.java +++ b/fate-serving-core/src/main/java/com/webank/ai/fate/serving/core/bean/Dict.java @@ -100,6 +100,8 @@ public class Dict { public static final String INPUT_DATA_HIT_RATE = "inputDataHitRate"; public static final String GUEST_MODEL_WEIGHT_HIT_RATE = "guestModelWeightHitRate"; public static final String GUEST_INPUT_DATA_HIT_RATE = "guestInputDataHitRate"; + public static final String TAG_INPUT_FORMAT = "tag"; + public static final String SPARSE_INPUT_FORMAT = "sparse"; public static final String MIN_MAX_SCALE = "min_max_scale"; public static final String STANDARD_SCALE = "standard_scale"; public static final String DSL_COMPONENTS = "components"; diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/DataIO.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/DataIO.java index 07d20bd4c..6c2164472 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/DataIO.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/DataIO.java @@ -20,11 +20,13 @@ import com.webank.ai.fate.core.mlmodel.buffer.DataIOMetaProto.DataIOMeta; import com.webank.ai.fate.core.mlmodel.buffer.DataIOParamProto.DataIOParam; import com.webank.ai.fate.serving.core.bean.Context; +import com.webank.ai.fate.serving.core.bean.Dict; import com.webank.ai.fate.serving.core.bean.FederatedParams; import com.webank.ai.fate.serving.core.bean.StatusCode; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -32,6 +34,8 @@ public class DataIO extends BaseModel { private static final Logger logger = LoggerFactory.getLogger(DataIO.class); private DataIOMeta dataIOMeta; private DataIOParam dataIOParam; + private List<String> header; + private String inputformat; private Imputer imputer; private Outlier outlier; private boolean isImputer; @@ -56,9 +60,12 @@ public int initModel(byte[] protoMeta, byte[] protoParam) { this.outlier = new Outlier(this.dataIOMeta.getOutlierMeta().getOutlierValueList(), this.dataIOParam.getOutlierParam().getOutlierReplaceValue()); } + + this.header = this.dataIOParam.getHeaderList(); + this.inputformat = this.dataIOMeta.getInputFormat(); } catch (Exception ex) { ex.printStackTrace(); - logger.error("init DataIo error",ex); + logger.error("init DataIo error", ex); return StatusCode.ILLEGALDATA; } logger.info("Finish init DataIO class"); @@ -67,16 +74,31 @@ public int initModel(byte[] protoMeta, byte[] protoParam) { @Override public Map<String, Object> handlePredict(Context context, List<Map<String, Object>> inputData, FederatedParams predictParams) { - Map<String, Object> input = inputData.get(0); + Map<String, Object> data = inputData.get(0); + Map<String, Object> output = new HashMap<>(); + + if (this.inputformat.equals(Dict.TAG_INPUT_FORMAT) || this.inputformat.equals(Dict.SPARSE_INPUT_FORMAT + )) { + for (String col: this.header) { + output.put(col, data.getOrDefault(col, 0)); + } + } if (this.isImputer) { - input = this.imputer.transform(input); + output = this.imputer.transform(output); } if (this.isOutlier) { - input = this.outlier.transform(input); + output = this.outlier.transform(output); } - return input; + /* + for (String col: data.keySet()) { + if (!output.containsKey(col)) { + output.put(col, data.get(col)); + } + }*/ + + return output; } } diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/Imputer.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/Imputer.java index 3deadd441..2e2567db6 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/Imputer.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/Imputer.java @@ -19,6 +19,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -34,21 +35,21 @@ public Imputer(List<String> missingValues, Map<String, String> missingReplaceVal } public Map<String, Object> transform(Map<String, Object> inputData) { - if(inputData!=null) { - for (String key : inputData.keySet()) { - if(inputData.get(key)!=null) { - String value = inputData.get(key).toString(); - if (this.missingValueSet.contains(value.toLowerCase())) { - try { - inputData.put(key, this.missingReplaceValues.get(key)); - } catch (Exception ex) { - logger.error("Imputer transform error",ex); - inputData.put(key, 0.); - } - } + Map<String, Object> output = new HashMap<>(); + + for (String col: this.missingReplaceValues.keySet()) { + if (inputData.containsKey(col)) { + String value = inputData.get(col).toString(); + if (this.missingValueSet.contains(value.toLowerCase())) { + output.put(col, this.missingReplaceValues.get(col)); + } else { + output.put(col, 0); } + } else { + output.put(col, this.missingReplaceValues.get(col)); } } - return inputData; + + return output; } } From 34aa7b80c25f6e4ff2f2a30d0f2fa14af93c049e Mon Sep 17 00:00:00 2001 From: tanmc123 <mingchaotan@outlook.com> Date: Fri, 22 May 2020 17:46:47 +0800 Subject: [PATCH 17/22] Fix binning Signed-off-by: tanmc123 <mingchaotan@outlook.com> --- .../model/HeteroFeatureBinning.java | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java index 4fad71acb..0f21ead01 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java @@ -15,6 +15,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Collections; public class HeteroFeatureBinning extends BaseModel { @@ -56,30 +57,39 @@ public int initModel(byte[] protoMeta, byte[] protoParam) { @Override public Map<String, Object> handlePredict(Context context, List<Map<String, Object>> inputData, FederatedParams predictParams) { HashMap<String, Object> outputData = new HashMap<>(8); + HashMap<String, Long> headerMap = new HashMap<>(); Map<String, Object> firstData = inputData.get(0); if (!this.needRun) { return firstData; } + for (int i=0; i < this.header.size(); i ++) { + headerMap.put(this.header.get(i), (long) i); + } + for (String colName : firstData.keySet()) { try{ if (! this.splitPoints.containsKey(colName)) { outputData.put(colName, firstData.get(colName)); continue; } - Long thisColIndex = (long) this.header.indexOf(colName); +// Long thisColIndex = (long) this.header.indexOf(colName); + Long thisColIndex = headerMap.get(colName); if (! this.transformCols.contains(thisColIndex)) { outputData.put(colName, firstData.get(colName)); continue; } List<Double> splitPoint = this.splitPoints.get(colName); Double colValue = Double.valueOf(firstData.get(colName).toString()); - int colIndex = 0; - for (colIndex = 0; colIndex < splitPoint.size(); colIndex ++) { - if (colValue <= splitPoint.get(colIndex)) { - break; - } - } + int colIndex; + colIndex = Collections.binarySearch(splitPoint, colValue); +// for (colIndex = 0; colIndex < splitPoint.size(); colIndex ++) { +// +// +// if (colValue <= splitPoint.get(colIndex)) { +// break; +// } +// } outputData.put(colName, colIndex); }catch(Throwable e){ logger.error("HeteroFeatureBinning error" ,e); From 2196ec3c952ac317c5755ac8581976c2600394ab Mon Sep 17 00:00:00 2001 From: mgqa34 <mgq3374541@163.com> Date: Tue, 26 May 2020 19:45:10 +0800 Subject: [PATCH 18/22] add debug logger Signed-off-by: mgqa34 <mgq3374541@163.com> --- .../serving/federatedml/model/DataIO.java | 22 ++++++++++++++----- .../model/HeteroFeatureBinning.java | 21 +++++++++--------- .../serving/federatedml/model/Imputer.java | 2 +- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/DataIO.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/DataIO.java index 6c2164472..5ed4147fc 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/DataIO.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/DataIO.java @@ -75,21 +75,33 @@ public int initModel(byte[] protoMeta, byte[] protoParam) { @Override public Map<String, Object> handlePredict(Context context, List<Map<String, Object>> inputData, FederatedParams predictParams) { Map<String, Object> data = inputData.get(0); - Map<String, Object> output = new HashMap<>(); + Map<String, Object> outputData = new HashMap<>(); + + if(logger.isDebugEnabled()) { + logger.debug("input-data, not filling, {}", data); + } if (this.inputformat.equals(Dict.TAG_INPUT_FORMAT) || this.inputformat.equals(Dict.SPARSE_INPUT_FORMAT )) { + if(logger.isDebugEnabled()) { + logger.debug("Sparse Data Filling Zeros"); + } for (String col: this.header) { - output.put(col, data.getOrDefault(col, 0)); + outputData.put(col, data.getOrDefault(col, 0)); + } + } else { + outputData = data; + if(logger.isDebugEnabled()) { + logger.debug("Dense input-format, not filling, {}", outputData); } } if (this.isImputer) { - output = this.imputer.transform(output); + outputData = this.imputer.transform(outputData); } if (this.isOutlier) { - output = this.outlier.transform(output); + outputData = this.outlier.transform(outputData); } /* @@ -99,6 +111,6 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec } }*/ - return output; + return outputData; } } diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java index 0f21ead01..c86265e4b 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java @@ -63,19 +63,19 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec return firstData; } - for (int i=0; i < this.header.size(); i ++) { + for (int i = 0; i < this.header.size(); i++) { headerMap.put(this.header.get(i), (long) i); } for (String colName : firstData.keySet()) { - try{ - if (! this.splitPoints.containsKey(colName)) { + try { + if (!this.splitPoints.containsKey(colName)) { outputData.put(colName, firstData.get(colName)); - continue; + continue; } // Long thisColIndex = (long) this.header.indexOf(colName); Long thisColIndex = headerMap.get(colName); - if (! this.transformCols.contains(thisColIndex)) { + if (!this.transformCols.contains(thisColIndex)) { outputData.put(colName, firstData.get(colName)); continue; } @@ -90,14 +90,15 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec // break; // } // } - outputData.put(colName, colIndex); - }catch(Throwable e){ - logger.error("HeteroFeatureBinning error" ,e); + outputData.put(colName, colIndex); + } catch (Throwable e) { + logger.error("HeteroFeatureBinning error", e); } } - if(logger.isDebugEnabled()) { - logger.debug("HeteroFeatureBinning output {}", outputData); + if (logger.isDebugEnabled()) { + logger.debug("DEBUG: HeteroFeatureBinning output {}", outputData); } + return outputData; } diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/Imputer.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/Imputer.java index 2e2567db6..eac6b5c71 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/Imputer.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/Imputer.java @@ -43,7 +43,7 @@ public Map<String, Object> transform(Map<String, Object> inputData) { if (this.missingValueSet.contains(value.toLowerCase())) { output.put(col, this.missingReplaceValues.get(col)); } else { - output.put(col, 0); + output.put(col, inputData.get(col)); } } else { output.put(col, this.missingReplaceValues.get(col)); From d1ce65a95600f482e39647d5067ea79ab5015227 Mon Sep 17 00:00:00 2001 From: tanmc123 <mingchaotan@outlook.com> Date: Tue, 26 May 2020 21:08:36 +0800 Subject: [PATCH 19/22] Fix binning Signed-off-by: tanmc123 <mingchaotan@outlook.com> --- .../serving/federatedml/model/HeteroFeatureBinning.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java index c86265e4b..8dd45a1a2 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java @@ -81,8 +81,11 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec } List<Double> splitPoint = this.splitPoints.get(colName); Double colValue = Double.valueOf(firstData.get(colName).toString()); - int colIndex; - colIndex = Collections.binarySearch(splitPoint, colValue); + int colIndex = Collections.binarySearch(splitPoint, colValue); + if (colIndex < 0) { + colIndex = - colIndex - 1; + } + // for (colIndex = 0; colIndex < splitPoint.size(); colIndex ++) { // // From 883f1e0f19fa38fe5e851c9912d7c7056864b899 Mon Sep 17 00:00:00 2001 From: chenweijing <talkingwallace@sohu.com> Date: Tue, 2 Jun 2020 20:22:20 +0800 Subject: [PATCH 20/22] modify predict code Signed-off-by: Chen <cwjghglbdcj@gmail.com> --- .../federatedml/model/HeteroSecureBoost.java | 1 + .../model/HeteroSecureBoostingTreeGuest.java | 22 +++++++++--- .../model/HeteroSecureBoostingTreeHost.java | 34 +++++++++++-------- 3 files changed, 38 insertions(+), 19 deletions(-) diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java index cbe5b3659..7e18b1de5 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoost.java @@ -40,6 +40,7 @@ public abstract class HeteroSecureBoost extends BaseModel { protected List<String> classes; protected int treeDim; protected double learningRate; + protected boolean fastMode = true; @Override public int initModel(byte[] protoMeta, byte[] protoParam) { diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java index 7aab333e6..7efa319ac 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java @@ -169,7 +169,11 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec Map<String, Object> input = inputData.get(0); HashMap<String, Object> fidValueMapping = new HashMap<String, Object>(8); - ReturnResult returnResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE, false); + if(!this.fastMode){ + // ask host to prepare data, if fast mode is not enabled + ReturnResult returnResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE, false); + } + int featureHit = 0; for (String key : input.keySet()) { @@ -183,6 +187,8 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec int[] treeNodeIds = new int[this.treeNum]; double[] weights = new double[this.treeNum]; int communicationRound = 0; + + // start local inference while (true) { HashMap<String, Object> treeLocation = new HashMap<String, Object>(8); for (int i = 0; i < this.treeNum; ++i) { @@ -214,10 +220,18 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec try { logger.info("begin to federated"); - ReturnResult tempResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE_FOR_TREE, false); - Map<String, Object> returnData = tempResult.getData(); - boolean getNodeRoute = false; + ReturnResult tempResult; + if(this.fastMode){ + getNodeRoute = true; + tempResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE, false); + } + else{ + tempResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE_FOR_TREE, false); + } + + + Map<String, Object> returnData = tempResult.getData(); for(Object obj: returnData.values()){ if(!(obj instanceof Integer)) getNodeRoute = true; // get node position if value is integer break; diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeHost.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeHost.java index 495815b89..c7c00b93f 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeHost.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeHost.java @@ -32,7 +32,6 @@ public class HeteroSecureBoostingTreeHost extends HeteroSecureBoost { private final String site = "host"; private final String modelId = "HeteroSecureBoostingTreeHost"; // need to change - private boolean fastMode = true; // DefaultCacheManager cacheManager = BaseContext.applicationContext.getBean(DefaultCacheManager.class); @@ -143,7 +142,19 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec ++featureHit; } } - this.saveData(context, tag, fidValueMapping); + + if(this.fastMode){ + if(logger.isDebugEnabled()){ + logger.info("fast mode enabled"); + } + // if use fast mode, return data is the look up table: <tree_idx, < node_idx, direction>> in first + // communication round + ret = this.extractHostNodeRoute(fidValueMapping); + } + else{ + this.saveData(context, tag, fidValueMapping); + } + return ret; } @@ -152,19 +163,12 @@ public Map<String, Object> predictSingleRound(Context context, Map<String, Objec String tag = predictParams.getCaseId() + "." + this.componentName + "." + Dict.INPUT_DATA; Map<String, Object> input = this.getData(context, tag); - if(!this.fastMode){ - Map<String, Object> ret = new HashMap<String, Object>(8); - for (String treeIdx : interactiveData.keySet()) { - int idx = Integer.valueOf(treeIdx); - int nodeId = this.traverseTree(idx, (Integer) interactiveData.get(treeIdx), input); - ret.put(treeIdx, nodeId); - } - return ret; - } - else { - // if use fast mode, return data is the look up table: <tree_idx, < node_idx, direction>> - Map<String, Object> ret = this.extractHostNodeRoute(input); - return ret; + Map<String, Object> ret = new HashMap<String, Object>(8); + for (String treeIdx : interactiveData.keySet()) { + int idx = Integer.valueOf(treeIdx); + int nodeId = this.traverseTree(idx, (Integer) interactiveData.get(treeIdx), input); + ret.put(treeIdx, nodeId); } + return ret; } } From c0517ca03d911dad4c583d65820ada78b8b700c3 Mon Sep 17 00:00:00 2001 From: tanmc123 <mingchaotan@outlook.com> Date: Tue, 2 Jun 2020 20:55:57 +0800 Subject: [PATCH 21/22] Fix binning logic: For data that belongs to last bin, the bin index may be equal to bin length which is wrong. Signed-off-by: tanmc123 <mingchaotan@outlook.com> --- .../fate/serving/federatedml/model/HeteroFeatureBinning.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java index 8dd45a1a2..87f9e935c 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroFeatureBinning.java @@ -16,6 +16,7 @@ import java.util.List; import java.util.Map; import java.util.Collections; +import java.lang.Math; public class HeteroFeatureBinning extends BaseModel { @@ -83,9 +84,8 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec Double colValue = Double.valueOf(firstData.get(colName).toString()); int colIndex = Collections.binarySearch(splitPoint, colValue); if (colIndex < 0) { - colIndex = - colIndex - 1; + colIndex = Math.min((- colIndex - 1), splitPoint.size() - 1); } - // for (colIndex = 0; colIndex < splitPoint.size(); colIndex ++) { // // From a978522b8b18bcf93e0c376a5b45bab5db2e8f84 Mon Sep 17 00:00:00 2001 From: chenweijing <talkingwallace@sohu.com> Date: Wed, 3 Jun 2020 14:40:36 +0800 Subject: [PATCH 22/22] remove redundant codes Signed-off-by: Chen <cwjghglbdcj@gmail.com> --- .../federatedml/model/HeteroSecureBoostingTreeGuest.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java index 7efa319ac..5969143e7 100644 --- a/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java +++ b/federatedml/src/main/java/com/webank/ai/fate/serving/federatedml/model/HeteroSecureBoostingTreeGuest.java @@ -232,10 +232,6 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec Map<String, Object> returnData = tempResult.getData(); - for(Object obj: returnData.values()){ - if(!(obj instanceof Integer)) getNodeRoute = true; // get node position if value is integer - break; - } if(this.fastMode && getNodeRoute){