diff --git a/wedpr-common/utils/src/main/java/com/webank/wedpr/common/utils/CSVFileParser.java b/wedpr-common/utils/src/main/java/com/webank/wedpr/common/utils/CSVFileParser.java index 23d19d70..7e730fdc 100644 --- a/wedpr-common/utils/src/main/java/com/webank/wedpr/common/utils/CSVFileParser.java +++ b/wedpr-common/utils/src/main/java/com/webank/wedpr/common/utils/CSVFileParser.java @@ -173,6 +173,8 @@ public Object call(CSVReaderHeaderAware reader) throws Exception { new FileWriter(extractConfig.getExtractFilePath()), extractConfig.getWriteChunkSize())) { // write the data(Note: here no need to write the header) + writer.write( + Constant.DEFAULT_ID_FIELD + Constant.DEFAULT_LINE_SPLITTER); while ((row = reader.readMap()) != null) { int column = 0; for (String field : extractConfig.getExtractFields()) { diff --git a/wedpr-common/utils/src/main/java/com/webank/wedpr/common/utils/Constant.java b/wedpr-common/utils/src/main/java/com/webank/wedpr/common/utils/Constant.java index 82a3f6ce..13e04f1c 100644 --- a/wedpr-common/utils/src/main/java/com/webank/wedpr/common/utils/Constant.java +++ b/wedpr-common/utils/src/main/java/com/webank/wedpr/common/utils/Constant.java @@ -25,6 +25,7 @@ public class Constant { public static String WEDPR_SUCCESS_MSG = "success"; public static Integer WEDPR_FAILED = -1; + public static Integer WEDPR_AUTH_FAILED = 401; public static String CHAIN_CONFIG_FILE = "config.toml"; diff --git a/wedpr-components/meta/project/src/main/java/com/webank/wedpr/components/project/dao/JobDO.java b/wedpr-components/meta/project/src/main/java/com/webank/wedpr/components/project/dao/JobDO.java index 63d8f148..1219932e 100644 --- a/wedpr-components/meta/project/src/main/java/com/webank/wedpr/components/project/dao/JobDO.java +++ b/wedpr-components/meta/project/src/main/java/com/webank/wedpr/components/project/dao/JobDO.java @@ -139,6 +139,8 @@ public static JobResult deserialize(String data) { @JsonIgnore private transient Object jobRequest; private String status; + private JobStatus jobStatus; + @JsonIgnore private String result; private List datasetList; @@ -180,6 +182,18 @@ public void setName(String name) { this.name = name; } + public void setStatus(String status) { + this.status = status; + this.jobStatus = JobStatus.deserialize(status); + } + + public void setJobStatus(JobStatus jobStatus) { + this.jobStatus = jobStatus; + if (this.jobStatus != null) { + this.status = this.jobStatus.getStatus(); + } + } + public void setResult(String result) { this.result = result; if (StringUtils.isBlank(result)) { diff --git a/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/hook/PSIExecutorHook.java b/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/hook/PSIExecutorHook.java index b77d4e0a..10811ef9 100644 --- a/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/hook/PSIExecutorHook.java +++ b/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/hook/PSIExecutorHook.java @@ -50,6 +50,6 @@ protected void preparePSIJob(JobDO jobDO, PSIJobParam psiJobParam) throws Except // download and prepare the psi file psiJobParam.prepare(this.fileMetaBuilder, storage); // convert to PSIRequest - jobDO.setJobRequest(psiJobParam.convert(jobDO.getOwnerAgency())); + jobDO.setJobRequest(psiJobParam.convert(jobDO.getType(), jobDO.getOwnerAgency())); } } diff --git a/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/impl/ExecutorConfig.java b/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/impl/ExecutorConfig.java index 04e0c23b..b0f1530a 100644 --- a/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/impl/ExecutorConfig.java +++ b/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/impl/ExecutorConfig.java @@ -95,7 +95,7 @@ public static String getPsiPrepareFileName() { public static String getDefaultPSIResultPath(String user, String jobID) { return WeDPRCommonConfig.getUserJobCachePath( - user, JobType.PIR.getType(), jobID, PSI_RESULT_FILE_NAME); + user, JobType.PSI.getType(), jobID, PSI_RESULT_FILE_NAME); } public static String getMpcResultFileName() { diff --git a/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/impl/ml/model/ModelJobParam.java b/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/impl/ml/model/ModelJobParam.java index dcad2974..1cfa2ba0 100644 --- a/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/impl/ml/model/ModelJobParam.java +++ b/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/impl/ml/model/ModelJobParam.java @@ -34,6 +34,7 @@ import java.util.ArrayList; import java.util.List; import lombok.Data; +import lombok.SneakyThrows; @Data @JsonIgnoreProperties(ignoreUnknown = true) @@ -135,13 +136,22 @@ public void parseLabelProviderInfo() throws Exception { } // set the participants information + @SneakyThrows(Exception.class) public void parseParticipants() { // set the active party this.modelRequest .getParticipantIDList() .add(this.labelProviderDataset.getDataset().getOwnerAgency()); + boolean selfParticipant = false; // set the passive parties for (DatasetInfo datasetInfo : dataSetList) { + if (datasetInfo + .getDataset() + .getOwnerAgency() + .compareToIgnoreCase(WeDPRCommonConfig.getAgency()) + == 0) { + selfParticipant = true; + } if (datasetInfo .getDataset() .getOwnerAgency() @@ -152,6 +162,12 @@ public void parseParticipants() { } this.modelRequest.getParticipantIDList().add(datasetInfo.getDataset().getOwnerAgency()); } + if (!selfParticipant) { + throw new WeDPRException( + "The agency " + + WeDPRCommonConfig.getAgency() + + " must participant the model job!"); + } } public PreprocessingRequest toPreprocessingRequest(FileMetaBuilder fileMetaBuilder) diff --git a/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/impl/psi/model/PSIJobParam.java b/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/impl/psi/model/PSIJobParam.java index a5ca70af..a46fce0c 100644 --- a/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/impl/psi/model/PSIJobParam.java +++ b/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/executor/impl/psi/model/PSIJobParam.java @@ -31,11 +31,13 @@ import java.io.File; import java.util.ArrayList; import java.util.List; +import lombok.Data; import lombok.SneakyThrows; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +@Data public class PSIJobParam { private static final Logger logger = LoggerFactory.getLogger(PSIJobParam.class); @@ -95,22 +97,6 @@ public void setOutput(FileMeta output) { @JsonIgnore private List datasetIDList; - public String getJobID() { - return jobID; - } - - public void setJobID(String jobID) { - this.jobID = jobID; - } - - public List getPartyResourceInfoList() { - return partyResourceInfoList; - } - - public void setPartyResourceInfoList(List partyResourceInfoList) { - this.partyResourceInfoList = partyResourceInfoList; - } - public static PSIJobParam deserialize(String data) throws Exception { if (StringUtils.isBlank(data)) { throw new WeDPRException("The PSIJobParam must be non-empty!"); @@ -154,7 +140,7 @@ public void check(FileMetaBuilder fileMetaBuilder) throws Exception { } } - public PSIRequest convert(String ownerAgency) throws Exception { + public PSIRequest convert(JobType jobType, String ownerAgency) throws Exception { PSIRequest psiRequest = new PSIRequest(); psiRequest.setTaskID(this.taskID); psiRequest.setParties(toPSIParam(ownerAgency)); @@ -162,7 +148,10 @@ public PSIRequest convert(String ownerAgency) throws Exception { List receivers = new ArrayList<>(); boolean syncResult = false; for (PartyResourceInfo partyInfo : partyResourceInfoList) { - if (partyInfo.getReceiveResult()) { + // Note: the ml-psi and mpc-psi case, all parties are the receivers + if (jobType == JobType.ML_PSI + || jobType == JobType.MPC_PSI + || partyInfo.getReceiveResult()) { receivers.add(partyInfo.getDataset().getOwnerAgency()); syncResult = true; } @@ -172,13 +161,16 @@ public PSIRequest convert(String ownerAgency) throws Exception { return psiRequest; } + @SneakyThrows(Exception.class) private List toPSIParam(String ownerAgency) { List partyInfoList = new ArrayList<>(); + boolean selfParticipant = false; for (PartyResourceInfo party : partyResourceInfoList) { String agency = party.getDataset().getOwnerAgency(); PartyInfo partyInfo = new PartyInfo(agency); if (agency.compareToIgnoreCase(ownerAgency) == 0) { partyInfo.setPartyIndex(PartyInfo.PartyType.SERVER.getType()); + selfParticipant = true; } else { partyInfo.setPartyIndex(PartyInfo.PartyType.CLIENT.getType()); } @@ -186,6 +178,12 @@ private List toPSIParam(String ownerAgency) { new PartyInfo.PartyData(jobID, party.getDataset(), party.getOutput())); partyInfoList.add(partyInfo); } + if (!selfParticipant) { + throw new WeDPRException( + "The agency " + + WeDPRCommonConfig.getAgency() + + " must participant the PSI job!"); + } return partyInfoList; } @@ -296,22 +294,6 @@ public void prepare( } } - public String getTaskID() { - return taskID; - } - - public void setTaskID(String taskID) { - this.taskID = taskID; - } - - public List getDatasetIDList() { - return datasetIDList; - } - - public void setDatasetIDList(List datasetIDList) { - this.datasetIDList = datasetIDList; - } - public String serialize() throws Exception { return ObjectMapperFactory.getObjectMapper().writeValueAsString(this); } diff --git a/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/impl/SchedulerServiceImpl.java b/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/impl/SchedulerServiceImpl.java index 30b14e64..d7e88371 100644 --- a/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/impl/SchedulerServiceImpl.java +++ b/wedpr-components/scheduler/src/main/java/com/webank/wedpr/components/scheduler/impl/SchedulerServiceImpl.java @@ -55,13 +55,17 @@ public Object queryJobDetail(String user, String agency, String jobID) throws Ex JobDO jobDO = jobDOList.get(0); // run failed, no need to fetch the result, only fetch the log if (!JobStatus.success(jobDO.getStatus())) { - GetTaskResultRequest getTaskResultRequest = - new GetTaskResultRequest(user, jobDO.getId(), jobDO.getJobType()); - getTaskResultRequest.setOnlyFetchLog(Boolean.TRUE); - ModelJobResult.ModelJobData modelJobData = - (ModelJobResult.ModelJobData) - MLExecutorClient.getJobResult(loadBalancer, getTaskResultRequest); - return new JobDetailResponse(jobDO, null, null, modelJobData.getLogDetail()); + Object logDetail = null; + if (jobDO.getJobStatus().finished()) { + GetTaskResultRequest getTaskResultRequest = + new GetTaskResultRequest(user, jobDO.getId(), jobDO.getJobType()); + getTaskResultRequest.setOnlyFetchLog(Boolean.TRUE); + ModelJobResult.ModelJobData modelJobData = + (ModelJobResult.ModelJobData) + MLExecutorClient.getJobResult(loadBalancer, getTaskResultRequest); + logDetail = modelJobData.getLogDetail(); + } + return new JobDetailResponse(jobDO, null, null, logDetail); } // the ml job if (jobDO.getType().mlJob()) { diff --git a/wedpr-components/security/src/main/java/com/webank/wedpr/components/security/filter/APISignatureAuthFilter.java b/wedpr-components/security/src/main/java/com/webank/wedpr/components/security/filter/APISignatureAuthFilter.java index 6bf0b7d9..9740936c 100644 --- a/wedpr-components/security/src/main/java/com/webank/wedpr/components/security/filter/APISignatureAuthFilter.java +++ b/wedpr-components/security/src/main/java/com/webank/wedpr/components/security/filter/APISignatureAuthFilter.java @@ -73,7 +73,8 @@ protected void doFilterInternal( chain.doFilter(requestWrapper, response); } catch (Exception e) { logger.warn("APISignatureAuthFilter exception, error: ", e); - TokenUtils.responseToClient(response, e.getMessage(), HttpServletResponse.SC_FORBIDDEN); + TokenUtils.responseToClient( + response, e.getMessage(), HttpServletResponse.SC_UNAUTHORIZED); } } } diff --git a/wedpr-components/security/src/main/java/com/webank/wedpr/components/security/filter/JwtAuthenticationFilter.java b/wedpr-components/security/src/main/java/com/webank/wedpr/components/security/filter/JwtAuthenticationFilter.java index a0394c31..d7cc3f2a 100644 --- a/wedpr-components/security/src/main/java/com/webank/wedpr/components/security/filter/JwtAuthenticationFilter.java +++ b/wedpr-components/security/src/main/java/com/webank/wedpr/components/security/filter/JwtAuthenticationFilter.java @@ -76,9 +76,11 @@ protected void doFilterInternal( } catch (Exception e) { logger.info("jwt auth failed, error: ", e); String wedprResponse = - new WeDPRResponse(Constant.WEDPR_FAILED, "auth failed for " + e.getMessage()) + new WeDPRResponse( + Constant.WEDPR_AUTH_FAILED, "auth failed for " + e.getMessage()) .serialize(); - TokenUtils.responseToClient(response, wedprResponse, HttpServletResponse.SC_FORBIDDEN); + TokenUtils.responseToClient( + response, wedprResponse, HttpServletResponse.SC_UNAUTHORIZED); } } } diff --git a/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/core/impl/PirDatasetConstructorImpl.java b/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/core/impl/PirDatasetConstructorImpl.java index d74dea66..f2e48418 100644 --- a/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/core/impl/PirDatasetConstructorImpl.java +++ b/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/core/impl/PirDatasetConstructorImpl.java @@ -98,6 +98,7 @@ private void constructFromCSV(Dataset dataset, String idField) throws Exception if (datasetFieldsList.contains(Constant.PIR_ID_HASH_FIELD_NAME)) { throw new WeDPRException("Conflict with sys field " + Constant.PIR_ID_HASH_FIELD_NAME); } + Long startTime = System.currentTimeMillis(); List> sqlValues = CSVFileParser.processCsv2SqlMap(datasetFields, localFilePath); if (sqlValues.size() == 0) { @@ -107,6 +108,9 @@ private void constructFromCSV(Dataset dataset, String idField) throws Exception localFilePath); return; } + logger.info( + "processCsv2SqlMap success, timecost: {}ms", + System.currentTimeMillis() - startTime); String tableId = com.webank.wedpr.components.task.plugin.pir.utils.Constant.datasetId2tableId( dataset.getDatasetId());