Skip to content

Commit

Permalink
Merge pull request #100 from tjmlabs/batches-for-embeddings
Browse files Browse the repository at this point in the history
Fix the update_embeddings command for cases where all pages cannot fit in memory
  • Loading branch information
Jonathan-Adly authored Nov 24, 2024
2 parents b527301 + d09d321 commit 10bb4fe
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions web/api/management/commands/update_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from api.models import Page, PageEmbedding
from django.conf import settings
from django.core.management.base import BaseCommand
from django.db import transaction
from tenacity import retry, stop_after_attempt, wait_fixed

# Constants
Expand All @@ -16,20 +17,26 @@ class Command(BaseCommand):
help = "Update embeddings for all documents, we run this whenever we upgrade the base model"

def handle(self, *args: Any, **options: Any) -> None:
batch_size = 100
pages = Page.objects.all()

for page in pages:
image: List[str] = [page.img_base64]
embeddings_obj: List[Dict[str, Any]] = send_batch(image)
embeddings: List[float] = embeddings_obj[0]["embedding"]
page.embeddings.all().delete()
bulk_create_embeddings = [
PageEmbedding(page=page, embedding=embedding)
for embedding in embeddings
]
PageEmbedding.objects.bulk_create(bulk_create_embeddings)
self.stdout.write(self.style.SUCCESS(f"Updated embedding for {page.id}"))
sleep(DELAY_BETWEEN_BATCHES)
for i in range(0, pages.count(), batch_size):
batch = pages[i : i + batch_size]
for page in batch:
image: List[str] = [page.img_base64]
embeddings_obj: List[Dict[str, Any]] = send_batch(image)
embeddings: List[float] = embeddings_obj[0]["embedding"]
with transaction.atomic():
page.embeddings.all().delete()
bulk_create_embeddings = [
PageEmbedding(page=page, embedding=embedding)
for embedding in embeddings
]
PageEmbedding.objects.bulk_create(bulk_create_embeddings)
self.stdout.write(
self.style.SUCCESS(f"Updated embedding for {page.id}")
)
sleep(DELAY_BETWEEN_BATCHES)


@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
Expand Down

0 comments on commit 10bb4fe

Please sign in to comment.