initial implementation of attn based predicate extraction
ngupta10 committed May 23, 2024
1 parent f33e70c commit 6f6a148
Showing 12 changed files with 572 additions and 80 deletions.
@@ -7,7 +7,6 @@
from querent.core.transformers.bert_ner_opensourcellm import BERTLLM
from querent.common.types.ingested_images import IngestedImages
from querent.kg.ner_helperfunctions.ner_llm_transformer import NER_LLM
- from querent.kg.rel_helperfunctions.openai_functions import FunctionRegistry
from querent.common.types.querent_event import EventState, EventType
from querent.core.base_engine import BaseEngine
from querent.common.types.ingested_tokens import IngestedTokens
2 changes: 0 additions & 2 deletions querent/core/transformers/gpt_llm_gpt_ner.py
@@ -7,7 +7,6 @@
import spacy
from querent.config.core.gpt_llm_config import GPTConfig
from querent.common.types.ingested_images import IngestedImages
- from querent.kg.rel_helperfunctions.openai_functions import FunctionRegistry
from querent.common.types.querent_event import EventState, EventType
from querent.core.base_engine import BaseEngine
from querent.common.types.ingested_tokens import IngestedTokens
@@ -46,7 +45,6 @@ def __init__(
self.gpt_llm = OpenAI(api_key=config.openai_api_key)
else:
self.gpt_llm = OpenAI()
- self.function_registry = FunctionRegistry()
self.create_emb = EmbeddingStore()
self.user_context = config.user_context

38 changes: 2 additions & 36 deletions querent/kg/ner_helperfunctions/dependency_parsing.py
@@ -18,17 +18,13 @@
List of noun chunks identified in the sentence.
- filtered_chunks : list
-     Filtered noun chunks based on certain criteria.
- merged_entities : list
+ noun_chunks : list
Entities merged based on overlapping criteria.
Methods:
--------
load_spacy_model():
Loads the specified SpaCy model.
filter_chunks():
Filters the noun chunks based on length, stop words, and POS tagging.
merge_overlapping_entities():
Merges entities that overlap with each other.
compare_entities_with_chunks():
Compares the entities with the noun chunks and updates the entity details.
process_entities():
@@ -43,42 +39,12 @@ def __init__(self, entities=None, sentence=None, model=None):
self.nlp = model
self.doc = self.nlp(self.sentence)
self.noun_chunks = list(self.doc.noun_chunks)
- self.filtered_chunks = self.filter_chunks()
- self.merged_entities = self.merge_overlapping_entities()
+ self.noun_chunks = list(self.doc.noun_chunks)
self.compare_entities_with_chunks()
self.entities = self.process_entities()
except Exception as e:
raise Exception(f"Error Initializing Dependency Parsing Class: {e}")

- def filter_chunks(self):
-     try:
-         filtered_chunks = []
-         relevant_pos_tags = {"NOUN", "PROPN", "ADJ"}
-         for chunk in self.noun_chunks:
-             # Filtering logic
-             if len(chunk) > 1 and not chunk.root.is_stop and chunk.root.pos_ in relevant_pos_tags:
-                 filtered_chunks.append(chunk)
-         return filtered_chunks
-
-     except Exception as e:
-         raise Exception(f"Error filtering chunks: {e}")
-
-
- def merge_overlapping_entities(self):
-     try:
-         merged_entities = []
-         i = 0
-         while i < len(self.filtered_chunks):
-             entity = self.filtered_chunks[i]
-             while i + 1 < len(self.filtered_chunks) and entity.end >= self.filtered_chunks[i + 1].start:
-                 entity = self.doc[entity.start:self.filtered_chunks[i + 1].end]
-                 i += 1
-             merged_entities.append(entity)
-             i += 1
-         return merged_entities
-     except Exception as e:
-         raise Exception(f"Error merging overlapping entities: {e}")

def compare_entities_with_chunks(self):
try:
for entity in self.entities:
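With filter_chunks() and merge_overlapping_entities() gone, the parser keeps every spaCy noun chunk and aligns entities against them directly. A minimal standalone sketch of that reduced flow (the model name and the toy entity dict are illustrative assumptions, not repo values):

import spacy

nlp = spacy.load("en_core_web_sm")          # assumed model; any spaCy pipeline with a parser works
doc = nlp("The reservoir contains porous sandstone.")

# Simplified flow: keep every noun chunk; no length/POS filtering, no merging.
noun_chunks = list(doc.noun_chunks)

entities = [{"entity": "sandstone", "start_idx": 4}]
for entity in entities:
    for chunk in noun_chunks:
        # Roughly what compare_entities_with_chunks() does: attach the covering chunk.
        if entity["entity"] in chunk.text.lower():
            entity["noun_chunk"] = chunk.text.lower()
            entity["noun_chunk_length"] = len(chunk)

print(entities)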
39 changes: 32 additions & 7 deletions querent/kg/ner_helperfunctions/ner_llm_transformer.py
@@ -60,6 +60,7 @@ def __init__(
provided_model=None,
):
self.logger = setup_logger(__name__, "NER_LLM")
+ self.device = "cpu"
if provided_tokenizer:
self.ner_tokenizer = provided_tokenizer
else:
@@ -68,6 +69,7 @@
self.ner_model = provided_model
else:
self.ner_model = NER_LLM.load_model(ner_model_name, "NER")
+ self.ner_model.eval()
self.filler_tokens = filler_tokens or ["of", "a", "the", "in", "on", "at", "and", "or", "with","(",")","-"]
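The two small additions above pin inference to the CPU and put the NER model in evaluation mode; eval() disables dropout so repeated passes over the same chunk give identical logits. A toy illustration of that effect (not from the repo):

import torch
import torch.nn as nn

# Dropout is stochastic in train mode and a no-op after .eval(),
# which is why the commit calls self.ner_model.eval() before inference.
layer = nn.Dropout(p=0.5)
layer.train()
print(layer(torch.ones(4)))   # elements randomly zeroed and rescaled
layer.eval()
print(layer(torch.ones(4)))   # tensor([1., 1., 1., 1.]) -- deterministic at inference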


@@ -124,9 +126,10 @@ def _tokenize_and_chunk(self, data: str) -> List[Tuple[List[str], str, int]]:
raise Exception(f"An error occurred while tokenizing: {e}")
return tokenized_sentences

- def _token_distance(self, tokens, start_idx1, start_idx2, noun_chunk1, noun_chunk2):
+ def _token_distance(self, tokens, start_idx1, nn_chunk_length_idx1, start_idx2, noun_chunk1, noun_chunk2):
      distance = 0
-     for idx in range(start_idx1 + 1, start_idx2):
+     print("Tokens-----", tokens)
+     for idx in range(start_idx1 + nn_chunk_length_idx1, start_idx2):
token = tokens[idx]
if (token not in self.filler_tokens and
token not in noun_chunk1 and
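The new nn_chunk_length_idx1 argument starts the count after the whole first noun chunk rather than one token past its start, so multi-token chunks no longer inflate the distance. A standalone sketch of that counting logic (the function name, filler set, and example tokens are illustrative):

FILLER_TOKENS = {"of", "a", "the", "in", "on", "at", "and", "or", "with", "(", ")", "-"}

def token_distance(tokens, start_idx1, chunk1_length, start_idx2, chunk1, chunk2):
    # Count informative tokens between the end of chunk 1 and the start of chunk 2.
    distance = 0
    for idx in range(start_idx1 + chunk1_length, start_idx2):
        token = tokens[idx]
        if token not in FILLER_TOKENS and token not in chunk1 and token not in chunk2:
            distance += 1
    return distance

tokens = ["eocene", "fan", "deltas", "were", "deposited", "in", "the", "gulf", "of", "mexico"]
print(token_distance(tokens, 0, 3, 7, "eocene fan deltas", "gulf of mexico"))  # -> 2 ("were", "deposited")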
@@ -179,9 +182,10 @@ def extract_entities_from_chunk(self, chunk: List[str]):
results = []
try:
input_ids = self.ner_tokenizer.convert_tokens_to_ids(chunk)
- input_tensor = torch.tensor([input_ids])
+ input_tensor = torch.tensor([input_ids], device=self.device)
+ attention_mask = torch.ones(input_tensor.shape, device=self.device)
  with torch.no_grad():
-     outputs = self.ner_model(input_tensor)
+     outputs = self.ner_model(input_tensor, attention_mask=attention_mask)
predictions = torch.argmax(outputs[0], dim=2)
scores = torch.nn.functional.softmax(outputs[0], dim=2)
label_ids = predictions[0].tolist()
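The forward pass now builds its tensors on self.device and supplies an explicit all-ones attention mask instead of relying on the model default. A minimal sketch of that inference step with a Hugging Face token-classification model (the model name is an assumption, not the repo's configured one):

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

model_name = "dslim/bert-base-NER"   # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
model.eval()

device = "cpu"
chunk = tokenizer.tokenize("The Eocene fan delta overlies the shale.")
input_ids = tokenizer.convert_tokens_to_ids(chunk)
input_tensor = torch.tensor([input_ids], device=device)
attention_mask = torch.ones(input_tensor.shape, device=device)   # attend to every token

with torch.no_grad():
    outputs = model(input_tensor, attention_mask=attention_mask)
predictions = torch.argmax(outputs[0], dim=2)              # one label id per token
scores = torch.nn.functional.softmax(outputs[0], dim=2)    # per-label confidences
print(predictions, scores.max(dim=2).values)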
@@ -256,8 +260,9 @@ def extract_binary_pairs(self, entities: List[dict], tokens: List[str], all_sent
for j in range(i + 1, len(entities)):
if entities[i]["start_idx"] + 1 == entities[j]["start_idx"]:
continue
- distance = self._token_distance(tokens, entities[i]["start_idx"], entities[j]["start_idx"],entities[i]["noun_chunk"], entities[j]["noun_chunk"])
- if distance <= 30:
+ distance = self._token_distance(tokens, entities[i]["start_idx"], entities[i]["noun_chunk_length"],entities[j]["start_idx"],entities[i]["noun_chunk"], entities[j]["noun_chunk"])
+ print("Distance---------", distance)
+ if distance <= 10:
pair = (entities[i], entities[j])
if pair not in binary_pairs:
metadata = {
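Pairs are now kept only when at most 10 informative tokens (down from 30) separate the two noun chunks, biasing predicate extraction toward entities in close proximity. A condensed sketch of the pairing loop under that assumption (the entity dicts and injected distance function are illustrative):

MAX_TOKEN_DISTANCE = 10  # tightened from 30 in this commit

def extract_binary_pairs(entities, tokens, distance_fn):
    binary_pairs = []
    for i in range(len(entities)):
        for j in range(i + 1, len(entities)):
            # Skip immediately adjacent entities.
            if entities[i]["start_idx"] + 1 == entities[j]["start_idx"]:
                continue
            d = distance_fn(
                tokens,
                entities[i]["start_idx"], entities[i]["noun_chunk_length"],
                entities[j]["start_idx"],
                entities[i]["noun_chunk"], entities[j]["noun_chunk"],
            )
            if d <= MAX_TOKEN_DISTANCE and (entities[i], entities[j]) not in binary_pairs:
                binary_pairs.append((entities[i], entities[j]))
    return binary_pairs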
@@ -332,7 +337,24 @@ def filter_matching_entities(self, tuples_nested_list, entities_nested_list):

return matched_tuples


+ def find_subword_indices(self, text, entity):
+     subwords = self.ner_tokenizer.tokenize(entity)
+     subword_ids = self.ner_tokenizer.convert_tokens_to_ids(subwords)
+     token_ids = self.ner_tokenizer.convert_tokens_to_ids(self.ner_tokenizer.tokenize(text))
+     print("Length of token idsssss", len(token_ids))
+     print("Length of Subword IDs---",len(subword_ids), subword_ids)
+
+     subword_positions = []
+     for i in range(len(token_ids) - len(subword_ids) + 1):
+         if token_ids[i:i + len(subword_ids)] == subword_ids:
+             subword_positions.append((i+1, i + len(subword_ids)))
+     return subword_positions
+
+ def tokenize_sentence_with_positions(self, sentence: str):
+     tokens = self.ner_tokenizer.tokenize(sentence)
+     token_positions = [(token, idx +1 ) for idx, token in enumerate(tokens)]
+
+     return token_positions


def extract_entities_from_sentence(self, sentence: str, sentence_idx: int, all_sentences: List[str], fixed_entities_flag: bool, fixed_entities: List[str],entity_types: List[str]):
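The new find_subword_indices helper returns 1-based (start, end) spans for every occurrence of an entity's subword sequence in the tokenized sentence, presumably so attention weights can later be pooled over those positions; tokenize_sentence_with_positions provides the matching 1-based token list. A usage sketch with a standalone tokenizer (the tokenizer choice is an assumption):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative tokenizer

def find_subword_indices(text, entity):
    subword_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(entity))
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    positions = []
    for i in range(len(token_ids) - len(subword_ids) + 1):
        if token_ids[i:i + len(subword_ids)] == subword_ids:
            # 1-based (start, end) span, matching the (i + 1, i + len) convention above.
            positions.append((i + 1, i + len(subword_ids)))
    return positions

print(find_subword_indices("the turbidite overlies the basal sandstone", "sandstone"))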
@@ -355,7 +377,10 @@ def extract_entities_from_sentence(self, sentence: str, sentence_idx: int, all_s
entity['noun_chunk'] = entity['entity']
entity['noun_chunk_length'] = len(entity['noun_chunk'].split())
entities_withnnchunk = final_entities
+ print("Entitiessssss------", entities_withnnchunk)
  binary_pairs = self.extract_binary_pairs(entities_withnnchunk, tokens, all_sentences, sentence_idx)
+ print("Binary Pairs------", binary_pairs)
+
return entities_withnnchunk, binary_pairs
except Exception as e:
self.logger.error(f"Error extracting entities from sentence: {e}")
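Taken together, the sentence-level entry point now returns the entities (each annotated with noun_chunk and noun_chunk_length) plus the binary pairs that survive the distance filter. A hedged end-to-end usage sketch, assuming the provided_tokenizer/provided_model parameters shown in __init__ above and defaults for the remaining constructor arguments:

from transformers import AutoModelForTokenClassification, AutoTokenizer
from querent.kg.ner_helperfunctions.ner_llm_transformer import NER_LLM

# Assumed wiring: a pre-loaded tokenizer/model passed via the provided_* parameters;
# the real constructor may require additional configuration.
name = "dslim/bert-base-NER"   # illustrative choice
ner = NER_LLM(
    provided_tokenizer=AutoTokenizer.from_pretrained(name),
    provided_model=AutoModelForTokenClassification.from_pretrained(name),
)

sentence = "The Eocene fan delta prograded into the basin."
entities, binary_pairs = ner.extract_entities_from_sentence(
    sentence,
    sentence_idx=0,
    all_sentences=[sentence],
    fixed_entities_flag=False,
    fixed_entities=[],
    entity_types=[],
)
print(entities)       # each entity annotated with noun_chunk and noun_chunk_length
print(binary_pairs)   # pairs whose informative-token distance is <= 10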