diff --git a/hestia/similarity.py b/hestia/similarity.py index fc17e92..970ff04 100644 --- a/hestia/similarity.py +++ b/hestia/similarity.py @@ -413,6 +413,7 @@ def _embedding_distance( target_embds: Optional[np.ndarray] = None, distance: Union[str, Callable] = 'cosine', threads: int = cpu_count(), + threshold: float = 0.0, save_alignment: bool = False, filename: str = None, to_df: bool = True, @@ -429,8 +430,11 @@ def _embedding_distance( data = [] for idx in tqdm(range(mtx.shape[0])): for idx2 in range(mtx.shape[1]): + value = mtx[idx, idx2] + if value < threshold: + continue data.append({'query': idx, 'target': idx2, - 'metric': mtx[idx, idx2]}) + 'metric': 1 - value}) df = pd.DataFrame(data) if save_alignment: if filename is None: