Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix auth failed not redirect bug #120

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ public Object call(CSVReaderHeaderAware reader) throws Exception {
new FileWriter(extractConfig.getExtractFilePath()),
extractConfig.getWriteChunkSize())) {
// write the data(Note: here no need to write the header)
writer.write(
Constant.DEFAULT_ID_FIELD + Constant.DEFAULT_LINE_SPLITTER);
while ((row = reader.readMap()) != null) {
int column = 0;
for (String field : extractConfig.getExtractFields()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public class Constant {
public static String WEDPR_SUCCESS_MSG = "success";

public static Integer WEDPR_FAILED = -1;
public static Integer WEDPR_AUTH_FAILED = 401;

public static String CHAIN_CONFIG_FILE = "config.toml";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ public static JobResult deserialize(String data) {

@JsonIgnore private transient Object jobRequest;
private String status;
private JobStatus jobStatus;

@JsonIgnore private String result;
private List<String> datasetList;

Expand Down Expand Up @@ -180,6 +182,18 @@ public void setName(String name) {
this.name = name;
}

public void setStatus(String status) {
this.status = status;
this.jobStatus = JobStatus.deserialize(status);
}

public void setJobStatus(JobStatus jobStatus) {
this.jobStatus = jobStatus;
if (this.jobStatus != null) {
this.status = this.jobStatus.getStatus();
}
}

public void setResult(String result) {
this.result = result;
if (StringUtils.isBlank(result)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,6 @@ protected void preparePSIJob(JobDO jobDO, PSIJobParam psiJobParam) throws Except
// download and prepare the psi file
psiJobParam.prepare(this.fileMetaBuilder, storage);
// convert to PSIRequest
jobDO.setJobRequest(psiJobParam.convert(jobDO.getOwnerAgency()));
jobDO.setJobRequest(psiJobParam.convert(jobDO.getType(), jobDO.getOwnerAgency()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public static String getPsiPrepareFileName() {

public static String getDefaultPSIResultPath(String user, String jobID) {
return WeDPRCommonConfig.getUserJobCachePath(
user, JobType.PIR.getType(), jobID, PSI_RESULT_FILE_NAME);
user, JobType.PSI.getType(), jobID, PSI_RESULT_FILE_NAME);
}

public static String getMpcResultFileName() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.ArrayList;
import java.util.List;
import lombok.Data;
import lombok.SneakyThrows;

@Data
@JsonIgnoreProperties(ignoreUnknown = true)
Expand Down Expand Up @@ -135,13 +136,22 @@ public void parseLabelProviderInfo() throws Exception {
}

// set the participants information
@SneakyThrows(Exception.class)
public void parseParticipants() {
// set the active party
this.modelRequest
.getParticipantIDList()
.add(this.labelProviderDataset.getDataset().getOwnerAgency());
boolean selfParticipant = false;
// set the passive parties
for (DatasetInfo datasetInfo : dataSetList) {
if (datasetInfo
.getDataset()
.getOwnerAgency()
.compareToIgnoreCase(WeDPRCommonConfig.getAgency())
== 0) {
selfParticipant = true;
}
if (datasetInfo
.getDataset()
.getOwnerAgency()
Expand All @@ -152,6 +162,12 @@ public void parseParticipants() {
}
this.modelRequest.getParticipantIDList().add(datasetInfo.getDataset().getOwnerAgency());
}
if (!selfParticipant) {
throw new WeDPRException(
"The agency "
+ WeDPRCommonConfig.getAgency()
+ " must participant the model job!");
}
}

public PreprocessingRequest toPreprocessingRequest(FileMetaBuilder fileMetaBuilder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import lombok.Data;
import lombok.SneakyThrows;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Data
public class PSIJobParam {
private static final Logger logger = LoggerFactory.getLogger(PSIJobParam.class);

Expand Down Expand Up @@ -95,22 +97,6 @@ public void setOutput(FileMeta output) {

@JsonIgnore private List<String> datasetIDList;

public String getJobID() {
return jobID;
}

public void setJobID(String jobID) {
this.jobID = jobID;
}

public List<PartyResourceInfo> getPartyResourceInfoList() {
return partyResourceInfoList;
}

public void setPartyResourceInfoList(List<PartyResourceInfo> partyResourceInfoList) {
this.partyResourceInfoList = partyResourceInfoList;
}

public static PSIJobParam deserialize(String data) throws Exception {
if (StringUtils.isBlank(data)) {
throw new WeDPRException("The PSIJobParam must be non-empty!");
Expand Down Expand Up @@ -154,15 +140,18 @@ public void check(FileMetaBuilder fileMetaBuilder) throws Exception {
}
}

public PSIRequest convert(String ownerAgency) throws Exception {
public PSIRequest convert(JobType jobType, String ownerAgency) throws Exception {
PSIRequest psiRequest = new PSIRequest();
psiRequest.setTaskID(this.taskID);
psiRequest.setParties(toPSIParam(ownerAgency));
resetPartyIndex(psiRequest.getParties());
List<String> receivers = new ArrayList<>();
boolean syncResult = false;
for (PartyResourceInfo partyInfo : partyResourceInfoList) {
if (partyInfo.getReceiveResult()) {
// Note: the ml-psi and mpc-psi case, all parties are the receivers
if (jobType == JobType.ML_PSI
|| jobType == JobType.MPC_PSI
|| partyInfo.getReceiveResult()) {
receivers.add(partyInfo.getDataset().getOwnerAgency());
syncResult = true;
}
Expand All @@ -172,20 +161,29 @@ public PSIRequest convert(String ownerAgency) throws Exception {
return psiRequest;
}

@SneakyThrows(Exception.class)
private List<PartyInfo> toPSIParam(String ownerAgency) {
List<PartyInfo> partyInfoList = new ArrayList<>();
boolean selfParticipant = false;
for (PartyResourceInfo party : partyResourceInfoList) {
String agency = party.getDataset().getOwnerAgency();
PartyInfo partyInfo = new PartyInfo(agency);
if (agency.compareToIgnoreCase(ownerAgency) == 0) {
partyInfo.setPartyIndex(PartyInfo.PartyType.SERVER.getType());
selfParticipant = true;
} else {
partyInfo.setPartyIndex(PartyInfo.PartyType.CLIENT.getType());
}
partyInfo.setData(
new PartyInfo.PartyData(jobID, party.getDataset(), party.getOutput()));
partyInfoList.add(partyInfo);
}
if (!selfParticipant) {
throw new WeDPRException(
"The agency "
+ WeDPRCommonConfig.getAgency()
+ " must participant the PSI job!");
}
return partyInfoList;
}

Expand Down Expand Up @@ -296,22 +294,6 @@ public void prepare(
}
}

public String getTaskID() {
return taskID;
}

public void setTaskID(String taskID) {
this.taskID = taskID;
}

public List<String> getDatasetIDList() {
return datasetIDList;
}

public void setDatasetIDList(List<String> datasetIDList) {
this.datasetIDList = datasetIDList;
}

public String serialize() throws Exception {
return ObjectMapperFactory.getObjectMapper().writeValueAsString(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,17 @@ public Object queryJobDetail(String user, String agency, String jobID) throws Ex
JobDO jobDO = jobDOList.get(0);
// run failed, no need to fetch the result, only fetch the log
if (!JobStatus.success(jobDO.getStatus())) {
GetTaskResultRequest getTaskResultRequest =
new GetTaskResultRequest(user, jobDO.getId(), jobDO.getJobType());
getTaskResultRequest.setOnlyFetchLog(Boolean.TRUE);
ModelJobResult.ModelJobData modelJobData =
(ModelJobResult.ModelJobData)
MLExecutorClient.getJobResult(loadBalancer, getTaskResultRequest);
return new JobDetailResponse(jobDO, null, null, modelJobData.getLogDetail());
Object logDetail = null;
if (jobDO.getJobStatus().finished()) {
GetTaskResultRequest getTaskResultRequest =
new GetTaskResultRequest(user, jobDO.getId(), jobDO.getJobType());
getTaskResultRequest.setOnlyFetchLog(Boolean.TRUE);
ModelJobResult.ModelJobData modelJobData =
(ModelJobResult.ModelJobData)
MLExecutorClient.getJobResult(loadBalancer, getTaskResultRequest);
logDetail = modelJobData.getLogDetail();
}
return new JobDetailResponse(jobDO, null, null, logDetail);
}
// the ml job
if (jobDO.getType().mlJob()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ protected void doFilterInternal(
chain.doFilter(requestWrapper, response);
} catch (Exception e) {
logger.warn("APISignatureAuthFilter exception, error: ", e);
TokenUtils.responseToClient(response, e.getMessage(), HttpServletResponse.SC_FORBIDDEN);
TokenUtils.responseToClient(
response, e.getMessage(), HttpServletResponse.SC_UNAUTHORIZED);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ protected void doFilterInternal(
} catch (Exception e) {
logger.info("jwt auth failed, error: ", e);
String wedprResponse =
new WeDPRResponse(Constant.WEDPR_FAILED, "auth failed for " + e.getMessage())
new WeDPRResponse(
Constant.WEDPR_AUTH_FAILED, "auth failed for " + e.getMessage())
.serialize();
TokenUtils.responseToClient(response, wedprResponse, HttpServletResponse.SC_FORBIDDEN);
TokenUtils.responseToClient(
response, wedprResponse, HttpServletResponse.SC_UNAUTHORIZED);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ private void constructFromCSV(Dataset dataset, String idField) throws Exception
if (datasetFieldsList.contains(Constant.PIR_ID_HASH_FIELD_NAME)) {
throw new WeDPRException("Conflict with sys field " + Constant.PIR_ID_HASH_FIELD_NAME);
}
Long startTime = System.currentTimeMillis();
List<List<String>> sqlValues =
CSVFileParser.processCsv2SqlMap(datasetFields, localFilePath);
if (sqlValues.size() == 0) {
Expand All @@ -107,6 +108,9 @@ private void constructFromCSV(Dataset dataset, String idField) throws Exception
localFilePath);
return;
}
logger.info(
"processCsv2SqlMap success, timecost: {}ms",
System.currentTimeMillis() - startTime);
String tableId =
com.webank.wedpr.components.task.plugin.pir.utils.Constant.datasetId2tableId(
dataset.getDatasetId());
Expand Down
Loading