
Commit

Initial commit
ajparsons committed Sep 11, 2024
1 parent 45257ea commit a09a179
Showing 10 changed files with 2,328 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/auto_publish.yml
@@ -21,7 +21,7 @@ on:
type: boolean
default: false
push:
branches: [main]
branches: [main-old]

jobs:

6 changes: 3 additions & 3 deletions .github/workflows/test.yml
@@ -2,9 +2,9 @@ name: Run tests

on:
push:
branches-ignore: [ main ]
branches-ignore: [ main-old ]
pull_request:
branches-ignore: [ main ]
branches-ignore: [ main-old ]
workflow_call:
workflow_dispatch:

@@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10"]
poetry-version: ["1.8"]

runs-on: ubuntu-latest
6 changes: 2 additions & 4 deletions .vscode/settings.json
@@ -1,12 +1,10 @@
{
"python.linting.pylintEnabled": true,
"python.defaultInterpreterPath": "/usr/local/bin/python",
"python.terminal.activateEnvironment": false,
"python.formatting.provider": "black",
"python.analysis.typeCheckingMode": "strict",
"python.analysis.typeCheckingMode": "basic",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": true
"source.organizeImports": "explicit"
},
"git.pullTags": false,
"ltex.language": "en-GB",
11 changes: 10 additions & 1 deletion README.md
@@ -1,3 +1,12 @@
# mini-transcript-search

Experiment in low dependency vector search
This is an experiment in low dependency vector search.

For searching just a few days of transcripts we don't need an index or a big database: we can calculate cosine similarity directly.

See infer.ipynb for usage as a module.

```bash
python -m mini_transcript_search "register of members financial interests" --threshold 0.4 --n 5
```
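
The README's "no index" point comes down to a single cosine-similarity pass over the stored embeddings. A minimal sketch of that calculation (editor's illustration: `cosine_top_n`, its arguments, and the normalisation step are hypothetical, not the package's actual internals):

```python
# Editor's sketch, not part of the commit: a plain cosine-similarity pass
# over a small matrix of stored embeddings, with no index structure.
import numpy as np


def cosine_top_n(query_vec: np.ndarray, embeddings: np.ndarray, n: int = 5):
    """Return the indices and scores of the n rows most similar to query_vec."""
    # normalise both sides so the dot product equals cosine similarity
    q = query_vec / np.linalg.norm(query_vec)
    e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    scores = e @ q
    top = np.argsort(scores)[::-1][:n]
    return top, scores[top]
```

Normalising both vectors first turns the dot product into cosine similarity, which is why a few days of speeches can be searched without any index or database.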

45 changes: 45 additions & 0 deletions notebooks/infer.ipynb
@@ -0,0 +1,45 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from mini_transcript_search import ModelHandler\n",
"import datetime\n",
"from pathlib import Path\n",
"\n",
"handler = ModelHandler(use_local_model=False, override_stored=True)\n",
"\n",
"search_query = \"register of members financial interests\"\n",
"\n",
"yesterday = datetime.date.today() - datetime.timedelta(days=1)\n",
"\n",
"# the last week starting yesterday\n",
"last_week = ModelHandler.DateRange(start_date=yesterday, end_date=yesterday)\n",
"results = handler.query(\n",
" search_query,\n",
" threshold=0.4,\n",
" n=10,\n",
" date_range=last_week,\n",
" chamber=ModelHandler.Chamber.COMMONS,\n",
" transcript_type=ModelHandler.TranscriptType.DEBATES,\n",
")\n",
"\n",
"# dump csv\n",
"results.df().to_csv(Path(\"last_week.csv\"), index=False)\n",
"\n",
"# dump json\n",
"results.to_path(Path(\"last_week.json\"))"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2,119 changes: 2,119 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions pyproject.toml
@@ -13,16 +13,25 @@ include = [

[tool.poetry_bumpversion.file."src/mini_transcript_search/__init__.py"]

[tool.poetry.scripts]
transcript-search = "mini_transcript_search.__main__:app"

[tool.poetry.dependencies]
python = "^3.8"
python = "^3.9,<3.13"
mysoc-validator = "^0.3.0"
fastembed = "^0.3.6"
numpy = "1.26.4"
pandas = "^2.2.2"
requests = "^2.32.3"
pyarrow = "^17.0.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.1.2"
pytest-cov = "^3.0.0"
pylint = "^2.12.2"
ruff = "^0.4.4"
pyright = "^1.1"
toml = "^0.10.2"
ruff = "^0.6.4"

[build-system]
requires = ["poetry-core>=1.0.0"]
7 changes: 6 additions & 1 deletion src/mini_transcript_search/__init__.py
@@ -1,4 +1,9 @@
"""
Experiment in low dependency vector search
"""
__version__ = "0.1.0"

from .search import ModelHandler

__version__ = "0.1.0"

__all__ = ["ModelHandler", "__version__"]
64 changes: 64 additions & 0 deletions src/mini_transcript_search/__main__.py
@@ -0,0 +1,64 @@
import datetime
from typing import Annotated, Optional, Union

import typer
from mysoc_validator.models.consts import Chamber, TranscriptType

from .search import ModelHandler, default_model

app = typer.Typer()
yesterday = datetime.date.today() - datetime.timedelta(days=1)


def parse_date(date: Union[datetime.date, str]) -> datetime.date:
if isinstance(date, datetime.date):
return date
return datetime.date.fromisoformat(date)


DateField = Annotated[datetime.date, typer.Argument(parser=parse_date)]


@app.command()
def search(
query: str,
threshold: float = 0.2,
n: Optional[int] = None,
start_date: DateField = yesterday,
end_date: DateField = yesterday,
chamber: Chamber = Chamber.COMMONS,
transcript_type: TranscriptType = TranscriptType.DEBATES,
model_id: str = default_model,
use_local_model: bool = True,
override_stored: bool = False,
):
handler = ModelHandler(
model_id=model_id,
use_local_model=use_local_model,
override_stored=override_stored,
)

# if threshold is a float greater than one, convert to int
if threshold > 1:
threshold = int(threshold)

date_range = ModelHandler.DateRange(start_date=start_date, end_date=end_date)
results = handler.query(
query=query,
threshold=threshold,
n=n,
date_range=date_range,
chamber=chamber,
transcript_type=transcript_type,
)

typer.echo(results.json())


if __name__ == "__main__":
app()
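
A hedged sketch of exercising the new CLI entry point in-process with Typer's test runner (this assumes Typer's usual behaviour of exposing a single registered command as the app itself; actually running it needs either a local model download or an `HF_TOKEN` for the remote API):

```python
# Editor's sketch: invoking the search command without a shell.
from typer.testing import CliRunner

from mini_transcript_search.__main__ import app

runner = CliRunner()
result = runner.invoke(
    app,
    ["register of members financial interests", "--threshold", "0.4", "--n", "5"],
)
print(result.output)  # JSON results, as emitted by typer.echo(results.json())
```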
67 changes: 67 additions & 0 deletions src/mini_transcript_search/inference.py
@@ -0,0 +1,67 @@
import os
from contextlib import redirect_stderr
from typing import Optional

import numpy as np
import pandas as pd
import requests
from fastembed import TextEmbedding # type: ignore
from numpy.typing import NDArray


class Inference:
"""
Class to handle inference for text embeddings.
Uses the free Hugging Face Inference API when local is False; setting
local=True generates embeddings with fastembed instead.
"""

def __init__(
self, model_id: str, hf_token: Optional[str] = None, local: bool = False
):
self.model_id: str = model_id
self.hf_token = hf_token if hf_token else os.environ.get("HF_TOKEN", None)
self.local = local
self._model = None
if self.hf_token is None and self.local is False:
raise ValueError("Need to set hf_token for remote embedding generation.")

def query_local(self, texts: list[str]) -> list[NDArray[np.float64]]:
if self._model is None:
# suppress download progress bar
with open(os.devnull, "w") as fnull:
with redirect_stderr(fnull):
self._model = TextEmbedding(model_name=self.model_id)
return list(self._model.embed(texts)) # type: ignore

def query_remote(self, texts: list[str]) -> list[NDArray[np.float64]]:
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{self.model_id}"
headers = {"Authorization": f"Bearer {self.hf_token}"}
response = requests.post(
api_url,
headers=headers,
json={"inputs": texts, "options": {"wait_for_model": True}},
)
return response.json()

def query_id_and_text(self, id_and_text: dict[str, str]) -> pd.DataFrame:
id_values = list(id_and_text.keys())
text_values = list(id_and_text.values())
embeddings = self.query(text_values)
if isinstance(embeddings, dict):
raise ValueError(f"Error in embeddings: {embeddings}")
return pd.DataFrame(
{
"id": id_values,
"text": text_values,
"embedding": embeddings,
}
)

def query(self, texts: list[str]):
if len(texts) == 0:
return []
if self.local:
return self.query_local(texts)
else:
return self.query_remote(texts)
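
A short usage sketch for the `Inference` class above (editor's illustration: the model id shown is a placeholder, the package's real default is `default_model` in `mini_transcript_search.search`, and remote use would additionally need an `hf_token` argument or `HF_TOKEN` environment variable):

```python
# Editor's sketch: local embedding generation with the new Inference class.
from mini_transcript_search.inference import Inference

inf = Inference(model_id="BAAI/bge-small-en-v1.5", local=True)  # placeholder model id

# plain list of texts -> list of embedding vectors
vectors = inf.query(["register of members financial interests"])

# mapping of speech ids to text -> DataFrame with id, text and embedding columns
df = inf.query_id_and_text({"speech-001": "The member raised a point of order."})
print(df.head())
```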
