Skip to content

Commit

Permalink
support kill job
Browse files Browse the repository at this point in the history
  • Loading branch information
cyjseagull committed Nov 1, 2024
1 parent 6ee87c5 commit f274edf
Show file tree
Hide file tree
Showing 17 changed files with 138 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@

package com.webank.wedpr.components.loadbalancer;

import java.util.List;

public interface LoadBalancer {
public static enum Policy {
ROUND_ROBIN,
HASH,
}

EntryPointInfo selectService(Policy policy, String serviceType);

List<EntryPointInfo> selectAllEndPoint(String serviceType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ public LoadBalancerImpl(EntryPointFetcher entryPointFetcher) {
this.entryPointFetcher = entryPointFetcher;
}

@Override
public List<EntryPointInfo> selectAllEndPoint(String serviceType) {
return entryPointFetcher.getAliveEntryPoints(serviceType);
}

@Override
public EntryPointInfo selectService(Policy policy, String serviceType) {
List<EntryPointInfo> entryPointInfoList =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ public static JobResult deserialize(String data) {
@JsonIgnore private transient List<FollowerDO> taskParties;

@JsonIgnore private transient Integer limitItems = -1;
@JsonIgnore private transient Boolean killed = false;

// shouldSync or not
private Boolean shouldSync;
Expand Down Expand Up @@ -304,6 +305,9 @@ public Boolean isJobParty(String agency) {
if (this.ownerAgency.compareToIgnoreCase(agency) == 0) {
return Boolean.TRUE;
}
if (taskParties == null) {
return Boolean.FALSE;
}
for (FollowerDO followerDO : taskParties) {
if (followerDO.getAgency().compareToIgnoreCase(agency) == 0) {
return Boolean.TRUE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@
select
<choose>
<when test="onlyMeta == true">
`id`, `owner`, `owner_agency`, `job_type`
`id`, `owner`, `owner_agency`, `job_type`, `parties`
</when>
<otherwise>
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.webank.wedpr.components.http.client.HttpClientImpl;
import com.webank.wedpr.components.scheduler.dag.entity.JobWorker;
import com.webank.wedpr.components.scheduler.dag.utils.WorkerUtils;
import com.webank.wedpr.components.scheduler.dag.worker.WorkerStatus;
import com.webank.wedpr.components.scheduler.executor.impl.ml.MLExecutorConfig;
import com.webank.wedpr.components.scheduler.executor.impl.ml.request.ModelJobRequest;
import com.webank.wedpr.components.scheduler.executor.impl.ml.response.MLResponse;
Expand Down Expand Up @@ -71,7 +72,7 @@ public String submitTask(String params, JobWorker jobWorker) throws Exception {
}

@SneakyThrows
public void pollTask(String taskId) throws WeDPRException {
public WorkerStatus pollTask(String taskId) throws WeDPRException {

String requestUrl = MLExecutorConfig.getRunTaskApiUrl(url, taskId);

Expand Down Expand Up @@ -102,9 +103,13 @@ public void pollTask(String taskId) throws WeDPRException {
+ " ,response: "
+ response);
}
if (response.killed()) {
logger.info("The ml task {} has been killed, response: {}", taskId, response);
return response.getData().getWorkerStatus();
}

if (response.success()) {
return;
return response.getData().getWorkerStatus();
}

// task is running
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ protected void registerExecutors(
new ExecutiveContextBuilder(projectMapperWrapper),
threadPoolService);
executorManager.registerExecutor(ExecutorType.DAG.getType(), dagSchedulerExecutor);
// register the pir executor, TODO: implement the taskFinishHandler
// register the pir executor
executorManager.registerExecutor(
ExecutorType.PIR.getType(),
new PirExecutor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ public void onReceiveRunAction(ResourceSyncer.CommitArgs commitArgs)
public void onReceiveKillAction(ResourceSyncer.CommitArgs commitArgs)
throws JsonProcessingException {
BatchJobList jobList =
BatchJobList.deserialize(commitArgs.getResourceActionRecord().getResourceAction());
BatchJobList.deserialize(commitArgs.getResourceActionRecord().getResourceContent());
logger.info("onReceiveKillAction, job size: {}", jobList.getJobs());
this.scheduler.batchKillJobs(jobList.getJobs());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,13 @@ public void executeWorker(Worker worker) throws Exception {
workerId);
return;
}
worker.run(jobWorker.getStatus());

jobWorkerMapper.updateJobWorkerStatus(workerId, WorkerStatus.SUCCESS.name());
logger.info("worker executed successfully, jobId: {}, workId: {}", jobId, workerId);
WorkerStatus status = worker.run(jobWorker.getStatus());
if (status != WorkerStatus.KILLED) {
jobWorkerMapper.updateJobWorkerStatus(workerId, status.name());
logger.info("worker executed successfully, jobId: {}, workId: {}", jobId, workerId);
} else {
logger.info("worker has been killed, jobId: {}, workId: {}", jobId, workerId);
}
} catch (Exception e) {
logger.error("worker executed failed, jobId: {}, workId: {}, e: ", jobId, workerId, e);
jobWorkerMapper.updateJobWorkerStatus(workerId, WorkerStatus.FAILURE.name());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public ModelWorker(
}

@Override
public void engineRun() throws Exception {
public WorkerStatus engineRun() throws Exception {

String jobId = getJobId();
String workerId = getWorkerId();
Expand Down Expand Up @@ -52,8 +52,7 @@ public void engineRun() throws Exception {
// submit task
String taskId = modelClient.submitTask(args, getJobWorker());
// poll until the task finished
modelClient.pollTask(getJobWorker().getWorkerId());

return modelClient.pollTask(getJobWorker().getWorkerId());
} finally {
long endTimeMillis = System.currentTimeMillis();
logger.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public MpcWorker(
}

@Override
public void engineRun() throws WeDPRException {
public WorkerStatus engineRun() throws WeDPRException {

EntryPointInfo entryPoint =
getLoadBalancer()
Expand All @@ -33,5 +33,6 @@ public void engineRun() throws WeDPRException {
}

logger.info("## getting mpc client: {}", entryPoint);
return WorkerStatus.SUCCESS;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public PsiWorker(
}

@Override
public void engineRun() throws Exception {
public WorkerStatus engineRun() throws Exception {

String jobId = getJobId();
String workerId = getWorkerId();
Expand Down Expand Up @@ -57,6 +57,7 @@ public void engineRun() throws Exception {
String taskId = psiClient.submitTask(workerArgs);
// poll until the task finished
psiClient.pollTask(taskId);
return WorkerStatus.SUCCESS;
} finally {
long endTimeMillis = System.currentTimeMillis();
logger.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,16 @@ public void logWorker() {
*
* @return
*/
public abstract void engineRun() throws Exception;
public abstract WorkerStatus engineRun() throws Exception;

public boolean run(String workerStatus) throws Exception {
public WorkerStatus run(String workerStatus) throws Exception {

if (workerStatus.equals(WorkerStatus.SUCCESS.name())) {
logger.info(
"worker has been executed successfully, jobId: {}, workId: {}",
jobId,
workerId);
return false;
return WorkerStatus.SUCCESS;
}

logWorker();
Expand All @@ -101,9 +101,9 @@ public boolean run(String workerStatus) throws Exception {
while (attemptTimes++ < retryTimes) {
try {
logger.info(workerStartLog(workerId));
this.engineRun();
WorkerStatus status = this.engineRun();
logger.info(workerEndLog(workerId));
return true;
return status;
} catch (Exception e) {
if (attemptTimes >= retryTimes) {
logger.error(
Expand All @@ -125,7 +125,7 @@ public boolean run(String workerStatus) throws Exception {
}
}

return false;
return WorkerStatus.FAILURE;
}

String workerStartLog(String workId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ public String getStatus() {
}

public boolean isFailed() {
return ordinal() == WorkerStatus.FAILURE.ordinal()
|| ordinal() == WorkerStatus.KILLED.ordinal();
return ordinal() == WorkerStatus.FAILURE.ordinal();
}

public boolean isKilled() {
return ordinal() == WorkerStatus.KILLED.ordinal();
}

public boolean isSuccess() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ public TaskFinishedHandler getTaskFinishedHandler() {
}

public void onTaskFinished(ExecuteResult result) {
// need to kill the job, no need to call the handler
if (!job.getKilled()) {
logger.info(
"onTaskFinished return directly for the job is been killed, job: {}",
job.toString());
return;
}
JobDO.JobResultItem subJobResult =
new JobDO.JobResultItem(
taskID, result.getResultStatus().success(), result.serialize());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
package com.webank.wedpr.components.scheduler.executor.impl.dag;

import com.webank.wedpr.common.protocol.ExecutorType;
import com.webank.wedpr.common.utils.BaseResponse;
import com.webank.wedpr.common.utils.ThreadPoolService;
import com.webank.wedpr.common.utils.WeDPRException;
import com.webank.wedpr.components.http.client.HttpClientImpl;
import com.webank.wedpr.components.loadbalancer.EntryPointInfo;
import com.webank.wedpr.components.loadbalancer.LoadBalancer;
import com.webank.wedpr.components.project.JobChecker;
import com.webank.wedpr.components.project.dao.JobDO;
import com.webank.wedpr.components.scheduler.api.WorkFlowOrchestratorApi;
import com.webank.wedpr.components.scheduler.dag.DagWorkFlowSchedulerImpl;
import com.webank.wedpr.components.scheduler.dag.api.WorkFlowScheduler;
import com.webank.wedpr.components.scheduler.dag.utils.ServiceName;
import com.webank.wedpr.components.scheduler.executor.ExecuteResult;
import com.webank.wedpr.components.scheduler.executor.Executor;
import com.webank.wedpr.components.scheduler.executor.callback.TaskFinishedHandler;
import com.webank.wedpr.components.scheduler.executor.impl.ExecutiveContext;
import com.webank.wedpr.components.scheduler.executor.impl.ExecutiveContextBuilder;
import com.webank.wedpr.components.scheduler.executor.impl.ml.MLExecutorConfig;
import com.webank.wedpr.components.scheduler.executor.impl.ml.response.MLResponseFactory;
import com.webank.wedpr.components.scheduler.executor.impl.model.FileMetaBuilder;
import com.webank.wedpr.components.scheduler.executor.manager.ExecutorManager;
import com.webank.wedpr.components.scheduler.mapper.JobWorkerMapper;
import com.webank.wedpr.components.scheduler.workflow.WorkFlow;
import com.webank.wedpr.components.scheduler.workflow.WorkFlowOrchestrator;
import com.webank.wedpr.components.scheduler.workflow.builder.JobWorkFlowBuilderManager;
import com.webank.wedpr.components.storage.api.FileStorageInterface;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -30,10 +38,10 @@ public class DagSchedulerExecutor implements Executor {
private final WorkFlowScheduler workFlowScheduler;
private final WorkFlowOrchestratorApi workflowOrchestrator;
private final ExecutorManager executorManager;

private final ExecutiveContextBuilder executiveContextBuilder;

private final ThreadPoolService threadPoolService;
private final LoadBalancer loadBalancer;

public DagSchedulerExecutor(
LoadBalancer loadBalancer,
Expand All @@ -44,6 +52,7 @@ public DagSchedulerExecutor(
ExecutorManager executorManager,
ExecutiveContextBuilder executiveContextBuilder,
ThreadPoolService threadPoolService) {
this.loadBalancer = loadBalancer;
this.executiveContextBuilder = executiveContextBuilder;
this.threadPoolService = threadPoolService;
this.executorManager = executorManager;
Expand Down Expand Up @@ -89,7 +98,6 @@ public void innerExecute(JobDO jobDO) {
WorkFlow workflow = workflowOrchestrator.buildWorkFlow(jobDO);

this.workFlowScheduler.schedule(jobDO.getId(), workflow);

executiveContext.onTaskFinished(new ExecuteResult(ExecuteResult.ResultStatus.SUCCESS));

long endTimeMillis = System.currentTimeMillis();
Expand All @@ -100,10 +108,8 @@ public void innerExecute(JobDO jobDO) {
(endTimeMillis - startTimeMillis));

} catch (Exception e) {

executiveContext.onTaskFinished(
new ExecuteResult(e.getMessage(), ExecuteResult.ResultStatus.FAILED));

long endTimeMillis = System.currentTimeMillis();

logger.warn(
Expand All @@ -115,7 +121,56 @@ public void innerExecute(JobDO jobDO) {
}

@Override
public void kill(JobDO jobDO) throws Exception {}
public void kill(JobDO jobDO) throws Exception {
if (jobDO.getType().mlJob()) {
killModelJob(jobDO);
}
}

// Note: since the job may exist in any node, establish kill command to all nodes
public void killModelJob(JobDO jobDO) throws Exception {
logger.info("killModelJob: {}", jobDO.getId());
List<EntryPointInfo> aliveEntryPoint =
loadBalancer.selectAllEndPoint(ServiceName.MODEL.getValue());
if (aliveEntryPoint == null || aliveEntryPoint.isEmpty()) {
return;
}
boolean failed = false;
String reason = "";
for (EntryPointInfo entryPointInfo : aliveEntryPoint) {
try {
logger.info("kill job: {}, entrypoint: {}", jobDO.toString(), entryPointInfo);
HttpClientImpl httpClient =
new HttpClientImpl(
MLExecutorConfig.getRunTaskApiUrl(
entryPointInfo.getEntryPoint(), jobDO.getId()),
MLExecutorConfig.getMaxTotalConnection(),
MLExecutorConfig.buildConfig(),
new MLResponseFactory());
BaseResponse response = httpClient.execute(httpClient.getUrl(), true);
if (response.statusOk()) {
logger.info(
"kill job success: {}, entrypoint: {}",
jobDO.getJobRequest(),
entryPointInfo);
return;
}
logger.error(
"kill job {} failed, response: {}, entrypoint: {}",
jobDO.getId(),
response.serialize(),
entryPointInfo);
throw new WeDPRException("kill job failed, response: " + response.serialize());
} catch (Exception e) {
failed = true;
reason = e.getMessage();
}
}
if (failed) {
throw new WeDPRException(reason);
}
logger.info("killModelJob: {} success", jobDO.getId());
}

@Override
public ExecuteResult queryStatus(String jobID) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ public Boolean failed() {
return data.getWorkerStatus().isFailed();
}

public Boolean killed() {
if (data == null) {
return Boolean.FALSE;
}
return data.getWorkerStatus().isKilled();
}

public static MLResponse deserialize(String data) throws JsonProcessingException {
return ObjectMapperFactory.getObjectMapper().readValue(data, MLResponse.class);
}
Expand Down
Loading

0 comments on commit f274edf

Please sign in to comment.