Skip to content

Commit

Permalink
Merge pull request #204 from FederatedAI/dev-2.1.7
Browse files Browse the repository at this point in the history
Dev 2.1.7
dylan-fan authored Nov 9, 2023
2 parents 8892c98 + 893b30c commit ab7ba98
Showing 62 changed files with 1,108 additions and 268 deletions.
4 changes: 3 additions & 1 deletion document/docs/service/adapter.md
Original file line number Diff line number Diff line change
@@ -40,8 +40,10 @@ Context为上下文信息,用于传递请求所需参数,featureIds用于传
#在host方的配置文件serving-server.properties中将其配置成自定义的类的全路径,如下所示
feature.single.adaptor=com.webank.ai.fate.serving.adaptor.dataaccess.CustomAdapter
feature.batch.adaptor=com.webank.ai.fate.serving.adaptor.dataaccess.CustomBatchAdapter
feature.batch.single.adatpor=com.webank.ai.fate.serving.adaptor.dataaccess.CustomAdapter
```
可以根据需要实现Adapter中的逻辑,并修改serving-server.properties中feature.single.adaptor或feature.batch.adaptor配置项为新增Adapter的全类名即可。可以参考源码中的MockAdaptor
: feature.batch.single.adatpor与feature.batch.adatpor配套使用,feature.batch.single.adatpor可根据用户场景自行实现,fate-serving中目前支持httpAdaptor

## fate-serving-extension
为了更好的代码解耦合,代码中将自定义adapter分离到fate-serving-extension模块中。用户可在此模块中开发自定义的adapter。
@@ -69,7 +71,7 @@ x0:1,x1:5,x2:13,x3:58,x4:95,x5:352,x6:418,x7:833,x8:888,x9:937,x10:32776

#### HttpAdapter
在serving-server.properties文件中配置属性feature.single.adaptor和http.adapter.url,feature.single.adaptor为继承AbstractSingleFeatureDataAdaptor
接口,url为调用获取数据接口地址。
接口,url为调用获取数据接口地址。http.adapter.url中标明的用户接口,返回格式请定义为 {"code": 200, "data": xxx}标准格式即可,httpAdapter中会根据接口返回状态码是否为200判断用户数据拉取接口是否执行成功。
```yaml
feature.single.adaptor=com.webank.ai.fate.serving.adaptor.dataaccess.HttpAdapter
http.adapter.url=http://127.0.0.1:9380/v1/http/adapter/getFeature
2 changes: 2 additions & 0 deletions document/docs/service/admin.md
Original file line number Diff line number Diff line change
@@ -5,6 +5,8 @@ serving-admin提供了FATE-Serving集群的可视化操作界面,依赖zookeep
### 功能介绍
#### 用户管理
默认用户:admin,默认密码:admin,用户可在[conf/application.properties](config/admin.md)中修改预设用户。
除此之外serving-admin提供一个基本的登录密码加解密功能,用户可在[conf/application.properties](config/application.properties)
中通过设置admin.isEncrypt参数为true(默认为false关闭),同时根据spring.security中的BCryptPasswordEncoder库对密码进行提前处理并预设为默认密码。
serving-admin仅实现简单的用户登录,用户可业务需求,自行实现登录逻辑,或接入第三方平台。

#### 节点管理
2 changes: 1 addition & 1 deletion fate-serving-admin-ui/README.md
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@ fate-serving定制联邦服务管理端
## Build Setup

# Install dependencies
npm install
npm install --force

# Serve with hot reload at localhost:8080
npm run serve
3 changes: 2 additions & 1 deletion fate-serving-admin-ui/package.json
Original file line number Diff line number Diff line change
@@ -48,6 +48,7 @@
"svg-sprite-loader": "^3.9.2",
"vue-template-compiler": "^2.6.10",
"webpack-bundle-analyzer": ">=3.3.2",
"webpack-cli": "^3.2.3"
"webpack-cli": "^3.2.3",
"webpack": "^4.0.0"
}
}
4 changes: 2 additions & 2 deletions fate-serving-admin-ui/pom.xml
Original file line number Diff line number Diff line change
@@ -55,7 +55,7 @@
<goal>install-node-and-npm</goal>
</goals>
<configuration>
<nodeVersion>v9.11.1</nodeVersion>
<nodeVersion>v16.20.2</nodeVersion>
</configuration>
</execution>
<!-- Install all project dependencies -->
@@ -68,7 +68,7 @@
<phase>generate-resources</phase>
<!-- Optional configuration which provides for running any npm command -->
<configuration>
<arguments>install</arguments>
<arguments>install --force</arguments>
</configuration>
</execution>
<!-- Build and minify static files -->
2 changes: 1 addition & 1 deletion fate-serving-admin/bin/service.sh
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ basepath=$(cd `dirname $0`;pwd)
configpath=$(cd $basepath/conf;pwd)
module=serving-admin
main_class=com.webank.ai.fate.serving.admin.Bootstrap
module_version=2.1.6
module_version=2.1.7

case "$1" in
start)
5 changes: 5 additions & 0 deletions fate-serving-admin/pom.xml
Original file line number Diff line number Diff line change
@@ -51,6 +51,11 @@
<version>${fate.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-security</artifactId>
</dependency>



<!--<dependency>
Original file line number Diff line number Diff line change
@@ -119,6 +119,7 @@ public void stop() {
} catch (InterruptedException e) {
e.printStackTrace();
}
tryNum++;
}
}
}
Original file line number Diff line number Diff line change
@@ -44,8 +44,7 @@ public static boolean isAllowModify(String project, String config) {
return Boolean.FALSE;
}

boolean match = Arrays.asList(value.config).contains(config);
return match;
return Arrays.asList(value.config).contains(config);
}

}
Original file line number Diff line number Diff line change
@@ -72,8 +72,7 @@ public Cache cache() {
Integer maxSize = MetaInfo.PROPERTY_LOCAL_CACHE_MAXSIZE;
Integer expireTime = MetaInfo.PROPERTY_LOCAL_CACHE_EXPIRE;
Integer interval = MetaInfo.PROPERTY_LOCAL_CACHE_INTERVAL;
ExpiringLRUCache lruCache = new ExpiringLRUCache(maxSize, expireTime, interval);
return lruCache;
return new ExpiringLRUCache(maxSize, expireTime, interval);
}


Original file line number Diff line number Diff line change
@@ -16,12 +16,14 @@

