Skip to content

Commit

Permalink
support kill job
Browse files Browse the repository at this point in the history
  • Loading branch information
cyjseagull committed Nov 1, 2024
1 parent 6ee87c5 commit d1fd2f2
Show file tree
Hide file tree
Showing 18 changed files with 169 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@

package com.webank.wedpr.components.loadbalancer;

import java.util.List;

public interface LoadBalancer {
public static enum Policy {
ROUND_ROBIN,
HASH,
}

EntryPointInfo selectService(Policy policy, String serviceType);

List<EntryPointInfo> selectAllEndPoint(String serviceType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ public LoadBalancerImpl(EntryPointFetcher entryPointFetcher) {
this.entryPointFetcher = entryPointFetcher;
}

@Override
public List<EntryPointInfo> selectAllEndPoint(String serviceType) {
return entryPointFetcher.getAliveEntryPoints(serviceType);
}

@Override
public EntryPointInfo selectService(Policy policy, String serviceType) {
List<EntryPointInfo> entryPointInfoList =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ public static JobResult deserialize(String data) {
@JsonIgnore private transient List<FollowerDO> taskParties;

@JsonIgnore private transient Integer limitItems = -1;
@JsonIgnore private transient Boolean killed = false;

// shouldSync or not
private Boolean shouldSync;
Expand Down Expand Up @@ -304,6 +305,9 @@ public Boolean isJobParty(String agency) {
if (this.ownerAgency.compareToIgnoreCase(agency) == 0) {
return Boolean.TRUE;
}
if (taskParties == null) {
return Boolean.FALSE;
}
for (FollowerDO followerDO : taskParties) {
if (followerDO.getAgency().compareToIgnoreCase(agency) == 0) {
return Boolean.TRUE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@
select
<choose>
<when test="onlyMeta == true">
`id`, `owner`, `owner_agency`, `job_type`
`id`, `owner`, `owner_agency`, `job_type`, `parties`
</when>
<otherwise>
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.webank.wedpr.components.http.client.HttpClientImpl;
import com.webank.wedpr.components.scheduler.dag.entity.JobWorker;
import com.webank.wedpr.components.scheduler.dag.utils.WorkerUtils;
import com.webank.wedpr.components.scheduler.dag.worker.WorkerStatus;
import com.webank.wedpr.components.scheduler.executor.impl.ml.MLExecutorConfig;
import com.webank.wedpr.components.scheduler.executor.impl.ml.request.ModelJobRequest;
import com.webank.wedpr.components.scheduler.executor.impl.ml.response.MLResponse;
Expand Down Expand Up @@ -71,7 +72,7 @@ public String submitTask(String params, JobWorker jobWorker) throws Exception {
}

@SneakyThrows
public void pollTask(String taskId) throws WeDPRException {
public WorkerStatus pollTask(String taskId) throws WeDPRException {

String requestUrl = MLExecutorConfig.getRunTaskApiUrl(url, taskId);

Expand Down Expand Up @@ -102,9 +103,13 @@ public void pollTask(String taskId) throws WeDPRException {
+ " ,response: "
+ response);
}
if (response.killed()) {
logger.info("The ml task {} has been killed, response: {}", taskId, response);
return response.getData().getWorkerStatus();
}

if (response.success()) {
return;
return response.getData().getWorkerStatus();
}

// task is running
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static com.webank.wedpr.components.scheduler.SchedulerBuilder.WORKER_NAME;

import com.webank.wedpr.common.protocol.ExecutorType;
import com.webank.wedpr.common.protocol.JobStatus;
import com.webank.wedpr.common.utils.ThreadPoolService;
import com.webank.wedpr.components.api.credential.dao.ApiCredentialMapper;
import com.webank.wedpr.components.crypto.CryptoToolkitFactory;
Expand Down Expand Up @@ -132,7 +133,33 @@ protected void registerExecutors(
new ExecutiveContextBuilder(projectMapperWrapper),
threadPoolService);
executorManager.registerExecutor(ExecutorType.DAG.getType(), dagSchedulerExecutor);
// register the pir executor, TODO: implement the taskFinishHandler
// default
TaskFinishedHandler taskFinishedHandler =
new TaskFinishedHandler() {
@Override
public void onFinish(JobDO jobDO, ExecuteResult result) {
try {
if (result.getResultStatus() == null
|| result.getResultStatus().failed()) {
projectMapperWrapper.updateFinalJobResult(
jobDO, JobStatus.RunFailed, result.serialize());
} else {
projectMapperWrapper.updateFinalJobResult(
jobDO, JobStatus.RunSuccess, result.serialize());
}
} catch (Exception e) {
logger.error(
"update job status for job {} failed, result: {}, error: ",
jobDO.getId(),
result.toString(),
e);
}
}
};

executorManager.registerOnTaskFinished(ExecutorType.DAG.getType(), taskFinishedHandler);

// register the pir executor
executorManager.registerExecutor(
ExecutorType.PIR.getType(),
new PirExecutor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ public void onReceiveRunAction(ResourceSyncer.CommitArgs commitArgs)
public void onReceiveKillAction(ResourceSyncer.CommitArgs commitArgs)
throws JsonProcessingException {
BatchJobList jobList =
BatchJobList.deserialize(commitArgs.getResourceActionRecord().getResourceAction());
BatchJobList.deserialize(commitArgs.getResourceActionRecord().getResourceContent());
logger.info("onReceiveKillAction, job size: {}", jobList.getJobs());
this.scheduler.batchKillJobs(jobList.getJobs());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,13 @@ public void executeWorker(Worker worker) throws Exception {
workerId);
return;
}
worker.run(jobWorker.getStatus());

jobWorkerMapper.updateJobWorkerStatus(workerId, WorkerStatus.SUCCESS.name());
logger.info("worker executed successfully, jobId: {}, workId: {}", jobId, workerId);
WorkerStatus status = worker.run(jobWorker.getStatus());
if (status != WorkerStatus.KILLED) {
jobWorkerMapper.updateJobWorkerStatus(workerId, status.name());
logger.info("worker executed successfully, jobId: {}, workId: {}", jobId, workerId);
} else {
logger.info("worker has been killed, jobId: {}, workId: {}", jobId, workerId);
}
} catch (Exception e) {
logger.error("worker executed failed, jobId: {}, workId: {}, e: ", jobId, workerId, e);
jobWorkerMapper.updateJobWorkerStatus(workerId, WorkerStatus.FAILURE.name());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public ModelWorker(
}

@Override
public void engineRun() throws Exception {
public WorkerStatus engineRun() throws Exception {

String jobId = getJobId();
String workerId = getWorkerId();
Expand Down Expand Up @@ -52,8 +52,7 @@ public void engineRun() throws Exception {
// submit task
String taskId = modelClient.submitTask(args, getJobWorker());
// poll until the task finished
modelClient.pollTask(getJobWorker().getWorkerId());

return modelClient.pollTask(getJobWorker().getWorkerId());
} finally {
long endTimeMillis = System.currentTimeMillis();
logger.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public MpcWorker(
}

@Override
public void engineRun() throws WeDPRException {
public WorkerStatus engineRun() throws WeDPRException {

EntryPointInfo entryPoint =
getLoadBalancer()
Expand All @@ -33,5 +33,6 @@ public void engineRun() throws WeDPRException {
}

logger.info("## getting mpc client: {}", entryPoint);
return WorkerStatus.SUCCESS;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public PsiWorker(
}

@Override
public void engineRun() throws Exception {
public WorkerStatus engineRun() throws Exception {

String jobId = getJobId();
String workerId = getWorkerId();
Expand Down Expand Up @@ -57,6 +57,7 @@ public void engineRun() throws Exception {
String taskId = psiClient.submitTask(workerArgs);
// poll until the task finished
psiClient.pollTask(taskId);
return WorkerStatus.SUCCESS;
} finally {
long endTimeMillis = System.currentTimeMillis();
logger.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,16 @@ public void logWorker() {
*
* @return
*/
public abstract void engineRun() throws Exception;
public abstract WorkerStatus engineRun() throws Exception;

public boolean run(String workerStatus) throws Exception {
public WorkerStatus run(String workerStatus) throws Exception {

if (workerStatus.equals(WorkerStatus.SUCCESS.name())) {
logger.info(
"worker has been executed successfully, jobId: {}, workId: {}",
jobId,
workerId);
return false;
return WorkerStatus.SUCCESS;
}

logWorker();
Expand All @@ -101,9 +101,9 @@ public boolean run(String workerStatus) throws Exception {
while (attemptTimes++ < retryTimes) {
try {
logger.info(workerStartLog(workerId));
this.engineRun();
WorkerStatus status = this.engineRun();
logger.info(workerEndLog(workerId));
return true;
return status;
} catch (Exception e) {
if (attemptTimes >= retryTimes) {
logger.error(
Expand All @@ -125,7 +125,7 @@ public boolean run(String workerStatus) throws Exception {
}
}

return false;
return WorkerStatus.FAILURE;
}

String workerStartLog(String workId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ public String getStatus() {
}

public boolean isFailed() {
return ordinal() == WorkerStatus.FAILURE.ordinal()
|| ordinal() == WorkerStatus.KILLED.ordinal();
return ordinal() == WorkerStatus.FAILURE.ordinal();
}

public boolean isKilled() {
return ordinal() == WorkerStatus.KILLED.ordinal();
}

public boolean isSuccess() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ public TaskFinishedHandler getTaskFinishedHandler() {
}

public void onTaskFinished(ExecuteResult result) {
// need to kill the job, no need to call the handler
if (job.getKilled()) {
logger.info(
"onTaskFinished return directly for the job is been killed, job: {}",
job.toString());
return;
}
JobDO.JobResultItem subJobResult =
new JobDO.JobResultItem(
taskID, result.getResultStatus().success(), result.serialize());
Expand Down
Loading

0 comments on commit d1fd2f2

Please sign in to comment.