diff --git a/build.gradle b/build.gradle index 853ee38d0..3cc242586 100644 --- a/build.gradle +++ b/build.gradle @@ -696,15 +696,11 @@ List jacocoExclusions = [ // TODO: add test coverage (kaituo) 'org.opensearch.forecast.*', - 'org.opensearch.ad.transport.ADHCImputeNodeResponse', - 'org.opensearch.ad.transport.GetAnomalyDetectorTransportAction', - 'org.opensearch.timeseries.transport.BooleanNodeResponse', 'org.opensearch.timeseries.ml.TimeSeriesSingleStreamCheckpointDao', 'org.opensearch.timeseries.transport.JobRequest', 'org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler', 'org.opensearch.timeseries.ml.Inferencer', 'org.opensearch.timeseries.transport.SingleStreamResultRequest', - 'org.opensearch.timeseries.transport.BooleanResponse', 'org.opensearch.timeseries.rest.handler.IndexJobActionHandler.1', 'org.opensearch.timeseries.transport.SuggestConfigParamResponse', 'org.opensearch.timeseries.transport.SuggestConfigParamRequest', @@ -732,6 +728,7 @@ List jacocoExclusions = [ 'org.opensearch.timeseries.util.TimeUtil', 'org.opensearch.ad.transport.ADHCImputeTransportAction', 'org.opensearch.timeseries.ml.RealTimeInferencer', + 'org.opensearch.timeseries.util.ExpiringValue', ] diff --git a/src/main/java/org/opensearch/ad/ml/ADRealTimeInferencer.java b/src/main/java/org/opensearch/ad/ml/ADRealTimeInferencer.java index 94d2223a3..2ec4f254d 100644 --- a/src/main/java/org/opensearch/ad/ml/ADRealTimeInferencer.java +++ b/src/main/java/org/opensearch/ad/ml/ADRealTimeInferencer.java @@ -7,6 +7,8 @@ import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME; +import java.time.Clock; + import org.opensearch.ad.caching.ADCacheProvider; import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.indices.ADIndex; @@ -32,7 +34,8 @@ public ADRealTimeInferencer( ADColdStartWorker coldStartWorker, ADSaveResultStrategy resultWriteWorker, ADCacheProvider cache, - ThreadPool threadPool + ThreadPool threadPool, + Clock clock ) { super( modelManager, @@ -43,7 +46,8 @@ public ADRealTimeInferencer( resultWriteWorker, cache, threadPool, - AD_THREAD_POOL_NAME + AD_THREAD_POOL_NAME, + clock ); } diff --git a/src/main/java/org/opensearch/ad/model/ADTask.java b/src/main/java/org/opensearch/ad/model/ADTask.java index 19fc87682..dbfd42d34 100644 --- a/src/main/java/org/opensearch/ad/model/ADTask.java +++ b/src/main/java/org/opensearch/ad/model/ADTask.java @@ -345,7 +345,8 @@ public static ADTask parse(XContentParser parser, String taskId) throws IOExcept detector.getCustomResultIndexMinSize(), detector.getCustomResultIndexMinAge(), detector.getCustomResultIndexTTL(), - detector.getFlattenResultIndexMapping() + detector.getFlattenResultIndexMapping(), + detector.getLastBreakingUIChangeTime() ); return new Builder() .taskId(parsedTaskId) diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java index d88ffa653..c8ba4a685 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java @@ -151,6 +151,8 @@ public Integer getShingleSize(Integer customShingleSize) { * @param customResultIndexMinAge custom result index lifecycle management min age condition * @param customResultIndexTTL custom result index lifecycle management ttl * @param flattenResultIndexMapping flag to indicate whether to flatten result index mapping or not + * @param lastBreakingUIChangeTime last update time to configuration that can break UI and we 
have + * to display updates from the changed time */ public AnomalyDetector( String detectorId, @@ -178,7 +180,8 @@ public AnomalyDetector( Integer customResultIndexMinSize, Integer customResultIndexMinAge, Integer customResultIndexTTL, - Boolean flattenResultIndexMapping + Boolean flattenResultIndexMapping, + Instant lastBreakingUIChangeTime ) { super( detectorId, @@ -206,7 +209,8 @@ public AnomalyDetector( customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastBreakingUIChangeTime ); checkAndThrowValidationErrors(ValidationAspect.DETECTOR); @@ -284,6 +288,7 @@ public AnomalyDetector(StreamInput input) throws IOException { this.customResultIndexMinAge = input.readOptionalInt(); this.customResultIndexTTL = input.readOptionalInt(); this.flattenResultIndexMapping = input.readOptionalBoolean(); + this.lastUIBreakingChangeTime = input.readOptionalInstant(); } public XContentBuilder toXContent(XContentBuilder builder) throws IOException { @@ -350,6 +355,7 @@ public void writeTo(StreamOutput output) throws IOException { output.writeOptionalInt(customResultIndexMinAge); output.writeOptionalInt(customResultIndexTTL); output.writeOptionalBoolean(flattenResultIndexMapping); + output.writeOptionalInstant(lastUIBreakingChangeTime); } @Override @@ -447,6 +453,7 @@ public static AnomalyDetector parse( Integer customResultIndexMinAge = null; Integer customResultIndexTTL = null; Boolean flattenResultIndexMapping = null; + Instant lastBreakingUIChangeTime = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -584,6 +591,9 @@ public static AnomalyDetector parse( case FLATTEN_RESULT_INDEX_MAPPING: flattenResultIndexMapping = onlyParseBooleanValue(parser); break; + case BREAKING_UI_CHANGE_TIME: + lastBreakingUIChangeTime = ParseUtils.toInstant(parser); + break; default: parser.skipChildren(); break; @@ -615,7 +625,8 @@ public static AnomalyDetector parse( customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastBreakingUIChangeTime ); detector.setDetectionDateRange(detectionDateRange); return detector; diff --git a/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java index 66981d54c..8f5b5645d 100644 --- a/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java @@ -84,6 +84,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli : WriteRequest.RefreshPolicy.IMMEDIATE; RestRequest.Method method = request.getHttpRequest().method(); + if (method == RestRequest.Method.POST && detectorId != AnomalyDetector.NO_ID) { + // reset detector to empty string detectorId is only meant for updating detector + detectorId = AnomalyDetector.NO_ID; + } + IndexAnomalyDetectorRequest indexAnomalyDetectorRequest = new IndexAnomalyDetectorRequest( detectorId, seqNo, diff --git a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java index 7c86610a4..13f70c840 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java +++ 
b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java @@ -245,7 +245,8 @@ protected AnomalyDetector copyConfig(User user, Config config) { config.getCustomResultIndexMinSize(), config.getCustomResultIndexMinAge(), config.getCustomResultIndexTTL(), - config.getFlattenResultIndexMapping() + config.getFlattenResultIndexMapping(), + breakingUIChange ? Instant.now() : config.getLastBreakingUIChangeTime() ); } diff --git a/src/main/java/org/opensearch/ad/transport/ADHCImputeTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADHCImputeTransportAction.java index 08973296a..96ff6696e 100644 --- a/src/main/java/org/opensearch/ad/transport/ADHCImputeTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADHCImputeTransportAction.java @@ -31,7 +31,6 @@ import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.Config; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.util.ActionListenerExecutor; import org.opensearch.transport.TransportService; @@ -129,14 +128,12 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest return; } Config config = configOptional.get(); - long windowDelayMillis = ((IntervalTimeConfiguration) config.getWindowDelay()).toDuration().toMillis(); int featureSize = config.getEnabledFeatureIds().size(); long dataEndMillis = nodeRequest.getRequest().getDataEndMillis(); long dataStartMillis = nodeRequest.getRequest().getDataStartMillis(); - long executionEndTime = dataEndMillis + windowDelayMillis; String taskId = nodeRequest.getRequest().getTaskId(); for (ModelState modelState : cache.get().getAllModels(configId)) { - if (shouldProcessModelState(modelState, executionEndTime, clusterService, hashRing)) { + if (shouldProcessModelState(modelState, dataEndMillis, clusterService, hashRing)) { double[] nanArray = new double[featureSize]; Arrays.fill(nanArray, Double.NaN); adInferencer @@ -163,8 +160,8 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest * Determines whether the model state should be processed based on various conditions. * * Conditions checked: - * - The model's last seen execution end time is not the minimum Instant value. - * - The current execution end time is greater than or equal to the model's last seen execution end time, + * - The model's last seen data end time is not the minimum Instant value. This means the model hasn't been initialized yet. + * - The current data end time is greater than the model's last seen data end time, * indicating that the model state was updated in previous intervals. * - The entity associated with the model state is present. * - The owning node for real-time processing of the entity, with the same local version, is present in the hash ring. @@ -175,14 +172,14 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest * concurrently (e.g., during tests when multiple threads may operate quickly). * * @param modelState The current state of the model. - * @param executionEndTime The end time of the current execution interval. + * @param dataEndTime The data end time of current interval. * @param clusterService The service providing information about the current cluster node. * @param hashRing The hash ring used to determine the owning node for real-time processing of entities. * @return true if the model state should be processed; otherwise, false. 
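The condition documented above reduces to a small timestamp guard. Below is a minimal standalone sketch of that guard (plain JDK types and a hypothetical class name, not the plugin's shouldProcessModelState, which additionally checks the entity and the owning node in the hash ring): a model is imputed only if it has already processed at least one interval and the incoming interval's data end time is strictly newer than the last one the model has seen.

import java.time.Instant;

public final class ImputeGuardSketch {
    // True only when the model has processed at least one interval and the incoming
    // interval's data end time is strictly newer than the last one the model scored.
    static boolean shouldImpute(Instant lastSeenDataEndTime, long currentDataEndMillis) {
        return !Instant.MIN.equals(lastSeenDataEndTime)
            && currentDataEndMillis > lastSeenDataEndTime.toEpochMilli();
    }

    public static void main(String[] args) {
        Instant lastSeen = Instant.parse("2024-05-01T00:00:00Z");
        System.out.println(shouldImpute(Instant.MIN, lastSeen.toEpochMilli()));              // false: no interval seen yet
        System.out.println(shouldImpute(lastSeen, lastSeen.toEpochMilli()));                 // false: same interval already scored
        System.out.println(shouldImpute(lastSeen, lastSeen.plusSeconds(60).toEpochMilli())); // true: strictly newer interval
    }
}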
*/ private boolean shouldProcessModelState( ModelState modelState, - long executionEndTime, + long dataEndTime, ClusterService clusterService, HashRing hashRing ) { @@ -194,8 +191,8 @@ private boolean shouldProcessModelState( // Check if the model state conditions are met for processing // We cannot use last used time as it will be updated whenever we update its priority in CacheBuffer.update when there is a // PriorityCache.get. - return modelState.getLastSeenExecutionEndTime() != Instant.MIN - && executionEndTime >= modelState.getLastSeenExecutionEndTime().toEpochMilli() + return modelState.getLastSeenDataEndTime() != Instant.MIN + && dataEndTime > modelState.getLastSeenDataEndTime().toEpochMilli() && modelState.getEntity().isPresent() && owningNode.isPresent() && owningNode.get().getId().equals(clusterService.localNode().getId()); diff --git a/src/main/java/org/opensearch/forecast/ForecastTaskProfileRunner.java b/src/main/java/org/opensearch/forecast/ForecastTaskProfileRunner.java index f7deb5578..3eb93fdc7 100644 --- a/src/main/java/org/opensearch/forecast/ForecastTaskProfileRunner.java +++ b/src/main/java/org/opensearch/forecast/ForecastTaskProfileRunner.java @@ -14,8 +14,18 @@ public class ForecastTaskProfileRunner implements TaskProfileRunner listener) { - // return null since forecasting have no in-memory task profiles as AD - listener.onResponse(null); + // return null in other fields since forecasting have no in-memory task profiles as AD + listener + .onResponse( + new ForecastTaskProfile( + configLevelTask, + null, + null, + null, + configLevelTask == null ? null : configLevelTask.getTaskId(), + null + ) + ); } } diff --git a/src/main/java/org/opensearch/forecast/ml/ForecastRealTimeInferencer.java b/src/main/java/org/opensearch/forecast/ml/ForecastRealTimeInferencer.java index d2373dfce..c56017552 100644 --- a/src/main/java/org/opensearch/forecast/ml/ForecastRealTimeInferencer.java +++ b/src/main/java/org/opensearch/forecast/ml/ForecastRealTimeInferencer.java @@ -7,6 +7,8 @@ import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME; +import java.time.Clock; + import org.opensearch.forecast.caching.ForecastCacheProvider; import org.opensearch.forecast.caching.ForecastPriorityCache; import org.opensearch.forecast.indices.ForecastIndex; @@ -32,7 +34,8 @@ public ForecastRealTimeInferencer( ForecastColdStartWorker coldStartWorker, ForecastSaveResultStrategy resultWriteWorker, ForecastCacheProvider cache, - ThreadPool threadPool + ThreadPool threadPool, + Clock clock ) { super( modelManager, @@ -43,7 +46,8 @@ public ForecastRealTimeInferencer( resultWriteWorker, cache, threadPool, - FORECAST_THREAD_POOL_NAME + FORECAST_THREAD_POOL_NAME, + clock ); } diff --git a/src/main/java/org/opensearch/forecast/model/ForecastTask.java b/src/main/java/org/opensearch/forecast/model/ForecastTask.java index bb6a53d50..3fb2e515a 100644 --- a/src/main/java/org/opensearch/forecast/model/ForecastTask.java +++ b/src/main/java/org/opensearch/forecast/model/ForecastTask.java @@ -343,7 +343,8 @@ public static ForecastTask parse(XContentParser parser, String taskId) throws IO forecaster.getCustomResultIndexMinSize(), forecaster.getCustomResultIndexMinAge(), forecaster.getCustomResultIndexTTL(), - forecaster.getFlattenResultIndexMapping() + forecaster.getFlattenResultIndexMapping(), + forecaster.getLastBreakingUIChangeTime() ); return new Builder() .taskId(parsedTaskId) @@ -375,10 +376,12 @@ public static ForecastTask parse(XContentParser parser, String taskId) throws IO 
@Generated @Override public boolean equals(Object other) { - if (this == other) + if (this == other) { return true; - if (other == null || getClass() != other.getClass()) + } + if (other == null || getClass() != other.getClass()) { return false; + } ForecastTask that = (ForecastTask) other; return super.equals(that) && Objects.equal(getForecaster(), that.getForecaster()) diff --git a/src/main/java/org/opensearch/forecast/model/Forecaster.java b/src/main/java/org/opensearch/forecast/model/Forecaster.java index 0cac28d8b..756b5c4e0 100644 --- a/src/main/java/org/opensearch/forecast/model/Forecaster.java +++ b/src/main/java/org/opensearch/forecast/model/Forecaster.java @@ -135,7 +135,8 @@ public Forecaster( Integer customResultIndexMinSize, Integer customResultIndexMinAge, Integer customResultIndexTTL, - Boolean flattenResultIndexMapping + Boolean flattenResultIndexMapping, + Instant lastBreakingUIChangeTime ) { super( forecasterId, @@ -163,7 +164,8 @@ public Forecaster( customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastBreakingUIChangeTime ); checkAndThrowValidationErrors(ValidationAspect.FORECASTER); @@ -306,6 +308,7 @@ public static Forecaster parse( Integer customResultIndexMinAge = null; Integer customResultIndexTTL = null; Boolean flattenResultIndexMapping = null; + Instant lastBreakingUIChangeTime = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -437,6 +440,9 @@ public static Forecaster parse( case FLATTEN_RESULT_INDEX_MAPPING: flattenResultIndexMapping = parser.booleanValue(); break; + case BREAKING_UI_CHANGE_TIME: + lastBreakingUIChangeTime = ParseUtils.toInstant(parser); + break; default: parser.skipChildren(); break; @@ -468,7 +474,8 @@ public static Forecaster parse( customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastBreakingUIChangeTime ); return forecaster; } diff --git a/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java index 24a9ab037..acb25d5f6 100644 --- a/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java +++ b/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java @@ -87,6 +87,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli : WriteRequest.RefreshPolicy.IMMEDIATE; RestRequest.Method method = request.getHttpRequest().method(); + if (method == RestRequest.Method.POST && forecasterId != Config.NO_ID) { + // reset detector to empty string detectorId is only meant for updating detector + forecasterId = Config.NO_ID; + } + IndexForecasterRequest indexAnomalyDetectorRequest = new IndexForecasterRequest( forecasterId, seqNo, diff --git a/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java index 58033c199..01edc14ef 100644 --- a/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java +++ b/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java @@ -258,7 +258,8 @@ protected Config copyConfig(User user, Config config) { config.getCustomResultIndexMinSize(), config.getCustomResultIndexMinAge(), config.getCustomResultIndexTTL(), - 
config.getFlattenResultIndexMapping() + config.getFlattenResultIndexMapping(), + breakingUIChange ? Instant.now() : config.getLastBreakingUIChangeTime() ); } diff --git a/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java index 3898fa9f0..719ec88ac 100644 --- a/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java @@ -840,7 +840,8 @@ public PooledObject wrap(LinkedBuffer obj) { adColdstartQueue, adSaveResultStrategy, adCacheProvider, - threadPool + threadPool, + getClock() ); ADCheckpointReadWorker adCheckpointReadQueue = new ADCheckpointReadWorker( @@ -1230,7 +1231,8 @@ public PooledObject wrap(LinkedBuffer obj) { forecastColdstartQueue, forecastSaveResultStrategy, forecastCacheProvider, - threadPool + threadPool, + getClock() ); ForecastCheckpointReadWorker forecastCheckpointReadQueue = new ForecastCheckpointReadWorker( diff --git a/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java index 1d984be46..c9dacaca7 100644 --- a/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java +++ b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java @@ -174,7 +174,7 @@ public ModelState get(String modelId, Config config) { // reset every 60 intervals return new DoorKeeper( TimeSeriesSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION, - config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), + config.getIntervalDuration().multipliedBy(TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ), clock, TimeSeriesSettings.CACHE_DOOR_KEEPER_COUNT_THRESHOLD ); diff --git a/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java b/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java index 2cb0f0b17..67d9e92c9 100644 --- a/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java +++ b/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java @@ -241,7 +241,7 @@ private void coldStart( // reset every 60 intervals return new DoorKeeper( TimeSeriesSettings.DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION, - config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), + config.getIntervalDuration().multipliedBy(TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ), clock, TimeSeriesSettings.COLD_START_DOOR_KEEPER_COUNT_THRESHOLD ); @@ -251,7 +251,7 @@ private void coldStart( logger .info( "Won't retry real-time cold start within {} intervals for model {}", - TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ, + TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ, modelId ); return; diff --git a/src/main/java/org/opensearch/timeseries/ml/ModelManager.java b/src/main/java/org/opensearch/timeseries/ml/ModelManager.java index efc774e02..495b6048a 100644 --- a/src/main/java/org/opensearch/timeseries/ml/ModelManager.java +++ b/src/main/java/org/opensearch/timeseries/ml/ModelManager.java @@ -169,7 +169,7 @@ public IntermediateResultType score( throw e; } finally { modelState.setLastUsedTime(clock.instant()); - modelState.setLastSeenExecutionEndTime(clock.instant()); + modelState.setLastSeenDataEndTime(sample.getDataEndTime()); } return createEmptyResult(); } diff --git a/src/main/java/org/opensearch/timeseries/ml/ModelState.java b/src/main/java/org/opensearch/timeseries/ml/ModelState.java index cf337cf3b..f4078220f 100644 --- 
a/src/main/java/org/opensearch/timeseries/ml/ModelState.java +++ b/src/main/java/org/opensearch/timeseries/ml/ModelState.java @@ -36,7 +36,7 @@ public class ModelState implements org.opensearch.timeseries.ExpiringState { // time when the ML model was used last time protected Instant lastUsedTime; protected Instant lastCheckpointTime; - protected Instant lastSeenExecutionEndTime; + protected Instant lastSeenDataEndTime; protected Clock clock; protected float priority; protected Deque samples; @@ -75,7 +75,7 @@ public ModelState( this.priority = priority; this.entity = entity; this.samples = samples; - this.lastSeenExecutionEndTime = Instant.MIN; + this.lastSeenDataEndTime = Instant.MIN; } /** @@ -252,11 +252,11 @@ public Map getModelStateAsMap() { }; } - public Instant getLastSeenExecutionEndTime() { - return lastSeenExecutionEndTime; + public Instant getLastSeenDataEndTime() { + return lastSeenDataEndTime; } - public void setLastSeenExecutionEndTime(Instant lastSeenExecutionEndTime) { - this.lastSeenExecutionEndTime = lastSeenExecutionEndTime; + public void setLastSeenDataEndTime(Instant lastSeenExecutionEndTime) { + this.lastSeenDataEndTime = lastSeenExecutionEndTime; } } diff --git a/src/main/java/org/opensearch/timeseries/ml/RealTimeInferencer.java b/src/main/java/org/opensearch/timeseries/ml/RealTimeInferencer.java index 7a7d11630..ecc6fb3a1 100644 --- a/src/main/java/org/opensearch/timeseries/ml/RealTimeInferencer.java +++ b/src/main/java/org/opensearch/timeseries/ml/RealTimeInferencer.java @@ -5,10 +5,12 @@ package org.opensearch.timeseries.ml; -import java.util.Collections; +import java.time.Clock; +import java.util.Comparator; import java.util.Locale; import java.util.Map; -import java.util.WeakHashMap; +import java.util.PriorityQueue; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; @@ -19,8 +21,10 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.MaintenanceState; import org.opensearch.timeseries.caching.CacheProvider; import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.indices.IndexManagement; import org.opensearch.timeseries.indices.TimeSeriesIndex; import org.opensearch.timeseries.model.Config; @@ -31,7 +35,9 @@ import org.opensearch.timeseries.ratelimit.FeatureRequest; import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.ratelimit.SaveResultStrategy; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.util.ExpiringValue; import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; @@ -39,7 +45,10 @@ * Since we assume model state's last access time is current time and compare it with incoming data's execution time, * this class is only meant to be used by real time analysis. 
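The rename from lastSeenExecutionEndTime to lastSeenDataEndTime pairs with the ModelManager change above: the field is now stamped from the sample's own data end time rather than the wall clock. A minimal sketch of that bookkeeping, using simplified stand-in types rather than the plugin's ModelManager and ModelState:

import java.time.Clock;
import java.time.Instant;

final class ScoreBookkeepingSketch {
    Instant lastUsedTime;
    Instant lastSeenDataEndTime = Instant.MIN;

    // Mirrors the finally block in ModelManager.score: lastUsedTime still comes from the clock,
    // while lastSeenDataEndTime is taken from the sample itself so it can later be compared
    // against an interval's dataEndMillis without adding window delay.
    void score(Instant sampleDataEndTime, Clock clock) {
        try {
            // ... run the model on the sample ...
        } finally {
            lastUsedTime = clock.instant();
            lastSeenDataEndTime = sampleDataEndTime;
        }
    }

    public static void main(String[] args) {
        ScoreBookkeepingSketch state = new ScoreBookkeepingSketch();
        state.score(Instant.parse("2024-05-01T00:01:00Z"), Clock.systemUTC());
        System.out.println(state.lastSeenDataEndTime); // 2024-05-01T00:01:00Z
    }
}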
*/ -public abstract class RealTimeInferencer, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriterType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, ModelManagerType extends ModelManager, SaveResultStrategyType extends SaveResultStrategy, CacheType extends TimeSeriesCache, ColdStartWorkerType extends ColdStartWorker> { +public abstract class RealTimeInferencer, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriterType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, ModelManagerType extends ModelManager, SaveResultStrategyType extends SaveResultStrategy, CacheType extends TimeSeriesCache, ColdStartWorkerType extends ColdStartWorker> + implements + MaintenanceState { + private static final Logger LOG = LogManager.getLogger(RealTimeInferencer.class); protected ModelManagerType modelManager; protected Stats stats; @@ -48,9 +57,18 @@ public abstract class RealTimeInferencer cache; - private Map modelLocks = Collections.synchronizedMap(new WeakHashMap<>()); + // ensure no two threads can score samples at the same time which can happen in tests + // where we send a lot of requests in a fast pace and the run API returns immediately + // without waiting for the requests get finished processing. It can also happen in + // production as the impute request and actual data scoring in the next interval + // can happen at the same time. + private Map> modelLocks; private ThreadPool threadPool; private String threadPoolName; + // ensure we process samples in the ascending order of time in case race conditions. + private Map>> sampleQueues; + private Comparator sampleComparator; + private Clock clock; public RealTimeInferencer( ModelManagerType modelManager, @@ -61,7 +79,8 @@ public RealTimeInferencer( SaveResultStrategyType resultWriteWorker, CacheProvider cache, ThreadPool threadPool, - String threadPoolName + String threadPoolName, + Clock clock ) { this.modelManager = modelManager; this.stats = stats; @@ -72,10 +91,10 @@ public RealTimeInferencer( this.cache = cache; this.threadPool = threadPool; this.threadPoolName = threadPoolName; - // WeakHashMap allows for automatic removal of entries when the key is no longer referenced elsewhere. - // This helps prevent memory leaks as the garbage collector can reclaim memory when modelId is no - // longer in use. - this.modelLocks = Collections.synchronizedMap(new WeakHashMap<>()); + this.modelLocks = new ConcurrentHashMap<>(); + this.sampleQueues = new ConcurrentHashMap<>(); + this.sampleComparator = Comparator.comparing(Sample::getDataEndTime); + this.clock = clock; } /** @@ -87,43 +106,65 @@ public RealTimeInferencer( * @return whether process succeeds or not */ public boolean process(Sample sample, ModelState modelState, Config config, String taskId) { - long windowDelayMillis = config.getWindowDelay() == null - ? 
0 - : ((IntervalTimeConfiguration) config.getWindowDelay()).toDuration().toMillis(); - long curExecutionEnd = sample.getDataEndTime().toEpochMilli() + windowDelayMillis; - long nextExecutionEnd = curExecutionEnd + config.getIntervalInMilliseconds(); - - return processWithTimeout(sample, modelState, config, taskId, curExecutionEnd, nextExecutionEnd); + String modelId = modelState.getModelId(); + ExpiringValue> expiringSampleQueue = sampleQueues + .computeIfAbsent( + modelId, + k -> new ExpiringValue<>( + new PriorityQueue<>(sampleComparator), + config.getIntervalDuration().multipliedBy(TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ).toMillis(), + clock + ) + ); + expiringSampleQueue.getValue().add(sample); + return processWithTimeout(modelState, config, taskId, sample); } - private boolean processWithTimeout( - Sample sample, - ModelState modelState, - Config config, - String taskId, - long curExecutionEnd, - long nextExecutionEnd - ) { + private boolean processWithTimeout(ModelState modelState, Config config, String taskId, Sample sample) { String modelId = modelState.getModelId(); - ReentrantLock lock = (ReentrantLock) modelLocks.computeIfAbsent(modelId, k -> new ReentrantLock()); + ReentrantLock lock = (ReentrantLock) modelLocks + .computeIfAbsent( + modelId, + k -> new ExpiringValue<>( + new ReentrantLock(), + config.getIntervalDuration().multipliedBy(TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ).toMillis(), + clock + ) + ) + .getValue(); + boolean success = false; if (lock.tryLock()) { try { - tryProcess(sample, modelState, config, taskId, curExecutionEnd); + PriorityQueue queue = sampleQueues.get(modelId).getValue(); + while (!queue.isEmpty()) { + Sample curSample = queue.poll(); + long windowDelayMillis = config.getWindowDelay() == null + ? 0 + : ((IntervalTimeConfiguration) config.getWindowDelay()).toDuration().toMillis(); + long curExecutionEnd = curSample.getDataEndTime().toEpochMilli() + windowDelayMillis; + + success = tryProcess(curSample, modelState, config, taskId, curExecutionEnd); + } } finally { if (lock.isHeldByCurrentThread()) { lock.unlock(); } } - return true; } else { - if (System.currentTimeMillis() >= nextExecutionEnd) { + long windowDelayMillis = config.getWindowDelay() == null + ? 
0 + : ((IntervalTimeConfiguration) config.getWindowDelay()).toDuration().toMillis(); + long curExecutionEnd = sample.getDataEndTime().toEpochMilli() + windowDelayMillis; + long nextExecutionEnd = curExecutionEnd + config.getIntervalInMilliseconds(); + // schedule a retry if not already time out + if (clock.millis() >= nextExecutionEnd) { LOG.warn("Timeout reached, not retrying."); } else { // Schedule a retry in one second threadPool .schedule( - () -> processWithTimeout(sample, modelState, config, taskId, curExecutionEnd, nextExecutionEnd), + () -> processWithTimeout(modelState, config, taskId, sample), new TimeValue(1, TimeUnit.SECONDS), threadPoolName ); @@ -131,6 +172,7 @@ private boolean processWithTimeout( return false; } + return success; } private boolean tryProcess(Sample sample, ModelState modelState, Config config, String taskId, long curExecutionEnd) { @@ -193,7 +235,7 @@ private void reColdStart(Config config, String modelId, Exception e, Sample samp coldStartWorker .put( new FeatureRequest( - System.currentTimeMillis() + config.getIntervalInMilliseconds(), + clock.millis() + config.getIntervalInMilliseconds(), config.getId(), RequestPriority.MEDIUM, modelId, @@ -203,4 +245,16 @@ private void reColdStart(Config config, String modelId, Exception e, Sample samp ) ); } + + @Override + public void maintenance() { + try { + sampleQueues.entrySet().removeIf(entry -> entry.getValue().isExpired()); + modelLocks.entrySet().removeIf(entry -> entry.getValue().isExpired()); + } catch (Exception e) { + // will be thrown to transport broadcast handler + throw new TimeSeriesException("Fail to maintain RealTimeInferencer", e); + } + + } } diff --git a/src/main/java/org/opensearch/timeseries/model/Config.java b/src/main/java/org/opensearch/timeseries/model/Config.java index 8c0586cde..d61807528 100644 --- a/src/main/java/org/opensearch/timeseries/model/Config.java +++ b/src/main/java/org/opensearch/timeseries/model/Config.java @@ -80,6 +80,11 @@ public abstract class Config implements Writeable, ToXContentObject { public static final String RESULT_INDEX_FIELD_MIN_AGE = "result_index_min_age"; public static final String RESULT_INDEX_FIELD_TTL = "result_index_ttl"; public static final String FLATTEN_RESULT_INDEX_MAPPING = "flatten_result_index_mapping"; + // Changing categorical field, feature attributes, interval, windowDelay, time field, horizon, indices, + // result index would force us to display results only from the most recent update. Otherwise, + // the UI appear cluttered and unclear. + // We cannot use last update time as it would change whenever other fields like name changes. 
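A compact illustration of the locking scheme the RealTimeInferencer hunks above move to, using plain JDK types and a hypothetical class name rather than the plugin's generics: samples are queued per model in data-end-time order, and whichever thread wins the tryLock drains the queue, so an impute request and the next interval's real data can neither score the same model concurrently nor out of order. (The real class schedules a one-second retry instead of returning immediately when the lock is held.)

import java.time.Instant;
import java.util.Comparator;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

final class OrderedModelProcessorSketch {
    record Sample(double value, Instant dataEndTime) {}

    private final Map<String, ReentrantLock> locks = new ConcurrentHashMap<>();
    private final Map<String, PriorityQueue<Sample>> queues = new ConcurrentHashMap<>();
    private final Comparator<Sample> byDataEndTime = Comparator.comparing(Sample::dataEndTime);

    boolean process(String modelId, Sample sample) {
        queues.computeIfAbsent(modelId, k -> new PriorityQueue<>(byDataEndTime)).add(sample);
        ReentrantLock lock = locks.computeIfAbsent(modelId, k -> new ReentrantLock());
        if (!lock.tryLock()) {
            return false; // the real inferencer schedules a one-second retry here instead of giving up
        }
        try {
            // Drain everything queued for this model in ascending data-end-time order.
            PriorityQueue<Sample> queue = queues.get(modelId);
            Sample next;
            while ((next = queue.poll()) != null) {
                System.out.println("scoring " + modelId + " at " + next.dataEndTime());
            }
            return true;
        } finally {
            lock.unlock();
        }
    }

    public static void main(String[] args) {
        OrderedModelProcessorSketch p = new OrderedModelProcessorSketch();
        // Pre-queue a later sample, as happens when a locked-out caller leaves its sample behind.
        p.queues.computeIfAbsent("model-a", k -> new PriorityQueue<>(p.byDataEndTime))
            .add(new Sample(2.0, Instant.parse("2024-05-01T00:02:00Z")));
        p.process("model-a", new Sample(1.0, Instant.parse("2024-05-01T00:01:00Z"))); // scores 00:01, then 00:02
    }
}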
+ public static final String BREAKING_UI_CHANGE_TIME = "last_ui_breaking_change_time"; protected String id; protected Long version; @@ -120,6 +125,7 @@ public abstract class Config implements Writeable, ToXContentObject { protected Integer customResultIndexMinAge; protected Integer customResultIndexTTL; protected Boolean flattenResultIndexMapping; + protected Instant lastUIBreakingChangeTime; public static String INVALID_RESULT_INDEX_NAME_SIZE = "Result index name size must contains less than " + MAX_RESULT_INDEX_NAME_SIZE @@ -151,7 +157,8 @@ protected Config( Integer customResultIndexMinSize, Integer customResultIndexMinAge, Integer customResultIndexTTL, - Boolean flattenResultIndexMapping + Boolean flattenResultIndexMapping, + Instant lastBreakingUIChangeTime ) { if (Strings.isBlank(name)) { errorMessage = CommonMessages.EMPTY_NAME; @@ -291,6 +298,7 @@ protected Config( this.customResultIndexMinAge = Strings.trimToNull(resultIndex) == null ? null : customResultIndexMinAge; this.customResultIndexTTL = Strings.trimToNull(resultIndex) == null ? null : customResultIndexTTL; this.flattenResultIndexMapping = Strings.trimToNull(resultIndex) == null ? null : flattenResultIndexMapping; + this.lastUIBreakingChangeTime = lastBreakingUIChangeTime; } public int suggestHistory() { @@ -335,6 +343,7 @@ public Config(StreamInput input) throws IOException { this.customResultIndexMinAge = input.readOptionalInt(); this.customResultIndexTTL = input.readOptionalInt(); this.flattenResultIndexMapping = input.readOptionalBoolean(); + this.lastUIBreakingChangeTime = input.readOptionalInstant(); } /* @@ -388,6 +397,7 @@ public void writeTo(StreamOutput output) throws IOException { output.writeOptionalInt(customResultIndexMinAge); output.writeOptionalInt(customResultIndexTTL); output.writeOptionalBoolean(flattenResultIndexMapping); + output.writeOptionalInstant(lastUIBreakingChangeTime); } public boolean invalidShingleSizeRange(Integer shingleSizeToTest) { @@ -525,6 +535,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (flattenResultIndexMapping != null) { builder.field(FLATTEN_RESULT_INDEX_MAPPING, flattenResultIndexMapping); } + if (lastUIBreakingChangeTime != null) { + builder.field(BREAKING_UI_CHANGE_TIME, lastUIBreakingChangeTime.toEpochMilli()); + } return builder; } @@ -737,6 +750,10 @@ public Boolean getFlattenResultIndexMapping() { return flattenResultIndexMapping; } + public Instant getLastBreakingUIChangeTime() { + return lastUIBreakingChangeTime; + } + /** * Identifies redundant feature names. 
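The new field is written to _source as epoch milliseconds (see toXContent above) and is optional on the wire, so configs created before this change simply omit it. A hypothetical client-side sketch, not code from this PR, of how a UI could consume it: drop stored results that predate the last breaking configuration change before rendering.

import java.time.Instant;
import java.util.List;

final class BreakingChangeFilterSketch {
    record ResultDoc(String id, Instant dataEndTime) {}

    // Results older than the cutoff are hidden; a null cutoff means the config predates the field.
    static List<ResultDoc> displayable(List<ResultDoc> results, Long lastUiBreakingChangeMillis) {
        if (lastUiBreakingChangeMillis == null) {
            return results;
        }
        Instant cutoff = Instant.ofEpochMilli(lastUiBreakingChangeMillis);
        return results.stream().filter(r -> !r.dataEndTime().isBefore(cutoff)).toList();
    }

    public static void main(String[] args) {
        List<ResultDoc> docs = List.of(
            new ResultDoc("old", Instant.parse("2024-04-30T00:00:00Z")),
            new ResultDoc("new", Instant.parse("2024-05-02T00:00:00Z"))
        );
        long cutoff = Instant.parse("2024-05-01T00:00:00Z").toEpochMilli();
        System.out.println(displayable(docs, cutoff)); // only the "new" result remains
    }
}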
* diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java b/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java index 9ce014274..eca71c555 100644 --- a/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java +++ b/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java @@ -145,6 +145,7 @@ public abstract class AbstractTimeSeriesActionHandler(taskManager, transportService); this.configValidationAspect = configValidationAspect; + this.breakingUIChange = false; } /** @@ -456,6 +458,11 @@ private void onGetConfigResponse(GetResponse response, boolean indexingDryRun, S ); return; } + } else { + if (!ParseUtils.listEqualsWithoutConsideringOrder(existingConfig.getCategoryFields(), config.getCategoryFields()) + || !Objects.equals(existingConfig.getCustomResultIndexOrAlias(), config.getCustomResultIndexOrAlias())) { + breakingUIChange = true; + } } ActionListener confirmBatchRunningListener = ActionListener @@ -675,7 +682,6 @@ protected void validateCategoricalField( ); } - @SuppressWarnings("unchecked") protected void searchConfigInputIndices(String configId, boolean indexingDryRun, ActionListener listener) { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() .query(QueryBuilders.matchAllQuery()) diff --git a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java index 86ee34e88..34bf7835f 100644 --- a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java +++ b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java @@ -57,7 +57,7 @@ public class TimeSeriesSettings { public static final int DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION = 100_000; // clean up door keeper every 60 intervals - public static final int DOOR_KEEPER_MAINTENANCE_FREQ = 60; + public static final int EXPIRING_VALUE_MAINTENANCE_FREQ = 60; // 1 million insertion costs roughly 1 MB. 
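The handler hunk above sets breakingUIChange when an update touches the category fields (compared without regard to order) or the custom result index, and copyConfig then stamps Instant.now() for those updates while carrying the previous value forward otherwise. A simplified sketch of that decision with stand-in method names, not the handler's actual code:

import java.time.Instant;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;

final class BreakingChangeDetectionSketch {
    // Category fields are compared as sets, so reordering them is not a breaking change;
    // switching the custom result index (or alias) is.
    static boolean isUiBreaking(List<String> oldCategories, List<String> newCategories,
                                String oldResultIndex, String newResultIndex) {
        boolean sameCategories = oldCategories != null && newCategories != null
            ? new HashSet<>(oldCategories).equals(new HashSet<>(newCategories))
            : Objects.equals(oldCategories, newCategories);
        return !sameCategories || !Objects.equals(oldResultIndex, newResultIndex);
    }

    // copyConfig keeps the previous timestamp unless the update was breaking.
    static Instant stampChangeTime(boolean uiBreaking, Instant previous) {
        return uiBreaking ? Instant.now() : previous;
    }

    public static void main(String[] args) {
        System.out.println(isUiBreaking(List.of("ip", "host"), List.of("host", "ip"), "idx", "idx")); // false
        System.out.println(isUiBreaking(List.of("ip"), List.of("host"), "idx", "idx"));               // true
        System.out.println(stampChangeTime(false, Instant.EPOCH));                                    // 1970-01-01T00:00:00Z
    }
}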
public static final int DOOR_KEEPER_FOR_CACHE_MAX_INSERTION = 1_000_000; diff --git a/src/main/java/org/opensearch/timeseries/transport/BooleanNodeResponse.java b/src/main/java/org/opensearch/timeseries/transport/BooleanNodeResponse.java index c6b4f1285..ebb38e7c3 100644 --- a/src/main/java/org/opensearch/timeseries/transport/BooleanNodeResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/BooleanNodeResponse.java @@ -31,6 +31,7 @@ public boolean isAnswerTrue() { @Override public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); out.writeBoolean(answer); } } diff --git a/src/main/java/org/opensearch/timeseries/transport/BooleanResponse.java b/src/main/java/org/opensearch/timeseries/transport/BooleanResponse.java index 8eb18475a..b5ef0af6b 100644 --- a/src/main/java/org/opensearch/timeseries/transport/BooleanResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/BooleanResponse.java @@ -37,6 +37,7 @@ public boolean isAnswerTrue() { @Override public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); out.writeBoolean(answer); } diff --git a/src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java index 25b7a2170..c64ac504e 100644 --- a/src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java @@ -22,6 +22,7 @@ import org.opensearch.ad.caching.ADCacheProvider; import org.opensearch.ad.ml.ADColdStart; import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ADRealTimeInferencer; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.CronAction; import org.opensearch.cluster.service.ClusterService; @@ -30,6 +31,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.forecast.caching.ForecastCacheProvider; import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastRealTimeInferencer; import org.opensearch.forecast.task.ForecastTaskManager; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.NodeStateManager; @@ -45,6 +47,8 @@ public class CronTransportAction extends TransportNodesAction { + private V value; + private long lastAccessTime; + private long expirationTimeInMillis; + private Clock clock; + + public ExpiringValue(V value, long expirationTimeInMillis, Clock clock) { + this.value = value; + this.expirationTimeInMillis = expirationTimeInMillis; + this.clock = clock; + updateLastAccessTime(); + } + + public V getValue() { + updateLastAccessTime(); + return value; + } + + public boolean isExpired() { + return isExpired(clock.millis()); + } + + public boolean isExpired(long currentTimeMillis) { + return (currentTimeMillis - lastAccessTime) >= expirationTimeInMillis; + } + + public void updateLastAccessTime() { + lastAccessTime = clock.millis(); + } +} diff --git a/src/main/resources/mappings/anomaly-detection-state.json b/src/main/resources/mappings/anomaly-detection-state.json index be37da1eb..898a12d8b 100644 --- a/src/main/resources/mappings/anomaly-detection-state.json +++ b/src/main/resources/mappings/anomaly-detection-state.json @@ -1,7 +1,7 @@ { "dynamic": false, "_meta": { - "schema_version": 4 + "schema_version": 5 }, "properties": { "schema_version": { diff --git a/src/main/resources/mappings/config.json b/src/main/resources/mappings/config.json index 89b334f90..ad679f183 100644 --- 
a/src/main/resources/mappings/config.json +++ b/src/main/resources/mappings/config.json @@ -1,7 +1,7 @@ { "dynamic": false, "_meta": { - "schema_version": 6 + "schema_version": 7 }, "properties": { "schema_version": { @@ -232,6 +232,10 @@ }, "flatten_result_index_mapping": { "type": "boolean" + }, + "last_ui_breaking_change_time" : { + "type": "date", + "format": "strict_date_time||epoch_millis" } } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java index 990dfd907..257b8d5d9 100644 --- a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java @@ -859,7 +859,8 @@ public void doE detector.getCustomResultIndexMinSize(), detector.getCustomResultIndexMinAge(), detector.getCustomResultIndexTTL(), - detector.getFlattenResultIndexMapping() + detector.getFlattenResultIndexMapping(), + Instant.now() ); try { listener.onResponse((Response) TestHelpers.createGetResponse(clone, clone.getId(), CommonName.CONFIG_INDEX)); diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java index 627be0240..e07728fcb 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java @@ -315,7 +315,8 @@ public ToXContentObject[] getConfig(String detectorId, BasicHeader header, boole detector.getCustomResultIndexMinSize(), detector.getCustomResultIndexMinAge(), detector.getCustomResultIndexTTL(), - detector.getFlattenResultIndexMapping() + detector.getFlattenResultIndexMapping(), + detector.getLastBreakingUIChangeTime() ), detectorJob, historicalAdTask, @@ -642,7 +643,8 @@ protected AnomalyDetector cloneDetector(AnomalyDetector anomalyDetector, String anomalyDetector.getCustomResultIndexMinSize(), anomalyDetector.getCustomResultIndexMinAge(), anomalyDetector.getCustomResultIndexTTL(), - anomalyDetector.getFlattenResultIndexMapping() + anomalyDetector.getFlattenResultIndexMapping(), + Instant.now() ); return detector; } diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java index d89e03128..46600ed5e 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java @@ -1144,7 +1144,7 @@ public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedEx // make sure when the next maintenance coming, current door keeper gets reset // note our detector interval is 1 minute and the door keeper will expire in 60 intervals, which are 60 minutes - when(clock.instant()).thenReturn(Instant.now().plus(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ + 1, ChronoUnit.MINUTES)); + when(clock.instant()).thenReturn(Instant.now().plus(TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ + 1, ChronoUnit.MINUTES)); entityColdStarter.maintenance(); modelState = createStateForCacheRelease(); diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java index 152772691..b10c1afa4 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java +++ 
b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java @@ -344,7 +344,8 @@ public void testInvalidShingleSize() throws Exception { null, null, null, - null + null, + Instant.now() ) ); } @@ -381,7 +382,8 @@ public void testNullDetectorName() throws Exception { null, null, null, - null + null, + Instant.now() ) ); } @@ -418,7 +420,8 @@ public void testBlankDetectorName() throws Exception { null, null, null, - null + null, + Instant.now() ) ); } @@ -455,7 +458,8 @@ public void testNullTimeField() throws Exception { null, null, null, - null + null, + Instant.now() ) ); } @@ -492,7 +496,8 @@ public void testNullIndices() throws Exception { null, null, null, - null + null, + Instant.now() ) ); } @@ -529,7 +534,8 @@ public void testEmptyIndices() throws Exception { null, null, null, - null + null, + Instant.now() ) ); } @@ -566,7 +572,8 @@ public void testNullDetectionInterval() throws Exception { null, null, null, - null + null, + Instant.now() ) ); } @@ -602,7 +609,8 @@ public void testInvalidRecency() { null, null, null, - null + null, + Instant.now() ) ); assertEquals("Recency emphasis must be an integer greater than 1.", exception.getMessage()); @@ -639,7 +647,8 @@ public void testInvalidDetectionInterval() { null, null, null, - null + null, + Instant.now() ) ); assertEquals("Detection interval must be a positive integer", exception.getMessage()); @@ -676,7 +685,8 @@ public void testInvalidWindowDelay() { null, null, null, - null + null, + Instant.now() ) ); assertEquals("Interval -1 should be non-negative", exception.getMessage()); @@ -726,7 +736,8 @@ public void testGetShingleSize() throws IOException { null, null, null, - null + null, + Instant.now() ); assertEquals((int) anomalyDetector.getShingleSize(), 5); } @@ -761,7 +772,8 @@ public void testGetShingleSizeReturnsDefaultValue() throws IOException { null, null, null, - null + null, + Instant.now() ); // seasonalityIntervals is not null and custom shingle size is null, use seasonalityIntervals to deterine shingle size assertEquals(seasonalityIntervals / TimeSeriesSettings.SEASONALITY_TO_SHINGLE_RATIO, (int) anomalyDetector.getShingleSize()); @@ -792,7 +804,8 @@ public void testGetShingleSizeReturnsDefaultValue() throws IOException { null, null, null, - null + null, + Instant.now() ); // seasonalityIntervals is null and custom shingle size is null, use default shingle size assertEquals(TimeSeriesSettings.DEFAULT_SHINGLE_SIZE, (int) anomalyDetector.getShingleSize()); @@ -825,7 +838,8 @@ public void testNullFeatureAttributes() throws IOException { null, null, null, - null + null, + Instant.now() ); assertNotNull(anomalyDetector.getFeatureAttributes()); assertEquals(0, anomalyDetector.getFeatureAttributes().size()); @@ -858,7 +872,8 @@ public void testValidateResultIndex() throws IOException { null, null, null, - null + null, + Instant.now() ); String errorMessage = anomalyDetector.validateCustomResultIndex("abc"); assertEquals(ADCommonMessages.INVALID_RESULT_INDEX_PREFIX, errorMessage); @@ -1025,7 +1040,8 @@ public void testNullFixedValue() throws IOException { null, null, null, - null + null, + Instant.now() ) ); assertEquals("Got: " + e.getMessage(), "Enabled features are present, but no default fill values are provided.", e.getMessage()); diff --git a/src/test/java/org/opensearch/ad/model/GetAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/model/GetAnomalyDetectorTransportActionTests.java index 64295e4e2..bafed343c 100644 --- 
a/src/test/java/org/opensearch/ad/model/GetAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/model/GetAnomalyDetectorTransportActionTests.java @@ -92,4 +92,58 @@ public void testRealtimeTaskAssignedWithSingleStreamRealTimeTaskName() throws Ex // For this example, we'll verify that the correct task is passed to getConfigAndJob verify(getForecaster).getConfigAndJob(eq(configID), anyBoolean(), anyBoolean(), any(), eq(Optional.of(adTask)), eq(listener)); } + + @SuppressWarnings("unchecked") + public void testInvalidTaskName() throws Exception { + // Arrange + String configID = "test-config-id"; + + // Create a task with singleStreamRealTimeTaskName + Map tasks = new HashMap<>(); + String invalidTaskName = "blah"; + ADTask adTask = ADTask.builder().taskType(invalidTaskName).build(); + tasks.put(invalidTaskName, adTask); + + // Mock taskManager to return the tasks + ADTaskManager taskManager = mock(ADTaskManager.class); + doAnswer(invocation -> { + List taskList = new ArrayList<>(tasks.values()); + ((Consumer>) invocation.getArguments()[4]).accept(taskList); + return null; + }).when(taskManager).getAndExecuteOnLatestTasks(anyString(), any(), any(), any(), any(), any(), anyBoolean(), anyInt(), any()); + + // Mock listener + ActionListener listener = mock(ActionListener.class); + + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings settings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(settings); + GetAnomalyDetectorTransportAction getForecaster = spy( + new GetAnomalyDetectorTransportAction( + mock(TransportService.class), + null, + mock(ActionFilters.class), + clusterService, + null, + null, + Settings.EMPTY, + null, + taskManager, + null + ) + ); + + // Act + GetConfigRequest request = new GetConfigRequest(configID, 0L, true, true, "", "", true, null); + getForecaster.getExecute(request, listener); + + // Assert + // Verify that realtimeTask is assigned using singleStreamRealTimeTaskName + // This can be checked by verifying interactions or internal state + // For this example, we'll verify that the correct task is passed to getConfigAndJob + verify(getForecaster).getConfigAndJob(eq(configID), anyBoolean(), anyBoolean(), any(), eq(Optional.empty()), eq(listener)); + } } diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java index 1f40fa1f9..72c50cbaa 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.time.Clock; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; @@ -159,7 +160,8 @@ public void setUp() throws Exception { coldstartQueue, resultWriteStrategy, cacheProvider, - threadPool + threadPool, + mock(Clock.class) ); // Integer.MAX_VALUE makes a huge heap diff --git a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java index 0ffe7b5d7..90634347d 100644 --- a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java +++ b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java @@ -226,7 +226,8 @@ public static Response createAnomalyDetector( null, null, null, - null 
+ null, + now ); if (historical) { @@ -316,7 +317,6 @@ public static int countADResultOfDetector(RestClient client, String detectorId, TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/results/_search", ImmutableMap.of(), TestHelpers.toHttpEntity(query), - null ); Map responseMap = entityAsMap(searchAdTaskResponse); @@ -342,7 +342,6 @@ public static int countDetectors(RestClient client, String detectorType) throws TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/_search", ImmutableMap.of(), TestHelpers.toHttpEntity(query), - null ); Map responseMap = entityAsMap(searchAdTaskResponse); diff --git a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java index 5a95a2cc1..c47638325 100644 --- a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java @@ -54,6 +54,7 @@ import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.rest.handler.AbstractTimeSeriesActionHandler; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.RestHandlerUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -159,6 +160,7 @@ public void testCreateAnomalyDetectorWithDuplicateName() throws Exception { null, null, null, + null, null ); @@ -205,6 +207,23 @@ public void testCreateAnomalyDetector() throws Exception { int version = (int) responseMap.get("_version"); assertNotEquals("response is missing Id", AnomalyDetector.NO_ID, id); assertTrue("incorrect version", version > 0); + + // users cannot specify detector id when creating a detector + AnomalyDetector detector2 = createIndexAndGetAnomalyDetector(INDEX_NAME); + String blahId = "__blah__"; + response = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI, + ImmutableMap.of(RestHandlerUtils.DETECTOR_ID, blahId), + TestHelpers.toHttpEntity(detector2), + null + ); + assertEquals("Create anomaly detector failed", RestStatus.CREATED, TestHelpers.restStatus(response)); + responseMap = entityAsMap(response); + id = (String) responseMap.get("_id"); + assertNotEquals("response is missing Id", blahId, id); } public void testCreateAnomalyDetectorWithDateNanos() throws Exception { @@ -271,7 +290,8 @@ public void testUpdateAnomalyDetectorCategoryField() throws Exception { null, null, null, - null + null, + detector.getLastBreakingUIChangeTime() ); Exception ex = expectThrows( ResponseException.class, @@ -338,7 +358,8 @@ public void testUpdateAnomalyDetector() throws Exception { null, null, null, - null + null, + detector.getLastBreakingUIChangeTime() ); updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); @@ -410,7 +431,8 @@ public void testUpdateAnomalyDetectorNameToExisting() throws Exception { null, null, null, - null + null, + detector1.getLastBreakingUIChangeTime() ); TestHelpers @@ -459,7 +481,8 @@ public void testUpdateAnomalyDetectorNameToNew() throws Exception { null, null, null, - null + null, + Instant.now() ); TestHelpers @@ -514,7 +537,8 @@ public void testUpdateAnomalyDetectorWithNotExistingIndex() throws Exception { null, null, null, - null + null, + detector.getLastBreakingUIChangeTime() ); deleteIndexWithAdminClient(CommonName.CONFIG_INDEX); @@ -886,7 +910,8 @@ public void testUpdateAnomalyDetectorWithRunningAdJob() throws Exception { null, null, null, - null + null, + detector.getLastBreakingUIChangeTime() ); TestHelpers diff --git 
a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java index 505ed36b0..3a437e02d 100644 --- a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java @@ -349,7 +349,8 @@ private AnomalyDetector randomAnomalyDetector(AnomalyDetector detector) { detector.getCustomResultIndexMinSize(), detector.getCustomResultIndexMinAge(), detector.getCustomResultIndexTTL(), - detector.getFlattenResultIndexMapping() + detector.getFlattenResultIndexMapping(), + detector.getLastBreakingUIChangeTime() ); } diff --git a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java index d83514ef2..98bb89f64 100644 --- a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java +++ b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java @@ -17,6 +17,7 @@ import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; +import java.util.Locale; import java.util.Map; import org.apache.http.HttpHeaders; @@ -73,34 +74,66 @@ public static String generatePassword(String username) { String lowerCase = "abcdefghijklmnopqrstuvwxyz"; String digits = "0123456789"; String special = "_"; - String characters = upperCase + lowerCase + digits + special; - - SecureRandom rng = new SecureRandom(); - // Ensure password includes at least one character from each set - char[] password = new char[15]; - password[0] = upperCase.charAt(rng.nextInt(upperCase.length())); - password[1] = lowerCase.charAt(rng.nextInt(lowerCase.length())); - password[2] = digits.charAt(rng.nextInt(digits.length())); - password[3] = special.charAt(rng.nextInt(special.length())); - - for (int i = 4; i < 15; i++) { - char nextChar; - do { - nextChar = characters.charAt(rng.nextInt(characters.length())); - } while (username.indexOf(nextChar) > -1); - password[i] = nextChar; + // Remove characters from username (case-insensitive) + String usernameLower = username.toLowerCase(Locale.ROOT); + for (char c : usernameLower.toCharArray()) { + upperCase = upperCase.replaceAll("(?i)" + c, ""); + lowerCase = lowerCase.replaceAll("(?i)" + c, ""); + digits = digits.replace(String.valueOf(c), ""); + special = special.replace(String.valueOf(c), ""); } - // Shuffle the array to ensure the first 4 characters are not always in the same position - for (int i = password.length - 1; i > 0; i--) { - int index = rng.nextInt(i + 1); - char temp = password[index]; - password[index] = password[i]; - password[i] = temp; + // Combine all remaining characters + String characters = upperCase + lowerCase + digits + special; + + // Check if we have enough characters to proceed + if (characters.length() < 4) { + throw new IllegalArgumentException("Not enough characters to generate password without using username characters."); } - return new String(password); + SecureRandom rng = new SecureRandom(); + String password; + + do { + // Ensure password includes at least one character from each set, if available + StringBuilder passwordBuilder = new StringBuilder(); + if (!upperCase.isEmpty()) { + passwordBuilder.append(upperCase.charAt(rng.nextInt(upperCase.length()))); + } + if (!lowerCase.isEmpty()) { + passwordBuilder.append(lowerCase.charAt(rng.nextInt(lowerCase.length()))); + } + if (!digits.isEmpty()) { + passwordBuilder.append(digits.charAt(rng.nextInt(digits.length()))); + } + if (!special.isEmpty()) { + 
passwordBuilder.append(special.charAt(rng.nextInt(special.length()))); + } + + // Fill the rest of the password length with random characters + int remainingLength = 15 - passwordBuilder.length(); + for (int i = 0; i < remainingLength; i++) { + passwordBuilder.append(characters.charAt(rng.nextInt(characters.length()))); + } + + // Convert to char array for shuffling + char[] passwordChars = passwordBuilder.toString().toCharArray(); + + // Shuffle the password characters + for (int i = passwordChars.length - 1; i > 0; i--) { + int index = rng.nextInt(i + 1); + char temp = passwordChars[index]; + passwordChars[index] = passwordChars[i]; + passwordChars[i] = temp; + } + + password = new String(passwordChars); + + // Repeat if password contains the username as a substring (case-insensitive) + } while (password.toLowerCase(Locale.ROOT).contains(usernameLower.toLowerCase(Locale.ROOT))); + + return password; } @Before @@ -304,7 +337,8 @@ public void testUpdateApiFilterByEnabledForAdmin() throws IOException { null, null, null, - null + null, + Instant.now() ); // User client has admin all access, and has "opensearch" backend role so client should be able to update detector // But the detector's backend role should not be replaced as client's backend roles (all_access). @@ -359,7 +393,8 @@ public void testUpdateApiFilterByEnabled() throws IOException { null, null, null, - null + null, + Instant.now() ); enableFilterBy(); // User Fish has AD full access, and has "odfe" backend role which is one of Alice's backend role, so diff --git a/src/test/java/org/opensearch/ad/transport/ADHCImputeNodesResponseTests.java b/src/test/java/org/opensearch/ad/transport/ADHCImputeNodesResponseTests.java index f2657f21d..73320c671 100644 --- a/src/test/java/org/opensearch/ad/transport/ADHCImputeNodesResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADHCImputeNodesResponseTests.java @@ -115,4 +115,29 @@ public void testADHCImputeNodeResponseSerialization() throws IOException { assertNotNull(deserializedNodeResponse.getPreviousException()); assertEquals("exception: " + previousException.getMessage(), deserializedNodeResponse.getPreviousException().getMessage()); } + + public void testNoExceptionSerialization() throws IOException { + // Arrange + DiscoveryNode node = new DiscoveryNode( + "nodeId", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + + ADHCImputeNodeResponse nodeResponse = new ADHCImputeNodeResponse(node, null); + + // Act: Serialize the node response + BytesStreamOutput output = new BytesStreamOutput(); + nodeResponse.writeTo(output); + + // Deserialize the node response + StreamInput input = output.bytes().streamInput(); + ADHCImputeNodeResponse deserializedNodeResponse = new ADHCImputeNodeResponse(input); + + // Assert + assertEquals(node, deserializedNodeResponse.getNode()); + assertNull(deserializedNodeResponse.getPreviousException()); + } } diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java index c4556fe0a..9c3828443 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java @@ -37,6 +37,7 @@ import static org.opensearch.timeseries.TestHelpers.createIndexBlockedState; import java.io.IOException; +import java.time.Clock; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; @@ -365,7 +366,8 @@ public void setUp() 
throws Exception { coldStartWorker, mock(ADSaveResultStrategy.class), cacheProvider, - threadPool + threadPool, + mock(Clock.class) ); } @@ -625,7 +627,8 @@ public void testInsufficientCapacityExceptionDuringRestoringModel() throws Inter coldStartWorker, mock(ADSaveResultStrategy.class), cacheProvider, - threadPool + threadPool, + mock(Clock.class) ); ADPriorityCache adPriorityCache = mock(ADPriorityCache.class); diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java index 7c8d4f3c2..b3d30d5cb 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java @@ -227,7 +227,8 @@ private AnomalyDetector randomDetector(List indices, List featu null, null, null, - null + null, + Instant.now() ); } @@ -258,7 +259,8 @@ private AnomalyDetector randomHCDetector(List indices, List fea null, null, null, - null + null, + Instant.now() ); } diff --git a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java index 5bd044ea8..9f31d0719 100644 --- a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java @@ -264,7 +264,16 @@ public void setUp() throws Exception { adStats = new ADStats(statsMap); resultSaver = new ADSaveResultStrategy(1, resultWriteQueue); - inferencer = new ADRealTimeInferencer(manager, adStats, checkpointDao, entityColdStartQueue, resultSaver, provider, threadPool); + inferencer = new ADRealTimeInferencer( + manager, + adStats, + checkpointDao, + entityColdStartQueue, + resultSaver, + provider, + threadPool, + clock + ); entityResult = new EntityADResultTransportAction( actionFilters, @@ -397,7 +406,8 @@ public void testFailToScore() { entityColdStartQueue, resultSaver, provider, - threadPool + threadPool, + clock ); entityResult = new EntityADResultTransportAction( actionFilters, diff --git a/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java b/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java index 266b3b009..f5be130a1 100644 --- a/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java +++ b/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java @@ -86,7 +86,8 @@ public void testNullDetectorIdAndTaskAction() throws IOException { null, null, null, - null + null, + Instant.now() ); ForwardADTaskRequest request = new ForwardADTaskRequest(detector, null, null, null, null, Version.V_2_1_0); ActionRequestValidationException validate = request.validate(); diff --git a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java index fb93d8d1c..cd1505ced 100644 --- a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java @@ -330,7 +330,8 @@ public void setUp() throws Exception { entityColdStartQueue, resultSaver, provider, - threadPool + threadPool, + clock ); } diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java index 53f6f0ab5..9a57c6a5e 100644 --- 
a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java @@ -405,7 +405,8 @@ public void testValidateAnomalyDetectorWithInvalidDetectorName() throws IOExcept null, null, null, - null + null, + Instant.now() ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); ValidateConfigRequest request = new ValidateConfigRequest( @@ -454,7 +455,8 @@ public void testValidateAnomalyDetectorWithDetectorNameTooLong() throws IOExcept null, null, null, - null + null, + Instant.now() ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); ValidateConfigRequest request = new ValidateConfigRequest( diff --git a/src/test/java/org/opensearch/forecast/model/ForecasterTests.java b/src/test/java/org/opensearch/forecast/model/ForecasterTests.java index 380137345..b8719360d 100644 --- a/src/test/java/org/opensearch/forecast/model/ForecasterTests.java +++ b/src/test/java/org/opensearch/forecast/model/ForecasterTests.java @@ -90,7 +90,8 @@ public void testForecasterConstructor() { customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastUpdateTime ); assertEquals(forecasterId, forecaster.getId()); @@ -144,7 +145,8 @@ public void testForecasterConstructorWithNullForecastInterval() { customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastUpdateTime ); }); @@ -183,7 +185,8 @@ public void testNegativeInterval() { customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastUpdateTime ); }); @@ -222,7 +225,8 @@ public void testMaxCategoryFieldsLimits() { customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastUpdateTime ); }); @@ -261,7 +265,8 @@ public void testBlankName() { customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastUpdateTime ); }); @@ -300,7 +305,8 @@ public void testInvalidCustomResultIndex() { customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastUpdateTime ); }); @@ -338,7 +344,8 @@ public void testValidCustomResultIndex() { customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastUpdateTime ); assertEquals(resultIndex, forecaster.getCustomResultIndexOrAlias()); @@ -374,7 +381,8 @@ public void testInvalidHorizon() { customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastUpdateTime ); }); diff --git a/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java b/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java index aad6b2039..46d1bdacd 100644 --- a/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java +++ b/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java @@ -6,33 +6,49 @@ package org.opensearch.forecast.rest; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.timeseries.util.RestHandlerUtils.RUN_ONCE; +import static 
org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; +import static org.opensearch.timeseries.util.RestHandlerUtils.STOP_JOB; import static org.opensearch.timeseries.util.RestHandlerUtils.SUGGEST; import static org.opensearch.timeseries.util.RestHandlerUtils.VALIDATE; +import java.io.IOException; import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.apache.http.HttpEntity; +import org.apache.http.ParseException; +import org.apache.http.util.EntityUtils; import org.hamcrest.MatcherAssert; import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; import org.opensearch.client.RestClient; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.forecast.AbstractForecastSyntheticDataTest; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.model.ForecastTaskProfile; import org.opensearch.forecast.model.Forecaster; import org.opensearch.forecast.settings.ForecastEnabledSetting; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.EntityTaskProfile; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.util.RestHandlerUtils; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.gson.JsonObject; /** @@ -40,14 +56,22 @@ * - Suggest * - Validate * - Create - * + * - run once + * - start + * - stop + * - update */ public class ForecastRestApiIT extends AbstractForecastSyntheticDataTest { + public static final int MAX_RETRY_TIMES = 200; private static final String SUGGEST_INTERVAL_URI; private static final String SUGGEST_INTERVAL_HORIZON_HISTORY_URI; private static final String VALIDATE_FORECASTER; private static final String VALIDATE_FORECASTER_MODEL; private static final String CREATE_FORECASTER; + private static final String RUN_ONCE_FORECASTER; + private static final String START_FORECASTER; + private static final String STOP_FORECASTER; + private static final String UPDATE_FORECASTER; static { SUGGEST_INTERVAL_URI = String @@ -72,6 +96,10 @@ public class ForecastRestApiIT extends AbstractForecastSyntheticDataTest { VALIDATE_FORECASTER_MODEL = String .format(Locale.ROOT, "%s/%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, VALIDATE, "model"); CREATE_FORECASTER = TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI; + RUN_ONCE_FORECASTER = String.format(Locale.ROOT, "%s/%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, "%s", RUN_ONCE); + START_FORECASTER = String.format(Locale.ROOT, "%s/%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, "%s", START_JOB); + STOP_FORECASTER = String.format(Locale.ROOT, "%s/%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, "%s", STOP_JOB); + UPDATE_FORECASTER = String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, "%s"); } @Override @@ -1947,7 +1975,7 @@ public void testCreate() throws Exception { .makeRequest( client(), "POST", - String.format(Locale.ROOT, CREATE_FORECASTER),// VALIDATE_FORECASTER_MODEL), + 
String.format(Locale.ROOT, CREATE_FORECASTER), ImmutableMap.of(), TestHelpers.toHttpEntity(formattedForecaster), null @@ -2034,4 +2062,497 @@ public void testCreate() throws Exception { Map responseMap = entityAsMap(response); assertEquals("opensearch-forecast-result-b", ((Map) responseMap.get("forecaster")).get("result_index")); } + + public void testRunOnce() throws Exception { + Instant trainTime = loadRuleData(200); + // case 1: happy case + String forecasterDef = "{\n" + + " \"name\": \"Second-Test-Forecaster-4\",\n" + + " \"description\": \"ok rate\",\n" + + " \"time_field\": \"timestamp\",\n" + + " \"indices\": [\n" + + " \"%s\"\n" + + " ],\n" + + " \"feature_attributes\": [\n" + + " {\n" + + " \"feature_id\": \"max1\",\n" + + " \"feature_name\": \"max1\",\n" + + " \"feature_enabled\": true,\n" + + " \"importance\": 1,\n" + + " \"aggregation_query\": {\n" + + " \"max1\": {\n" + + " \"max\": {\n" + + " \"field\": \"transform._doc_count\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"window_delay\": {\n" + + " \"period\": {\n" + + " \"interval\": %d,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " },\n" + + " \"ui_metadata\": {\n" + + " \"aabb\": {\n" + + " \"ab\": \"bb\"\n" + + " }\n" + + " },\n" + + " \"schema_version\": 2,\n" + + " \"horizon\": 24,\n" + + " \"forecast_interval\": {\n" + + " \"period\": {\n" + + " \"interval\": 10,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " }\n" + + "}"; + + // +1 to make sure it is big enough + long windowDelayMinutes = Duration.between(trainTime, Instant.now()).toMinutes() + 1; + final String formattedForecaster = String.format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, windowDelayMinutes); + Response response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, CREATE_FORECASTER), + ImmutableMap.of(), + TestHelpers.toHttpEntity(formattedForecaster), + null + ); + Map responseMap = entityAsMap(response); + String forecasterId = (String) responseMap.get("_id"); + + // run once + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, RUN_ONCE_FORECASTER, forecasterId), + ImmutableMap.of(), + (HttpEntity) null, + null + ); + + ForecastTaskProfile forecastTaskProfile = (ForecastTaskProfile) waitUntilTaskReachState( + forecasterId, + ImmutableSet.of(TaskState.TEST_COMPLETE.name()) + ).get(0); + assertTrue(forecastTaskProfile != null); + assertTrue(forecastTaskProfile.getTask().isLatest()); + + responseMap = entityAsMap(response); + String taskId = (String) responseMap.get(EntityTaskProfile.TASK_ID_FIELD); + assertEquals(taskId, forecastTaskProfile.getTaskId()); + + response = searchTaskResult(taskId); + responseMap = entityAsMap(response); + int total = (int) (((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); + assertTrue("actual: " + total, total > 40); + + // case 2: cannot run once while forecaster is started + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, START_FORECASTER, forecasterId), + ImmutableMap.of(), + (HttpEntity) null, + null + ); + responseMap = entityAsMap(response); + assertEquals(forecasterId, responseMap.get("_id")); + + // starting another run once before finishing causes error + Exception ex = expectThrows( + ResponseException.class, + () -> TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, RUN_ONCE_FORECASTER, forecasterId), + ImmutableMap.of(), + (HttpEntity) null, + null + ) + ); + + String reason = ex.getMessage(); + assertTrue("actual: " + 
reason, reason.contains("Cannot run once " + forecasterId + " when real time job is running.")); + + // case 3: stop forecaster + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, STOP_FORECASTER, forecasterId), + ImmutableMap.of(), + (HttpEntity) null, + null + ); + responseMap = entityAsMap(response); + assertEquals(forecasterId, responseMap.get("_id")); + } + + public ForecastTaskProfile getForecastTaskProfile(String forecasterId) throws IOException, ParseException { + Response profileResponse = TestHelpers + .makeRequest( + client(), + "GET", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI + "/" + forecasterId + "/_profile/" + ForecastCommonName.FORECAST_TASK, + ImmutableMap.of(), + "", + null + ); + return parseForecastTaskProfile(profileResponse); + } + + public Response searchTaskResult(String taskId) throws IOException { + Response response = TestHelpers + .makeRequest( + client(), + "GET", + "opensearch-forecast-result*/_search", + ImmutableMap.of(), + TestHelpers + .toHttpEntity( + "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"task_id\":\"" + taskId + "\"}}]}},\"track_total_hits\":true}" + ), + null + ); + return response; + } + + public ForecastTaskProfile parseForecastTaskProfile(Response profileResponse) throws IOException, ParseException { + String profileResult = EntityUtils.toString(profileResponse.getEntity()); + XContentParser parser = TestHelpers.parser(profileResult); + ForecastTaskProfile forecastTaskProfile = null; + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + if ("forecast_task".equals(fieldName)) { + forecastTaskProfile = ForecastTaskProfile.parse(parser); + } else { + parser.skipChildren(); + } + } + return forecastTaskProfile; + } + + protected List waitUntilTaskReachState(String forecasterId, Set targetStates) throws InterruptedException { + List results = new ArrayList<>(); + int i = 0; + ForecastTaskProfile forecastTaskProfile = null; + // Increase retryTimes if some task can't reach done state + while ((forecastTaskProfile == null || !targetStates.contains(forecastTaskProfile.getTask().getState())) && i < MAX_RETRY_TIMES) { + try { + forecastTaskProfile = getForecastTaskProfile(forecasterId); + } catch (Exception e) { + logger.error("failed to get ForecastTaskProfile", e); + } finally { + Thread.sleep(1000); + } + i++; + } + assertNotNull(forecastTaskProfile); + results.add(forecastTaskProfile); + results.add(i); + return results; + } + + public void testCreateDetector() throws Exception { + // Case 1: users cannot specify forecaster id when creating a forecaster + Instant trainTime = loadRuleData(200); + String forecasterDef = "{\n" + + " \"name\": \"Second-Test-Forecaster-4\",\n" + + " \"description\": \"ok rate\",\n" + + " \"time_field\": \"timestamp\",\n" + + " \"indices\": [\n" + + " \"%s\"\n" + + " ],\n" + + " \"feature_attributes\": [\n" + + " {\n" + + " \"feature_id\": \"max1\",\n" + + " \"feature_name\": \"max1\",\n" + + " \"feature_enabled\": true,\n" + + " \"importance\": 1,\n" + + " \"aggregation_query\": {\n" + + " \"max1\": {\n" + + " \"max\": {\n" + + " \"field\": \"transform._doc_count\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"window_delay\": {\n" + + " \"period\": {\n" + + " \"interval\": %d,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " },\n" + + " \"ui_metadata\": {\n" + + " \"aabb\": {\n" + + " \"ab\": \"bb\"\n" + + " }\n" + + " },\n" + + " \"schema_version\": 2,\n" + + " 
\"horizon\": 24,\n" + + " \"forecast_interval\": {\n" + + " \"period\": {\n" + + " \"interval\": 10,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " }\n" + + "}"; + + // +1 to make sure it is big enough + long windowDelayMinutes = Duration.between(trainTime, Instant.now()).toMinutes() + 1; + final String formattedForecaster = String.format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, windowDelayMinutes); + String blahId = "__blah__"; + Response response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, CREATE_FORECASTER), + ImmutableMap.of(RestHandlerUtils.FORECASTER_ID, blahId), + TestHelpers.toHttpEntity(formattedForecaster), + null + ); + Map responseMap = entityAsMap(response); + String forecasterId = (String) responseMap.get("_id"); + assertNotEquals("response is missing Id", blahId, forecasterId); + } + + public void testUpdateDetector() throws Exception { + // Case 1: update non-impactful fields like name or description won't change last breaking change UI time + Instant trainTime = loadRuleData(200); + String forecasterDef = "{\n" + + " \"name\": \"Second-Test-Forecaster-4\",\n" + + " \"description\": \"ok rate\",\n" + + " \"time_field\": \"timestamp\",\n" + + " \"indices\": [\n" + + " \"%s\"\n" + + " ],\n" + + " \"feature_attributes\": [\n" + + " {\n" + + " \"feature_id\": \"max1\",\n" + + " \"feature_name\": \"max1\",\n" + + " \"feature_enabled\": true,\n" + + " \"importance\": 1,\n" + + " \"aggregation_query\": {\n" + + " \"max1\": {\n" + + " \"max\": {\n" + + " \"field\": \"transform._doc_count\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"window_delay\": {\n" + + " \"period\": {\n" + + " \"interval\": %d,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " },\n" + + " \"ui_metadata\": {\n" + + " \"aabb\": {\n" + + " \"ab\": \"bb\"\n" + + " }\n" + + " },\n" + + " \"schema_version\": 2,\n" + + " \"horizon\": 24,\n" + + " \"forecast_interval\": {\n" + + " \"period\": {\n" + + " \"interval\": 10,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " }\n" + + "}"; + + // +1 to make sure it is big enough + long windowDelayMinutes = Duration.between(trainTime, Instant.now()).toMinutes() + 1; + final String formattedForecaster = String.format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, windowDelayMinutes); + Response response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, CREATE_FORECASTER), + ImmutableMap.of(), + TestHelpers.toHttpEntity(formattedForecaster), + null + ); + Map responseMap = entityAsMap(response); + String forecasterId = (String) responseMap.get("_id"); + assertEquals(null, responseMap.get("last_ui_breaking_change_time")); + + // changing description won't change last_breaking_change_ui_time + forecasterDef = "{\n" + + " \"name\": \"Second-Test-Forecaster-4\",\n" + + " \"description\": \"ok rate1\",\n" + + " \"time_field\": \"timestamp\",\n" + + " \"indices\": [\n" + + " \"%s\"\n" + + " ],\n" + + " \"feature_attributes\": [\n" + + " {\n" + + " \"feature_id\": \"max1\",\n" + + " \"feature_name\": \"max1\",\n" + + " \"feature_enabled\": true,\n" + + " \"importance\": 1,\n" + + " \"aggregation_query\": {\n" + + " \"max1\": {\n" + + " \"max\": {\n" + + " \"field\": \"transform._doc_count\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"window_delay\": {\n" + + " \"period\": {\n" + + " \"interval\": %d,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " },\n" + + " \"ui_metadata\": {\n" + + " \"aabb\": {\n" + + " \"ab\": \"bb\"\n" + + " }\n" + + " },\n" + + " 
\"schema_version\": 2,\n" + + " \"horizon\": 24,\n" + + " \"forecast_interval\": {\n" + + " \"period\": {\n" + + " \"interval\": 10,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " }\n" + + "}"; + response = TestHelpers + .makeRequest( + client(), + "PUT", + String.format(Locale.ROOT, UPDATE_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(formattedForecaster), + null + ); + responseMap = entityAsMap(response); + assertEquals(null, responseMap.get("last_ui_breaking_change_time")); + + // changing categorical fields changes last_ui_breaking_change_time + forecasterDef = "{\n" + + " \"name\": \"Second-Test-Forecaster-4\",\n" + + " \"description\": \"ok rate1\",\n" + + " \"time_field\": \"timestamp\",\n" + + " \"indices\": [\n" + + " \"%s\"\n" + + " ],\n" + + " \"feature_attributes\": [\n" + + " {\n" + + " \"feature_id\": \"max1\",\n" + + " \"feature_name\": \"max1\",\n" + + " \"feature_enabled\": true,\n" + + " \"importance\": 1,\n" + + " \"aggregation_query\": {\n" + + " \"max1\": {\n" + + " \"max\": {\n" + + " \"field\": \"transform._doc_count\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"window_delay\": {\n" + + " \"period\": {\n" + + " \"interval\": %d,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " },\n" + + " \"ui_metadata\": {\n" + + " \"aabb\": {\n" + + " \"ab\": \"bb\"\n" + + " }\n" + + " },\n" + + " \"schema_version\": 2,\n" + + " \"horizon\": 24,\n" + + " \"forecast_interval\": {\n" + + " \"period\": {\n" + + " \"interval\": 10,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " },\n" + + " \"category_field\": [\"componentName\"]" + + "}"; + response = TestHelpers + .makeRequest( + client(), + "PUT", + String.format(Locale.ROOT, UPDATE_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(formattedForecaster), + null + ); + responseMap = entityAsMap(response); + assertEquals(responseMap.get("last_update_time"), responseMap.get("last_ui_breaking_change_time")); + + // changing custom result index changes last_ui_breaking_change_time + forecasterDef = "{\n" + + " \"name\": \"Second-Test-Forecaster-4\",\n" + + " \"description\": \"ok rate1\",\n" + + " \"time_field\": \"timestamp\",\n" + + " \"indices\": [\n" + + " \"%s\"\n" + + " ],\n" + + " \"feature_attributes\": [\n" + + " {\n" + + " \"feature_id\": \"max1\",\n" + + " \"feature_name\": \"max1\",\n" + + " \"feature_enabled\": true,\n" + + " \"importance\": 1,\n" + + " \"aggregation_query\": {\n" + + " \"max1\": {\n" + + " \"max\": {\n" + + " \"field\": \"transform._doc_count\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"window_delay\": {\n" + + " \"period\": {\n" + + " \"interval\": %d,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " },\n" + + " \"ui_metadata\": {\n" + + " \"aabb\": {\n" + + " \"ab\": \"bb\"\n" + + " }\n" + + " },\n" + + " \"schema_version\": 2,\n" + + " \"horizon\": 24,\n" + + " \"forecast_interval\": {\n" + + " \"period\": {\n" + + " \"interval\": 10,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " },\n" + + " \"category_field\": [\"componentName\"]," + + " \"result_index\": \"opensearch-forecast-result-b\"" + + "}"; + response = TestHelpers + .makeRequest( + client(), + "PUT", + String.format(Locale.ROOT, UPDATE_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(formattedForecaster), + null + ); + responseMap = entityAsMap(response); + assertEquals(responseMap.get("last_update_time"), responseMap.get("last_ui_breaking_change_time")); + } } diff --git 
a/src/test/java/org/opensearch/timeseries/TestHelpers.java b/src/test/java/org/opensearch/timeseries/TestHelpers.java index 2c75febc0..22dcf64bd 100644 --- a/src/test/java/org/opensearch/timeseries/TestHelpers.java +++ b/src/test/java/org/opensearch/timeseries/TestHelpers.java @@ -340,7 +340,8 @@ public static AnomalyDetector randomAnomalyDetector( null, null, null, - null + null, + lastUpdateTime ); } @@ -395,7 +396,8 @@ public static AnomalyDetector randomDetector( null, null, null, - null + null, + Instant.now() ); } @@ -461,7 +463,8 @@ public static AnomalyDetector randomAnomalyDetectorUsingCategoryFields( null, null, null, - null + null, + Instant.now() ); } @@ -502,7 +505,8 @@ public static AnomalyDetector randomAnomalyDetector(String timefield, String ind null, null, null, - null + null, + Instant.now() ); } @@ -535,7 +539,8 @@ public static AnomalyDetector randomAnomalyDetectorWithEmptyFeature() throws IOE null, null, null, - null + null, + Instant.now().truncatedTo(ChronoUnit.SECONDS) ); } @@ -575,7 +580,8 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio null, null, null, - null + null, + Instant.now().truncatedTo(ChronoUnit.SECONDS) ); } @@ -753,7 +759,8 @@ public AnomalyDetector build() { null, null, null, - null + null, + lastUpdateTime ); } } @@ -790,7 +797,8 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio null, null, null, - null + null, + Instant.now().truncatedTo(ChronoUnit.SECONDS) ); } @@ -1940,13 +1948,15 @@ public Forecaster build() { resultIndex, horizon, imputationOption, - randomIntBetween(1, 1000), + // Recency emphasis must be an integer greater than 1 + randomIntBetween(2, 1000), randomIntBetween(1, 128), randomIntBetween(1, 1000), customResultIndexMinSize, customResultIndexMinAge, customResultIndexTTL, - flattenResultIndexMapping + flattenResultIndexMapping, + lastUpdateTime ); } } @@ -1974,13 +1984,15 @@ public static Forecaster randomForecaster() throws IOException { null, randomIntBetween(1, 20), randomImputationOption(featureList), - randomIntBetween(1, 1000), + // Recency emphasis must be an integer greater than 1 + randomIntBetween(2, 1000), randomIntBetween(1, 128), randomIntBetween(1, 1000), null, null, null, - null + null, + Instant.now().truncatedTo(ChronoUnit.SECONDS) ); } diff --git a/src/test/java/org/opensearch/timeseries/transport/BooleanResponseTests.java b/src/test/java/org/opensearch/timeseries/transport/BooleanResponseTests.java new file mode 100644 index 000000000..8d181bf3d --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/transport/BooleanResponseTests.java @@ -0,0 +1,204 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchTestCase; + +public class BooleanResponseTests extends OpenSearchTestCase { + + public void testBooleanResponseSerialization() throws IOException { + // Arrange + DiscoveryNode node = new 
DiscoveryNode( + "nodeId", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + + BooleanNodeResponse nodeResponseTrue = new BooleanNodeResponse(node, true); + BooleanNodeResponse nodeResponseFalse = new BooleanNodeResponse(node, false); + List nodes = List.of(nodeResponseTrue, nodeResponseFalse); + List failures = Collections.emptyList(); + ClusterName clusterName = new ClusterName("test-cluster"); + + BooleanResponse response = new BooleanResponse(clusterName, nodes, failures); + + // Act: Serialize the response + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + + // Deserialize the response + StreamInput input = output.bytes().streamInput(); + BooleanResponse deserializedResponse = new BooleanResponse(input); + + // Assert + assertEquals(clusterName, deserializedResponse.getClusterName()); + assertEquals(response.getNodes().size(), deserializedResponse.getNodes().size()); + assertEquals(response.failures().size(), deserializedResponse.failures().size()); + assertEquals(response.isAnswerTrue(), deserializedResponse.isAnswerTrue()); + } + + public void testBooleanResponseReadNodesFromAndWriteNodesTo() throws IOException { + // Arrange + DiscoveryNode node1 = new DiscoveryNode( + "nodeId1", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + DiscoveryNode node2 = new DiscoveryNode( + "nodeId2", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + + BooleanNodeResponse nodeResponse1 = new BooleanNodeResponse(node1, true); + BooleanNodeResponse nodeResponse2 = new BooleanNodeResponse(node2, false); + List nodes = List.of(nodeResponse1, nodeResponse2); + ClusterName clusterName = new ClusterName("test-cluster"); + BooleanResponse response = new BooleanResponse(clusterName, nodes, Collections.emptyList()); + + // Act: Write nodes to output + BytesStreamOutput output = new BytesStreamOutput(); + response.writeNodesTo(output, nodes); + + // Read nodes from input + StreamInput input = output.bytes().streamInput(); + List readNodes = response.readNodesFrom(input); + + // Assert + assertEquals(nodes.size(), readNodes.size()); + assertEquals(nodes.get(0).isAnswerTrue(), readNodes.get(0).isAnswerTrue()); + assertEquals(nodes.get(1).isAnswerTrue(), readNodes.get(1).isAnswerTrue()); + } + + public void testBooleanNodeResponseSerialization() throws IOException { + // Arrange + DiscoveryNode node = new DiscoveryNode( + "nodeId", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + + BooleanNodeResponse nodeResponse = new BooleanNodeResponse(node, true); + + // Act: Serialize the node response + BytesStreamOutput output = new BytesStreamOutput(); + nodeResponse.writeTo(output); + + // Deserialize the node response + StreamInput input = output.bytes().streamInput(); + BooleanNodeResponse deserializedNodeResponse = new BooleanNodeResponse(input); + + // Assert + assertEquals(node, deserializedNodeResponse.getNode()); + assertEquals(nodeResponse.isAnswerTrue(), deserializedNodeResponse.isAnswerTrue()); + } + + public void testBooleanResponseAnswerAggregation() { + // Arrange + DiscoveryNode node1 = new DiscoveryNode( + "nodeId1", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + DiscoveryNode node2 = new DiscoveryNode( + "nodeId2", + buildNewFakeTransportAddress(), + 
Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + + BooleanNodeResponse nodeResponseTrue = new BooleanNodeResponse(node1, true); + BooleanNodeResponse nodeResponseFalse = new BooleanNodeResponse(node2, false); + List nodes = List.of(nodeResponseTrue, nodeResponseFalse); + ClusterName clusterName = new ClusterName("test-cluster"); + + // Act + BooleanResponse response = new BooleanResponse(clusterName, nodes, Collections.emptyList()); + + // Assert + assertTrue(response.isAnswerTrue()); // Since at least one node responded true + } + + public void testBooleanResponseAllFalse() { + // Arrange + DiscoveryNode node1 = new DiscoveryNode( + "nodeId1", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + DiscoveryNode node2 = new DiscoveryNode( + "nodeId2", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + + BooleanNodeResponse nodeResponse1 = new BooleanNodeResponse(node1, false); + BooleanNodeResponse nodeResponse2 = new BooleanNodeResponse(node2, false); + List nodes = List.of(nodeResponse1, nodeResponse2); + ClusterName clusterName = new ClusterName("test-cluster"); + + // Act + BooleanResponse response = new BooleanResponse(clusterName, nodes, Collections.emptyList()); + + // Assert + assertFalse(response.isAnswerTrue()); // Since all nodes responded false + } + + public void testToXContent() throws IOException { + // Arrange + DiscoveryNode node = new DiscoveryNode( + "nodeId", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + + BooleanNodeResponse nodeResponse = new BooleanNodeResponse(node, true); + List nodes = Collections.singletonList(nodeResponse); + ClusterName clusterName = new ClusterName("test-cluster"); + BooleanResponse response = new BooleanResponse(clusterName, nodes, Collections.emptyList()); + + // Act + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + + String jsonString = builder.toString(); + + // Assert + assertTrue(jsonString.contains("\"answer\":true")); + } +} diff --git a/src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java index 7939e522c..03c1a2a01 100644 --- a/src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java @@ -28,6 +28,7 @@ import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.ml.ADColdStart; import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ADRealTimeInferencer; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; @@ -40,6 +41,7 @@ import org.opensearch.forecast.caching.ForecastCacheProvider; import org.opensearch.forecast.caching.ForecastPriorityCache; import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastRealTimeInferencer; import org.opensearch.forecast.task.ForecastTaskManager; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; @@ -83,6 +85,9 @@ public void setUp() throws Exception { when(forecastCacheProvider.get()).thenReturn(forecastCache); ForecastTaskManager forecastTaskManager = 
mock(ForecastTaskManager.class);
+        ADRealTimeInferencer adRealTimeInferencer = mock(ADRealTimeInferencer.class);
+        ForecastRealTimeInferencer forecastRealTimeInferencer = mock(ForecastRealTimeInferencer.class);
+
         action = new CronTransportAction(
             threadPool,
             clusterService,
@@ -95,7 +100,9 @@ public void setUp() throws Exception {
             entityColdStarter,
             forecastColdStarter,
             adTaskManager,
-            forecastTaskManager
+            forecastTaskManager,
+            adRealTimeInferencer,
+            forecastRealTimeInferencer
         );
     }
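
Note on the Clock parameter threaded through ADRealTimeInferencer in this patch: the test changes above pass mock(Clock.class) or an existing clock fixture into the new last constructor argument. A minimal sketch of an alternative wiring, assuming a test only needs deterministic wall-clock time (FixedClockExample and FIXED_CLOCK are illustrative names, not part of the patch):

import java.time.Clock;
import java.time.Instant;
import java.time.ZoneOffset;

class FixedClockExample {
    // Pin "now" so any timestamp or expiry logic inside the inferencer is reproducible across test runs.
    static final Clock FIXED_CLOCK = Clock.fixed(Instant.parse("2024-01-01T00:00:00Z"), ZoneOffset.UTC);

    // The argument order mirrors the updated constructor used in EntityResultTransportActionTests above:
    // new ADRealTimeInferencer(manager, adStats, checkpointDao, entityColdStartQueue,
    //         resultSaver, provider, threadPool, FIXED_CLOCK);
}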