diff --git a/tensorboard/plugins/projector/vz_projector/data.ts b/tensorboard/plugins/projector/vz_projector/data.ts index ba89c6e997e..e0a64d15d8c 100644 --- a/tensorboard/plugins/projector/vz_projector/data.ts +++ b/tensorboard/plugins/projector/vz_projector/data.ts @@ -97,6 +97,8 @@ export const PCA_SAMPLE_DIM = 200; const NUM_PCA_COMPONENTS = 10; /** Id of message box used for umap optimization progress bar. */ const UMAP_MSG_ID = 'umap-optimization'; +/** Minimum KNN neighbors threshold */ +const MIN_NUM_KNN_NEIGHBORS = 300; /** * Reserved metadata attributes used for sequence information * NOTE: Use "__seq_next__" as "__next__" is deprecated. @@ -474,16 +476,26 @@ export class DataSet { ); } else { const knnGpuEnabled = (await util.hasWebGLSupport()) && !IS_FIREFOX; + const numKnnNeighborsToCompute = Math.max( + nNeighbors, + MIN_NUM_KNN_NEIGHBORS + ); const result = await (knnGpuEnabled - ? knn.findKNNGPUCosDistNorm(data, nNeighbors, (d) => d.vector) + ? knn.findKNNGPUCosDistNorm( + data, + numKnnNeighborsToCompute, + (d) => d.vector + ) : knn.findKNN( data, - nNeighbors, + numKnnNeighborsToCompute, (d) => d.vector, (a, b) => vector.cosDistNorm(a, b) )); this.nearest = result; - return Promise.resolve(result); + return Promise.resolve( + result.map((neighbors) => neighbors.slice(0, nNeighbors)) + ); } } /* Perturb TSNE and update dataset point coordinates. */