Skip to content

Commit

Permalink
Merge pull request #99 from tjmlabs/update-model-command
Browse files Browse the repository at this point in the history
updating to make pooled embeddings
  • Loading branch information
Jonathan-Adly authored Nov 24, 2024
2 parents 665e86b + aaa1691 commit b527301
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 1 deletion.
1 change: 1 addition & 0 deletions web/.coveragerc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[run]
omit =
*/migrations/*
*/management/commands/*
[html]
directory = api/tests/htmlcov

Expand Down
44 changes: 44 additions & 0 deletions web/api/management/commands/update_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from time import sleep
from typing import Any, Dict, List

import requests
from api.models import Page, PageEmbedding
from django.conf import settings
from django.core.management.base import BaseCommand
from tenacity import retry, stop_after_attempt, wait_fixed

# Constants
EMBEDDINGS_URL = settings.EMBEDDINGS_URL
DELAY_BETWEEN_BATCHES = 1 # seconds


class Command(BaseCommand):
help = "Update embeddings for all documents, we run this whenever we upgrade the base model"

def handle(self, *args: Any, **options: Any) -> None:
pages = Page.objects.all()

for page in pages:
image: List[str] = [page.img_base64]
embeddings_obj: List[Dict[str, Any]] = send_batch(image)
embeddings: List[float] = embeddings_obj[0]["embedding"]
page.embeddings.all().delete()
bulk_create_embeddings = [
PageEmbedding(page=page, embedding=embedding)
for embedding in embeddings
]
PageEmbedding.objects.bulk_create(bulk_create_embeddings)
self.stdout.write(self.style.SUCCESS(f"Updated embedding for {page.id}"))
sleep(DELAY_BETWEEN_BATCHES)


@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
def send_batch(images: List[str]) -> List[Dict[str, Any]]:
payload: Dict[str, Any] = {"input": {"task": "image", "input_data": images}}
headers: Dict[str, str] = {
"Authorization": f"Bearer {settings.EMBEDDINGS_URL_TOKEN}"
}
response = requests.post(settings.EMBEDDINGS_URL, json=payload, headers=headers)
response.raise_for_status()
data: List[Dict[str, Any]] = response.json()["output"]["data"]
return data
2 changes: 1 addition & 1 deletion web/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ async def send_batch(
# each page = 1 embedding, an
# exampple all_embeddings = [
# {
# "embedding": [[0.1, 0.2, ..., 0.128], [0.1, 0.2, ...]], # List of 1030 members, each a list of 128 floats
# "embedding": [[0.1, 0.2, ..., 0.128], [0.1, 0.2, ...]], # List of n members, each a list of 128 floats
# "index": 0,
# "object": "embedding"
# },
Expand Down
2 changes: 2 additions & 0 deletions web/mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[mypy]
python_version = 3.12
ignore_missing_imports = True
warn_unused_ignores = True
disable_error_code = import-untyped
plugins = mypy_django_plugin.main

[mypy.plugins.django-stubs]
Expand Down

0 comments on commit b527301

Please sign in to comment.