
Commit

fixes to attn mechanism
ngupta10 committed Jun 5, 2024
1 parent 6f6a148 commit 3357142
Showing 9 changed files with 384 additions and 213 deletions.
55 changes: 43 additions & 12 deletions querent/core/transformers/bert_ner_opensourcellm.py
@@ -1,9 +1,7 @@
import json
import re
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer
import transformers
import time

import unidecode
from querent.common.types.ingested_table import IngestedTables
from querent.kg.ner_helperfunctions.fixed_predicate import FixedPredicateExtractor
from querent.common.types.ingested_images import IngestedImages
@@ -30,6 +28,8 @@
from querent.kg.rel_helperfunctions.embedding_store import EmbeddingStore
from querent.models.model_manager import ModelManager
from querent.models.gguf_metadata_extractor import GGUFMetadataExtractor
from querent.kg.rel_helperfunctions.attn_based_relationship_model_getter import get_model
from querent.kg.rel_helperfunctions.attn_based_relationship_filter import process_tokens, trim_triples

class BERTLLM(BaseEngine):
def __init__(
@@ -49,6 +49,7 @@ def __init__(
self.sample_relationships = config.sample_relationships
self.user_context = config.user_context
self.isConfinedSearch = config.is_confined_search
self.attn_based_rel_extraction = True
self.create_emb = EmbeddingStore() if not Embedding else Embedding

try:
@@ -76,20 +77,19 @@ def _initialize_components(self, config):

def _initialize_models(self, config):
self.ner_model_initialized = self.model_manager.get_model(config.ner_model_name)
if not self.skip_inferences:
if not self.skip_inferences and self.attn_based_rel_extraction == False:
extractor = GGUFMetadataExtractor(config.rel_model_path)
model_metadata = extractor.dump_metadata()
rel_model_name = extractor.extract_general_name(model_metadata)
self.rel_model_initialized = self.model_manager.get_model(rel_model_name, model_path=config.rel_model_path)

self.ner_llm_instance = NER_LLM(ner_model_name=self.ner_model_initialized)
self.ner_tokenizer = self.ner_llm_instance.ner_tokenizer
self.ner_model = self.ner_llm_instance.ner_model
self.nlp_model = NER_LLM.set_nlp_model(config.spacy_model_path)
self.nlp_model = NER_LLM.get_class_variable()

def _initialize_extractors(self, config):
if not self.skip_inferences:
if not self.skip_inferences and self.attn_based_rel_extraction == False:
mock_config = Opensource_LLM_Config(
qa_template=config.user_context,
model_type=config.rel_model_type,
@@ -100,7 +100,30 @@ def _initialize_extractors(self, config):
nltk_path=config.nltk_path
)
self.semantic_extractor = RelationExtractor(mock_config, self.create_emb)


elif not self.skip_inferences and self.attn_based_rel_extraction == True:
model_config = AutoConfig.from_pretrained(config.rel_model_path)
print("Model Config -------------", model_config)
if 'bert' in model_config.model_type.lower():
self.ner_helper_instance = NER_LLM(ner_model_name=config.rel_model_path)
self.ner_helper_tokenizer = self.ner_helper_instance.ner_tokenizer
self.ner_helper_model = self.ner_helper_instance.ner_model
self.extractor = get_model("bert",model_tokenizer= self.ner_helper_tokenizer,model=self.ner_helper_model)
elif 'llama' in model_config.model_type.lower() or 'mpt' in model_config.model_type.lower():
# model_id = "TheBloke/Llama-2-7B-GGUF"
# filename = "llama-2-7b.Q5_K_M.gguf"
# self.ner_tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename)
# self.model = transformers.AutoModelForCausalLM.from_pretrained(model_id, gguf_file=filename)
# self.ner_helper_instance = NER_LLM(provided_tokenizer =self.ner_tokenizer, provided_model=self.model)
print("Loaded Model-------------11")
self.model = transformers.AutoModelForCausalLM.from_pretrained(config.rel_model_path,trust_remote_code=True)
print("Loaded Model-------------")
self.ner_helper_instance = NER_LLM(ner_model_name= config.rel_model_path, provided_model=self.model)
self.ner_helper_tokenizer = self.ner_helper_instance.ner_tokenizer
self.ner_helper_model = self.ner_helper_instance.ner_model
self.extractor = get_model("llama",model_tokenizer= self.ner_helper_tokenizer,model=self.ner_helper_model)
else:
raise ValueError("Selected Model not supported for Attnetion Based Graph Extraction")
self.attn_scores_instance = EntityAttentionExtractor(model=self.ner_model, tokenizer=self.ner_tokenizer)

def _initialize_entity_context_extractor(self):
@@ -323,6 +346,7 @@ def _process_entity_types(self, doc_entity_pairs):
doc_entity_pairs = self.entity_context_extractor.process_entity_types(doc_entities=doc_entity_pairs)
if any(doc_entity_pairs):
doc_entity_pairs = self.ner_llm_instance.remove_duplicates(doc_entity_pairs)
print("Binary Pairs -------------", doc_entity_pairs)
return doc_entity_pairs

def _process_pairs_with_embeddings(self, pairs_withattn, file):
@@ -347,10 +371,17 @@ def _filter_triples(self, pairs_with_predicates, pairs_withattn):
return filtered_triples

async def _process_relationships(self, filtered_triples, file, doc_source):
relationships = self.semantic_extractor.process_tokens(
filtered_triples,
fixed_entities=(len(self.sample_entities) >= 1)
)
if self.attn_based_rel_extraction == False:
relationships = self.semantic_extractor.process_tokens(
filtered_triples,
fixed_entities=(len(self.sample_entities) >= 1)
)
else:
print("Trimming -----")
filtered_triples = trim_triples(filtered_triples)
print("Filtereddddddd Triples ------------", len(filtered_triples))
relationships = process_tokens(filtered_triples=filtered_triples, ner_instance=self.ner_helper_instance, extractor=self.extractor, nlp_model=self.nlp_model)
print("Predicates Triples From Attn Method----", relationships)
if not relationships:
return

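The new _initialize_extractors path above selects an attention-based relationship extractor by inspecting the Hugging Face model type that AutoConfig reports for the relation model path. A minimal sketch of that dispatch, assuming only the transformers calls visible in the diff and using a hypothetical local path in place of config.rel_model_path:

from transformers import AutoConfig

rel_model_path = "models/relation-model"  # hypothetical path; the diff uses config.rel_model_path
model_config = AutoConfig.from_pretrained(rel_model_path)
model_type = model_config.model_type.lower()

if "bert" in model_type:
    backend = "bert"    # encoder path: attention scores come from the BERT-style tokenizer/model
elif "llama" in model_type or "mpt" in model_type:
    backend = "llama"   # decoder path: the diff loads the model with AutoModelForCausalLM
else:
    raise ValueError("Selected Model not supported for Attention Based Graph Extraction")

# backend is then the first argument the diff passes to get_model(backend, model_tokenizer=..., model=...)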
3 changes: 2 additions & 1 deletion querent/core/transformers/relationship_extraction_llm.py
@@ -138,7 +138,8 @@ def create_semantic_triple(self, input1, input2):
"context": input2_data.get("context", ""),
"file_path": input2_data.get("file_path", ""),
"subject_type": input1.get("subject_type","Unlabeled"),
"object_type": input1.get("object_type","Unlabeled")
"object_type": input1.get("object_type","Unlabeled"),
"score":1
}),
input1.get("object","")
)
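The one-line change above adds a default "score" of 1 to the JSON payload that create_semantic_triple serializes for each predicate. A minimal sketch of the resulting payload shape, using placeholder values rather than repository data and showing only the fields visible in this hunk:

import json

# Placeholder values for illustration only; real values come from input1 / input2
# inside create_semantic_triple, and fields defined above this hunk are omitted.
predicate_payload = json.dumps({
    "context": "example context sentence",
    "file_path": "docs/example.pdf",
    "subject_type": "Unlabeled",
    "object_type": "Unlabeled",
    "score": 1,
})
triple = ("example subject", predicate_payload, "example object")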
76 changes: 42 additions & 34 deletions querent/kg/ner_helperfunctions/ner_llm_transformer.py
@@ -128,7 +128,6 @@ def _tokenize_and_chunk(self, data: str) -> List[Tuple[List[str], str, int]]:

def _token_distance(self, tokens, start_idx1, nn_chunk_length_idx1, start_idx2, noun_chunk1, noun_chunk2):
distance = 0
print("Tokens-----", tokens)
for idx in range(start_idx1 + nn_chunk_length_idx1, start_idx2):
token = tokens[idx]
if (token not in self.filler_tokens and
@@ -140,32 +139,39 @@ def _token_distance(self, tokens, start_idx1, nn_chunk_length_idx1, start_idx2,


def transform_entity_pairs(self, entity_pairs):
transformed_pairs = []
sentence_group = {}
for pair, metadata in entity_pairs:
combined_sentence = ' '.join(filter(None, [
metadata['previous_sentence'],
metadata['current_sentence'],
metadata['next_sentence']
]))
if combined_sentence not in sentence_group:
sentence_group[combined_sentence] = []
sentence_group[combined_sentence].append(pair)

for combined_sentence, pairs in sentence_group.items():
for entity1, entity2 in pairs:
meta_dict = {
"entity1_score": entity1['score'],
"entity2_score": entity2['score'],
"entity1_label": entity1['label'],
"entity2_label": entity2['label'],
"entity1_nn_chunk":entity1['noun_chunk'],
"entity2_nn_chunk":entity2['noun_chunk'],
}
new_pair = (entity1['entity'], combined_sentence, entity2['entity'], meta_dict)
transformed_pairs.append(new_pair)

return transformed_pairs
try:
transformed_pairs = []
sentence_group = {}
for pair, metadata in entity_pairs:
combined_sentence = ' '.join(filter(None, [
metadata['previous_sentence'],
metadata['current_sentence'],
metadata['next_sentence']
]))
current_sentence = metadata['current_sentence']
if combined_sentence not in sentence_group:
sentence_group[combined_sentence] = []
sentence_group[combined_sentence].append(pair + (current_sentence,))

for combined_sentence, pairs in sentence_group.items():
for entity1, entity2, current_sentence in pairs:
meta_dict = {
"entity1_score": entity1['score'],
"entity2_score": entity2['score'],
"entity1_label": entity1['label'],
"entity2_label": entity2['label'],
"entity1_nn_chunk":entity1['noun_chunk'],
"entity2_nn_chunk":entity2['noun_chunk'],
"current_sentence":current_sentence
}
new_pair = (entity1['entity'], combined_sentence, entity2['entity'], meta_dict)
transformed_pairs.append(new_pair)

return transformed_pairs
except Exception as e:
print("EEEEEEEEEEEEEE", e)
self.logger.error(f"Error trasnforming entity pairs: {e}")
raise Exception(f"Error trasnforming entity pairs: {e}")

def get_chunks(self, tokens: List[str], max_chunk_size=510):
chunks = []
@@ -210,13 +216,14 @@ def combine_entities_wordpiece(self, entities: List[dict], tokens: List[str]):
i = 0
while i < len(entities):
entity = entities[i]
while i + 1 < len(entities) and entities[i + 1]["entity"].startswith("##"):
while i + 1 < len(entities) and entities[i + 1]["entity"].startswith("##") and entities[i + 1]["start_idx"] - entities[i]["start_idx"] ==1:
entity["entity"] += entities[i + 1]["entity"][2:]
entity["score"] = (entity["score"] + entities[i + 1]["score"]) / 2
i += 1
combined_entities.append(entity)
i += 1
final_entities = []
print("Combined Entitiesssssss----------", combined_entities)
for entity in combined_entities:
entity_text = entity["entity"]
start_idx = entity["start_idx"]
@@ -261,7 +268,6 @@ def extract_binary_pairs(self, entities: List[dict], tokens: List[str], all_sent
if entities[i]["start_idx"] + 1 == entities[j]["start_idx"]:
continue
distance = self._token_distance(tokens, entities[i]["start_idx"], entities[i]["noun_chunk_length"],entities[j]["start_idx"],entities[i]["noun_chunk"], entities[j]["noun_chunk"])
print("Distance---------", distance)
if distance <= 10:
pair = (entities[i], entities[j])
if pair not in binary_pairs:
@@ -338,13 +344,13 @@ def filter_matching_entities(self, tuples_nested_list, entities_nested_list):
return matched_tuples

def find_subword_indices(self, text, entity):
print("entity----", entity)
subwords = self.ner_tokenizer.tokenize(entity)
subword_ids = self.ner_tokenizer.convert_tokens_to_ids(subwords)
token_ids = self.ner_tokenizer.convert_tokens_to_ids(self.ner_tokenizer.tokenize(text))
print("Length of token idsssss", len(token_ids))
print("Length of Subword IDs---",len(subword_ids), subword_ids)

subword_positions = []
print("entity----", subwords)
print("entity -----------", self.ner_tokenizer.tokenize(text))
for i in range(len(token_ids) - len(subword_ids) + 1):
if token_ids[i:i + len(subword_ids)] == subword_ids:
subword_positions.append((i+1, i + len(subword_ids)))
@@ -362,24 +368,25 @@ def extract_entities_from_sentence(self, sentence: str, sentence_idx: int, all_s
tokens = self.tokenize_sentence(sentence)
chunks = self.get_chunks(tokens)
all_entities = []
print("Tokenssssss-----", tokens)
for chunk in chunks:
if fixed_entities_flag == False:
entities = self.extract_entities_from_chunk(chunk)
else:
entities = self.extract_fixed_entities_from_chunk(chunk,fixed_entities, entity_types)
all_entities.extend(entities)
print("Before Wordpiece---------------", all_entities)
final_entities = self.combine_entities_wordpiece(all_entities, tokens)
if fixed_entities_flag == False:
print("Final Entities ----", final_entities)
parsed_entities = Dependency_Parsing(entities=final_entities, sentence=sentence, model=NER_LLM.nlp)
entities_withnnchunk = parsed_entities.entities
else:
for entity in final_entities:
entity['noun_chunk'] = entity['entity']
entity['noun_chunk_length'] = len(entity['noun_chunk'].split())
entities_withnnchunk = final_entities
print("Entitiessssss------", entities_withnnchunk)
binary_pairs = self.extract_binary_pairs(entities_withnnchunk, tokens, all_sentences, sentence_idx)
print("Binary Pairs------", binary_pairs)

return entities_withnnchunk, binary_pairs
except Exception as e:
Expand Down Expand Up @@ -429,6 +436,7 @@ def remove_duplicates(self, data):

if cleaned_sublist:
new_data.append(cleaned_sublist)


return new_data

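The combine_entities_wordpiece change above now merges a trailing "##" wordpiece into the preceding entity only when the two start indices are adjacent, so subword pieces that are not consecutive in the token sequence stay separate. A small self-contained sketch of that merge rule, using made-up entity dicts rather than repository output:

entities = [
    {"entity": "sand", "start_idx": 4, "score": 0.96},
    {"entity": "##stone", "start_idx": 5, "score": 0.94},  # adjacent to "sand": merged
    {"entity": "##ous", "start_idx": 9, "score": 0.90},    # not adjacent: kept as its own entry
]

combined_entities = []
i = 0
while i < len(entities):
    entity = dict(entities[i])
    while (i + 1 < len(entities)
           and entities[i + 1]["entity"].startswith("##")
           and entities[i + 1]["start_idx"] - entities[i]["start_idx"] == 1):
        entity["entity"] += entities[i + 1]["entity"][2:]
        entity["score"] = (entity["score"] + entities[i + 1]["score"]) / 2
        i += 1
    combined_entities.append(entity)
    i += 1
# combined_entities -> a merged "sandstone" entry (averaged score) plus a separate "##ous" entry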