Skip to content

Commit

Permalink
Indexed ProofBank
Browse files Browse the repository at this point in the history
  • Loading branch information
XanderVertegaal committed Aug 26, 2024
1 parent 515f9f1 commit 83a417d
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 10 deletions.
81 changes: 78 additions & 3 deletions backend/aethel_db/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,87 @@
from typing import Optional
from collections import defaultdict
from dataclasses import dataclass
from parseport.logger import logger

from django.conf import settings

from aethel import ProofBank
from aethel.frontend import Sample

# NOTE(review): `dataset` is declared a second time further down in this file
# as `IndexedProofBank | None`; this earlier declaration looks like a leftover
# of the previous implementation -- confirm and remove.
dataset: Optional[ProofBank] = None

@dataclass(frozen=True)
class IndexedProofBank(ProofBank):
    """A ProofBank extended with lookup tables for fast sample retrieval.

    Both indices map a string key to a list of positions in ``self.samples``,
    so a lookup does not have to scan every sample.
    """

    # Key: string form of a lexical type. Value: positions in `samples`.
    type_index: dict[str, list[int]]
    # Key: a word. Value: positions in `samples`.
    word_index: dict[str, list[int]]

    def __post_init__(self):
        logger.info("Dataset loaded.")

    def by_type(self, type: str) -> list[Sample]:
        """Return the samples listed in the type index under *type*.

        Unknown types yield an empty list.
        """
        return [self.samples[i] for i in self.type_index.get(type, [])]

    def by_word(self, word: str) -> list[Sample]:
        """Return the samples listed in the word index under *word*.

        Unknown words yield an empty list.
        """
        return [self.samples[i] for i in self.word_index.get(word, [])]

    def by_words(self, words: list[str]) -> list[Sample]:
        """Return the samples that are indexed under every word in *words*.

        An empty *words* list yields an empty result instead of raising
        ``IndexError``. Each matching sample is returned once, in the order
        in which the samples are stored.
        """
        if not words:
            return []

        candidates = set(self.word_index.get(words[0], ()))

        for word in words[1:]:
            if not candidates:
                # The intersection cannot grow again; stop early.
                break
            candidates.intersection_update(self.word_index.get(word, ()))

        # Sort so the result order is deterministic rather than set order.
        return [self.samples[i] for i in sorted(candidates)]


# Module-level handle to the indexed dataset; it stays None until
# load_dataset() assigns an IndexedProofBank to it.
dataset: IndexedProofBank | None = None


def load_dataset():
    """Load the proof bank from disk, index it, and store it in ``dataset``.

    The file at ``settings.DATASET_PATH`` is read exactly once. Every sample
    is then scanned to build two lookup tables:

    - the type index, keyed by ``str(phrase.type)``;
    - the word index, keyed by ``item.word``.

    Each table maps its key to the positions of the samples that contain it;
    a sample position is listed at most once per key. When ``settings.DEBUG``
    is set, a progress bar is printed while indexing.
    """
    global dataset
    proofbank = ProofBank.load_data(settings.DATASET_PATH)
    type_index = defaultdict(list)
    word_index = defaultdict(list)

    def record(index, key, sample_index):
        # Samples are visited in increasing order, so a repeated occurrence
        # within the same sample can only duplicate the last entry.
        positions = index[key]
        if not positions or positions[-1] != sample_index:
            positions.append(sample_index)

    total_length = len(proofbank.samples)
    if settings.DEBUG:
        print("Indexing dataset...")

    for sample_index, sample in enumerate(proofbank.samples):
        for phrase in sample.lexical_phrases:
            record(type_index, str(phrase.type), sample_index)
            for item in phrase.items:
                record(word_index, item.word, sample_index)

        if settings.DEBUG:
            # Progress is reported 1-based: "1/total" after the first sample.
            progress(sample_index + 1, total_length)

    dataset = IndexedProofBank(
        samples=proofbank.samples,
        # Plain dicts, so later lookups of unknown keys do not create entries.
        type_index=dict(type_index),
        word_index=dict(word_index),
        version=proofbank.version,
    )


