[8.19] [ML] Check for model deployment in inference endpoints before stopping (#129325) #129907

Merged (1 commit, Jun 24, 2025)

6 changes: 6 additions & 0 deletions docs/changelog/129325.yaml
@@ -0,0 +1,6 @@
pr: 129325
summary: Check for model deployment in inference endpoints before stopping
area: Machine Learning
type: bug
issues:
- 128549
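
At the REST level, the fix turns a plain _stop call into a 409 CONFLICT when an inference endpoint still uses the deployment, while force=true keeps working. Below is a minimal sketch with the low-level Java REST client, mirroring the integration tests later in this diff; the client, helper name, and status assertions are illustrative assumptions, not part of this PR.

import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.client.RestClient;

final class StopDeploymentSketch {
    // Attempt a normal stop first; fall back to force=true on the new 409.
    static void stopPossiblyUsedDeployment(RestClient client, String deploymentId) throws Exception {
        Request stop = new Request("POST", "/_ml/trained_models/" + deploymentId + "/deployment/_stop");
        try {
            client.performRequest(stop);
        } catch (ResponseException e) {
            // New behavior from this PR: CONFLICT if an inference endpoint uses the deployment.
            if (e.getResponse().getStatusLine().getStatusCode() != 409) {
                throw e;
            }
            Request forceStop = new Request("POST", "/_ml/trained_models/" + deploymentId + "/deployment/_stop");
            forceStop.addParameter("force", "true");
            Response response = client.performRequest(forceStop);
            assert response.getStatusLine().getStatusCode() == 200;
        }
    }
}
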
@@ -75,7 +75,7 @@ public void testAttachToDeployment() throws IOException {
var deploymentStats = stats.get(0).get("deployment_stats");
assertNotNull(stats.toString(), deploymentStats);

stopMlNodeDeployment(deploymentId);
forceStopMlNodeDeployment(deploymentId);
}

public void testAttachWithModelId() throws IOException {
@@ -146,7 +146,7 @@ public void testAttachWithModelId() throws IOException {
)
);

stopMlNodeDeployment(deploymentId);
forceStopMlNodeDeployment(deploymentId);
}

public void testModelIdDoesNotMatch() throws IOException {
@@ -229,6 +229,29 @@ public void testNumAllocationsIsUpdated() throws IOException {
);
}

public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOException {
var modelId = "try_stop_attach_to_deployment";
var deploymentId = "test_stop_attach_to_deployment";

CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
var response = startMlNodeDeploymemnt(modelId, deploymentId);
assertStatusOkOrCreated(response);

var inferenceId = "test_stop_inference_on_existing_deployment";
putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);

var stopShouldNotSucceed = expectThrows(ResponseException.class, () -> stopMlNodeDeployment(deploymentId));
assertThat(
stopShouldNotSucceed.getMessage(),
containsString(
Strings.format("Cannot stop deployment [%s] as it is used by inference endpoint [%s]", deploymentId, inferenceId)
)
);

// Force stop will stop the deployment
forceStopMlNodeDeployment(deploymentId);
}

private String endpointConfig(String deploymentId) {
return Strings.format("""
{
@@ -292,6 +315,12 @@ private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations
}

protected void stopMlNodeDeployment(String deploymentId) throws IOException {
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
Request request = new Request("POST", endpoint);
client().performRequest(request);
}

protected void forceStopMlNodeDeployment(String deploymentId) throws IOException {
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
Request request = new Request("POST", endpoint);
request.addParameter("force", "true");
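
The endpointConfig helper above is truncated in this diff. Judging by the deployment_id check added in TransportStopTrainedModelDeploymentAction below, it plausibly builds an "elasticsearch" service config that attaches the endpoint to the existing deployment. A hedged reconstruction follows; the exact fields are assumptions, not the PR's actual helper body.

// Hypothetical reconstruction for illustration only: an inference endpoint
// that attaches to an existing deployment carries the deployment_id service
// setting, which is what checkIfUsedByInferenceEndpoint later looks for.
private String endpointConfigSketch(String deploymentId) {
    return Strings.format("""
        {
            "service": "elasticsearch",
            "service_settings": {
                "deployment_id": "%s"
            }
        }
        """, deploymentId);
}
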
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference;

import org.elasticsearch.client.Request;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.TaskType;
@@ -18,6 +19,8 @@
import java.util.List;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.containsString;

public class CustomElandModelIT extends InferenceBaseRestTest {

// The model definition is taken from org.elasticsearch.xpack.ml.integration.TextExpansionQueryIT
@@ -92,6 +95,47 @@ public void testSparse() throws IOException {
assertNotNull(results.get("sparse_embedding"));
}

public void testCannotStopDeployment() throws IOException {
String modelId = "custom-model-that-cannot-be-stopped";

createTextExpansionModel(modelId, client());
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE, client());
putVocabulary(
List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
modelId,
client()
);

var inferenceConfig = """
{
"service": "elasticsearch",
"service_settings": {
"model_id": "custom-model-that-cannot-be-stopped",
"num_allocations": 1,
"num_threads": 1
}
}
""";

var inferenceId = "sparse-inf";
putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING);
infer(inferenceId, List.of("washing", "machine"));

// Stopping the deployment using the ML trained models API should fail
// because the deployment was created by the inference endpoint API
String stopEndpoint = org.elasticsearch.common.Strings.format("_ml/trained_models/%s/deployment/_stop?error_trace", inferenceId);
Request stopRequest = new Request("POST", stopEndpoint);
var e = expectThrows(ResponseException.class, () -> client().performRequest(stopRequest));
assertThat(
e.getMessage(),
containsString("Cannot stop deployment [sparse-inf] as it was created by inference endpoint [sparse-inf]")
);

// Force stop works
String forceStopEndpoint = org.elasticsearch.common.Strings.format("_ml/trained_models/%s/deployment/_stop?force", inferenceId);
assertStatusOkOrCreated(client().performRequest(new Request("POST", forceStopEndpoint)));
}

static void createTextExpansionModel(String modelId, RestClient client) throws IOException {
// with_special_tokens: false for this test with limited vocab
Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
@@ -17,20 +17,25 @@
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.discovery.MasterNotDiscoveredException;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
@@ -47,6 +52,7 @@
import java.util.Set;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction.getModelAliases;

/**
@@ -63,7 +69,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct

private static final Logger logger = LogManager.getLogger(TransportStopTrainedModelDeploymentAction.class);

private final IngestService ingestService;
private final OriginSettingClient client;
private final TrainedModelAssignmentClusterService trainedModelAssignmentClusterService;
private final InferenceAuditor auditor;

@@ -72,7 +78,7 @@ public TransportStopTrainedModelDeploymentAction(
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
IngestService ingestService,
Client client,
TrainedModelAssignmentClusterService trainedModelAssignmentClusterService,
InferenceAuditor auditor
) {
@@ -85,7 +91,7 @@
StopTrainedModelDeploymentAction.Response::new,
EsExecutors.DIRECT_EXECUTOR_SERVICE
);
this.ingestService = ingestService;
this.client = new OriginSettingClient(client, ML_ORIGIN);
this.trainedModelAssignmentClusterService = trainedModelAssignmentClusterService;
this.auditor = Objects.requireNonNull(auditor);
}
@@ -154,21 +160,84 @@ protected void doExecute(

// NOTE, should only run on Master node
assert clusterService.localNode().isMasterNode();

if (request.isForce() == false) {
checkIfUsedByInferenceEndpoint(
request.getId(),
ActionListener.wrap(canStop -> stopDeployment(task, request, maybeAssignment.get(), listener), listener::onFailure)
);
} else {
stopDeployment(task, request, maybeAssignment.get(), listener);
}
}

private void stopDeployment(
Task task,
StopTrainedModelDeploymentAction.Request request,
TrainedModelAssignment assignment,
ActionListener<StopTrainedModelDeploymentAction.Response> listener
) {
trainedModelAssignmentClusterService.setModelAssignmentToStopping(
request.getId(),
ActionListener.wrap(
    setToStopping -> normalUndeploy(task, request.getId(), maybeAssignment.get(), request, listener),
    failure -> {
        if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
            listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
            return;
        }
        listener.onFailure(failure);
    }
)
ActionListener.wrap(setToStopping -> normalUndeploy(task, request.getId(), assignment, request, listener), failure -> {
    if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
        listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
        return;
    }
    listener.onFailure(failure);
})
);
}

private void checkIfUsedByInferenceEndpoint(String deploymentId, ActionListener<Boolean> listener) {

GetInferenceModelAction.Request getAllEndpoints = new GetInferenceModelAction.Request("*", TaskType.ANY);
client.execute(GetInferenceModelAction.INSTANCE, getAllEndpoints, listener.delegateFailureAndWrap((l, response) -> {
// filter by the ml node services
var mlNodeEndpoints = response.getEndpoints()
.stream()
.filter(model -> model.getService().equals("elasticsearch") || model.getService().equals("elser"))
.toList();

var endpointOwnsDeployment = mlNodeEndpoints.stream()
.filter(model -> model.getInferenceEntityId().equals(deploymentId))
.findFirst();
if (endpointOwnsDeployment.isPresent()) {
l.onFailure(
new ElasticsearchStatusException(
"Cannot stop deployment [{}] as it was created by inference endpoint [{}]",
RestStatus.CONFLICT,
deploymentId,
endpointOwnsDeployment.get().getInferenceEntityId()
)
);
return;
}

// The inference endpoint may have been created by attaching to an existing deployment.
for (var endpoint : mlNodeEndpoints) {
var serviceSettingsXContent = XContentHelper.toXContent(endpoint.getServiceSettings(), XContentType.JSON, false);
var settingsMap = XContentHelper.convertToMap(serviceSettingsXContent, false, XContentType.JSON).v2();
// Endpoints with the deployment_id setting are attached to an existing deployment.
var deploymentIdFromSettings = (String) settingsMap.get("deployment_id");
if (deploymentIdFromSettings != null && deploymentIdFromSettings.equals(deploymentId)) {
// The endpoint was created to use this deployment
l.onFailure(
new ElasticsearchStatusException(
"Cannot stop deployment [{}] as it is used by inference endpoint [{}]",
RestStatus.CONFLICT,
deploymentId,
endpoint.getInferenceEntityId()
)
);
return;
}
}

l.onResponse(true);
}));
}

private void redirectToMasterNode(
DiscoveryNode masterNode,
StopTrainedModelDeploymentAction.Request request,
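
Taken together, checkIfUsedByInferenceEndpoint enforces a two-part ownership rule: an endpoint blocks a non-force stop if its inference entity id equals the deployment id (the endpoint created the deployment), or if its service settings carry a matching deployment_id (the endpoint attached to an existing deployment). A condensed standalone sketch of that rule follows; the Endpoint record and helper are illustrative stand-ins, not types from the PR.

import java.util.List;
import java.util.Map;

final class DeploymentOwnershipSketch {
    // Hypothetical stand-in for the endpoint configs returned by GetInferenceModelAction.
    record Endpoint(String inferenceEntityId, String service, Map<String, Object> serviceSettings) {}

    /** Returns the id of the endpoint that blocks stopping, or null if a non-force stop may proceed. */
    static String blockingEndpoint(String deploymentId, List<Endpoint> endpoints) {
        for (Endpoint endpoint : endpoints) {
            // Only the ML node services ("elasticsearch", "elser") hold model deployments.
            boolean mlNodeService = endpoint.service().equals("elasticsearch") || endpoint.service().equals("elser");
            if (mlNodeService == false) {
                continue;
            }
            // Case 1: the endpoint created the deployment, so the ids match.
            if (endpoint.inferenceEntityId().equals(deploymentId)) {
                return endpoint.inferenceEntityId();
            }
            // Case 2: the endpoint attached to an existing deployment via the deployment_id setting.
            if (deploymentId.equals(endpoint.serviceSettings().get("deployment_id"))) {
                return endpoint.inferenceEntityId();
            }
        }
        return null;
    }
}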