[py-tx] Implement a new cleaner PDQ index solution from scratch #1613

Open
Dcallies opened this issue Aug 21, 2024 · 5 comments
Labels
help wanted · mlh (Related to Major League Hacking Fellowship) · pdq (Items related to the pdq libraries or reference implementations) · python-threatexchange (Items related to the threatexchange python tool / library)

Comments

@Dcallies
Contributor

When we built the PDQ index, it was our first attempt, and we made a lot of strange/bad choices.

Namely:

  • The hash -> id mapping is extremely convoluted and needlessly complicated
  • The build and lookup implementations could be simplified

I think we could provide a second, much simpler implementation, and then find a way to swap it in.

The key elements:

Pass in the index type as an argument during construction

    def __init__(
        self,
        threshold: int = DEFAULT_MATCH_DIST,
        faiss_index: t.Optional[faiss.Index] = None,
    ) -> None:

Simplify the stored state of the index implementation

        # Body of __init__
        super().__init__()
        if faiss_index is None:
            # Brute force
            faiss_index = faiss.IndexFlatL2(DIMENSIONALITY)
        self.faiss_index = _PDQHashIndex(faiss_index)
        self.threshold = threshold
        self._deduper = {}
        self._idx_to_entries: t.List[t.List[T]] = []
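
A minimal sketch of how the two pieces of state above might work together. The `dedupe_and_add` helper here is hypothetical (a `_dedupe_and_add` is referenced later in `build` but never defined in this issue), and faiss is left out so the bookkeeping stands alone. The idea is that each distinct hash gets one row id, and entries that share a hash share that row.

```python
import typing as t

T = t.TypeVar("T")


class DedupedEntries(t.Generic[T]):
    """Hypothetical stand-in for the hash -> row id -> entries bookkeeping."""

    def __init__(self) -> None:
        # hash string -> row id (the position the hash would have in faiss)
        self._deduper: t.Dict[str, int] = {}
        # row id -> every entry that was added under that hash
        self._idx_to_entries: t.List[t.List[T]] = []

    def dedupe_and_add(self, signal_str: str, entry: T) -> bool:
        """Record entry; return True if signal_str is new and needs a faiss row."""
        idx = self._deduper.get(signal_str)
        if idx is not None:
            # Known hash: attach the entry to the existing row
            self._idx_to_entries[idx].append(entry)
            return False
        # New hash: it takes the next row id
        self._deduper[signal_str] = len(self._idx_to_entries)
        self._idx_to_entries.append([entry])
        return True


state: DedupedEntries[str] = DedupedEntries()
state.dedupe_and_add("f" * 64, "entry-a")
state.dedupe_and_add("0" * 64, "entry-b")
state.dedupe_and_add("f" * 64, "entry-c")  # duplicate hash, same row
```

Because the row id is just the insertion order of distinct hashes, it lines up with the ids faiss assigns when the same hashes are added in the same order.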

Use a simpler inner wrapper to handle some of the PDQ details

class _PDQHashIndex:
    """
    A wrapper around the faiss index for pickle serialization
    """

    def __init__(self, faiss_index: faiss.Index) -> None:
        self.faiss_index = faiss_index

    def search(
        self,
        queries: t.Sequence[str],
        threshold: int,
    ) -> t.List[t.List[t.Tuple[int, float]]]:
        """
        Search method that returns, for each query (in input order),
        a list of (id, distance) matches
        """
        qs = convert_pdq_strings_to_ndarray(queries)
        # in Python, the results are returned as a triplet of 1D arrays lims, D, I
        # where result for query i is in I[lims[i]:lims[i+1]] (indices of neighbors)
        # D[lims[i]:lims[i+1]] (distances).
        # range_search only returns distances strictly below the radius,
        # hence threshold + 1 to include matches at exactly threshold
        limits, D, I = self.faiss_index.range_search(qs, threshold + 1)

        results = []
        for i in range(len(queries)):
            matches = [result.item() for result in I[limits[i] : limits[i + 1]]]
            distances = [dist.item() for dist in D[limits[i] : limits[i + 1]]]
            results.append(list(zip(matches, distances)))
        return results

    def __getstate__(self):
        data = faiss.serialize_index(self.faiss_index)
        return data

    def __setstate__(self, data):
        self.faiss_index = faiss.deserialize_index(data)
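
To make the `lims, D, I` layout described in the comments concrete, here is a small sketch that splits the flat triplet into one match list per query. It uses plain lists with made-up values rather than real faiss output, and `unpack_range_results` is a hypothetical helper name, not part of faiss or this library.

```python
import typing as t


def unpack_range_results(
    limits: t.Sequence[int],
    distances: t.Sequence[float],
    ids: t.Sequence[int],
) -> t.List[t.List[t.Tuple[int, float]]]:
    """Split the flat (lims, D, I) triplet into one match list per query."""
    results = []
    # Query i owns the half-open slice [limits[i], limits[i + 1])
    for start, end in zip(limits[:-1], limits[1:]):
        results.append(list(zip(ids[start:end], distances[start:end])))
    return results


# Made-up values: query 0 has two matches, query 1 has none, query 2 has one.
lims = [0, 2, 2, 3]
D = [0.0, 12.0, 30.0]
I = [7, 3, 5]
per_query = unpack_range_results(lims, D, I)
```

A query with no matches shows up as two equal consecutive limits, which yields an empty list rather than an error.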

Putting it together with search

    def query(self, query: str) -> t.List[IndexMatch[T]]:
        results = self.faiss_index.search([query], self.threshold)
        return [
            IndexMatch(int(distf), entry)
            for idx, distf in results[0]
            for entry in self._idx_to_entries[idx]
        ]
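
The `int(distf)` cast above relies on the float distance being a whole number. A sketch of why that can hold, assuming the vectors given to the flat L2 index carry one 0/1 value per PDQ bit (the issue does not spell out what `convert_pdq_strings_to_ndarray` produces): faiss L2 indexes report squared Euclidean distance, and on 0/1 vectors that equals the Hamming distance. The helper names below are hypothetical and use only the standard library.

```python
import typing as t


def pdq_hex_to_bits(pdq_hex: str) -> t.List[int]:
    """Expand a 64-character hex PDQ hash into 256 bits (0 or 1)."""
    if len(pdq_hex) != 64:
        raise ValueError("expected 64 hex characters")
    value = int(pdq_hex, 16)
    # Most significant bit first
    return [(value >> shift) & 1 for shift in range(255, -1, -1)]


def hamming(a: str, b: str) -> int:
    """Number of differing bits between two hex-encoded hashes."""
    return bin(int(a, 16) ^ int(b, 16)).count("1")


def squared_l2(a: t.Sequence[int], b: t.Sequence[int]) -> float:
    """Squared Euclidean distance, the quantity a flat L2 index reports."""
    return float(sum((x - y) ** 2 for x, y in zip(a, b)))


h1 = "f" * 64
h2 = "f" * 62 + "00"  # differs from h1 in the last 8 bits
```

Each differing bit contributes exactly 1 to the squared distance, so the two measures agree and a Hamming threshold can be passed to the L2 search unchanged.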

Dynamically selecting lookup type from build function

    @classmethod
    def build(cls: t.Type[Self], entries: t.Iterable[t.Tuple[str, T]]) -> Self:
        """
        Faiss has many potential options that we can choose based on the size of the index.
        """
        entry_list = list(entries)
        xn = len(entry_list)
        if xn < 1024:  # If small enough, just use brute force
            return super().build(entry_list)
        centroids = pick_n_centroids(xn)
        # TODO - use the same magic factory string as the old one does
        # The IVF prefix needs an encoding after it; "Flat" stores vectors uncompressed
        index = faiss.index_factory(DIMENSIONALITY, f"IVF{centroids},Flat")
        # Squelch warnings about not having enough points...
        index.cp.min_points_per_centroid = 1
        index.nprobe = 16  # 16-64 should be high enough accuracy for 1-10M
        ret = cls(faiss_index=index)
        for signal_str, entry in entry_list:
            ret._dedupe_and_add(signal_str, entry, add_to_faiss=False)
        xb = convert_pdq_strings_to_ndarray(tuple(s for s in ret._deduper))
        index.train(xb)
        index.add(xb)
        return ret
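
`pick_n_centroids` is called above but not defined in this issue. One possible sketch, assuming the rule of thumb from the faiss index-selection guidelines of roughly 4 * sqrt(n) centroids for smaller datasets; the exact multiplier and the clamping are assumptions, not something the issue specifies.

```python
import math


def pick_n_centroids(n_hashes: int) -> int:
    """Hypothetical heuristic: about 4 * sqrt(n) centroids, never more than n."""
    if n_hashes < 1:
        raise ValueError("need at least one hash")
    # Clamp so there is never more than one centroid per hash
    return max(1, min(n_hashes, 4 * math.isqrt(n_hashes)))
```

The clamp matters mostly for tiny inputs, which `build` already routes to brute force, but it keeps the helper safe to call on its own.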

Test everything

Add a robust set of unit tests for this functionality

  1. Test 0 entries
  2. Test sample set entries
  3. Test > brute force entries
  4. Serialization and deserialization
  5. Duplicate hashes return right thing
  6. Test the conditions from [pytx] No match results if creating a local_file with only 1 hash in it #1318
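
A sketch of how a few of the cases in this list could be laid out with `unittest`. `BruteForceStandIn` is a hypothetical pure-Python placeholder so the example runs without faiss; its tuple return shape and the threshold of 31 are assumptions, and the real tests would point `INDEX_CLS` at the new index class. Serialization and the above-brute-force case are left out because they depend on the real implementation.

```python
import typing as t
import unittest


class BruteForceStandIn:
    """Hypothetical stand-in with the build/query shape the real index would have."""

    def __init__(self, threshold: int = 31) -> None:  # assumed default threshold
        self.threshold = threshold
        self.entries: t.List[t.Tuple[str, t.Any]] = []

    @classmethod
    def build(cls, entries: t.Iterable[t.Tuple[str, t.Any]]) -> "BruteForceStandIn":
        ret = cls()
        ret.entries = list(entries)
        return ret

    def query(self, query: str) -> t.List[t.Tuple[int, t.Any]]:
        q = int(query, 16)
        out = []
        for signal_str, entry in self.entries:
            # Hamming distance between the two hex-encoded hashes
            dist = bin(q ^ int(signal_str, 16)).count("1")
            if dist <= self.threshold:
                out.append((dist, entry))
        return out


class IndexContractTest(unittest.TestCase):
    INDEX_CLS = BruteForceStandIn  # swap in the real class under test
    HASH = "f" * 64
    FAR = "0" * 64  # 256 bits away from HASH

    def test_empty(self) -> None:
        self.assertEqual(self.INDEX_CLS.build([]).query(self.HASH), [])

    def test_single_hash(self) -> None:  # the #1318 condition
        index = self.INDEX_CLS.build([(self.HASH, "only")])
        self.assertEqual(index.query(self.HASH), [(0, "only")])
        self.assertEqual(index.query(self.FAR), [])

    def test_duplicates(self) -> None:
        index = self.INDEX_CLS.build([(self.HASH, "a"), (self.HASH, "b")])
        self.assertCountEqual(index.query(self.HASH), [(0, "a"), (0, "b")])
```

Keeping the class under test in a single attribute makes it cheap to run the same cases against both the old and the new index during the swap.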

Rollout plan

After we confirm that everything is working as expected, we'll swap out the index class that the PDQ signal type uses by default. I think we can get away without a major version bump for this.

@zackjh3

zackjh3 commented Oct 3, 2024

I will start a fix on this issue at the Hackathon

@haianhng31
Contributor

@Dcallies I plan to take this issue after I finish with the pdq rotation. When you have time, could you help me divide this issue into smaller subproblems to work on?

@Dcallies
Contributor Author

Dcallies commented Nov 8, 2024

Yes - do you need them as issues or just to write them out?

Steps:

  1. Create a new file under pdq called index2.py which contains PDQSignalTypeIndex2. It can initially be unimplemented, as long as it passes CI
  2. Add content to the body of PDQSignalTypeIndex2, specifically the constructor. Add an implementation for serialize/deserialize and test it
  3. Implement the most basic version of search, and add tests, including a case that shows that it can handle 1 PDQ hash (the issue this fixes)
  4. Implement the rest of the required functionality, including build, add, etc. Add unit tests for everything
  5. Add the selection logic for flat index vs optimized based on number of input hashes
  6. Test the compatibility of switching over the default PDQ index class - you just need to verify that with a previously serialized index, no exceptions are thrown from the CLI when reading the old version.

@haianhng31
Contributor

@Dcallies do you want me to start swapping out the index class that the PDQ signal type uses by default?

@Dcallies
Contributor Author

Dcallies commented Dec 3, 2024

Not quite yet, we're missing the optimized solution: faiss IVF. IVF faiss indices should be used when the number of hashes is above some number (e.g. 1,000 hashes), and the selection should be implemented in build() based on the initial input size.
