diff --git a/rtfs/_types/results.py b/rtfs/_types/results.py index c01ce7c..d29ccff 100644 --- a/rtfs/_types/results.py +++ b/rtfs/_types/results.py @@ -9,6 +9,6 @@ class NodeResponse(TypedDict): class Response(TypedDict): - results: dict[str, NodeResponse] + results: dict[str, NodeResponse] | None query_time: float commit_sha: str | None diff --git a/rtfs/app.py b/rtfs/app.py index 9b4bbfe..03e3fc6 100644 --- a/rtfs/app.py +++ b/rtfs/app.py @@ -53,24 +53,16 @@ def current_rtfs(state: State) -> Indexes: @get(path="/", dependencies={"rtfs": Provide(current_rtfs, sync_to_thread=False)}) -async def get_rtfs(query: dict[str, str], rtfs: Indexes) -> Response[Mapping[str, Any]]: - query_search = query.get("search") - library = query.get("library", "").lower() - format_ = query.get("format", "url") - - if not query_search or not library: +async def get_rtfs(search: str, library: str, direct: bool | None, rtfs: Indexes) -> Response[Mapping[str, Any]]: + if not search or not library: return Response(content={"available_libraries": rtfs.libraries}, media_type=MediaType.JSON, status_code=200) - elif not query_search: + elif not search: return Response({"error": "Missing `search` query parameter."}, media_type=MediaType.JSON, status_code=400) elif not library: return Response({"error": "Missing `library` query parameter."}, media_type=MediaType.JSON, status_code=400) - if format_ not in ("url", "source"): - return Response( - {"error": "The `format` parameter must be `url` or `source`."}, media_type=MediaType.JSON, status_code=400 - ) + result = rtfs.get_direct(library, search) if direct else rtfs.get_query(library, search) - result = rtfs.get_query(library, query_search) if result is None: return Response( content={ diff --git a/rtfs/fuzzy.py b/rtfs/fuzzy.py new file mode 100644 index 0000000..4c2a774 --- /dev/null +++ b/rtfs/fuzzy.py @@ -0,0 +1,345 @@ +""" +This Source Code Form is subject to the terms of the Mozilla Public +License, v. 2.0. If a copy of the MPL was not distributed with this +file, You can obtain one at http://mozilla.org/MPL/2.0/. +""" + +# help with: http://chairnerd.seatgeek.com/fuzzywuzzy-fuzzy-string-matching-in-python/ + +from __future__ import annotations + +import heapq +import re +from difflib import SequenceMatcher +from typing import TYPE_CHECKING, Literal, TypeVar, overload + +if TYPE_CHECKING: + from collections.abc import Callable, Generator, Iterable, Sequence + +T = TypeVar("T") + + +def ratio(a: str, b: str) -> int: + m = SequenceMatcher(None, a, b) + return int(round(100 * m.ratio())) + + +def quick_ratio(a: str, b: str) -> int: + m = SequenceMatcher(None, a, b) + return int(round(100 * m.quick_ratio())) + + +def partial_ratio(a: str, b: str) -> int: + short, long = (a, b) if len(a) <= len(b) else (b, a) + m = SequenceMatcher(None, short, long) + + blocks = m.get_matching_blocks() + + scores: list[float] = [] + for i, j, _ in blocks: + start = max(j - i, 0) + end = start + len(short) + o = SequenceMatcher(None, short, long[start:end]) + r = o.ratio() + + if 100 * r > 99: + return 100 + scores.append(r) + + return int(round(100 * max(scores))) + + +_word_regex = re.compile(r"\W", re.IGNORECASE) + + +def _sort_tokens(a: str) -> str: + a = _word_regex.sub(" ", a).lower().strip() + return " ".join(sorted(a.split())) + + +def token_sort_ratio(a: str, b: str) -> int: + a = _sort_tokens(a) + b = _sort_tokens(b) + return ratio(a, b) + + +def quick_token_sort_ratio(a: str, b: str) -> int: + a = _sort_tokens(a) + b = _sort_tokens(b) + return quick_ratio(a, b) + + +def partial_token_sort_ratio(a: str, b: str) -> int: + a = _sort_tokens(a) + b = _sort_tokens(b) + return partial_ratio(a, b) + + +@overload +def _extraction_generator( + query: str, + choices: Sequence[str], + scorer: Callable[[str, str], int] = ..., + score_cutoff: int = ..., +) -> Generator[tuple[str, int], None, None]: ... + + +@overload +def _extraction_generator( + query: str, + choices: dict[str, T], + scorer: Callable[[str, str], int] = ..., + score_cutoff: int = ..., +) -> Generator[tuple[str, int, T], None, None]: ... + + +def _extraction_generator( + query: str, + choices: Sequence[str] | dict[str, T], + scorer: Callable[[str, str], int] = quick_ratio, + score_cutoff: int = 0, +) -> Generator[tuple[str, int, T] | tuple[str, int], None, None]: + if isinstance(choices, dict): + for key, value in choices.items(): + score = scorer(query, key) + if score >= score_cutoff: + yield (key, score, value) + else: + for choice in choices: + score = scorer(query, choice) + if score >= score_cutoff: + yield (choice, score) + + +@overload +def extract( + query: str, + choices: Sequence[str], + *, + scorer: Callable[[str, str], int] = ..., + score_cutoff: int = ..., + limit: int | None = ..., +) -> list[tuple[str, int]]: ... + + +@overload +def extract( + query: str, + choices: dict[str, T], + *, + scorer: Callable[[str, str], int] = ..., + score_cutoff: int = ..., + limit: int | None = ..., +) -> list[tuple[str, int, T]]: ... + + +def extract( + query: str, + choices: dict[str, T] | Sequence[str], + *, + scorer: Callable[[str, str], int] = quick_ratio, + score_cutoff: int = 0, + limit: int | None = 10, +) -> list[tuple[str, int]] | list[tuple[str, int, T]]: + it = _extraction_generator(query, choices, scorer, score_cutoff) + key = lambda t: t[1] + if limit is not None: + return heapq.nlargest(limit, it, key=key) # type: ignore + return sorted(it, key=key, reverse=True) # type: ignore + + +@overload +def extract_one( + query: str, + choices: Sequence[str], + *, + scorer: Callable[[str, str], int] = ..., + score_cutoff: int = ..., +) -> tuple[str, int] | None: ... + + +@overload +def extract_one( + query: str, + choices: dict[str, T], + *, + scorer: Callable[[str, str], int] = ..., + score_cutoff: int = ..., +) -> tuple[str, int, T] | None: ... + + +def extract_one( + query: str, + choices: dict[str, T] | Sequence[str], + *, + scorer: Callable[[str, str], int] = quick_ratio, + score_cutoff: int = 0, +) -> tuple[str, int] | None | tuple[str, int, T] | None: + it = _extraction_generator(query, choices, scorer, score_cutoff) + key = lambda t: t[1] + try: + return max(it, key=key) + except: + # iterator could return nothing + return None + + +@overload +def extract_or_exact( + query: str, + choices: Sequence[str], + *, + scorer: Callable[[str, str], int] = ..., + score_cutoff: int = ..., + limit: int | None = ..., +) -> list[tuple[str, int]]: ... + + +@overload +def extract_or_exact( + query: str, + choices: dict[str, T], + *, + scorer: Callable[[str, str], int] = ..., + score_cutoff: int = ..., + limit: int | None = ..., +) -> list[tuple[str, int, T]]: ... + + +def extract_or_exact( + query: str, + choices: dict[str, T] | Sequence[str], + *, + scorer: Callable[[str, str], int] = quick_ratio, + score_cutoff: int = 0, + limit: int | None = None, +) -> list[tuple[str, int]] | list[tuple[str, int, T]]: + matches = extract(query, choices, scorer=scorer, score_cutoff=score_cutoff, limit=limit) + if len(matches) == 0: + return [] + + if len(matches) == 1: + return matches + + top = matches[0][1] + second = matches[1][1] + + # check if the top one is exact or more than 30% more correct than the top + if top == 100 or top > (second + 30): + return [matches[0]] # type: ignore + + return matches + + +@overload +def extract_matches( + query: str, + choices: Sequence[str], + *, + scorer: Callable[[str, str], int] = ..., + score_cutoff: int = ..., +) -> list[tuple[str, int]]: ... + + +@overload +def extract_matches( + query: str, + choices: dict[str, T], + *, + scorer: Callable[[str, str], int] = ..., + score_cutoff: int = ..., +) -> list[tuple[str, int, T]]: ... + + +def extract_matches( + query: str, + choices: dict[str, T] | Sequence[str], + *, + scorer: Callable[[str, str], int] = quick_ratio, + score_cutoff: int = 0, +) -> list[tuple[str, int]] | list[tuple[str, int, T]]: + matches = extract(query, choices, scorer=scorer, score_cutoff=score_cutoff, limit=None) + if len(matches) == 0: + return [] + + top_score = matches[0][1] + to_return = [] + index = 0 + while True: + try: + match = matches[index] + except IndexError: + break + else: + index += 1 + + if match[1] != top_score: + break + + to_return.append(match) + return to_return + + +@overload +def finder( + text: str, + collection: Iterable[T], + *, + key: Callable[[T], str] | None = ..., + raw: Literal[True], +) -> list[tuple[int, int, T]]: ... + + +@overload +def finder( + text: str, + collection: Iterable[T], + *, + key: Callable[[T], str] | None = ..., + raw: Literal[False], +) -> list[T]: ... + + +@overload +def finder( + text: str, + collection: Iterable[T], + *, + key: Callable[[T], str] | None = ..., + raw: bool = ..., +) -> list[T]: ... + + +def finder( + text: str, + collection: Iterable[T], + *, + key: Callable[[T], str] | None = None, + raw: bool = False, +) -> list[tuple[int, int, T]] | list[T]: + suggestions: list[tuple[int, int, T]] = [] + text = str(text) + pat = ".*?".join(map(re.escape, text)) + regex = re.compile(pat, flags=re.IGNORECASE) + for item in collection: + to_search = key(item) if key else str(item) + r = regex.search(to_search) + if r: + suggestions.append((len(r.group()), r.start(), item)) + + def sort_key(tup: tuple[int, int, T]) -> tuple[int, int, str | T]: + if key: + return tup[0], tup[1], key(tup[2]) + return tup + + if raw: + return sorted(suggestions, key=sort_key) + else: + return [z for _, _, z in sorted(suggestions, key=sort_key)] + + +def find(text: str, collection: Iterable[str], *, key: Callable[[str], str] | None = None) -> str | None: + try: + return finder(text, collection, key=key)[0] + except IndexError: + return None diff --git a/rtfs/index.py b/rtfs/index.py index b171960..a717799 100644 --- a/rtfs/index.py +++ b/rtfs/index.py @@ -2,7 +2,6 @@ import ast import configparser -import difflib import logging import os import pathlib @@ -11,6 +10,8 @@ from yarl import URL +from .fuzzy import extract + LOGGER = logging.getLogger(__name__) LOGGER.setLevel(logging.DEBUG) VERSION_REGEX = re.compile(r"__version__\s*=\s*(?:'|\")((?:\w|\.)*)(?:'|\")") @@ -272,4 +273,4 @@ def index_lib(self) -> None: ) def find_matches(self, word: str) -> list[Node]: - return [self.nodes[v] for v in difflib.get_close_matches(word, self.keys, cutoff=0.55)] + return [self.nodes[v[0]] for v in extract(word, self.keys, score_cutoff=20, limit=3)] diff --git a/rtfs/indexer.py b/rtfs/indexer.py index 205a3d3..d224b7d 100644 --- a/rtfs/indexer.py +++ b/rtfs/indexer.py @@ -52,6 +52,23 @@ def get_query(self, lib: str, query: str) -> Response | None: "commit_sha": self.index[lib].commit, } + def get_direct(self, lib: str, query: str) -> Response | None: + if not self._is_indexed: + raise RuntimeError("Indexing is not complete.") + + if lib not in self.index: + return + + start = time.monotonic() + result = self.index[lib].nodes.get(query) + end = time.monotonic() - start + + return { + "results": {result.name: {"source": result.source, "url": result.url}} if result else None, + "query_time": end, + "commit_sha": self.index[lib].commit, + } + def _do_pull(self, index: Index) -> bool: try: subprocess.run(["/bin/bash", "-c", f"cd {index.repo_path} && git pull"])