generated from ajparsons/python-poetry-auto-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
2,328 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,7 @@ on: | |
type: boolean | ||
default: false | ||
push: | ||
branches: [main] | ||
branches: [main-old] | ||
|
||
jobs: | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,12 @@ | ||
# mini-transcript-search | ||
|
||
Experiment in low dependency vector search | ||
This is an experiment in low dependency vector search. | ||
|
||
When only a few days of transcripts need checking, we don't need an index or a big database: we can calculate cosine similarity directly. | |
|
||
See infer.ipynb for usage as a module. | ||
|
||
```bash | ||
python -m mini_transcript_search "register of members financial interests" --threshold 0.4 --n 5 | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from mini_transcript_search import ModelHandler\n", | ||
"import datetime\n", | ||
"from pathlib import Path\n", | ||
"\n", | ||
"handler = ModelHandler(use_local_model=False, override_stored=True)\n", | ||
"\n", | ||
"search_query = \"register of members financial interests\"\n", | ||
"\n", | ||
"yesterday = datetime.date.today() - datetime.timedelta(days=1)\n", | ||
"\n", | ||
"# the last week starting yesterday\n", | ||
"last_week = ModelHandler.DateRange(start_date=yesterday, end_date=yesterday)\n", | ||
"results = handler.query(\n", | ||
" search_query,\n", | ||
" threshold=0.4,\n", | ||
" n=10,\n", | ||
" date_range=last_week,\n", | ||
" chamber=ModelHandler.Chamber.COMMONS,\n", | ||
" transcript_type=ModelHandler.TranscriptType.DEBATES,\n", | ||
")\n", | ||
"\n", | ||
"# dump csv\n", | ||
"results.df().to_csv(Path(\"last_week.csv\"), index=False)\n", | ||
"\n", | ||
"# dump json\n", | ||
"results.to_path(Path(\"last_week.json\"))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,9 @@ | ||
""" | ||
Experiment in low dependency vector search | ||
""" | ||
__version__ = "0.1.0" | ||
|
||
from .search import ModelHandler | ||
|
||
__version__ = "0.1.0" | ||
|
||
__all__ = ["ModelHandler", "__version__"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import datetime | ||
from typing import Annotated, Optional, Union | ||
|
||
import typer | ||
from mysoc_validator.models.consts import Chamber, TranscriptType | ||
|
||
from .search import ModelHandler, default_model | ||
|
||
# Typer application that the `search` command below is registered on.
app = typer.Typer()

# `yesterday` is the default for both date parameters of `search`.
# It is computed once, when the module is imported.
yesterday = datetime.date.today() - datetime.timedelta(days=1)
last_week = yesterday - datetime.timedelta(days=7)

# NOTE(review): this handler is built at import time, but `search` constructs
# its own handler from the CLI options - confirm whether this one is still needed.
handler = ModelHandler(use_local_model=False)
|
||
|
||
def parse_date(date: Union[datetime.date, str]) -> datetime.date:
    """
    Coerce a command-line date value to a ``datetime.date``.

    Anything that is already a date (including ``datetime.datetime``, which
    subclasses it) is handed back unchanged; any other value is parsed with
    ``datetime.date.fromisoformat``.

    Raises:
        ValueError: If a string is not a valid ISO format date.
    """
    if not isinstance(date, datetime.date):
        return datetime.date.fromisoformat(date)
    return date
|
||
|
||
# CLI parameter type for dates: a `datetime.date` declared as a Typer argument
# whose raw value is converted by `parse_date`.
DateField = Annotated[datetime.date, typer.Argument(parser=parse_date)] | |
|
||
|
||
@app.command()
def search(
    query: str,
    threshold: float = 0.2,
    n: Optional[int] = None,
    start_date: DateField = yesterday,
    end_date: DateField = yesterday,
    chamber: Chamber = Chamber.COMMONS,
    transcript_type: TranscriptType = TranscriptType.DEBATES,
    model_id: str = default_model,
    use_local_model: bool = True,
    override_stored: bool = False,
):
    # Run one search over the transcripts between start_date and end_date
    # and echo the results as JSON.
    #
    # NOTE: documented with comments rather than a docstring, because Typer
    # would surface a docstring as this command's help text.

    # Build a handler for this invocation from the CLI options.
    search_handler = ModelHandler(
        model_id=model_id,
        use_local_model=use_local_model,
        override_stored=override_stored,
    )

    # A threshold above one is truncated to a whole number before querying;
    # values of one or below are passed through as floats.
    cutoff = int(threshold) if threshold > 1 else threshold

    window = ModelHandler.DateRange(start_date=start_date, end_date=end_date)
    results = search_handler.query(
        query=query,
        threshold=cutoff,
        n=n,
        date_range=window,
        chamber=chamber,
        transcript_type=transcript_type,
    )

    typer.echo(results.json())
|
||
|
||
# Run the Typer app when this module is executed as a script.
if __name__ == "__main__": | |
app() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import os | ||
from contextlib import redirect_stderr | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import requests | ||
from fastembed import TextEmbedding # type: ignore | ||
from numpy.typing import NDArray | ||
|
||
|
||
class Inference:
    """
    Generate text embeddings for a given model.

    Two backends are available, selected by ``local``:

    * remote (the default) - the Hugging Face hosted inference API, which
      requires an access token;
    * local - a fastembed ``TextEmbedding`` model, created on first use.
    """

    def __init__(
        self, model_id: str, hf_token: Optional[str] = None, local: bool = False
    ):
        """
        Args:
            model_id: Model identifier handed to whichever backend is used.
            hf_token: Hugging Face API token. When not given (or empty), the
                ``HF_TOKEN`` environment variable is used instead.
            local: Use the local fastembed backend instead of the remote API.

        Raises:
            ValueError: If the remote backend is selected and no non-empty
                token is available.
        """
        self.model_id: str = model_id
        self.hf_token = hf_token or os.environ.get("HF_TOKEN")
        self.local = local
        self._model = None
        # An empty token (e.g. HF_TOKEN="") is as unusable as a missing one, so
        # test truthiness rather than `is None`. `local` is tested the same
        # way `query` tests it, so the two checks cannot disagree.
        if not self.hf_token and not self.local:
            raise ValueError("Need to set hf_token for remote embedding generation.")

    def query_local(self, texts: list[str]) -> list[NDArray[np.float64]]:
        """
        Embed ``texts`` with the local model, creating the model on first call.
        """
        if self._model is None:
            # Send stderr to devnull while the model is created, to suppress
            # the download progress bar.
            with open(os.devnull, "w") as fnull, redirect_stderr(fnull):
                self._model = TextEmbedding(model_name=self.model_id)
        return list(self._model.embed(texts))  # type: ignore

    def query_remote(self, texts: list[str]) -> list[NDArray[np.float64]]:
        """
        Embed ``texts`` via the Hugging Face feature-extraction endpoint.

        The decoded JSON body is returned as-is, without checking the HTTP
        status. ``query_id_and_text`` treats a dict result as an error payload.
        """
        api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{self.model_id}"
        headers = {"Authorization": f"Bearer {self.hf_token}"}
        response = requests.post(
            api_url,
            headers=headers,
            json={"inputs": texts, "options": {"wait_for_model": True}},
        )
        return response.json()

    def query_id_and_text(self, id_and_text: dict[str, str]) -> pd.DataFrame:
        """
        Embed a mapping of id -> text.

        Returns:
            A DataFrame with ``id``, ``text`` and ``embedding`` columns, one
            row per input item, in the mapping's iteration order.

        Raises:
            ValueError: If the backend returned a dict instead of a list of
                embeddings.
        """
        id_values = list(id_and_text.keys())
        text_values = list(id_and_text.values())
        embeddings = self.query(text_values)
        if isinstance(embeddings, dict):
            raise ValueError(f"Error in embeddings: {embeddings}")
        return pd.DataFrame(
            {
                "id": id_values,
                "text": text_values,
                "embedding": embeddings,
            }
        )

    def query(self, texts: list[str]):
        """
        Embed ``texts`` with the configured backend.

        An empty input returns an empty list without touching either backend.
        """
        if len(texts) == 0:
            return []
        if self.local:
            return self.query_local(texts)
        return self.query_remote(texts)