diff --git a/hloc/pairs_from_retrieval.py b/hloc/pairs_from_retrieval.py index 32336801..4ccfbc10 100644 --- a/hloc/pairs_from_retrieval.py +++ b/hloc/pairs_from_retrieval.py @@ -81,6 +81,7 @@ def main( db_list=None, db_model=None, db_descriptors=None, + match_mask=None, ): logger.info("Extracting image pairs from a retrieval database.") @@ -108,8 +109,15 @@ def main( query_desc = get_descriptors(query_names, descriptors) sim = torch.einsum("id,jd->ij", query_desc.to(device), db_desc.to(device)) - # Avoid self-matching - self = np.array(query_names)[:, None] == np.array(db_names)[None] + if match_mask is None: + # Avoid self-matching + self = np.array(query_names)[:, None] == np.array(db_names)[None] + else: + assert match_mask.shape == ( + len(query_names), + len(db_names), + ), "mask shape must match size of query and database images!" + self = match_mask pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0) pairs = [(query_names[i], db_names[j]) for i, j in pairs]