Skip to content

Commit

Permalink
Check before delete (#3209)
Browse files Browse the repository at this point in the history
* add logic to detect agent before deleting

Signed-off-by: xinyual <[email protected]>

* add logic to detect agent before deleting

Signed-off-by: xinyual <[email protected]>

* add logic to detect pipelines before delete model

Signed-off-by: xinyual <[email protected]>

* check pipeline before deleting

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* remove useless file

Signed-off-by: xinyual <[email protected]>

* rename functions

Signed-off-by: xinyual <[email protected]>

* fix failure test

Signed-off-by: xinyual <[email protected]>

* add UT

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* renam

Signed-off-by: xinyual <[email protected]>

* refactor to parallel check

Signed-off-by: xinyual <[email protected]>

* concate error message

Signed-off-by: xinyual <[email protected]>

* move logic after user access check

Signed-off-by: xinyual <[email protected]>

* change agent model searcher map to set

Signed-off-by: xinyual <[email protected]>

* rename and remove useless method

Signed-off-by: xinyual <[email protected]>

* fix bug to fetch all pipelines

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* remove and add comment

Signed-off-by: xinyual <[email protected]>

* rename and add more UTs

Signed-off-by: xinyual <[email protected]>

* use correct key

Signed-off-by: xinyual <[email protected]>

* simplify function

Signed-off-by: xinyual <[email protected]>

* change to a better class

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* change compareAndSet to set

Signed-off-by: xinyual <[email protected]>

* apply comment

Signed-off-by: xinyual <[email protected]>

* change name and reformat logic

Signed-off-by: xinyual <[email protected]>

* change name

Signed-off-by: xinyual <[email protected]>

* remove useless line

Signed-off-by: xinyual <[email protected]>

* change to a better method

Signed-off-by: xinyual <[email protected]>

* change name

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* add java doc for function

Signed-off-by: xinyual <[email protected]>

* add another interface

Signed-off-by: xinyual <[email protected]>

* apply java spotless

Signed-off-by: xinyual <[email protected]>

* change interface to with model

Signed-off-by: xinyual <[email protected]>

* apply spot less

Signed-off-by: xinyual <[email protected]>

* add settings

Signed-off-by: xinyual <[email protected]>

* apply spot less

Signed-off-by: xinyual <[email protected]>

* add test for cluster setting

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* recover useless change

Signed-off-by: xinyual <[email protected]>

* change default value of cluster setting

Signed-off-by: xinyual <[email protected]>

* rename setting and add comment

Signed-off-by: xinyual <[email protected]>

* apply spot

Signed-off-by: xinyual <[email protected]>

* remove logic for hidden model

Signed-off-by: xinyual <[email protected]>

* reorder code

Signed-off-by: xinyual <[email protected]>

* reorder code

Signed-off-by: xinyual <[email protected]>

* reorder code

Signed-off-by: xinyual <[email protected]>

* apply spot

Signed-off-by: xinyual <[email protected]>

* add UT

Signed-off-by: xinyual <[email protected]>

* add more UT

Signed-off-by: xinyual <[email protected]>

* remove search for hidden agent

Signed-off-by: xinyual <[email protected]>

* fix logic and apply spot

Signed-off-by: xinyual <[email protected]>

* add exist for UT

Signed-off-by: xinyual <[email protected]>

* change dsl to query index

Signed-off-by: xinyual <[email protected]>

* change query logic

Signed-off-by: xinyual <[email protected]>

* remove useless ut

Signed-off-by: xinyual <[email protected]>

* rebert

Signed-off-by: xinyual <[email protected]>

* apply spot

Signed-off-by: xinyual <[email protected]>

* rechange code

Signed-off-by: xinyual <[email protected]>

* apply spot

Signed-off-by: xinyual <[email protected]>

* remove useless should

Signed-off-by: xinyual <[email protected]>

* apply spot

Signed-off-by: xinyual <[email protected]>

* fix final dsl logic and ut

Signed-off-by: xinyual <[email protected]>

---------

Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual authored Jan 24, 2025
1 parent af96fe0 commit 570edaf
Show file tree
Hide file tree
Showing 10 changed files with 825 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class CommonValue {
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters.";

// Index mapping paths
public static final String ML_MODEL_GROUP_INDEX_MAPPING_PATH = "index-mappings/ml_model_group.json";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
Expand All @@ -33,7 +33,7 @@
*/
@Log4j2
@ToolAnnotation(MLModelTool.TYPE)
public class MLModelTool implements Tool {
public class MLModelTool implements WithModelTool {
public static final String TYPE = "MLModelTool";
public static final String RESPONSE_FIELD = "response_field";
public static final String MODEL_ID_FIELD = "model_id";
Expand Down Expand Up @@ -127,7 +127,7 @@ public boolean validate(Map<String, String> parameters) {
return true;
}

public static class Factory implements Tool.Factory<MLModelTool> {
public static class Factory implements WithModelTool.Factory<MLModelTool> {
private Client client;

private static Factory INSTANCE;
Expand Down Expand Up @@ -172,5 +172,10 @@ public String getDefaultType() {
public String getDefaultVersion() {
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(MODEL_ID_FIELD);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package org.opensearch.ml.engine.utils;

import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
import static org.opensearch.ml.common.CommonValue.TOOL_PARAMETERS_PREFIX;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.search.builder.SearchSourceBuilder;

public class AgentModelsSearcher {
private final Set<String> relatedModelIdSet;

public AgentModelsSearcher(Map<String, Tool.Factory> toolFactories) {
relatedModelIdSet = new HashSet<>();
for (Map.Entry<String, Tool.Factory> entry : toolFactories.entrySet()) {
Tool.Factory toolFactory = entry.getValue();
if (toolFactory instanceof WithModelTool.Factory) {
WithModelTool.Factory withModelTool = (WithModelTool.Factory) toolFactory;
relatedModelIdSet.addAll(withModelTool.getAllModelKeys());
}
}
}

/**
* Construct a should query to search all agent which containing candidate model Id
@param candidateModelId the candidate model Id
@return a should search request towards agent index.
*/
public SearchRequest constructQueryRequestToSearchModelIdInsideAgent(String candidateModelId) {
SearchRequest searchRequest = new SearchRequest(ML_AGENT_INDEX);
// Two conditions here
// 1. {[(exists hidden field) and (hidden field = false)] or (not exist hidden field)} and
// 2. Any model field contains candidate ID
BoolQueryBuilder searchAgentQuery = QueryBuilders.boolQuery();

BoolQueryBuilder hiddenFieldQuery = QueryBuilders.boolQuery();
// not exist hidden
hiddenFieldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)));
// exist but equal to false
BoolQueryBuilder existHiddenFieldQuery = QueryBuilders.boolQuery();
existHiddenFieldQuery.must(QueryBuilders.termsQuery(MLAgent.IS_HIDDEN_FIELD, false));
existHiddenFieldQuery.must(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD));
hiddenFieldQuery.should(existHiddenFieldQuery);

//
BoolQueryBuilder modelIdQuery = QueryBuilders.boolQuery();
for (String keyField : relatedModelIdSet) {
modelIdQuery.should(QueryBuilders.termsQuery(TOOL_PARAMETERS_PREFIX + keyField, candidateModelId));
}

searchAgentQuery.must(hiddenFieldQuery);
searchAgentQuery.must(modelIdQuery);
searchRequest.source(new SearchSourceBuilder().query(searchAgentQuery));
return searchRequest;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.engine.tools.MLModelTool.DEFAULT_DESCRIPTION;
import static org.opensearch.ml.engine.tools.MLModelTool.MODEL_ID_FIELD;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -218,5 +219,6 @@ public void testTool() {
assertTrue(tool.validate(otherParams));
assertFalse(tool.validate(emptyParams));
assertEquals(DEFAULT_DESCRIPTION, tool.getDescription());
assertEquals(List.of(MODEL_ID_FIELD), MLModelTool.Factory.getInstance().getAllModelKeys());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.utils;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.junit.Test;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.WithModelTool;

public class AgentModelSearcherTests {

@Test
public void testConstructor_CollectsModelIds() {
// Arrange
WithModelTool.Factory withModelToolFactory1 = mock(WithModelTool.Factory.class);
when(withModelToolFactory1.getAllModelKeys()).thenReturn(Arrays.asList("modelKey1", "modelKey2"));

WithModelTool.Factory withModelToolFactory2 = mock(WithModelTool.Factory.class);
when(withModelToolFactory2.getAllModelKeys()).thenReturn(Collections.singletonList("anotherModelKey"));

// This tool factory does not implement WithModelTool.Factory
Tool.Factory regularToolFactory = mock(Tool.Factory.class);

Map<String, Tool.Factory> toolFactories = new HashMap<>();
toolFactories.put("withModelTool1", withModelToolFactory1);
toolFactories.put("withModelTool2", withModelToolFactory2);
toolFactories.put("regularTool", regularToolFactory);

// Act
AgentModelsSearcher searcher = new AgentModelsSearcher(toolFactories);

// (Optional) We can't directly access relatedModelIdSet,
// but we can test the behavior indirectly using the search call:
SearchRequest request = searcher.constructQueryRequestToSearchModelIdInsideAgent("candidateId");

// Assert
// Verify the searchRequest uses all keys from the WithModelTool factories
BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) request.source().query();
// We expect modelKey1, modelKey2, anotherModelKey => total 3 "should" clauses
assertEquals(2, boolQueryBuilder.must().size());
for (QueryBuilder query : boolQueryBuilder.must()) {
BoolQueryBuilder subBoolQueryBuilder = (BoolQueryBuilder) query;
assertTrue(subBoolQueryBuilder.should().size() == 2 || subBoolQueryBuilder.should().size() == 3);
if (subBoolQueryBuilder.should().size() == 3) {
boolQueryBuilder.should().forEach(subQuery -> {
assertTrue(subQuery instanceof TermsQueryBuilder);
TermsQueryBuilder termsQuery = (TermsQueryBuilder) subQuery;
// Each TermsQueryBuilder should contain candidateModelId
assertTrue(termsQuery.values().contains("candidateId"));
});
} else {
boolQueryBuilder.should().forEach(subQuery -> {
assertTrue(subQuery instanceof BoolQueryBuilder);
BoolQueryBuilder boolQuery = (BoolQueryBuilder) subQuery;
assertTrue(boolQuery.must().size() == 2 || boolQuery.mustNot().size() == 1);
if (boolQuery.must().size() == 2) {
boolQuery.must().forEach(existSubQuery -> {
assertTrue(existSubQuery instanceof ExistsQueryBuilder || existSubQuery instanceof TermsQueryBuilder);
if (existSubQuery instanceof TermsQueryBuilder) {
TermsQueryBuilder termsQuery = (TermsQueryBuilder) existSubQuery;
assertTrue(termsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD));
assertTrue(termsQuery.values().contains(false));
} else {
ExistsQueryBuilder existsQuery = (ExistsQueryBuilder) existSubQuery;
assertTrue(existsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD));
}
});
} else {
QueryBuilder mustNotQuery = boolQuery.mustNot().get(0);
assertTrue(mustNotQuery instanceof ExistsQueryBuilder);
assertEquals(MLAgent.IS_HIDDEN_FIELD, ((ExistsQueryBuilder) mustNotQuery).fieldName());
}
});
}
}

}
}
Loading

0 comments on commit 570edaf

Please sign in to comment.