package com.webank.ai.fate.serving.admin.controller;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import com.webank.ai.fate.api.networking.common.CommonServiceGrpc;
import com.webank.ai.fate.api.networking.common.CommonServiceProto;
import com.webank.ai.fate.serving.admin.bean.ServiceConfiguration;
import com.webank.ai.fate.serving.admin.services.ComponentService;
import com.webank.ai.fate.serving.admin.utils.NetAddressChecker;
import com.webank.ai.fate.serving.core.bean.*;
import com.webank.ai.fate.serving.core.constant.StatusCode;
import com.webank.ai.fate.serving.core.exceptions.RemoteRpcException;
@@ -53,24 +55,22 @@ public class ComponentController {

private static final Logger logger = LoggerFactory.getLogger(ComponentController.class);

private final ObjectMapper objectMapper = new ObjectMapper();

@Autowired
ComponentService componentServices;
GrpcConnectionPool grpcConnectionPool = GrpcConnectionPool.getPool();

@GetMapping("/component/list")
public ReturnResult list() {
ComponentService.NodeData cachedNodeData = componentServices.getCachedNodeData();
return ReturnResult.build(StatusCode.SUCCESS, Dict.SUCCESS, JsonUtil.json2Object(JsonUtil.object2Json(cachedNodeData), Map.class));
Map<String, Object> cachedNodeDataMap = objectMapper.convertValue(cachedNodeData, Map.class);
return ReturnResult.build(StatusCode.SUCCESS, Dict.SUCCESS, cachedNodeDataMap);
}

@GetMapping("/component/listProps")
public ReturnResult listProps(String host, int port, String keyword) {
if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}
if (!componentServices.isAllowAccess(host, port)) {
throw new RemoteRpcException("no allow access, target: " + host + ":" + port);
}
NetAddressChecker.check(host, port);
ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
CommonServiceGrpc.CommonServiceBlockingStub blockingStub = CommonServiceGrpc.newBlockingStub(managedChannel);
blockingStub = blockingStub.withDeadlineAfter(MetaInfo.PROPERTY_GRPC_TIMEOUT, TimeUnit.MILLISECONDS);
@@ -102,17 +102,15 @@ public ReturnResult listProps(String host, int port, String keyword) {

@PostMapping("/component/updateConfig")
public ReturnResult updateConfig(@RequestBody RequestParamWrapper requestParams) {
Preconditions.checkArgument(StringUtils.isNotBlank(requestParams.getFilePath()), "file path is blank");
Preconditions.checkArgument(StringUtils.isNotBlank(requestParams.getData()), "data is blank");
String filePath = requestParams.getFilePath();
String data = requestParams.getData();
Preconditions.checkArgument(StringUtils.isNotBlank(filePath), "file path is blank");
Preconditions.checkArgument(StringUtils.isNotBlank(data), "data is blank");

String host = requestParams.getHost();
int port = requestParams.getPort();
NetAddressChecker.check(host, port);

if (!componentServices.isAllowAccess(host, port)) {
throw new RemoteRpcException("no allow access, target: " + host + ":" + port);
}

String filePath = requestParams.getFilePath();
String fileName = filePath.substring(filePath.lastIndexOf(File.separator) + 1);

String project = componentServices.getProject(host, port);
@@ -124,8 +122,8 @@ public ReturnResult updateConfig(@RequestBody RequestParamWrapper requestParams)
CommonServiceGrpc.CommonServiceBlockingStub blockingStub = CommonServiceGrpc.newBlockingStub(managedChannel)
.withDeadlineAfter(MetaInfo.PROPERTY_GRPC_TIMEOUT, TimeUnit.MILLISECONDS);
CommonServiceProto.UpdateConfigRequest.Builder builder = CommonServiceProto.UpdateConfigRequest.newBuilder();
builder.setFilePath(requestParams.getFilePath());
builder.setData(requestParams.getData());
builder.setFilePath(filePath);
builder.setData(data);

CommonServiceProto.CommonResponse response = blockingStub.updateConfig(builder.build());
return ReturnResult.build(response.getStatusCode(), response.getMessage());
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.webank.ai.fate.serving.admin.controller;

import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.core.LoggerContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.*;

/**
* @author hcy
*/
@RequestMapping("/admin")
@RestController
public class DynamicLogController {
private static final Logger logger = LoggerFactory.getLogger(DynamicLogController.class);

@GetMapping("/alterSysLogLevel/{level}")
public String alterSysLogLevel(@PathVariable String level){
try {
LoggerContext context = (LoggerContext) LogManager.getContext(false);
context.getLogger("ROOT").setLevel(Level.valueOf(level));
return "ok";
} catch (Exception ex) {
logger.error("admin alterSysLogLevel failed : " + ex);
return "failed";
}

}

@GetMapping("/alterPkgLogLevel")
public String alterPkgLogLevel(@RequestParam String level, @RequestParam String pkgName){
try {
LoggerContext context = (LoggerContext) LogManager.getContext(false);
context.getLogger(pkgName).setLevel(Level.valueOf(level));
return "ok";
} catch (Exception ex) {
logger.error("admin alterPkgLogLevel failed : " + ex);
return "failed";
}
}
}
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;

import javax.servlet.http.HttpServletRequest;
import java.util.Arrays;
@@ -51,7 +52,10 @@ public class LoginController {
private String username;

@Value("${admin.password}")
private String password;
private String hashedPassword;

@Value("${admin.isEncrypt}")
private Boolean isEncrypt;

@Autowired
private Cache cache;
@@ -60,12 +64,20 @@ public class LoginController {
public ReturnResult login(@RequestBody RequestParamWrapper requestParams) {
String username = requestParams.getUsername();
String password = requestParams.getPassword();
boolean passwordIfCorrect;

Preconditions.checkArgument(StringUtils.isNotBlank(username), "parameter username is blank");
Preconditions.checkArgument(StringUtils.isNotBlank(password), "parameter password is blank");

ReturnResult result = new ReturnResult();
if (username.equals(this.username) && password.equals(this.password)) {

if (isEncrypt) {
passwordIfCorrect = new BCryptPasswordEncoder().matches(password, this.hashedPassword);
} else {
passwordIfCorrect = password.equals(this.hashedPassword);
}

if (username.equals(this.username) && passwordIfCorrect) {
String userInfo = StringUtils.join(Arrays.asList(username, password), "_");
String token = EncryptUtils.encrypt(Dict.USER_CACHE_KEY_PREFIX + userInfo, EncryptMethod.MD5);
cache.put(token, userInfo, MetaInfo.PROPERTY_CACHE_TYPE.equalsIgnoreCase("local") ? MetaInfo.PROPERTY_LOCAL_CACHE_EXPIRE : MetaInfo.PROPERTY_REDIS_EXPIRE);
@@ -77,7 +89,7 @@ public ReturnResult login(@RequestBody RequestParamWrapper requestParams) {
result.setRetcode(StatusCode.SUCCESS);
result.setData(data);
} else {
logger.info("user {} login failure, username or password {} is wrong.", username,password);
logger.error("user {} login failure, username or password {} is wrong.", username,password);
result.setRetcode(StatusCode.PARAM_ERROR);
result.setRetmsg("username or password is wrong");
}
@@ -94,7 +106,7 @@ public ReturnResult logout(HttpServletRequest request) {
cache.delete(sessionToken);
result.setRetcode(StatusCode.SUCCESS);
} else {
logger.info("Session token unavailable");
logger.error("Session token unavailable");
result.setRetcode(StatusCode.PARAM_ERROR);
result.setRetmsg("Session token unavailable");
}
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@
import com.webank.ai.fate.api.mlmodel.manager.ModelServiceGrpc;
import com.webank.ai.fate.api.mlmodel.manager.ModelServiceProto;
import com.webank.ai.fate.serving.admin.services.ComponentService;
import com.webank.ai.fate.serving.admin.utils.NetAddressChecker;
import com.webank.ai.fate.serving.core.bean.GrpcConnectionPool;
import com.webank.ai.fate.serving.core.bean.MetaInfo;
import com.webank.ai.fate.serving.core.bean.RequestParamWrapper;
@@ -66,12 +67,15 @@ public ReturnResult queryModel(String host, Integer port, String serviceId,Strin
Preconditions.checkArgument(StringUtils.isNotBlank(host), "parameter host is blank");
Preconditions.checkArgument(port != 0, "parameter port is blank");

if (page == null || page < 0) {
page = 1;
int defaultPage = 1;
int defaultPageSize = 10;

if (page == null || page <= 0) {
page = defaultPage;
}

if (pageSize == null) {
pageSize = 10;
if (pageSize == null || pageSize <= 0) {
pageSize = defaultPageSize;
}

if (logger.isDebugEnabled()) {
@@ -153,7 +157,6 @@ public Callable<ReturnResult> transfer(@RequestBody RequestParamWrapper requestP
return () -> {
String host = requestParams.getHost();
Integer port = requestParams.getPort();
List<String> serviceIds = requestParams.getServiceIds();
String tableName = requestParams.getTableName();
String namespace = requestParams.getNamespace();

@@ -176,13 +179,6 @@ public Callable<ReturnResult> transfer(@RequestBody RequestParamWrapper requestP
//.setServiceId()
.setNamespace(namespace).setTableName(tableName).setSourceIp(host).setSourcePort(port).build();



ModelServiceProto.UnloadRequest unloadRequest = ModelServiceProto.UnloadRequest.newBuilder()
.setTableName(tableName)
.setNamespace(namespace)
.build();

ListenableFuture<ModelServiceProto.FetchModelResponse> future = futureStub.fetchModel(fetchModelRequest);
ModelServiceProto.FetchModelResponse response = future.get(MetaInfo.PROPERTY_GRPC_TIMEOUT, TimeUnit.MILLISECONDS);

@@ -281,13 +277,7 @@ private ModelServiceGrpc.ModelServiceBlockingStub getModelServiceBlockingStub(St
Preconditions.checkArgument(StringUtils.isNotBlank(host), "parameter host is blank");
Preconditions.checkArgument(port != null && port.intValue() != 0, "parameter port was wrong");

if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}

if (!componentService.isAllowAccess(host, port)) {
throw new RemoteRpcException("no allow access, target: " + host + ":" + port);
}
NetAddressChecker.check(host, port);

ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
ModelServiceGrpc.ModelServiceBlockingStub blockingStub = ModelServiceGrpc.newBlockingStub(managedChannel);
@@ -299,17 +289,10 @@ private ModelServiceGrpc.ModelServiceFutureStub getModelServiceFutureStub(String
Preconditions.checkArgument(StringUtils.isNotBlank(host), "parameter host is blank");
Preconditions.checkArgument(port != null && port.intValue() != 0, "parameter port was wrong");

if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}

if (!componentService.isAllowAccess(host, port)) {
throw new RemoteRpcException("no allow access, target: " + host + ":" + port);
}
NetAddressChecker.check(host, port);

ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
ModelServiceGrpc.ModelServiceFutureStub futureStub = ModelServiceGrpc.newFutureStub(managedChannel);
return futureStub;
return ModelServiceGrpc.newFutureStub(managedChannel);
}

public void parseComponentInfo(ModelServiceProto.QueryModelResponse response){
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@

import com.webank.ai.fate.serving.admin.services.ComponentService;
import com.webank.ai.fate.serving.admin.services.HealthCheckService;
import com.webank.ai.fate.serving.admin.utils.NetAddressChecker;
import com.webank.ai.fate.serving.common.flow.JvmInfo;
import com.webank.ai.fate.serving.common.flow.MetricNode;
import com.webank.ai.fate.serving.common.health.HealthCheckRecord;
@@ -278,24 +279,15 @@ private CommonServiceGrpc.CommonServiceBlockingStub getMonitorServiceBlockStub(S
Preconditions.checkArgument(StringUtils.isNotBlank(host), "parameter host is blank");
Preconditions.checkArgument(port != 0, "parameter port was wrong");

if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}

if (!componentService.isAllowAccess(host, port)) {
throw new RemoteRpcException("no allow access, target: " + host + ":" + port);
}
NetAddressChecker.check(host, port);

ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
CommonServiceGrpc.CommonServiceBlockingStub blockingStub = CommonServiceGrpc.newBlockingStub(managedChannel);
return blockingStub;
return CommonServiceGrpc.newBlockingStub(managedChannel);
}

private InferenceServiceGrpc.InferenceServiceBlockingStub getInferenceServiceBlockingStub(String host, int port, int timeout) throws Exception {
Preconditions.checkArgument(StringUtils.isNotBlank(host), "host is blank");
if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}
NetAddressChecker.check(host, port);

ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
InferenceServiceGrpc.InferenceServiceBlockingStub blockingStub = InferenceServiceGrpc.newBlockingStub(managedChannel);
@@ -340,8 +332,8 @@ private void checkInferenceService(Map<String,Map> componentMap, ComponentServic
List currentList = currentComponentMap.get(serviceInfo.getHost() + ":" + port);
Map<String,Object> currentInfoMap = new HashMap<>();
currentInfoMap.put("type","inference");
if (result.getBody() == null || result.getBody().toStringUtf8() == "null") {
currentInfoMap.put("data",null);
if (result.getBody() == null || "null".equals(result.getBody().toStringUtf8())) {
currentInfoMap.put("data", null);
}
else{
Map resultMap = JsonUtil.json2Object(result.getBody().toStringUtf8(),Map.class);
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
import com.webank.ai.fate.register.url.URL;
import com.webank.ai.fate.register.zookeeper.ZookeeperRegistry;
import com.webank.ai.fate.serving.admin.services.ComponentService;
import com.webank.ai.fate.serving.admin.utils.NetAddressChecker;
import com.webank.ai.fate.serving.core.bean.*;
import com.webank.ai.fate.serving.core.exceptions.RemoteRpcException;
import com.webank.ai.fate.serving.core.exceptions.SysException;
@@ -54,7 +55,7 @@ public class RouterController {
ZookeeperRegistry zookeeperRegistry;
@Autowired
ComponentService componentService;
String ROUTER_URL = "proxy/online/queryRouter";
static final String ROUTER_URL = "proxy/online/queryRouter";


@PostMapping("/router/query")
@@ -108,12 +109,7 @@ public ReturnResult queryRouter(@RequestBody RouterTableRequest routerTable) {
private RouterTableServiceGrpc.RouterTableServiceBlockingStub getRouterTableServiceBlockingStub(String host, Integer port) {
ParameterUtils.checkArgument(StringUtils.isNotBlank(host), "parameter host is blank");
ParameterUtils.checkArgument(port != null && port != 0, "parameter port was wrong");
if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}
if (!componentService.isAllowAccess(host, port)) {
throw new RemoteRpcException("no allow access, target: " + host + ":" + port);
}
NetAddressChecker.check(host, port);

ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
RouterTableServiceGrpc.RouterTableServiceBlockingStub blockingStub = RouterTableServiceGrpc.newBlockingStub(managedChannel);
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@
import com.webank.ai.fate.register.zookeeper.ZookeeperRegistry;
import com.webank.ai.fate.serving.admin.bean.VerifyService;
import com.webank.ai.fate.serving.admin.services.ComponentService;
import com.webank.ai.fate.serving.admin.utils.NetAddressChecker;
import com.webank.ai.fate.serving.core.bean.*;
import com.webank.ai.fate.serving.core.constant.StatusCode;
import com.webank.ai.fate.serving.core.exceptions.RemoteRpcException;
@@ -72,12 +73,15 @@ public class ServiceController {
*/
@GetMapping("/service/list")
public ReturnResult listRegistered(Integer page, Integer pageSize) {
if (page == null || page < 0) {
page = 1;
int defaultPage = 1;
int defaultPageSize = 10;

if (page == null || page <= 0) {
page = defaultPage;
}

if (pageSize == null) {
pageSize = 10;
if (pageSize == null || pageSize <= 0) {
pageSize = defaultPageSize;
}

if (logger.isDebugEnabled()) {
@@ -103,8 +107,9 @@ public ReturnResult listRegistered(Integer page, Integer pageSize) {
URL url = URL.valueOf(u);
if (!Constants.EMPTY_PROTOCOL.equals(url.getProtocol())) {
String[] split = key.split("/");
if(!filterSet.contains(split[2]))
if (!filterSet.contains(split[2])) {
continue;
}
ServiceDataWrapper wrapper = new ServiceDataWrapper();
wrapper.setUrl(url.toFullString());
wrapper.setProject(split[0]);
@@ -128,8 +133,6 @@ public ReturnResult listRegistered(Integer page, Integer pageSize) {
totalSize = resultList.size();

resultList = resultList.stream().sorted((Comparator.comparing(o -> (o.getProject() + o.getEnvironment())))).collect(Collectors.toList());
// resultList = resultList.stream().sorted((Comparator.comparingInt(o -> (o.getProject() + o.getEnvironment()).hashCode()))).collect(Collectors.toList());
// Pagination
int totalPage = (resultList.size() + pageSize - 1) / pageSize;
if (page <= totalPage) {
resultList = resultList.subList((page - 1) * pageSize, Math.min(page * pageSize, resultList.size()));
@@ -228,17 +231,10 @@ private CommonServiceGrpc.CommonServiceFutureStub getCommonServiceFutureStub(Str
Preconditions.checkArgument(StringUtils.isNotBlank(host), "parameter host is blank");
Preconditions.checkArgument(port != null && port.intValue() != 0, "parameter port was wrong");

if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}

if (!componentService.isAllowAccess(host, port)) {
throw new RemoteRpcException("no allow access, target: " + host + ":" + port);
}
NetAddressChecker.check(host, port);

ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
CommonServiceGrpc.CommonServiceFutureStub futureStub = CommonServiceGrpc.newFutureStub(managedChannel);
return futureStub;
return CommonServiceGrpc.newFutureStub(managedChannel);
}

}
Original file line number Diff line number Diff line change
@@ -29,7 +29,8 @@ public class SecurityFilter implements Filter {

@Override
public void doFilter(ServletRequest req, ServletResponse resp, FilterChain filterChain) throws IOException, ServletException {
((HttpServletResponse) resp).addHeader("X-Frame-Options","DENY");
((HttpServletResponse) resp).addHeader("X-Frame-Options", "DENY");
((HttpServletResponse) resp).addHeader("X-XSS-Protection", "1; mode=block");
filterChain.doFilter(req, resp);
}
}
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ public class LoginInterceptor implements HandlerInterceptor {
@Autowired
private Cache cache;

private static List<String> EXCLUDES = Arrays.asList("/api/component/list", "/api/monitor/queryJvm", "/api/monitor/query", "/api/monitor/queryModel");
private static final List<String> EXCLUDES = Arrays.asList("/api/component/list", "/api/monitor/queryJvm", "/api/monitor/query", "/api/monitor/queryModel");

@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
Original file line number Diff line number Diff line change
@@ -44,8 +44,7 @@ protected resp doService(Context context, InboundPackage<req> data, OutboundPack
} catch (Throwable e) {
e.printStackTrace();
if (e.getCause() != null && e.getCause() instanceof BaseException) {
BaseException baseException = (BaseException) e.getCause();
throw baseException;
throw (BaseException) e.getCause();
} else if (e instanceof InvocationTargetException) {
InvocationTargetException ex = (InvocationTargetException) e;
throw new SysException(ex.getTargetException().getMessage());
@@ -55,9 +54,4 @@ protected resp doService(Context context, InboundPackage<req> data, OutboundPack
}
return result;
}

@Override
protected void printFlowLog(Context context) {

}
}
Original file line number Diff line number Diff line change
@@ -34,9 +34,8 @@
@Service
public class ComponentService {

private final static String PATH_SEPARATOR = "/";
private final static String DEFAULT_COMPONENT_ROOT = "FATE-COMPONENTS";
private final static String PROVIDER = "providers";
private static final String PATH_SEPARATOR = "/";
private static final String DEFAULT_COMPONENT_ROOT = "FATE-COMPONENTS";
Logger logger = LoggerFactory.getLogger(ComponentService.class);
@Autowired
ZookeeperRegistry zookeeperRegistry;
@@ -88,18 +87,12 @@ public String getProject(String host, int port) {
}
return false;
}).map(Map.Entry::getKey).findFirst();
if(project.isPresent())
return project.get();
else
return "";
return project.orElse("");
}

public boolean isAllowAccess(String host, int port) {
Set<String> whitelist = getWhitelist();
if (whitelist != null && whitelist.contains(host + ":" + port)) {
return true;
}
return false;
return whitelist != null && whitelist.contains(host + ":" + port);
}

@Scheduled(cron = "0/5 * * * * ?")
Original file line number Diff line number Diff line change
@@ -37,7 +37,6 @@ public class FateServiceRegister implements ServiceRegister, ApplicationContextA
Logger logger = LoggerFactory.getLogger(FateServiceRegister.class);
Map<String, ServiceAdaptor> serviceAdaptorMap = new HashMap<String, ServiceAdaptor>();
ApplicationContext applicationContext;
GrpcConnectionPool grpcConnectionPool = GrpcConnectionPool.getPool();

@Override
public ServiceAdaptor getServiceAdaptor(String name) {
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
import com.webank.ai.fate.api.networking.common.CommonServiceProto;
import com.webank.ai.fate.register.url.URL;
import com.webank.ai.fate.register.zookeeper.ZookeeperRegistry;
import com.webank.ai.fate.serving.admin.utils.NetAddressChecker;
import com.webank.ai.fate.serving.common.health.HealthCheckRecord;
import com.webank.ai.fate.serving.common.health.HealthCheckResult;
import com.webank.ai.fate.serving.common.health.HealthCheckStatus;
@@ -47,18 +48,17 @@ public class HealthCheckService implements InitializingBean {
GrpcConnectionPool grpcConnectionPool = GrpcConnectionPool.getPool();
Map<String,Object> healthRecord = new ConcurrentHashMap<>();

private static ThreadPoolExecutor executor = ThreadPoolUtil.newThreadPoolExecutor();

public Map getHealthCheckInfo(){
return healthRecord;
}

private void checkRemoteHealth(Map<String,Map> componentMap, String address, String component) {
if(StringUtils.isBlank(address))
if (StringUtils.isBlank(address)) {
return ;
}
Map<String,Map> currentComponentMap = componentMap.get(component);
String host = address.substring(0,address.indexOf(":"));
int port = Integer.parseInt(address.substring(address.indexOf(":") + 1));
String host = address.substring(0,address.indexOf(':'));
int port = Integer.parseInt(address.substring(address.indexOf(':') + 1));
CommonServiceGrpc.CommonServiceBlockingStub blockingStub = getMonitorServiceBlockStub(host, port);
CommonServiceProto.HealthCheckRequest.Builder builder = CommonServiceProto.HealthCheckRequest.newBuilder();
CommonServiceProto.CommonResponse commonResponse = blockingStub.checkHealthService(builder.build());
@@ -141,17 +141,10 @@ private CommonServiceGrpc.CommonServiceBlockingStub getMonitorServiceBlockStub(S
Preconditions.checkArgument(StringUtils.isNotBlank(host), "parameter host is blank");
Preconditions.checkArgument(port != 0, "parameter port was wrong");

if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}

if (!componentService.isAllowAccess(host, port)) {
throw new RemoteRpcException("no allow access, target: " + host + ":" + port);
}
NetAddressChecker.check(host, port);

ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
CommonServiceGrpc.CommonServiceBlockingStub blockingStub = CommonServiceGrpc.newBlockingStub(managedChannel);
return blockingStub;
return CommonServiceGrpc.newBlockingStub(managedChannel);
}
@Override
public void afterPropertiesSet() throws Exception {
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
import com.webank.ai.fate.serving.admin.controller.ValidateController;
import com.webank.ai.fate.serving.admin.services.AbstractAdminServiceProvider;
import com.webank.ai.fate.serving.admin.services.ComponentService;
import com.webank.ai.fate.serving.admin.utils.NetAddressChecker;
import com.webank.ai.fate.serving.common.rpc.core.FateService;
import com.webank.ai.fate.serving.common.rpc.core.FateServiceMethod;
import com.webank.ai.fate.serving.common.rpc.core.InboundPackage;
@@ -42,7 +43,10 @@
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.List;
@@ -99,6 +103,10 @@ public Object publishBind(Context context, InboundPackage data) throws Exception

@FateServiceMethod(name = "inference")
public Object inference(Context context, InboundPackage data) throws Exception {
ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
HttpServletRequest request = attributes.getRequest();
String caseId = request.getHeader("caseId");

Map params = (Map) data.getBody();
String host = (String) params.get(Dict.HOST);
int port = (int) params.get(Dict.PORT);
@@ -123,6 +131,10 @@ public Object inference(Context context, InboundPackage data) throws Exception {
inferenceRequest.setApplyId(params.get("applyId").toString());
}

if(caseId != null && !caseId.isEmpty()) {
inferenceRequest.setCaseId(caseId);
}

for (Map.Entry<String, Object> entry : featureData.entrySet()) {
inferenceRequest.getFeatureData().put(entry.getKey(), entry.getValue());
}
@@ -138,8 +150,7 @@ public Object inference(Context context, InboundPackage data) throws Exception {
ListenableFuture<InferenceServiceProto.InferenceMessage> future = inferenceServiceFutureStub.inference(builder.build());
InferenceServiceProto.InferenceMessage response = future.get(MetaInfo.PROPERTY_GRPC_TIMEOUT * 2, TimeUnit.MILLISECONDS);

Map returnResult = JsonUtil.json2Object(response.getBody().toStringUtf8(), Map.class);
return returnResult;
return JsonUtil.json2Object(response.getBody().toStringUtf8(), Map.class);
}

@FateServiceMethod(name = "batchInference")
@@ -190,8 +201,7 @@ public Object batchInference(Context context, InboundPackage data) throws Except
ListenableFuture<InferenceServiceProto.InferenceMessage> future = inferenceServiceFutureStub.batchInference(builder.build());
InferenceServiceProto.InferenceMessage response = future.get(MetaInfo.PROPERTY_GRPC_TIMEOUT * 2, TimeUnit.MILLISECONDS);

Map returnResult = JsonUtil.json2Object(response.getBody().toStringUtf8(), Map.class);
return returnResult;
return JsonUtil.json2Object(response.getBody().toStringUtf8(), Map.class);
}

/*private ModelServiceProto.PublishRequest buildPublishRequest(Map params) {
@@ -254,7 +264,6 @@ public Object batchInference(Context context, InboundPackage data) throws Except

@Override
protected Object transformExceptionInfo(Context context, ExceptionInfo data) {
String actionType = context.getActionType();
Map returnResult = new HashMap();
if (data != null) {
int code = data.getCode();
@@ -267,13 +276,7 @@ protected Object transformExceptionInfo(Context context, ExceptionInfo data) {
}

private ModelServiceGrpc.ModelServiceBlockingStub getModelServiceBlockingStub(String host, Integer port) throws Exception {
if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}

if (!componentService.isAllowAccess(host, port)) {
throw new RemoteRpcException("no allow access, target: " + host + ":" + port);
}
NetAddressChecker.check(host, port);

ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
ModelServiceGrpc.ModelServiceBlockingStub blockingStub = ModelServiceGrpc.newBlockingStub(managedChannel);
@@ -282,30 +285,16 @@ private ModelServiceGrpc.ModelServiceBlockingStub getModelServiceBlockingStub(St
}

private ModelServiceGrpc.ModelServiceFutureStub getModelServiceFutureStub(String host, Integer port) throws Exception {
if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}

if (!componentService.isAllowAccess(host, port)) {
throw new RemoteRpcException("no allow access, target: " + host + ":" + port);
}
NetAddressChecker.check(host, port);

ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
ModelServiceGrpc.ModelServiceFutureStub futureStub = ModelServiceGrpc.newFutureStub(managedChannel);
return futureStub;
return ModelServiceGrpc.newFutureStub(managedChannel);
}

private InferenceServiceGrpc.InferenceServiceFutureStub getInferenceServiceFutureStub(String host, Integer port) throws Exception {
if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}

if (!componentService.isAllowAccess(host, port)) {
throw new RemoteRpcException("no allow access, target: " + host + ":" + port);
}
NetAddressChecker.check(host, port);

ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
InferenceServiceGrpc.InferenceServiceFutureStub futureStub = InferenceServiceGrpc.newFutureStub(managedChannel);
return futureStub;
return InferenceServiceGrpc.newFutureStub(managedChannel);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package com.webank.ai.fate.serving.admin.utils;

import com.webank.ai.fate.serving.admin.services.ComponentService;
import com.webank.ai.fate.serving.core.exceptions.RemoteRpcException;
import com.webank.ai.fate.serving.core.exceptions.SysException;
import com.webank.ai.fate.serving.core.utils.NetUtils;

/**
* @author hcy
*/
public class NetAddressChecker {

private static final ComponentService componentService = new ComponentService();

public static void check(String host, Integer port) {
if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}

if (!componentService.isAllowAccess(host, port)) {
throw new RemoteRpcException("no allow access, target: " + host + ":" + port);
}
}
}
2 changes: 2 additions & 0 deletions fate-serving-admin/src/main/resources/application.properties
Original file line number Diff line number Diff line change
@@ -28,5 +28,7 @@ zk.url=localhost:2181,localhost:2182,localhost:2183
# username & password
admin.username=admin
admin.password=admin
# 登录密码是否加密, 默认false关闭, 为true时请采用spring-security中BCryptPasswordEncoder进行提前加密处理
admin.isEncrypt=false

spring.mvc.pathmatch.matching-strategy=ANT_PATH_MATCHER
2 changes: 1 addition & 1 deletion fate-serving-common/pom.xml
Original file line number Diff line number Diff line change
@@ -57,7 +57,7 @@
<dependency>
<groupId>commons-net</groupId>
<artifactId>commons-net</artifactId>
<version>3.8.0</version>
<version>3.9.0</version>
</dependency>
<dependency>
<groupId>com.github.oshi</groupId>
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package com.webank.ai.fate.serving.common.bean;

import java.io.Serializable;
import java.util.Objects;

/**
* @author hcy
*/
public class ThreadVO implements Serializable {
private static final long serialVersionUID = 0L;

private long id;
private String name;
private String group;
private int priority;
private Thread.State state;
private double cpu;
private long deltaTime;
private long time;
private boolean interrupted;
private boolean daemon;

public ThreadVO() {
}

public long getId() {
return id;
}

public void setId(long id) {
this.id = id;
}

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public String getGroup() {
return group;
}

public void setGroup(String group) {
this.group = group;
}

public int getPriority() {
return priority;
}

public void setPriority(int priority) {
this.priority = priority;
}

public Thread.State getState() {
return state;
}

public void setState(Thread.State state) {
this.state = state;
}

public double getCpu() {
return cpu;
}

public void setCpu(double cpu) {
this.cpu = cpu;
}

public long getDeltaTime() {
return deltaTime;
}

public void setDeltaTime(long deltaTime) {
this.deltaTime = deltaTime;
}

public long getTime() {
return time;
}

public void setTime(long time) {
this.time = time;
}

public boolean isInterrupted() {
return interrupted;
}

public void setInterrupted(boolean interrupted) {
this.interrupted = interrupted;
}

public boolean isDaemon() {
return daemon;
}

public void setDaemon(boolean daemon) {
this.daemon = daemon;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}

if (o == null || getClass() != o.getClass()) {
return false;
}

ThreadVO threadVO = (ThreadVO) o;

if (id != threadVO.id) {
return false;
}
return Objects.equals(name, threadVO.name);
}

@Override
public int hashCode() {
int result = (int) (id ^ (id >>> 32));
result = 31 * result + (name != null ? name.hashCode() : 0);
return result;
}

@Override
public String toString() {
return "ThreadVO{" +
"id=" + id +
", name='" + name + '\'' +
", group='" + group + '\'' +
", priority=" + priority +
", state=" + state +
", cpu=" + cpu +
", deltaTime=" + deltaTime +
", time=" + time +
", interrupted=" + interrupted +
", daemon=" + daemon +
'}';
}
}
Original file line number Diff line number Diff line change
@@ -32,7 +32,6 @@ public class ExpiringMap<K, V> extends LinkedHashMap<K, V> {
private static final float DEFAULT_LOAD_FACTOR = 0.75f;
private static final int DEFAULT_EXPIRATION_INTERVAL = 1;
private static AtomicInteger expireCount = new AtomicInteger(1);
private final Lock lock = new ReentrantLock();
private final ConcurrentHashMap<K, ExpiryObject> delegateMap;
private final ExpireThread expireThread;
private volatile int maxCapacity;
Original file line number Diff line number Diff line change
@@ -10,21 +10,20 @@ public enum HealthCheckItemEnum {
CHECK_FATEFLOW_IN_ZK("check fateflow in zookeeper",HealthCheckComponent.SERVINGSERVER),
CHECK_MODEL_LOADED("check model loaded",HealthCheckComponent.SERVINGSERVER);

// CHECK_MODEL_LOADED("check model loaded"),
// CHECK_MODEL_VALIDATE()

private String itemName;
private HealthCheckComponent component;

private String itemName;
private HealthCheckComponent component;
private HealthCheckItemEnum(String name,HealthCheckComponent healthCheckComponent ){
HealthCheckItemEnum(String name, HealthCheckComponent healthCheckComponent){
this.component = healthCheckComponent;
this.itemName= name;
}
public String getItemName(){
return itemName;
return itemName;
}

public HealthCheckComponent getComponent(){
return this.component;
return this.component;
}


Original file line number Diff line number Diff line change
@@ -40,6 +40,7 @@ public HealthCheckRecord(String checkItemName, String msg, HealthCheckStatus hea
this.healthCheckStatus = healthCheckStatus;
}

@Override
public String toString(){
return JsonUtil.object2Json(this);
}
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@ public void setRecords(List<HealthCheckRecord> records) {

List<HealthCheckRecord> records = Lists.newArrayList();

@Override
public String toString(){
return JsonUtil.object2Json(this);
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,28 @@
package com.webank.ai.fate.serving.common.health;

public enum HealthCheckStatus {
ok,
warn,
error
/**
* 健康
*/
ok("状态健康"),

/**
* 异常
*/
warn("状态异常"),

/**
* 错误
*/
error("状态错误");

private final String desc;

HealthCheckStatus(String desc) {
this.desc = desc;
}

public String getDesc() {
return desc;
}
}
Original file line number Diff line number Diff line change
@@ -59,9 +59,7 @@ public static String getPercentFormat(double d,int IntegerDigits,int FractionDig

nf.setMinimumFractionDigits(FractionDigits);// 小数点后保留几位

String str = nf.format(d);

return str;
return nf.format(d);

}

Original file line number Diff line number Diff line change
@@ -161,13 +161,14 @@ public int hashCode() {

@Override
public boolean equals(Object obj) {
if(obj!=null&&obj instanceof Model) {
if(obj instanceof Model) {
Model model = (Model) obj;
if(this.namespace!=null&&this.namespace!=null)
if(this.namespace != null && this.tableName != null) {
return this.namespace.equals(model.namespace) && this.tableName.equals(model.tableName);
else
} else {
return false;
}else {
}
} else {
return false;
}
}
Original file line number Diff line number Diff line change
@@ -104,7 +104,7 @@ private static HttpAdapterResponse getResponseByHeader(HttpRequestBase request)
int statusCode = response.getStatusLine().getStatusCode();
HttpAdapterResponse result = new HttpAdapterResponse();
result.setCode(statusCode);
result.setData(JsonUtil.json2Object(data,Map.class));
result.setData(JsonUtil.json2Object(data, Map.class));
return result;
} catch (IOException ex) {
logger.error("get http response failed:", ex);
Original file line number Diff line number Diff line change
@@ -62,9 +62,7 @@ private static void config(HttpRequestBase httpRequestBase, Map<String, String>
.setSocketTimeout(MetaInfo.HTTP_CLIENT_CONFIG_SOCK_TIME_OUT).build();
httpRequestBase.addHeader(Dict.CONTENT_TYPE, Dict.CONTENT_TYPE_JSON_UTF8);
if (headers != null) {
headers.forEach((key, value) -> {
httpRequestBase.addHeader(key, value);
});
headers.forEach(httpRequestBase::addHeader);
}
httpRequestBase.setConfig(requestConfig);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.webank.ai.fate.serving.common.utils;

import com.webank.ai.fate.serving.common.bean.ThreadVO;

import java.util.*;

/**
* @author hcy
*/
public class JVMCPUUtils {

private static Set<String> states = null;

static {
states = new HashSet<>(Thread.State.values().length);
for (Thread.State state : Thread.State.values()) {
states.add(state.name());
}
}

public static List<ThreadVO> getThreadsState() {

List<ThreadVO> threads = ThreadUtils.getThreads();

Collection<ThreadVO> resultThreads = new ArrayList<>();
for (ThreadVO thread : threads) {
if (thread.getState() != null && states.contains(thread.getState().name())) {
resultThreads.add(thread);
}
}


ThreadSampler threadSampler = new ThreadSampler();
threadSampler.setIncludeInternalThreads(true);
threadSampler.sample(resultThreads);
threadSampler.pause(1000);
return threadSampler.sample(resultThreads);
}
}
Original file line number Diff line number Diff line change
@@ -13,9 +13,9 @@ public static boolean tryTelnet(String host ,int port){
isConnected = true;
telnetClient.disconnect();
} catch (Exception e) {
//e.printStackTrace();
throw new RuntimeException(e);
}
return isConnected;
return isConnected;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
package com.webank.ai.fate.serving.common.utils;


import com.webank.ai.fate.serving.common.bean.ThreadVO;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import sun.management.HotspotThreadMBean;
import sun.management.ManagementFactoryHelper;

import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* @author hcy
*/

class ThreadSampler {
private static Logger logger = LoggerFactory.getLogger(ThreadSampler.class);
private static ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean();
private static HotspotThreadMBean hotspotThreadMBean;
private static boolean hotspotThreadMBeanEnable = true;

private Map<ThreadVO, Long> lastCpuTimes = new HashMap<ThreadVO, Long>();

private long lastSampleTimeNanos;
private boolean includeInternalThreads = true;


public List<ThreadVO> sample(Collection<ThreadVO> originThreads) {

List<ThreadVO> threads = new ArrayList<ThreadVO>(originThreads);

// Sample CPU
if (lastCpuTimes.isEmpty()) {
lastSampleTimeNanos = System.nanoTime();
for (ThreadVO thread : threads) {
if (thread.getId() > 0) {
long cpu = threadMXBean.getThreadCpuTime(thread.getId());
lastCpuTimes.put(thread, cpu);
thread.setTime(cpu / 1000000);
}
}

// add internal threads
Map<String, Long> internalThreadCpuTimes = getInternalThreadCpuTimes();
if (internalThreadCpuTimes != null) {
for (Map.Entry<String, Long> entry : internalThreadCpuTimes.entrySet()) {
String key = entry.getKey();
ThreadVO thread = createThreadVO(key);
thread.setTime(entry.getValue() / 1000000);
threads.add(thread);
lastCpuTimes.put(thread, entry.getValue());
}
}

//sort by time
Collections.sort(threads, new Comparator<ThreadVO>() {
@Override
public int compare(ThreadVO o1, ThreadVO o2) {
long l1 = o1.getTime();
long l2 = o2.getTime();
if (l1 < l2) {
return 1;
} else if (l1 > l2) {
return -1;
} else {
return 0;
}
}
});
return threads;
}

// Resample
long newSampleTimeNanos = System.nanoTime();
Map<ThreadVO, Long> newCpuTimes = new HashMap<ThreadVO, Long>(threads.size());
for (ThreadVO thread : threads) {
if (thread.getId() > 0) {
long cpu = threadMXBean.getThreadCpuTime(thread.getId());
newCpuTimes.put(thread, cpu);
}
}
// internal threads
Map<String, Long> newInternalThreadCpuTimes = getInternalThreadCpuTimes();
if (newInternalThreadCpuTimes != null) {
for (Map.Entry<String, Long> entry : newInternalThreadCpuTimes.entrySet()) {
ThreadVO threadVO = createThreadVO(entry.getKey());
threads.add(threadVO);
newCpuTimes.put(threadVO, entry.getValue());
}
}

// Compute delta time
final Map<ThreadVO, Long> deltas = new HashMap<ThreadVO, Long>(threads.size());
for (ThreadVO thread : newCpuTimes.keySet()) {
Long t = lastCpuTimes.get(thread);
if (t == null) {
t = 0L;
}
long time1 = t;
long time2 = newCpuTimes.get(thread);
if (time1 == -1) {
time1 = time2;
} else if (time2 == -1) {
time2 = time1;
}
long delta = time2 - time1;
deltas.put(thread, delta);
}

long sampleIntervalNanos = newSampleTimeNanos - lastSampleTimeNanos;

// Compute cpu usage
final HashMap<ThreadVO, Double> cpuUsages = new HashMap<ThreadVO, Double>(threads.size());
for (ThreadVO thread : threads) {
double cpu = sampleIntervalNanos == 0 ? 0 : (Math.rint(deltas.get(thread) * 10000.0 / sampleIntervalNanos) / 100.0);
cpuUsages.put(thread, cpu);
}

// Sort by CPU time : should be a rendering hint...
Collections.sort(threads, new Comparator<ThreadVO>() {
@Override
public int compare(ThreadVO o1, ThreadVO o2) {
long l1 = deltas.get(o1);
long l2 = deltas.get(o2);
if (l1 < l2) {
return 1;
} else if (l1 > l2) {
return -1;
} else {
return 0;
}
}
});

for (ThreadVO thread : threads) {
//nanos to mills
long timeMills = newCpuTimes.get(thread) / 1000000;
long deltaTime = deltas.get(thread) / 1000000;
double cpu = cpuUsages.get(thread);

thread.setCpu(cpu);
thread.setTime(timeMills);
thread.setDeltaTime(deltaTime);
}
lastCpuTimes = newCpuTimes;
lastSampleTimeNanos = newSampleTimeNanos;

return threads;
}

private Map<String, Long> getInternalThreadCpuTimes() {
if (hotspotThreadMBeanEnable && includeInternalThreads) {
try {
if (hotspotThreadMBean == null) {
hotspotThreadMBean = ManagementFactoryHelper.getHotspotThreadMBean();
}
return hotspotThreadMBean.getInternalThreadCpuTimes();
} catch (Exception ex) {
logger.error("getInternalThreadCpuTimes failed Cause : " + ex);
hotspotThreadMBeanEnable = false;
}
}
return null;
}

private ThreadVO createThreadVO(String name) {
ThreadVO threadVO = new ThreadVO();
threadVO.setId(-1);
threadVO.setName(name);
threadVO.setPriority(-1);
threadVO.setDaemon(true);
threadVO.setInterrupted(false);
return threadVO;
}

public void pause(long mills) {
try {
Thread.sleep(mills);
} catch (InterruptedException e) {
logger.error("pause failed Cause : " + e);
}
}

public boolean isIncludeInternalThreads() {
return includeInternalThreads;
}

public void setIncludeInternalThreads(boolean includeInternalThreads) {
this.includeInternalThreads = includeInternalThreads;
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.webank.ai.fate.serving.common.utils;

import com.webank.ai.fate.serving.common.bean.ThreadVO;

import java.util.ArrayList;
import java.util.List;

/**
* @author hcy
*/
public class ThreadUtils {

private static ThreadGroup getRoot() {
ThreadGroup group = Thread.currentThread().getThreadGroup();
ThreadGroup parent;
while ((parent = group.getParent()) != null) {
group = parent;
}
return group;
}

public static List<ThreadVO> getThreads() {
ThreadGroup root = getRoot();
Thread[] threads = new Thread[root.activeCount()];
while (root.enumerate(threads, true) == threads.length) {
threads = new Thread[threads.length * 2];
}
List<ThreadVO> list = new ArrayList<ThreadVO>(threads.length);
for (Thread thread : threads) {
if (thread != null) {
ThreadVO threadVO = createThreadVO(thread);
list.add(threadVO);
}
}
return list;
}

private static ThreadVO createThreadVO(Thread thread) {
ThreadGroup group = thread.getThreadGroup();
ThreadVO threadVO = new ThreadVO();
threadVO.setId(thread.getId());
threadVO.setName(thread.getName());
threadVO.setGroup(group == null ? "" : group.getName());
threadVO.setPriority(thread.getPriority());
threadVO.setState(thread.getState());
threadVO.setInterrupted(thread.isInterrupted());
threadVO.setDaemon(thread.isDaemon());
return threadVO;
}
}
Original file line number Diff line number Diff line change
@@ -117,6 +117,7 @@ public class Dict {
public static final String PROPERTY_MODEL_SYNC = "model.synchronize";
public static final String PROPERTY_SERVING_MAX_POOL_SIZE = "serving.max.pool.size";
public static final String PROPERTY_FEATURE_BATCH_ADAPTOR = "feature.batch.adaptor";
public static final String PROPERTY_FEATURE_BATCH_SINGLE_ADAPTOR = "feature.batch.single.adaptor";
public static final String PROPERTY_ACL_ENABLE = "acl.enable";
public static final String PROPERTY_ACL_USERNAME = "acl.username";
public static final String PROPERTY_ACL_PASSWORD = "acl.password";
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
package com.webank.ai.fate.serving.core.bean;

/**
* @author hcy
*/
public class HttpAdapterResponseCodeEnum{

// 通用http响应成功码
public static final int COMMON_HTTP_SUCCESS_CODE = 0;

//正常
public final static int SUCCESS_CODE = 200;
public static final int SUCCESS_CODE = 200;

//查询无果
public final static int ERROR_CODE = 404;
public static final int ERROR_CODE = 404;

}

Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@ public class MetaInfo {
public static Boolean PROPERTY_USE_REGISTER;
public static Boolean PROPERTY_USE_ZK_ROUTER;
public static String PROPERTY_FEATURE_BATCH_ADAPTOR;
public static String PROPERTY_FETTURE_BATCH_SINGLE_ADAPTOR;
public static Integer PROPERTY_BATCH_INFERENCE_MAX;
public static String PROPERTY_FEATURE_SINGLE_ADAPTOR;
public static Integer PROPERTY_SINGLE_INFERENCE_RPC_TIMEOUT;
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@

package com.webank.ai.fate.serving.adaptor.dataaccess;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.webank.ai.fate.serving.common.utils.HttpAdapterClientPool;
import com.webank.ai.fate.serving.core.bean.*;
import com.webank.ai.fate.serving.core.constant.StatusCode;
@@ -30,6 +31,7 @@ public class HttpAdapter extends AbstractSingleFeatureDataAdaptor {

private final static String HTTP_ADAPTER_URL = MetaInfo.PROPERTY_HTTP_ADAPTER_URL;

private static final ObjectMapper objectMapper = new ObjectMapper();
@Override
public void init() {
environment.getProperty("port");
@@ -44,13 +46,18 @@ public ReturnResult getData(Context context, Map<String, Object> featureIds) {
responseResult = HttpAdapterClientPool.doPost(HTTP_ADAPTER_URL, featureIds);
int responseCode = responseResult.getCode();
switch (responseCode) {
case HttpAdapterResponseCodeEnum.SUCCESS_CODE:
if (responseResult.getData() == null || responseResult.getData().size() == 0) {
case HttpAdapterResponseCodeEnum.COMMON_HTTP_SUCCESS_CODE:
Map<String, Object> responseResultData = responseResult.getData();
if (responseResultData == null || responseResultData.size() == 0) {
returnResult.setRetcode(StatusCode.FEATURE_DATA_ADAPTOR_ERROR);
returnResult.setRetmsg("responseData is null ");
} else if (!responseResultData.get("code").equals(HttpAdapterResponseCodeEnum.SUCCESS_CODE)) {
returnResult.setRetcode(StatusCode.FEATURE_DATA_ADAPTOR_ERROR);
returnResult.setRetmsg("responseData is : " + objectMapper.writeValueAsString(responseResultData.get("data")));
} else {
((Map<String, Object>)responseResultData.get("data")).remove("code");
returnResult.setRetcode(StatusCode.SUCCESS);
returnResult.setData(responseResult.getData());
returnResult.setData(responseResultData);
}
break;

Original file line number Diff line number Diff line change
@@ -19,30 +19,71 @@
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.webank.ai.fate.serving.core.adaptor.SingleFeatureDataAdaptor;
import com.webank.ai.fate.serving.core.bean.BatchHostFeatureAdaptorResult;
import com.webank.ai.fate.serving.core.bean.BatchHostFederatedParams;
import com.webank.ai.fate.serving.core.bean.Context;
import com.webank.ai.fate.serving.core.bean.ReturnResult;
import com.webank.ai.fate.serving.core.bean.*;
import com.webank.ai.fate.serving.core.utils.InferenceUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.*;

/**
* 许多host方并未提供批量查询接口,这个类将批量请求拆分成单笔请求发送,再合结果
*/
public class ParallelBatchToSingleFeatureAdaptor extends AbstractBatchFeatureDataAdaptor {

private static final Logger logger = LoggerFactory.getLogger(HttpAdapter.class);

int timeout;

SingleFeatureDataAdaptor singleFeatureDataAdaptor;

ListeningExecutorService listeningExecutorService = MoreExecutors.listeningDecorator(null);
ListeningExecutorService listeningExecutorService;

// 自定义Adapter初始化
public ParallelBatchToSingleFeatureAdaptor(int core, int max, int timeout) {
initExecutor(core, max, timeout);
}

// 默认Adapter初始化
public ParallelBatchToSingleFeatureAdaptor() {

// 默认线程池核心线程10
int defaultCore = 10;

public ParallelBatchToSingleFeatureAdaptor(int core, int max) {
new ThreadPoolExecutor(core, max, 1000, TimeUnit.MILLISECONDS, new ArrayBlockingQueue<>(1000), new ThreadPoolExecutor.AbortPolicy());
// 默认线程池最大线程 100
int defaultMax = 100;

// 默认countDownLatch超时时间永远比grpc超时时间小
timeout = MetaInfo.PROPERTY_GRPC_TIMEOUT.intValue() - 1;

initExecutor(defaultCore, defaultMax, timeout);
}


private void initExecutor(int core, int max, int timeout) {
SingleFeatureDataAdaptor singleFeatureDataAdaptor = null;
String adaptorClass = MetaInfo.PROPERTY_FETTURE_BATCH_SINGLE_ADAPTOR;
if (StringUtils.isNotEmpty(adaptorClass)) {
logger.info("try to load single adaptor for ParallelBatchToSingleFeatureAdaptor {}", adaptorClass);
singleFeatureDataAdaptor = (SingleFeatureDataAdaptor) InferenceUtils.getClassByName(adaptorClass);
}

if (singleFeatureDataAdaptor != null) {
String implementationClass = singleFeatureDataAdaptor.getClass().getName();
logger.info("SingleFeatureDataAdaptor implementation class: " + implementationClass);
} else {
logger.warn("SingleFeatureDataAdaptor is null.");
}

this.singleFeatureDataAdaptor = singleFeatureDataAdaptor;

this.timeout = timeout;

ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(core, max, 1000, TimeUnit.MILLISECONDS, new ArrayBlockingQueue<>(1000), new ThreadPoolExecutor.AbortPolicy());

listeningExecutorService = MoreExecutors.listeningDecorator(threadPoolExecutor);
}

@Override
@@ -56,37 +97,47 @@ public BatchHostFeatureAdaptorResult getFeatures(Context context, List<BatchHost
CountDownLatch countDownLatch = new CountDownLatch(featureIdList.size());
for (int i = 0; i < featureIdList.size(); i++) {
BatchHostFederatedParams.SingleInferenceData singleInferenceData = featureIdList.get(i);
// TODO: 2020/3/4 这里需要加 线程池满后的异常处理
this.listeningExecutorService.submit(new Runnable() {
@Override
public void run() {
try {
Integer index = singleInferenceData.getIndex();
ReturnResult returnResult = singleFeatureDataAdaptor.getData(context, singleInferenceData.getFeatureData());
BatchHostFeatureAdaptorResult.SingleBatchHostFeatureAdaptorResult adaptorResult = new BatchHostFeatureAdaptorResult.SingleBatchHostFeatureAdaptorResult();
adaptorResult.setFeatures(returnResult.getData());
result.getIndexResultMap().put(index, adaptorResult);
} finally {
countDownLatch.countDown();
try {
this.listeningExecutorService.submit(new Runnable() {
@Override
public void run() {
try {
Integer index = singleInferenceData.getIndex();
ReturnResult returnResult = singleFeatureDataAdaptor.getData(context, singleInferenceData.getSendToRemoteFeatureData());
BatchHostFeatureAdaptorResult.SingleBatchHostFeatureAdaptorResult adaptorResult = new BatchHostFeatureAdaptorResult.SingleBatchHostFeatureAdaptorResult();
adaptorResult.setFeatures(returnResult.getData());
result.getIndexResultMap().put(index, adaptorResult);
} finally {
countDownLatch.countDown();
}
}
});
} catch (RejectedExecutionException ree) {

// 处理线程池满后的异常, 等待2s后重新提交
logger.error("The thread pool has exceeded the maximum capacity, sleep 3s and submit again : " + ree.getMessage(), ree);
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
logger.error("Interrupt during thread sleep");
}
});
this.listeningExecutorService.submit((Runnable) this);
}
}

/**
* 这里的超时时间需要设置比rpc超时时间短,否则没有意义
*/
try {
countDownLatch.await(timeout, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
e.printStackTrace();
logger.error("Interrupt during countDownLatch thread await", e);
}

/**
* 如果等待超时也需要把已经返回的查询结果返回
*/
return result;

}

@Override
Original file line number Diff line number Diff line change
@@ -88,12 +88,10 @@ public BatchInferenceResult guestBatchInference(Context context, BatchInferenceR
throw new RemoteRpcException(transformRemoteErrorCode(remoteInferenceResult.getRetcode()), buildRemoteRpcErrorMsg(remoteInferenceResult.getRetcode(), remoteInferenceResult.getRetmsg()));
}
remoteResultMap.put(partyId, remoteInferenceResult);
} catch (RemoteRpcException e) {
throw e;
} catch (Exception e) {
if (!(e instanceof RemoteRpcException)) {
throw new RemoteRpcException("party id " + partyId + " remote error");
} else {
throw (RemoteRpcException) e;
}
throw new RemoteRpcException("party id " + partyId + " remote error");
} finally {
context.setDownstreamCost(System.currentTimeMillis() - context.getDownstreamBegin());
}
Original file line number Diff line number Diff line change
@@ -73,7 +73,7 @@ public int initModel(byte[] protoMeta, byte[] protoParam) {

protected String getSite(int treeId, int treeNodeId) {
String siteName = this.trees.get(treeId).getTree(treeNodeId).getSitename();
if(siteName!=null&&":".indexOf(siteName)!=0){
if(siteName != null && siteName.contains(":")){
return siteName.split(":")[1];
}else{
return siteName;
2 changes: 1 addition & 1 deletion fate-serving-proxy/bin/service.sh
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ basepath=$(cd `dirname $0`;pwd)
configpath=$(cd $basepath/conf;pwd)
module=serving-proxy
main_class=com.webank.ai.fate.serving.proxy.bootstrap.Bootstrap
module_version=2.1.6
module_version=2.1.7


case "$1" in
2 changes: 1 addition & 1 deletion fate-serving-proxy/pom.xml
Original file line number Diff line number Diff line change
@@ -88,7 +88,7 @@
<dependency>
<groupId>commons-net</groupId>
<artifactId>commons-net</artifactId>
<version>3.8.0</version>
<version>3.9.0</version>
</dependency>

</dependencies>
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.webank.ai.fate.serving.proxy.controller;

import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.core.LoggerContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.*;

/**
* @author hcy
*/
@RequestMapping("/proxy")
@RestController
public class DynamicLogController {
private static final Logger logger = LoggerFactory.getLogger(DynamicLogController.class);

@GetMapping("/alterSysLogLevel/{level}")
public String alterSysLogLevel(@PathVariable String level){
try {
LoggerContext context = (LoggerContext) LogManager.getContext(false);
context.getLogger("ROOT").setLevel(Level.valueOf(level));
return "ok";
} catch (Exception ex) {
logger.error("proxy alterSysLogLevel failed : " + ex);
return "failed";
}

}

@GetMapping("/alterPkgLogLevel")
public String alterPkgLogLevel(@RequestParam String level, @RequestParam String pkgName){
try {
LoggerContext context = (LoggerContext) LogManager.getContext(false);
context.getLogger(pkgName).setLevel(Level.valueOf(level));
return "ok";
} catch (Exception ex) {
logger.error("proxy alterPkgLogLevel failed : " + ex);
return "failed";
}
}
}
Original file line number Diff line number Diff line change
@@ -105,21 +105,29 @@ public String call() throws Exception {
if (logger.isDebugEnabled()) {
logger.debug("receive : {} headers {}", data, headers.toSingleValueMap());
}

String caseId = headers.getFirst("caseId");
final ServiceAdaptor serviceAdaptor = proxyServiceRegister.getServiceAdaptor(callName);

Context context = new BaseContext();
context.setCallName(callName);
context.setVersion(version);
context.setCaseId(caseId);
if (null == context.getCaseId() || context.getCaseId().isEmpty()) {
context.setCaseId(UUID.randomUUID().toString().replaceAll("-", ""));
}

InboundPackage<Map> inboundPackage = buildInboundPackageFederation(context, data, httpServletRequest);
OutboundPackage<Map> result = serviceAdaptor.service(context, inboundPackage);

if (result != null && result.getData() != null) {
result.getData().remove("log");
result.getData().remove("warn");
result.getData().remove("caseid");
return JsonUtil.object2Json(result.getData());
}
return "";


return "";
}
};
}
@@ -134,10 +142,6 @@ private InboundPackage<Map> buildInboundPackageFederation(Context context, Strin
Map head = (Map) jsonObject.getOrDefault(Dict.HEAD, new HashMap<>());
Map body = (Map) jsonObject.getOrDefault(Dict.BODY, new HashMap<>());
context.setHostAppid((String) head.getOrDefault(Dict.APP_ID, ""));
context.setCaseId((String) head.getOrDefault(Dict.CASE_ID, ""));
if (null == context.getCaseId() || context.getCaseId().isEmpty()) {
context.setCaseId(UUID.randomUUID().toString().replaceAll("-", ""));
}

InboundPackage<Map> inboundPackage = new InboundPackage<Map>();
inboundPackage.setBody(body);
Original file line number Diff line number Diff line change
@@ -117,24 +117,22 @@ public ZookeeperClient getZkClient() {

@Override
public void doRegisterComponent(URL url) {
if(url==null) {
if(url == null) {
String hostAddress = NetUtils.getLocalIp();
String path = PATH_SEPARATOR + DEFAULT_COMPONENT_ROOT + PATH_SEPARATOR + project + PATH_SEPARATOR + hostAddress + ":" + port;
url = new URL(path, Maps.newHashMap());
//url=url.addParameter(Constants.INSTANCE_ID, AbstractRegistry.INSTANCE_ID);
}

if(url !=null) {
String path = url.getPath();
Map content = new HashMap();
content.put(Constants.INSTANCE_ID, AbstractRegistry.INSTANCE_ID);
content.put(Constants.TIMESTAMP_KEY, System.currentTimeMillis());
content.put(Dict.VERSION, MetaInfo.CURRENT_VERSION);
this.zkClient.create(path, JsonUtil.object2Json(content), true);
this.componentUrl = url;
String path = url.getPath();
Map content = new HashMap();
content.put(Constants.INSTANCE_ID, AbstractRegistry.INSTANCE_ID);
content.put(Constants.TIMESTAMP_KEY, System.currentTimeMillis());
content.put(Dict.VERSION, MetaInfo.CURRENT_VERSION);
this.zkClient.create(path, JsonUtil.object2Json(content), true);
this.componentUrl = url;

logger.info("register component {} ", path);
}
logger.info("register component {} ", path);
}


@@ -155,7 +153,6 @@ public void unRegisterComponent() {
}else{
System.err.println("componentUrl is null");
}

}


@@ -165,14 +162,15 @@ public boolean tryUnregister(URL url) {
boolean exists = client.checkExists(toUrlPath(url));
String urlPath = toUrlPath(url);
if (exists) {
System.err.println("delete zk path "+urlPath);
if (logger.isDebugEnabled()) {
logger.debug("delete zk path " + urlPath);
}
zkClient.delete(toUrlPath(url));
registedString.remove(url.getServiceInterface() + url.getEnvironment());
syncServiceCacheFile();
return true;
}
else{
System.err.println(urlPath +"is not exist");
} else {
logger.error(urlPath + " is not exist");
}
} catch (Throwable e) {
throw new RuntimeException("Failed to unregister " + url + " to zookeeper " + getUrl() + ", cause: " + e.getMessage(), e);
@@ -253,7 +251,6 @@ private String parseRegisterService(String serviceName, RegisterService register
param = param + "&";
param = param + Constants.TIMESTAMP_KEY + "=" + System.currentTimeMillis();
String key = serviceName;
boolean appendParam = false;
if (version != 0) {
param = param + "&" + Constants.VERSION + "=" + version;
}
@@ -271,14 +268,10 @@ private void loadCacheParams(URL url) {
if (serviceWrapper.getWeight() != null) {
parameters.put(Constants.WEIGHT_KEY, String.valueOf(serviceWrapper.getWeight()));
}
// if (serviceWrapper.getVersion() != null) {
// parameters.put(Constants.VERSION_KEY, String.valueOf(serviceWrapper.getVersion()));
// }
} else {
serviceWrapper = new ServiceWrapper();
serviceWrapper.setRouterMode(parameters.get(Constants.ROUTER_MODE));
serviceWrapper.setWeight(parameters.get(Constants.WEIGHT_KEY) != null ? Integer.valueOf(parameters.get(Constants.WEIGHT_KEY)) : null);
//serviceWrapper.setVersion(parameters.get(Constants.VERSION_KEY) != null ? Long.valueOf(parameters.get(Constants.VERSION_KEY)) : null);
this.getServiceCacheMap().put(url.getEnvironment() + "/" + url.getPath(), serviceWrapper);
}
}
@@ -391,8 +384,7 @@ private URL generateUrl(String hostAddress, RegisterService service) {
throw new RuntimeException("Failed to register service:"+ service +" ,acquire serviceName is blank");
}

URL url = URL.valueOf(protocol + "://" + hostAddress + ":" + hostPort + Constants.PATH_SEPARATOR + parseRegisterService(serviceName, service));
return url;
return URL.valueOf(protocol + "://" + hostAddress + ":" + hostPort + Constants.PATH_SEPARATOR + parseRegisterService(serviceName, service));
}

public void addDynamicEnvironment(String environment) {
@@ -580,8 +572,7 @@ private String toServicePath(URL url) {
return toRootPath();
}

String result = toRootDir() + project + Constants.PATH_SEPARATOR + environment + Constants.PATH_SEPARATOR + URL.encode(name);
return result;
return toRootDir() + project + Constants.PATH_SEPARATOR + environment + Constants.PATH_SEPARATOR + URL.encode(name);
}

private String[] toCategoriesPath(URL url) {
@@ -632,7 +623,7 @@ private List<URL> toUrlsWithoutEmpty(URL consumer, List<String> providers) {

private List<URL> toUrlsWithEmpty(URL consumer, String path, List<String> providers) {
List<URL> urls = toUrlsWithoutEmpty(consumer, providers);
if (urls == null || urls.isEmpty()) {
if (urls.isEmpty()) {
int i = path.lastIndexOf(PATH_SEPARATOR);
String category = i < 0 ? path : path.substring(i + 1);
URL empty = URLBuilder.from(consumer)
2 changes: 1 addition & 1 deletion fate-serving-server/bin/service.sh
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ basepath=$(cd `dirname $0`;pwd)
configpath=$(cd $basepath/conf;pwd)
module=serving-server
main_class=com.webank.ai.fate.serving.Bootstrap
module_version=2.1.6
module_version=2.1.7


case "$1" in
12 changes: 11 additions & 1 deletion fate-serving-server/pom.xml
Original file line number Diff line number Diff line change
@@ -94,14 +94,24 @@
<dependency>
<groupId>commons-net</groupId>
<artifactId>commons-net</artifactId>
<version>3.8.0</version>
<version>3.9.0</version>
</dependency>

<dependency>
<groupId>com.googlecode.protobuf-java-format</groupId>
<artifactId>protobuf-java-format</artifactId>
<version>1.2</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
<exclusions>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-logging</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>

<build>
Original file line number Diff line number Diff line change
@@ -77,6 +77,7 @@ public static void parseConfig() {
MetaInfo.PROPERTY_SERVING_POOL_ALIVE_TIME = environment.getProperty(Dict.PROPERTY_SERVING_POOL_ALIVE_TIME) != null ? Integer.valueOf(environment.getProperty(Dict.PROPERTY_SERVING_POOL_ALIVE_TIME)) : 1000;
MetaInfo.PROPERTY_SERVING_POOL_QUEUE_SIZE = environment.getProperty(Dict.PROPERTY_SERVING_POOL_QUEUE_SIZE) != null ? Integer.valueOf(environment.getProperty(Dict.PROPERTY_SERVING_POOL_QUEUE_SIZE)) : 100;
MetaInfo.PROPERTY_FEATURE_BATCH_ADAPTOR = environment.getProperty(Dict.PROPERTY_FEATURE_BATCH_ADAPTOR);
MetaInfo.PROPERTY_FETTURE_BATCH_SINGLE_ADAPTOR = environment.getProperty(Dict.PROPERTY_FEATURE_BATCH_SINGLE_ADAPTOR);
MetaInfo.PROPERTY_BATCH_INFERENCE_MAX = environment.getProperty(Dict.PROPERTY_BATCH_INFERENCE_MAX) != null ? Integer.valueOf(environment.getProperty(Dict.PROPERTY_BATCH_INFERENCE_MAX)) : 300;
MetaInfo.PROPERTY_REMOTE_MODEL_INFERENCE_RESULT_CACHE_SWITCH = environment.getProperty(Dict.PROPERTY_REMOTE_MODEL_INFERENCE_RESULT_CACHE_SWITCH) != null ? Boolean.valueOf(environment.getProperty(Dict.PROPERTY_REMOTE_MODEL_INFERENCE_RESULT_CACHE_SWITCH)) : Boolean.FALSE;
MetaInfo.PROPERTY_SINGLE_INFERENCE_RPC_TIMEOUT = environment.getProperty(Dict.PROPERTY_SINGLE_INFERENCE_RPC_TIMEOUT) != null ? Integer.valueOf(environment.getProperty(Dict.PROPERTY_SINGLE_INFERENCE_RPC_TIMEOUT)) : 3000;
@@ -105,7 +106,6 @@ public static void parseConfig() {
MetaInfo.PROPERTY_MODEL_CACHE_PATH = StringUtils.isNotBlank(environment.getProperty(Dict.PROPERTY_MODEL_CACHE_PATH)) ? environment.getProperty(Dict.PROPERTY_MODEL_CACHE_PATH) : MetaInfo.PROPERTY_ROOT_PATH;
MetaInfo.PROPERTY_ACL_ENABLE = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_ACL_ENABLE, "false"));
MetaInfo.PROPERTY_MODEL_SYNC = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_MODEL_SYNC, "false"));
MetaInfo.PROPERTY_MODEL_SYNC = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_MODEL_SYNC, "false"));
MetaInfo.PROPERTY_GRPC_TIMEOUT = Integer.valueOf(environment.getProperty(Dict.PROPERTY_GRPC_TIMEOUT, "5000"));
MetaInfo.PROPERTY_ACL_USERNAME = environment.getProperty(Dict.PROPERTY_ACL_USERNAME);
MetaInfo.PROPERTY_ACL_PASSWORD = environment.getProperty(Dict.PROPERTY_ACL_PASSWORD);
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.webank.ai.fate.serving.controller;

import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.core.LoggerContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.*;

/**
* @author hcy
*/
@RequestMapping("/server")
@RestController
public class DynamicLogController {
private static final Logger logger = LoggerFactory.getLogger(DynamicLogController.class);

@GetMapping("/alterSysLogLevel/{level}")
public String alterSysLogLevel(@PathVariable String level){
try {
LoggerContext context = (LoggerContext) LogManager.getContext(false);
context.getLogger("ROOT").setLevel(Level.valueOf(level));
return "ok";
} catch (Exception ex) {
logger.error("server alterSysLogLevel failed : " + ex);
return "failed";
}

}

@GetMapping("/alterPkgLogLevel")
public String alterPkgLogLevel(@RequestParam String level, @RequestParam String pkgName){
try {
LoggerContext context = (LoggerContext) LogManager.getContext(false);
context.getLogger(pkgName).setLevel(Level.valueOf(level));
return "ok";
} catch (Exception ex) {
logger.error("server alterPkgLogLevel failed : " + ex);
return "failed";
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package com.webank.ai.fate.serving.controller;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ListenableFuture;
import com.webank.ai.fate.api.mlmodel.manager.ModelServiceGrpc;
import com.webank.ai.fate.api.mlmodel.manager.ModelServiceProto;
import com.webank.ai.fate.serving.core.bean.GrpcConnectionPool;
import com.webank.ai.fate.serving.core.bean.MetaInfo;
import com.webank.ai.fate.serving.core.bean.RequestParamWrapper;
import com.webank.ai.fate.serving.core.bean.ReturnResult;
import com.webank.ai.fate.serving.core.exceptions.SysException;
import com.webank.ai.fate.serving.core.utils.NetUtils;
import io.grpc.ManagedChannel;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.*;

import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;

/**
* @author hcy
*/
@RestController
public class ServerModelController {

Logger logger = LoggerFactory.getLogger(ServerModelController.class);

GrpcConnectionPool grpcConnectionPool = GrpcConnectionPool.getPool();

@RequestMapping(value = "/server/model/unbind", method = RequestMethod.POST)
@ResponseBody
public Callable<ReturnResult> unbind(@RequestBody RequestParamWrapper requestParams) throws Exception {
return () -> {
String host = requestParams.getHost();
Integer port = requestParams.getPort();
String tableName = requestParams.getTableName();
String namespace = requestParams.getNamespace();
List<String> serviceIds = requestParams.getServiceIds();

Preconditions.checkArgument(StringUtils.isNotBlank(tableName), "parameter tableName is blank");
Preconditions.checkArgument(StringUtils.isNotBlank(namespace), "parameter namespace is blank");
Preconditions.checkArgument(serviceIds != null && serviceIds.size() != 0, "parameter serviceId is blank");

ReturnResult result = new ReturnResult();

logger.debug("unbind model by tableName and namespace, host: {}, port: {}, tableName: {}, namespace: {}", host, port, tableName, namespace);

ModelServiceGrpc.ModelServiceFutureStub futureStub = getModelServiceFutureStub(host, port);

ModelServiceProto.UnbindRequest unbindRequest = ModelServiceProto.UnbindRequest.newBuilder()
.setTableName(tableName)
.setNamespace(namespace)
.addAllServiceIds(serviceIds)
.build();

ListenableFuture<ModelServiceProto.UnbindResponse> future = futureStub.unbind(unbindRequest);

ModelServiceProto.UnbindResponse response = future.get(MetaInfo.PROPERTY_GRPC_TIMEOUT, TimeUnit.MILLISECONDS);

logger.debug("response: {}", response);

result.setRetcode(response.getStatusCode());
result.setRetmsg(response.getMessage());
return result;
};
}

@RequestMapping(value = "/server/model/transfer", method = RequestMethod.POST)
@ResponseBody
public Callable<ReturnResult> transfer(@RequestBody RequestParamWrapper requestParams) {
return () -> {
String host = requestParams.getHost();
Integer port = requestParams.getPort();
String tableName = requestParams.getTableName();
String namespace = requestParams.getNamespace();

String targetHost = requestParams.getTargetHost();
Integer targetPort = requestParams.getTargetPort();

Preconditions.checkArgument(StringUtils.isNotBlank(tableName), "parameter tableName is blank");
Preconditions.checkArgument(StringUtils.isNotBlank(namespace), "parameter namespace is blank");

ReturnResult result = new ReturnResult();

logger.debug("transfer model by tableName and namespace, host: {}, port: {}, tableName: {}, namespace: {}, targetHost: {}, targetPort: {}"
, host, port, tableName, namespace, targetHost, targetPort);

ModelServiceGrpc.ModelServiceFutureStub futureStub = getModelServiceFutureStub(targetHost, targetPort);
ModelServiceProto.FetchModelRequest fetchModelRequest = ModelServiceProto.FetchModelRequest.newBuilder()
.setNamespace(namespace).setTableName(tableName).setSourceIp(host).setSourcePort(port).build();

ListenableFuture<ModelServiceProto.FetchModelResponse> future = futureStub.fetchModel(fetchModelRequest);
ModelServiceProto.FetchModelResponse response = future.get(MetaInfo.PROPERTY_GRPC_TIMEOUT, TimeUnit.MILLISECONDS);

logger.debug("response: {}", response);

result.setRetcode(response.getStatusCode());
result.setRetmsg(response.getMessage());
return result;
};
}

private ModelServiceGrpc.ModelServiceFutureStub getModelServiceFutureStub(String host, Integer port) {
Preconditions.checkArgument(StringUtils.isNotBlank(host), "parameter host is blank");
Preconditions.checkArgument(port != null && port != 0, "parameter port was wrong");

if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}

ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
return ModelServiceGrpc.newFutureStub(managedChannel);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package com.webank.ai.fate.serving.controller;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ListenableFuture;
import com.webank.ai.fate.api.networking.common.CommonServiceGrpc;
import com.webank.ai.fate.api.networking.common.CommonServiceProto;
import com.webank.ai.fate.serving.core.bean.GrpcConnectionPool;
import com.webank.ai.fate.serving.core.bean.MetaInfo;
import com.webank.ai.fate.serving.core.bean.RequestParamWrapper;
import com.webank.ai.fate.serving.core.bean.ReturnResult;
import com.webank.ai.fate.serving.core.exceptions.SysException;
import com.webank.ai.fate.serving.core.utils.NetUtils;
import io.grpc.ManagedChannel;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.*;

import java.util.concurrent.TimeUnit;

/**
* @author hcy
*/
@RestController
public class ServerServiceController {

Logger logger = LoggerFactory.getLogger(ServerServiceController.class);

GrpcConnectionPool grpcConnectionPool = GrpcConnectionPool.getPool();

@RequestMapping(value = "/server/service/weight/update", method = RequestMethod.POST)
@ResponseBody
public ReturnResult updateService(@RequestBody RequestParamWrapper requestParams) throws Exception {
String host = requestParams.getHost();
int port = requestParams.getPort();
String url = requestParams.getUrl();
String routerMode = requestParams.getRouterMode();
Integer weight = requestParams.getWeight();
Long version = requestParams.getVersion();

if (logger.isDebugEnabled()) {
logger.debug("try to update service");
}

Preconditions.checkArgument(StringUtils.isNotBlank(url), "parameter url is blank");

logger.info("update url: {}, routerMode: {}, weight: {}, version: {}", url, routerMode, weight, version);

CommonServiceGrpc.CommonServiceFutureStub commonServiceFutureStub = getCommonServiceFutureStub(host, port);
CommonServiceProto.UpdateServiceRequest.Builder builder = CommonServiceProto.UpdateServiceRequest.newBuilder();

builder.setUrl(url);
if (StringUtils.isNotBlank(routerMode)) {
builder.setRouterMode(routerMode);
}

if (weight != null) {
builder.setWeight(weight);
} else {
builder.setWeight(-1);
}

if (version != null) {
builder.setVersion(version);
} else {
builder.setVersion(-1);
}

ListenableFuture<CommonServiceProto.CommonResponse> future = commonServiceFutureStub.updateService(builder.build());

CommonServiceProto.CommonResponse response = future.get(MetaInfo.PROPERTY_GRPC_TIMEOUT, TimeUnit.MILLISECONDS);

ReturnResult result = new ReturnResult();
result.setRetcode(response.getStatusCode());
result.setRetmsg(response.getMessage());
return result;
}

private CommonServiceGrpc.CommonServiceFutureStub getCommonServiceFutureStub(String host, Integer port) {
Preconditions.checkArgument(StringUtils.isNotBlank(host), "parameter host is blank");
Preconditions.checkArgument(port != null && port != 0, "parameter port was wrong");

if (!NetUtils.isValidAddress(host + ":" + port)) {
throw new SysException("invalid address");
}

ManagedChannel managedChannel = grpcConnectionPool.getManagedChannel(host, port);
return CommonServiceGrpc.newFutureStub(managedChannel);
}
}
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@ public synchronized ModelServiceProto.UnbindResponse unbind(Context context, Mod
String modelKey = this.getNameSpaceKey(req.getTableName(), req.getNamespace());
if (!this.namespaceMap.containsKey(modelKey)) {
logger.error("not found model info table name {} namespace {}, please check if the model is already loaded.", req.getTableName(), req.getNamespace());
throw new ModelNullException(" found model info, please check if the model is already loaded.");
throw new ModelNullException("not found model info, please check if the model is already loaded.");
}
Model model = this.namespaceMap.get(modelKey);
String tableNamekey = this.getNameSpaceKey(model.getTableName(), model.getNamespace());
Original file line number Diff line number Diff line change
@@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
port=8000
#主机启动serving-server进程服务时,port端口定义为grpc端口
port=8000
#主机启动serving-server进程服务时,http端口在此定义;采用kubefate部署时,请注释关闭此选项,在k8s中serving-server对应svc资源文件上进行修改
server.port=8185
#serviceRoleName=serving
# cache
#remoteModelInferenceResultCacheSwitch=false
@@ -39,6 +42,7 @@ port=8000
# adapter
feature.single.adaptor=com.webank.ai.fate.serving.adaptor.dataaccess.MockAdapter
feature.batch.adaptor=com.webank.ai.fate.serving.adaptor.dataaccess.MockBatchAdapter
feature.batch.single.adaptor=com.webank.ai.fate.serving.adaptor.dataaccess.HttpAdapter
http.adapter.url=http://127.0.0.1:9380/v1/http/adapter/getFeature
# model transfer
model.transfer.url=http://127.0.0.1:9380/v1/model/transfer
@@ -54,4 +58,4 @@ zk.url=localhost:2181,localhost:2182,localhost:2183

# LR algorithm config
#lr.split.size=500
#lr.use.parallel=false
#lr.use.parallel=false
12 changes: 6 additions & 6 deletions pom.xml
Original file line number Diff line number Diff line change
@@ -30,14 +30,14 @@
<module>fate-serving-register</module>
<module>fate-serving-common</module>
<module>fate-serving-proxy</module>
<module>fate-serving-admin</module>
<module>fate-serving-admin-ui</module>
<!-- <module>fate-serving-admin</module>-->
<!-- <module>fate-serving-admin-ui</module>-->
<module>fate-serving-extension</module>
<module>fate-serving-sdk</module>
</modules>

<properties>
<fate.version>2.1.6</fate.version>
<fate.version>2.1.7</fate.version>
<java.version>1.8</java.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
@@ -48,7 +48,7 @@
<protobuf-maven-plugin.version>0.6.1</protobuf-maven-plugin.version>
<os-maven-plugin.version>1.6.1</os-maven-plugin.version>
<spring.boot.version>2.7.0</spring.boot.version>
<jackson.version>2.13.3</jackson.version>
<jackson.version>2.15.3</jackson.version>
<jedis.version>2.9.0</jedis.version>
<log4j2.version>2.17.1</log4j2.version>
<skipTests>true</skipTests>
@@ -226,7 +226,7 @@
<dependency>
<groupId>org.yaml</groupId>
<artifactId>snakeyaml</artifactId>
<version>1.26</version>
<version>1.32</version>
<scope>compile</scope>
</dependency>

@@ -251,7 +251,7 @@
<dependency>
<groupId>commons-net</groupId>
<artifactId>commons-net</artifactId>
<version>3.8.0</version>
<version>3.9.0</version>
</dependency>

<dependency>

0 comments on commit ab7ba98

Please sign in to comment.