Skip to content

Commit

Permalink
Added Checks For Valid K Values
Browse files Browse the repository at this point in the history
  • Loading branch information
Johnny-Knighten committed Mar 26, 2019
1 parent 5954287 commit 2749169
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions knn/mixins/nnsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class NNSearchMixin:
Attributes:
ball_tree (:obj:'BallTree'): The Ball Tree constructed and used by the Ball Tree methods
numb_of_train_vectors(int): The number of training vectors
"""

Expand All @@ -23,6 +24,7 @@ def __init__(self):
"""
self.ball_tee = None
self.number_of_train_vectors = 0

def _brute_force_nn_query(self, train_data, test_data, k=1, metric="euclidean"):
""" Finds the NNs of a set of query vectors using brute force.
Expand All @@ -41,6 +43,16 @@ def _brute_force_nn_query(self, train_data, test_data, k=1, metric="euclidean"):
k_smallest_distances (ndarray): A 2D array of distances of the NNs for each query vector.
"""
if k > train_data.shape[0]:
raise ValueError("k Must Be Smaller Than Or Equal To The Number Of Training Vectors")

if k < 0:
raise ValueError("k Must Be Greater Than 0")

if not isinstance(k, int):
raise ValueError("k Must Be An Integer")


# The User Is Allowed To Pass In Their Own Metric
if not callable(metric):
metrics = {"euclidean": dm.euclidean_pairwise,
Expand All @@ -67,6 +79,7 @@ def _ball_tree_build(self, train_data, leaf_size=1, metric="euclidean"):
metric(string): The metric used in the search.
"""
self.number_of_train_vectors = train_data.shape[0]
self.ball_tree = BallTree(train_data, leaf_size, metric)
self.ball_tree.build_tree()

Expand All @@ -85,6 +98,15 @@ def _ball_tree_nn_query(self, query_vectors, k=1):
distances (ndarray): A 2D array of distances of the NNs for each query vector.
"""
if k > self.number_of_train_vectors:
raise ValueError("k Must Be Smaller Than Or Equal To The Number Of Training Vectors")

if k < 0:
raise ValueError("k Must Be Greater Than 0")

if not isinstance(k, int):
raise ValueError("k Must Be An Integer")

self.ball_tree.query(query_vectors, k)

return self.ball_tree.heap_inds, self.ball_tree.heap

0 comments on commit 2749169

Please sign in to comment.