def progress(iteration, total, width=80, start="\r", newline_on_complete=True):
    """
    Prints a progress bar to the console in the form of: |█████████-----| 5/10.

    The bar is printed unconditionally; callers decide when to show it
    (for example only in DEBUG mode).

    Parameters:
    - iteration (int): The current iteration.
    - total (int): The total number of iterations. A total of 0 is rendered
      as a completed bar instead of raising ZeroDivisionError.
    - width (int, optional): The width of the whole line, including the
      delimiters and the "iteration/total" tally. Defaults to 80.
    - start (str, optional): The character(s) to display at the start of the progress bar. Defaults to "\r".
    - newline_on_complete (bool, optional): Whether to print a new line when the iteration is complete. Defaults to True.
    """
    tally = f" {iteration}/{total}"
    # Reserve room for the two "|" delimiters and the tally text.
    bar_width = max(width - 2 - len(tally), 0)

    if total == 0:
        # Nothing to iterate over: show the bar as complete.
        filled_length = bar_width
    else:
        filled_length = int(bar_width * iteration // total)
        # Keep the bar inside its width when iteration is out of range.
        filled_length = min(max(filled_length, 0), bar_width)

    bar = "█" * filled_length + "-" * (bar_width - filled_length)
    print(f"{start}|{bar}|{tally}", end="")
    # Print New Line on Complete
    if newline_on_complete and iteration == total:
        print()
6 changes: 3 additions & 3 deletions backend/aethel_db/search.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations
from aethel.frontend import LexicalPhrase, LexicalItem
from aethel.mill.types import type_repr
from aethel.mill.types import Type


def match_type_with_phrase(phrase: LexicalPhrase, type_input: Type) -> bool:
    """Return whether the type of *phrase* equals *type_input*.

    The comparison is made between type objects directly (``==``), not
    between their string representations.
    """
    return type_input == phrase.type


def match_word_with_phrase(phrase: LexicalPhrase, word_input: str) -> bool:
Expand Down
6 changes: 4 additions & 2 deletions backend/aethel_db/views/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aethel_db.models import dataset

from aethel.frontend import LexicalPhrase
from aethel.mill.types import type_prefix


@dataclass
Expand Down Expand Up @@ -91,14 +92,15 @@ def get(self, request: HttpRequest) -> JsonResponse:
response_object = AethelListResponse()

for sample in dataset.samples:
for phrase_index, phrase in enumerate(sample.lexical_phrases):
for phrase in sample.lexical_phrases:
word_match = word_input and match_word_with_phrase(phrase, word_input)
type_match = type_input and match_type_with_phrase(phrase, type_input)
if not (word_match or type_match):
continue

result = response_object.get_or_create_result(
phrase=phrase, type=str(phrase.type)
# type_prefix returns a string representation of the type, with spaces between the elements.
phrase=phrase, type=type_prefix(phrase.type)
)

result._sample_names.add(sample.name)
Expand Down
14 changes: 12 additions & 2 deletions backend/aethel_db/views/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from rest_framework.views import APIView
from rest_framework import status

from aethel.frontend import Sample
from aethel.frontend import Sample, Type

from aethel_db.models import dataset
from aethel_db.search import (
Expand Down Expand Up @@ -77,7 +77,17 @@ def get(self, request: HttpRequest) -> JsonResponse:

word_input = json.loads(word_input)

for sample in dataset.samples:
assert dataset is not None
# parse_prefix expects a type string with spaces.
type_input = Type.parse_prefix(type_input, debug=True)
by_type = dataset.by_type(str(type_input)) # re-serialize type to match index
by_word = dataset.by_words(word_input)
by_name = {sample.name: sample for sample in by_type + by_word}
# we have to do the intersection by name because Samples are not hashable
intersection = set(s.name for s in by_type).intersection(set(s.name for s in by_word))
samples = [by_name[name] for name in intersection]

for sample in samples:
for phrase_index, phrase in enumerate(sample.lexical_phrases):
word_match = word_input and match_word_with_phrase_exact(
phrase, word_input
Expand Down

0 comments on commit 83a417d

Please sign in to comment.