forked from opensearch-project/k-NN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Validate zero vector when using cosine metric (opensearch-project#1501)
Ensure zero vector is not used when using functionality with cosine similarity metric. Signed-off-by: panguixin <[email protected]>
- Loading branch information
1 parent
089db16
commit b7bdda4
Showing
20 changed files
with
472 additions
and
155 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
83 changes: 83 additions & 0 deletions
83
src/main/java/org/opensearch/knn/common/KNNValidationUtil.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.common; | ||
|
||
import java.util.Locale; | ||
import lombok.AccessLevel; | ||
import lombok.NoArgsConstructor; | ||
import org.opensearch.knn.index.VectorDataType; | ||
|
||
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; | ||
|
||
@NoArgsConstructor(access = AccessLevel.PRIVATE) | ||
public class KNNValidationUtil { | ||
/** | ||
* Validate the float vector value and throw exception if it is not a number or not in the finite range. | ||
* | ||
* @param value float vector value | ||
*/ | ||
public static void validateFloatVectorValue(float value) { | ||
if (Float.isNaN(value)) { | ||
throw new IllegalArgumentException("KNN vector values cannot be NaN"); | ||
} | ||
|
||
if (Float.isInfinite(value)) { | ||
throw new IllegalArgumentException("KNN vector values cannot be infinity"); | ||
} | ||
} | ||
|
||
/** | ||
* Validate the float vector value in the byte range if it is a finite number, | ||
* with no decimal values and in the byte range of [-128 to 127]. If not throw IllegalArgumentException. | ||
* | ||
* @param value float value in byte range | ||
*/ | ||
public static void validateByteVectorValue(float value) { | ||
validateFloatVectorValue(value); | ||
if (value % 1 != 0) { | ||
throw new IllegalArgumentException( | ||
String.format( | ||
Locale.ROOT, | ||
"[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", | ||
VECTOR_DATA_TYPE_FIELD, | ||
VectorDataType.BYTE.getValue() | ||
) | ||
|
||
); | ||
} | ||
if ((int) value < Byte.MIN_VALUE || (int) value > Byte.MAX_VALUE) { | ||
throw new IllegalArgumentException( | ||
String.format( | ||
Locale.ROOT, | ||
"[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", | ||
VECTOR_DATA_TYPE_FIELD, | ||
VectorDataType.BYTE.getValue(), | ||
Byte.MIN_VALUE, | ||
Byte.MAX_VALUE | ||
) | ||
); | ||
} | ||
} | ||
|
||
/** | ||
* Validate if the given vector size matches with the dimension provided in mapping. | ||
* | ||
* @param dimension dimension of vector | ||
* @param vectorSize size of the vector | ||
*/ | ||
public static void validateVectorDimension(int dimension, int vectorSize) { | ||
if (dimension != vectorSize) { | ||
String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize); | ||
throw new IllegalArgumentException(errorMessage); | ||
} | ||
} | ||
} |
45 changes: 45 additions & 0 deletions
45
src/main/java/org/opensearch/knn/common/KNNVectorUtil.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.common; | ||
|
||
import java.util.Objects; | ||
import lombok.AccessLevel; | ||
import lombok.NoArgsConstructor; | ||
|
||
@NoArgsConstructor(access = AccessLevel.PRIVATE) | ||
public class KNNVectorUtil { | ||
/** | ||
* Check if all the elements of a given vector are zero | ||
* | ||
* @param vector the vector | ||
* @return true if yes; otherwise false | ||
*/ | ||
public static boolean isZeroVector(byte[] vector) { | ||
Objects.requireNonNull(vector, "vector must not be null"); | ||
for (byte e : vector) { | ||
if (e != 0) { | ||
return false; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
/** | ||
* Check if all the elements of a given vector are zero | ||
* | ||
* @param vector the vector | ||
* @return true if yes; otherwise false | ||
*/ | ||
public static boolean isZeroVector(float[] vector) { | ||
Objects.requireNonNull(vector, "vector must not be null"); | ||
for (float e : vector) { | ||
if (e != 0f) { | ||
return false; | ||
} | ||
} | ||
return true; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.