This repository has been archived by the owner on Jan 3, 2023. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
manage.py
71 lines (54 loc) · 2.25 KB
/
manage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
if __name__ == "__main__":
    import pathlib

    import click

    @click.group()
    def cli():
        # Top-level click group; sub-commands are attached with
        # `cli.add_command(...)` at the bottom of this script.
        pass

    @click.command()
    @click.argument("output", type=pathlib.Path)
    @click.option("--tree-count", type=int, default=100)
    def generate_index(output: pathlib.Path, tree_count: int):
        """Create a new version of the index using all embeddings stored in
        the EmbeddingStore.

        The Annoy index is saved to OUTPUT, and the logo IDs are written one
        per line to OUTPUT with its suffix replaced by ".txt": line i holds
        the logo ID of Annoy item i.

        :param output: Path of the output index
        :param tree_count: Number of trees to use when building the Annoy index
        :raises click.ClickException: if the embedding store contains no
            embeddings; in that case neither output file is written.
        """
        import shutil
        import tempfile

        import tqdm
        from annoy import AnnoyIndex

        import settings
        from embeddings import EmbeddingStore
        from utils import get_logger

        logger = get_logger()

        with tempfile.TemporaryDirectory() as tmp_dir:
            embedding_path = pathlib.Path(tmp_dir) / "embeddings.hdf5"
            logger.info(f"Copying embedding file to {embedding_path}...")
            # Copy embedding files to a temporary location to avoid modification during
            # index generation
            shutil.copy(str(settings.EMBEDDINGS_HDF5_PATH), str(embedding_path))
            logger.info(f"Loading {embedding_path}...")
            embedding_store = EmbeddingStore(embedding_path)

            index = None
            # keys[i] is the logo ID of Annoy item i. Items get dense IDs
            # 0..n-1 (Annoy allocates storage for max(id) + 1 items), so the
            # real logo IDs are kept in a separate key file.
            keys = []

            logger.info("Adding embeddings to index...")
            for offset, (logo_id, embedding) in enumerate(
                tqdm.tqdm(embedding_store.iter_embeddings())
            ):
                if index is None:
                    # The vector dimension is only known once the first
                    # embedding has been read, so the index is created lazily.
                    index = AnnoyIndex(embedding.shape[-1], "euclidean")
                index.add_item(offset, embedding)
                keys.append(int(logo_id))

            if index is None:
                # Nothing to index: fail loudly rather than report success and
                # overwrite the key file with an empty one, which would leave it
                # out of sync with any index previously saved at `output`.
                raise click.ClickException(
                    f"No embeddings found in {settings.EMBEDDINGS_HDF5_PATH}, "
                    "index not generated"
                )

            logger.info("Building index...")
            index.build(tree_count)
            index.save(str(output))
            logger.info("Index built.")

            logger.info("Saving keys...")
            with output.with_suffix(".txt").open("w") as f:
                f.writelines(f"{key}\n" for key in keys)
            logger.info("Keys saved.")

    cli.add_command(generate_index)
    cli()