From 3357142ac6304ed0120025aea2e75a01df8c61fd Mon Sep 17 00:00:00 2001 From: ngupta10 Date: Wed, 5 Jun 2024 07:09:01 +0530 Subject: [PATCH] fixes to attn mechanism --- .../transformers/bert_ner_opensourcellm.py | 55 ++- .../relationship_extraction_llm.py | 3 +- .../ner_llm_transformer.py | 76 ++-- .../attn_based_relationship_filter.py | 344 +++++++++++++----- .../attn_based_relationship_model_getter.py | 27 +- .../attn_based_relationship_seach_scope.py | 81 ++--- .../contextual_predicate.py | 4 +- .../kg/rel_helperfunctions/embedding_store.py | 4 +- .../kg/rel_helperfunctions/triple_to_json.py | 3 +- 9 files changed, 384 insertions(+), 213 deletions(-) diff --git a/querent/core/transformers/bert_ner_opensourcellm.py b/querent/core/transformers/bert_ner_opensourcellm.py index 7a7260bf..722b2b0d 100644 --- a/querent/core/transformers/bert_ner_opensourcellm.py +++ b/querent/core/transformers/bert_ner_opensourcellm.py @@ -1,9 +1,7 @@ import json -import re -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer +import transformers import time - -import unidecode from querent.common.types.ingested_table import IngestedTables from querent.kg.ner_helperfunctions.fixed_predicate import FixedPredicateExtractor from querent.common.types.ingested_images import IngestedImages @@ -30,6 +28,8 @@ from querent.kg.rel_helperfunctions.embedding_store import EmbeddingStore from querent.models.model_manager import ModelManager from querent.models.gguf_metadata_extractor import GGUFMetadataExtractor +from querent.kg.rel_helperfunctions.attn_based_relationship_model_getter import get_model +from querent.kg.rel_helperfunctions.attn_based_relationship_filter import process_tokens, trim_triples class BERTLLM(BaseEngine): def __init__( @@ -49,6 +49,7 @@ def __init__( self.sample_relationships = config.sample_relationships self.user_context = config.user_context self.isConfinedSearch = config.is_confined_search + self.attn_based_rel_extraction = True self.create_emb = EmbeddingStore() if not Embedding else Embedding try: @@ -76,12 +77,11 @@ def _initialize_components(self, config): def _initialize_models(self, config): self.ner_model_initialized = self.model_manager.get_model(config.ner_model_name) - if not self.skip_inferences: + if not self.skip_inferences and self.attn_based_rel_extraction == False: extractor = GGUFMetadataExtractor(config.rel_model_path) model_metadata = extractor.dump_metadata() rel_model_name = extractor.extract_general_name(model_metadata) self.rel_model_initialized = self.model_manager.get_model(rel_model_name, model_path=config.rel_model_path) - self.ner_llm_instance = NER_LLM(ner_model_name=self.ner_model_initialized) self.ner_tokenizer = self.ner_llm_instance.ner_tokenizer self.ner_model = self.ner_llm_instance.ner_model @@ -89,7 +89,7 @@ def _initialize_models(self, config): self.nlp_model = NER_LLM.get_class_variable() def _initialize_extractors(self, config): - if not self.skip_inferences: + if not self.skip_inferences and self.attn_based_rel_extraction == False: mock_config = Opensource_LLM_Config( qa_template=config.user_context, model_type=config.rel_model_type, @@ -100,7 +100,30 @@ def _initialize_extractors(self, config): nltk_path=config.nltk_path ) self.semantic_extractor = RelationExtractor(mock_config, self.create_emb) - + + elif not self.skip_inferences and self.attn_based_rel_extraction == True: + model_config = AutoConfig.from_pretrained(config.rel_model_path) + print("Model Config -------------", model_config) + if 'bert' in 
model_config.model_type.lower(): + self.ner_helper_instance = NER_LLM(ner_model_name=config.rel_model_path) + self.ner_helper_tokenizer = self.ner_helper_instance.ner_tokenizer + self.ner_helper_model = self.ner_helper_instance.ner_model + self.extractor = get_model("bert",model_tokenizer= self.ner_helper_tokenizer,model=self.ner_helper_model) + elif 'llama' in model_config.model_type.lower() or 'mpt' in model_config.model_type.lower(): + # model_id = "TheBloke/Llama-2-7B-GGUF" + # filename = "llama-2-7b.Q5_K_M.gguf" + # self.ner_tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename) + # self.model = transformers.AutoModelForCausalLM.from_pretrained(model_id, gguf_file=filename) + # self.ner_helper_instance = NER_LLM(provided_tokenizer =self.ner_tokenizer, provided_model=self.model) + print("Loaded Model-------------11") + self.model = transformers.AutoModelForCausalLM.from_pretrained(config.rel_model_path,trust_remote_code=True) + print("Loaded Model-------------") + self.ner_helper_instance = NER_LLM(ner_model_name= config.rel_model_path, provided_model=self.model) + self.ner_helper_tokenizer = self.ner_helper_instance.ner_tokenizer + self.ner_helper_model = self.ner_helper_instance.ner_model + self.extractor = get_model("llama",model_tokenizer= self.ner_helper_tokenizer,model=self.ner_helper_model) + else: + raise ValueError("Selected Model not supported for Attnetion Based Graph Extraction") self.attn_scores_instance = EntityAttentionExtractor(model=self.ner_model, tokenizer=self.ner_tokenizer) def _initialize_entity_context_extractor(self): @@ -323,6 +346,7 @@ def _process_entity_types(self, doc_entity_pairs): doc_entity_pairs = self.entity_context_extractor.process_entity_types(doc_entities=doc_entity_pairs) if any(doc_entity_pairs): doc_entity_pairs = self.ner_llm_instance.remove_duplicates(doc_entity_pairs) + print("Binary Pairs -------------", doc_entity_pairs) return doc_entity_pairs def _process_pairs_with_embeddings(self, pairs_withattn, file): @@ -347,10 +371,17 @@ def _filter_triples(self, pairs_with_predicates, pairs_withattn): return filtered_triples async def _process_relationships(self, filtered_triples, file, doc_source): - relationships = self.semantic_extractor.process_tokens( - filtered_triples, - fixed_entities=(len(self.sample_entities) >= 1) - ) + if self.attn_based_rel_extraction == False: + relationships = self.semantic_extractor.process_tokens( + filtered_triples, + fixed_entities=(len(self.sample_entities) >= 1) + ) + else: + print("Trimming -----") + filtered_triples = trim_triples(filtered_triples) + print("Filtereddddddd Triples ------------", len(filtered_triples)) + relationships = process_tokens(filtered_triples=filtered_triples, ner_instance=self.ner_helper_instance, extractor=self.extractor, nlp_model=self.nlp_model) + print("Predicates Triples From Attn Method----", relationships) if not relationships: return diff --git a/querent/core/transformers/relationship_extraction_llm.py b/querent/core/transformers/relationship_extraction_llm.py index 6bb61da0..a51d89e2 100644 --- a/querent/core/transformers/relationship_extraction_llm.py +++ b/querent/core/transformers/relationship_extraction_llm.py @@ -138,7 +138,8 @@ def create_semantic_triple(self, input1, input2): "context": input2_data.get("context", ""), "file_path": input2_data.get("file_path", ""), "subject_type": input1.get("subject_type","Unlabeled"), - "object_type": input1.get("object_type","Unlabeled") + "object_type": input1.get("object_type","Unlabeled"), + "score":1 }), 
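+            # The prompt-based extractor has no attention score to report, so "score" defaults to 1 here;
+            # the attention-based path builds its triples with a computed search score instead.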
input1.get("object","") ) diff --git a/querent/kg/ner_helperfunctions/ner_llm_transformer.py b/querent/kg/ner_helperfunctions/ner_llm_transformer.py index 00865979..ae288c83 100644 --- a/querent/kg/ner_helperfunctions/ner_llm_transformer.py +++ b/querent/kg/ner_helperfunctions/ner_llm_transformer.py @@ -128,7 +128,6 @@ def _tokenize_and_chunk(self, data: str) -> List[Tuple[List[str], str, int]]: def _token_distance(self, tokens, start_idx1, nn_chunk_length_idx1, start_idx2, noun_chunk1, noun_chunk2): distance = 0 - print("Tokens-----", tokens) for idx in range(start_idx1 + nn_chunk_length_idx1, start_idx2): token = tokens[idx] if (token not in self.filler_tokens and @@ -140,32 +139,39 @@ def _token_distance(self, tokens, start_idx1, nn_chunk_length_idx1, start_idx2, def transform_entity_pairs(self, entity_pairs): - transformed_pairs = [] - sentence_group = {} - for pair, metadata in entity_pairs: - combined_sentence = ' '.join(filter(None, [ - metadata['previous_sentence'], - metadata['current_sentence'], - metadata['next_sentence'] - ])) - if combined_sentence not in sentence_group: - sentence_group[combined_sentence] = [] - sentence_group[combined_sentence].append(pair) - - for combined_sentence, pairs in sentence_group.items(): - for entity1, entity2 in pairs: - meta_dict = { - "entity1_score": entity1['score'], - "entity2_score": entity2['score'], - "entity1_label": entity1['label'], - "entity2_label": entity2['label'], - "entity1_nn_chunk":entity1['noun_chunk'], - "entity2_nn_chunk":entity2['noun_chunk'], - } - new_pair = (entity1['entity'], combined_sentence, entity2['entity'], meta_dict) - transformed_pairs.append(new_pair) - - return transformed_pairs + try: + transformed_pairs = [] + sentence_group = {} + for pair, metadata in entity_pairs: + combined_sentence = ' '.join(filter(None, [ + metadata['previous_sentence'], + metadata['current_sentence'], + metadata['next_sentence'] + ])) + current_sentence = metadata['current_sentence'] + if combined_sentence not in sentence_group: + sentence_group[combined_sentence] = [] + sentence_group[combined_sentence].append(pair + (current_sentence,)) + + for combined_sentence, pairs in sentence_group.items(): + for entity1, entity2, current_sentence in pairs: + meta_dict = { + "entity1_score": entity1['score'], + "entity2_score": entity2['score'], + "entity1_label": entity1['label'], + "entity2_label": entity2['label'], + "entity1_nn_chunk":entity1['noun_chunk'], + "entity2_nn_chunk":entity2['noun_chunk'], + "current_sentence":current_sentence + } + new_pair = (entity1['entity'], combined_sentence, entity2['entity'], meta_dict) + transformed_pairs.append(new_pair) + + return transformed_pairs + except Exception as e: + print("EEEEEEEEEEEEEE", e) + self.logger.error(f"Error trasnforming entity pairs: {e}") + raise Exception(f"Error trasnforming entity pairs: {e}") def get_chunks(self, tokens: List[str], max_chunk_size=510): chunks = [] @@ -210,13 +216,14 @@ def combine_entities_wordpiece(self, entities: List[dict], tokens: List[str]): i = 0 while i < len(entities): entity = entities[i] - while i + 1 < len(entities) and entities[i + 1]["entity"].startswith("##"): + while i + 1 < len(entities) and entities[i + 1]["entity"].startswith("##") and entities[i + 1]["start_idx"] - entities[i]["start_idx"] ==1: entity["entity"] += entities[i + 1]["entity"][2:] entity["score"] = (entity["score"] + entities[i + 1]["score"]) / 2 i += 1 combined_entities.append(entity) i += 1 final_entities = [] + print("Combined Entitiesssssss----------", combined_entities) 
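+        # combined_entities now holds the wordpiece-merged entities: a "##" continuation piece is
+        # folded into the preceding entity only when their start indices are consecutive.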
for entity in combined_entities: entity_text = entity["entity"] start_idx = entity["start_idx"] @@ -261,7 +268,6 @@ def extract_binary_pairs(self, entities: List[dict], tokens: List[str], all_sent if entities[i]["start_idx"] + 1 == entities[j]["start_idx"]: continue distance = self._token_distance(tokens, entities[i]["start_idx"], entities[i]["noun_chunk_length"],entities[j]["start_idx"],entities[i]["noun_chunk"], entities[j]["noun_chunk"]) - print("Distance---------", distance) if distance <= 10: pair = (entities[i], entities[j]) if pair not in binary_pairs: @@ -338,13 +344,13 @@ def filter_matching_entities(self, tuples_nested_list, entities_nested_list): return matched_tuples def find_subword_indices(self, text, entity): + print("entity----", entity) subwords = self.ner_tokenizer.tokenize(entity) subword_ids = self.ner_tokenizer.convert_tokens_to_ids(subwords) token_ids = self.ner_tokenizer.convert_tokens_to_ids(self.ner_tokenizer.tokenize(text)) - print("Length of token idsssss", len(token_ids)) - print("Length of Subword IDs---",len(subword_ids), subword_ids) - subword_positions = [] + print("entity----", subwords) + print("entity -----------", self.ner_tokenizer.tokenize(text)) for i in range(len(token_ids) - len(subword_ids) + 1): if token_ids[i:i + len(subword_ids)] == subword_ids: subword_positions.append((i+1, i + len(subword_ids))) @@ -362,14 +368,17 @@ def extract_entities_from_sentence(self, sentence: str, sentence_idx: int, all_s tokens = self.tokenize_sentence(sentence) chunks = self.get_chunks(tokens) all_entities = [] + print("Tokenssssss-----", tokens) for chunk in chunks: if fixed_entities_flag == False: entities = self.extract_entities_from_chunk(chunk) else: entities = self.extract_fixed_entities_from_chunk(chunk,fixed_entities, entity_types) all_entities.extend(entities) + print("Before Wordpiece---------------", all_entities) final_entities = self.combine_entities_wordpiece(all_entities, tokens) if fixed_entities_flag == False: + print("Final Entities ----", final_entities) parsed_entities = Dependency_Parsing(entities=final_entities, sentence=sentence, model=NER_LLM.nlp) entities_withnnchunk = parsed_entities.entities else: @@ -377,9 +386,7 @@ def extract_entities_from_sentence(self, sentence: str, sentence_idx: int, all_s entity['noun_chunk'] = entity['entity'] entity['noun_chunk_length'] = len(entity['noun_chunk'].split()) entities_withnnchunk = final_entities - print("Entitiessssss------", entities_withnnchunk) binary_pairs = self.extract_binary_pairs(entities_withnnchunk, tokens, all_sentences, sentence_idx) - print("Binary Pairs------", binary_pairs) return entities_withnnchunk, binary_pairs except Exception as e: @@ -429,6 +436,7 @@ def remove_duplicates(self, data): if cleaned_sublist: new_data.append(cleaned_sublist) + return new_data diff --git a/querent/kg/rel_helperfunctions/attn_based_relationship_filter.py b/querent/kg/rel_helperfunctions/attn_based_relationship_filter.py index de9a2e2a..04ed3136 100644 --- a/querent/kg/rel_helperfunctions/attn_based_relationship_filter.py +++ b/querent/kg/rel_helperfunctions/attn_based_relationship_filter.py @@ -1,9 +1,15 @@ +import ast +import json +import torch from querent.kg.rel_helperfunctions.attn_based_relationship_seach_scope import SearchContextualRelationship as sc from querent.kg.rel_helperfunctions.attn_based_relationship_seach_scope import EntityPair as ep from querent.kg.rel_helperfunctions.attn_based_relationship_seach_scope import perform_search from dataclasses import dataclass +from 
querent.logging.logger import setup_logger from typing import Optional - +from collections import defaultdict +import numpy +from querent.kg.ner_helperfunctions.ner_llm_transformer import NER_LLM @dataclass class Entity: @@ -15,6 +21,7 @@ class SemanticPairs: head: Entity tail: Entity relations: list[str] + scores: list[float] class IndividualFilter: @@ -25,7 +32,6 @@ def __init__(self, forward_relations: bool, threshold: float, token_idx_with_wor self.doc = spacy_doc def filter(self, candidates: list[sc], e_pair: ep) -> SemanticPairs: - print("---------------", e_pair) response = SemanticPairs( head=Entity( text=e_pair.head_entity['noun_chunk'].lower() @@ -33,14 +39,12 @@ def filter(self, candidates: list[sc], e_pair: ep) -> SemanticPairs: tail=Entity( text=e_pair.tail_entity['noun_chunk'].lower() ), - relations=[] + relations=[], + scores = [0] ) - + counter = 0 for candidate in candidates: - print("-----------Relation Token ", candidate.relation_tokens) - print("-----------Relation Score ", candidate.mean_score()) if candidate.mean_score() < self.threshold: - print("Less Scorte") continue rel_txt = '' rel_idx = [] @@ -67,7 +71,35 @@ def filter(self, candidates: list[sc], e_pair: ep) -> SemanticPairs: continue response.relations.append(rel_txt) + response.scores.append((response.scores[counter] + candidate.mean_score())) + # print("Response---", response) + counter = counter +1 + del response.scores[0] return response + + def combine_entities(self, entity_list): + # This list will store the final entities after combining + combined_entities = [] + # Temporary storage for current entity being processed + current_entity = None + + for entity, index in entity_list: + if entity.startswith('##'): + # If the entity starts with ##, concatenate it with the last part of current_entity + if current_entity: + current_entity = (current_entity[0] + entity[2:], current_entity[1]) + else: + # If the current_entity is not None, it means we have completed processing an entity + if current_entity: + combined_entities.append(current_entity) + # Start a new entity + current_entity = (entity, len(combined_entities) + 1) + + # Append the last processed entity if any + if current_entity: + combined_entities.append(current_entity) + + return combined_entities def lemmatize(self, relation: str, indexes: list[int]) -> str: if relation.isnumeric(): @@ -75,22 +107,17 @@ def lemmatize(self, relation: str, indexes: list[int]) -> str: new_text = '' # Another option would be including 'AUX' - remove_morpho = {'SYM', 'OTHER', 'PUNCT', 'NUM', 'INTJ'} + remove_morpho = {'SYM', 'OTHER', 'PUNCT', 'NUM', 'INTJ', 'DET', 'ADP', 'PART'} last_word = ' ' - print("Relationshippppppppppp--------", relation) - print("Indexxxxxxxxxx", indexes) words = [] for idx in indexes: words.append(self.token_idx_with_word[idx -1]) - print("Wordssssssss", words) + words = self.combine_entities(words) for word, word_id in words: - token = next((token for token in self.doc if token.text.lower() == word), None) - print("Tokennnnnnnnnnnnnnnnn", token, token.pos_) + token = next((token for token in self.doc if word in token.text.lower()), None) if token and token.pos_ not in remove_morpho: - print("Finalyy adding -------------") new_word = token.lemma_.lower() if last_word != new_word: - print("Finalyy adding -------------") new_text += new_word new_text += ' ' last_word = new_word @@ -98,22 +125,14 @@ def lemmatize(self, relation: str, indexes: list[int]) -> str: new_text = new_text.strip() return new_text -def clean_relations(ht_pairs: 
list[SemanticPairs]) -> list[SemanticPairs]: - unique_relations = set() - for ht_pair in ht_pairs: - filtered_relations = [] - for relation in ht_pair.relations: - unique_key = ht_pair.head.text + "|" + relation + "|" + ht_pair.tail.text - reverse_key = ht_pair.tail.text + "|" + relation + "|" + ht_pair.head.text - if unique_key not in unique_relations and reverse_key not in unique_relations: - filtered_relations.append(relation) - unique_relations.add(unique_key) - ht_pair.relations = filtered_relations - - new_list = [pair for pair in ht_pairs if len(pair.relations) > 0 and - (pair.head.text != pair.tail.text)] +def get_best_relation(semantic_pair): + scores = [score.item() if isinstance(score, torch.Tensor) else score for score in semantic_pair.scores] + max_index = scores.index(max(scores)) + best_relation = semantic_pair.relations[max_index] + best_score = scores[max_index] + + return best_relation, best_score - return new_list def frequency_cutoff(ht_relations: list[SemanticPairs], frequency: int): if frequency == 1: @@ -129,6 +148,128 @@ def frequency_cutoff(ht_relations: list[SemanticPairs], frequency: int): for ht_item in ht_relations: ht_item.relations = [rel for rel in ht_item.relations if counter[rel] >= frequency] +def trim_triples(data): + try: + trimmed_data = [] + for entity1, predicate, entity2 in data: + predicate_dict = json.loads(predicate) + trimmed_predicate = { + 'context': predicate_dict.get('context', ''), + 'entity1_nn_chunk': predicate_dict.get('entity1_nn_chunk', ''), + 'entity2_nn_chunk': predicate_dict.get('entity2_nn_chunk', ''), + 'entity1_label': predicate_dict.get('entity1_label', ''), + 'entity2_label': predicate_dict.get('entity2_label', ''), + 'file_path': predicate_dict.get('file_path', ''), + 'current_sentence': predicate_dict.get('current_sentence', '') + } + trimmed_data.append((entity1, trimmed_predicate, entity2)) + + return trimmed_data + except Exception as e: + raise Exception(f'Error in trimming triples: {e}') + +def process_tokens(ner_instance : NER_LLM, extractor, filtered_triples, nlp_model): + try: + updated_triples = [] + for subject, predicate_metadata, object in filtered_triples: + try: + context = predicate_metadata['current_sentence'].replace("\n"," ") + print("This is the context---------", context) + head_positions = ner_instance.find_subword_indices(context, predicate_metadata['entity1_nn_chunk']) + tail_positions = ner_instance.find_subword_indices(context, predicate_metadata['entity2_nn_chunk']) + print("Head and Tail positions acquired----", head_positions, tail_positions) + if head_positions[0][0] > tail_positions[0][0]: + head_entity = {'entity': object, 'noun_chunk':predicate_metadata['entity2_nn_chunk'], 'entity_label':predicate_metadata['entity2_label'] } + tail_entity = {'entity': subject, 'noun_chunk':predicate_metadata['entity1_nn_chunk'], 'entity_label':predicate_metadata['entity1_label']} + entity_pair = ep(head_entity, tail_entity, context, tail_positions, head_positions) + else: + head_entity = {'entity': subject, 'noun_chunk':predicate_metadata['entity1_nn_chunk'], 'entity_label':predicate_metadata['entity1_label']} + tail_entity = {'entity': object, 'noun_chunk':predicate_metadata['entity2_nn_chunk'], 'entity_label':predicate_metadata['entity2_label']} + entity_pair = ep(head_entity, tail_entity, context, head_positions, tail_positions) + print("Entity Pair---------", ep) + tokenized_sentence = extractor.tokenize_sentence(context) + model_input = extractor.model_input(tokenized_sentence) + attention_matrix = 
extractor.inference_attention(model_input) + token_idx_with_word = ner_instance.tokenize_sentence_with_positions(context) + spacy_doc = nlp_model(context) + filter = IndividualFilter(True, 0.02, token_idx_with_word, spacy_doc) + + ## HEAD Entity Based Attention Search + print("Lets start Perform Search") + candidate_paths = perform_search(entity_pair.head_entity['start_idx'], attention_matrix, entity_pair, search_candidates=5, require_contiguous=True, max_relation_length=8, num_initial_tokens=extractor.num_start_tokens()) + candidate_paths = remove_duplicates(candidate_paths) + print("Search Finished------------") + filtered_results = filter.filter(candidates=candidate_paths,e_pair=entity_pair) + predicate_he, score_he = get_best_relation(filtered_results) + print("Context----", context) + print("------------", predicate_he,"-------------------", score_he) + + ##TAIL ENTITY Based Attention Search + candidate_paths = perform_search(entity_pair.tail_entity['start_idx'], attention_matrix, entity_pair, search_candidates=5, require_contiguous=True, max_relation_length=8, num_initial_tokens=extractor.num_start_tokens()) + candidate_paths = remove_duplicates(candidate_paths) + filtered_results = filter.filter(candidates=candidate_paths,e_pair=entity_pair) + predicate_te, score_te = get_best_relation(filtered_results) + print("------------222222", predicate_te,"-------------------", score_te) + + if score_he > score_te and (score_he >= 0.2 or score_te >= 0.2): + triple = create_semantic_triple(head_entity=head_entity['noun_chunk'], + tail_entity=tail_entity['noun_chunk'], + predicate=predicate_he, + score=score_he, + predicate_metadata=predicate_metadata, + subject_type=head_entity['entity_label'], + object_type=tail_entity['entity_label']) + updated_triples.append(triple) + elif score_he < score_te and (score_he >= 0.2 or score_te >= 0.2): + triple = create_semantic_triple(head_entity=tail_entity['noun_chunk'], + tail_entity=head_entity['noun_chunk'], + predicate=predicate_te, + score=score_te, + predicate_metadata=predicate_metadata, + subject_type=tail_entity['entity_label'], + object_type=head_entity['entity_label']) + updated_triples.append(triple) + except Exception as e: + print(f"Caught an exception: {e}") + continue + return updated_triples + except Exception as e: + print("Exception in process tokens -----", e) + raise Exception(f'Error in extracting Attention Based Relationships: {e}') + + +def remove_duplicates(candidate_paths): + seen_relations = set() + unique_paths = [] + + for path in candidate_paths: + # Convert the relation_tokens to a tuple to make it hashable + relation_tokens_tuple = tuple(path.relation_tokens) + if relation_tokens_tuple not in seen_relations: + seen_relations.add(relation_tokens_tuple) + unique_paths.append(path) + + return unique_paths + +def create_semantic_triple(head_entity, tail_entity, predicate, score, predicate_metadata, subject_type, object_type): + try: + triple = ( + head_entity, + json.dumps({ + "predicate": predicate, + "predicate_type": "", + "context": predicate_metadata["context"].replace('\n',' '), + "file_path": predicate_metadata["file_path"], + "subject_type": subject_type, + "object_type": object_type, + "score":score, + }), + tail_entity + ) + return triple + except Exception as e: + print(f"Error in creating semantic triple: {e}") + raise Exception(f"Error in creating semantic triple: {e}") # Example usage # head_entity = { @@ -144,25 +285,25 @@ def frequency_cutoff(ht_relations: list[SemanticPairs], frequency: int): # 'score': 1.0, # 
'noun_chunk': 'Environmental Sciences Department', # } -context = { - 'current_sentence': 'Dr. Emily Stanton, a prominent figure in the Environmental Sciences Department, has been advocating for cleaner energy use on campus for years.', - 'previous_sentence': "This decision was influenced heavily by the growing concern among the student body and faculty about the city's escalating pollution levels.", - 'next_sentence': 'Her relentless efforts finally paid off when the university committed to a 40% reduction in carbon emissions over the next five years.' -} - -tail_entity = { - 'entity': 'introduction gas injection', - 'label': 'geo-method', - 'score': 1.0, - 'noun_chunk': 'introduction gas injection', -} - -head_entity = { - 'entity': 'oil production', - 'label': 'geo-method', - 'score': 1.0, - 'noun_chunk': 'oil production', -} +# context = { +# 'current_sentence': 'Dr. Emily Stanton, a prominent figure in the Environmental Sciences Department, has been advocating for cleaner energy use on campus for years.', +# 'previous_sentence': "This decision was influenced heavily by the growing concern among the student body and faculty about the city's escalating pollution levels.", +# 'next_sentence': 'Her relentless efforts finally paid off when the university committed to a 40% reduction in carbon emissions over the next five years.' +# } + +# head_entity = { +# 'entity': 'gas injection', +# 'label': 'geo-method', +# 'score': 1.0, +# 'noun_chunk': 'gas injection', +# } + +# tail_entity = { +# 'entity': 'oil production', +# 'label': 'geo-method', +# 'score': 1.0, +# 'noun_chunk': 'oil production', +# } # head_entity = { # 'entity': 'oil production', @@ -177,85 +318,92 @@ def frequency_cutoff(ht_relations: list[SemanticPairs], frequency: int): # 'score': 1.0, # 'noun_chunk': 'shale plays', +# } +# head_entity = { +# 'entity': 'nitrogen gas cyclic miscible and immiscible injection', +# 'label': 'organization', +# 'score': 1.0, +# 'noun_chunk': 'nitrogen gas cyclic miscible and immiscible injection', # } +# tail_entity = { +# 'entity': 'oil recovery', +# 'label': 'group', +# 'score': 1.0, +# 'noun_chunk': 'oil recovery', + +# } +# import transformers +# from querent.kg.ner_helperfunctions.ner_llm_transformer import NER_LLM +# from querent.kg.rel_helperfunctions.attn_based_relationship_model_getter import get_model +# import numpy -import transformers -from querent.kg.ner_helperfunctions.ner_llm_transformer import NER_LLM -from querent.kg.rel_helperfunctions.attn_based_relationship_model_getter import get_model -import numpy +# tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased',from_tf=True ) +# model = transformers.BertModel.from_pretrained('bert-base-uncased', from_tf=True) -tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased',from_tf=True ) -model = transformers.BertModel.from_pretrained('bert-base-uncased', from_tf=True) +# tokenizer = transformers.BertTokenizer.from_pretrained("botryan96/GeoBERT",from_tf=True ) +# model = transformers.BertModel.from_pretrained("botryan96/GeoBERT", from_tf=True) ##Get end indexes of noun chunks # combined_text = f"{context['previous_sentence']} {context['current_sentence']} {context['next_sentence']}" # combined_text = "Dr. Emily Stanton, a prominent figure in the Environmental Sciences Department, has been advocating for cleaner energy use on campus for years." # combined_text = "1. 
introduction gas injection has been a widely used technology for increasing oil production in unconventional shale plays in the united states, and it may be the most efficient approach for unlocking the remaining oil percentage. unconventional resources, like shale reservoirs, are widely recognized for their extremely low permeability and porosity.1 despite the fact that multistage hydraulic fracturing and horizontal well drilling techniques are used to extract the remaining oil from such reservoirs, only 4- 6% of the trapped oil can be extracted, and the oil production drops after a few months, attributing to the ultralow permeability.2-19 water injection is also one of the suitable strategies for increasing oil recovery from conventional reservoirs; nevertheless, due to weak injectivity, insuicient sweep potency, and clay swelling concerns, this approach is not the ideal solution for tight reservoirs.20,21 cyclic gas injection outperforms gas looding methods in terms of enhancing oil recovery, mainly in ultratight reservoirs.22,23 the total organic carbon (toc) is the most important inluencing parameter on gas injection in tight reservoirs because kerogen makes the surface of the pore oil-wet, making the oil inside challenging to extract.24 due to the combination of multiphase luids (i.e., gas, oil, condensate, and water) and scales, multiphase low production can create a number of challenges including wax and asphaltene deposition, hydrate formation, slugging, and emulsions.25 organic hydrocarbon particles settling in oil and gas reservoirs might create many low assurance problems throughout the extraction process." -combined_text = "1. introduction gas injection has been a widely used technology for increasing oil production in unconventional shale plays in the united states, and it may be the most efficient approach for unlocking the remaining oil percentage." -ner_helper = NER_LLM(ner_model_name="dummy",provided_model="dummy", provided_tokenizer= tokenizer) - -head_positions = ner_helper.find_subword_indices(combined_text, head_entity['noun_chunk']) -tail_positions = ner_helper.find_subword_indices(combined_text, tail_entity['noun_chunk']) +# combined_text = "1. introduction gas injection has been a widely used technology for increasing oil production in unconventional shale plays in the united states, and it may be the most efficient approach for unlocking the remaining oil percentage." 
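+# Note: whichever sample combined_text is used, the head/tail noun chunks above must appear in it
+# verbatim; find_subword_indices matches their wordpiece ids against the tokenized sentence and
+# returns (start, end) token spans.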
+# combined_text = "asphaltene precipitation and deposition during nitrogen gas cyclic\nmiscible and immiscible injection in eagle ford shale and its impact\non oil recovery\nmukhtar elturki and abdulmohsin imqam*\n cite".replace("\n", " ") +# ner_helper = NER_LLM(ner_model_name="dummy",provided_model="dummy", provided_tokenizer= tokenizer) +# head_positions = ner_helper.find_subword_indices(combined_text, head_entity['noun_chunk']) +# tail_positions = ner_helper.find_subword_indices(combined_text, tail_entity['noun_chunk']) -##Initialize Entity Pair -entity_pair = ep(head_entity, tail_entity, context, head_positions, tail_positions) +# ##Initialize Entity Pair +# entity_pair = ep(head_entity, tail_entity, context, head_positions, tail_positions) -# # Compute Attention_matrix -extractor = get_model("bert",model_tokenizer= tokenizer,model=model) -tokenized_sentence = extractor.tokenize_sentence(combined_text) -model_input = extractor.model_input(tokenized_sentence) -attention_matrix = extractor.inference_attention(model_input) +# # # Compute Attention_matrix -print("Attention------- Done", numpy.shape(attention_matrix)) +# extractor = get_model("bert",model_tokenizer= tokenizer,model=model) +# tokenized_sentence = extractor.tokenize_sentence(combined_text) +# model_input = extractor.model_input(tokenized_sentence) +# attention_matrix = extractor.inference_attention(model_input) +# print("Attention------- Done", numpy.shape(attention_matrix)) -candidate_paths = perform_search(attention_matrix, entity_pair, search_candidates=5, require_contiguous=True, max_relation_length=8, num_initial_tokens=1) -def remove_duplicates(candidate_paths): - seen_relations = set() - unique_paths = [] +# candidate_paths = perform_search(attention_matrix, entity_pair, search_candidates=5, require_contiguous=True, max_relation_length=8, num_initial_tokens=1) - for path in candidate_paths: - # Convert the relation_tokens to a tuple to make it hashable - relation_tokens_tuple = tuple(path.relation_tokens) - if relation_tokens_tuple not in seen_relations: - seen_relations.add(relation_tokens_tuple) - unique_paths.append(path) - return unique_paths -candidate_paths = remove_duplicates(candidate_paths) +# candidate_paths = remove_duplicates(candidate_paths) -# Display candidate paths -for path in candidate_paths: - print(f"Path: {path.relation_tokens}, Mean Score: {path.mean_score()}") +# # Display candidate paths +# for path in candidate_paths: +# print(f"Path: {path.relation_tokens}, Mean Score: {path.mean_score()}") -token_idx_with_word = ner_helper.tokenize_sentence_with_positions(combined_text) -print("Token_idx_wi&&&&&&&&&&&&&&&&", token_idx_with_word) -nlp_model = NER_LLM.set_nlp_model('en_core_web_lg') -nlp_model = NER_LLM.get_class_variable() -nlp = nlp_model -spacy_doc = nlp(combined_text) -for index, token in enumerate(tokenizer.tokenize(combined_text)): +# token_idx_with_word = ner_helper.tokenize_sentence_with_positions(combined_text) +# print("Token_idx_wi&&&&&&&&&&&&&&&&", token_idx_with_word) +# nlp_model = NER_LLM.set_nlp_model('en_core_web_lg') +# nlp_model = NER_LLM.get_class_variable() +# nlp = nlp_model +# spacy_doc = nlp(combined_text) +# for index, token in enumerate(tokenizer.tokenize(combined_text)): - print(f"Index {index}: {token}") -filter = IndividualFilter(True, 0.02, token_idx_with_word, spacy_doc) -filtered_results = filter.filter(candidates=candidate_paths,e_pair=entity_pair) -print("Final Results --------------", filtered_results) -print(head_positions, tail_positions) +# print(f"Index 
{index}: {token}") +# filter = IndividualFilter(True, 0.02, token_idx_with_word, spacy_doc) +# filtered_results = filter.filter(candidates=candidate_paths,e_pair=entity_pair) +# print("Final Results --------------", filtered_results) +# print(head_positions, tail_positions) -print("Cleaned Relationships ------------", clean_relations([filtered_results])) +# # print("Cleaned Relationships ------------", clean_relations([filtered_results])) -# for token in spacy_doc: -# print("Token and Token Pos", token, "---------", token.pos_) \ No newline at end of file +# predicate, score = get_best_relation(filtered_results) +# print("------------", predicate,"-------------------", score) +# # for token in spacy_doc: +# # print("Token and Token Pos", token, "---------", token.pos_) \ No newline at end of file diff --git a/querent/kg/rel_helperfunctions/attn_based_relationship_model_getter.py b/querent/kg/rel_helperfunctions/attn_based_relationship_model_getter.py index c1184bc3..75afcc03 100644 --- a/querent/kg/rel_helperfunctions/attn_based_relationship_model_getter.py +++ b/querent/kg/rel_helperfunctions/attn_based_relationship_model_getter.py @@ -68,7 +68,6 @@ def tokenize_sentence(self, sentence: str): class BertBasedModel(AttnRelationshipExtractor): def __init__(self, model_tokenizer, model): - print("initiualzing--------") super().__init__(model_tokenizer, model) def init_token_idx_2_word_doc_idx(self) -> list[tuple[str, int]]: @@ -96,14 +95,12 @@ def inference_attention(self, model_input: dict[str, torch.Tensor]) -> torch.Ten output = self.model(**model_input, output_attentions=True) last_att_layer = output.attentions[-1] mean = torch.mean(last_att_layer, dim=1) - print("Mean ", mean[0]) return mean[0] def maximum_tokens(self) -> int: return 512 def tokenize_sentence(self, sentence: str): - print("hereeeeee") return self.tokenizer.encode(sentence, add_special_tokens=False) @@ -138,6 +135,9 @@ def inference_attention(self, model_input: dict[str, torch.Tensor]): def maximum_tokens(self) -> int: return 2048 + + def tokenize_sentence(self, sentence: str): + return self.tokenizer.encode(sentence, add_special_tokens=False) def get_model(model_name:str, model_tokenizer: str, model: str) -> AttnRelationshipExtractor: @@ -148,24 +148,3 @@ def get_model(model_name:str, model_tokenizer: str, model: str) -> AttnRelations raise Exception("Model not found") -# Usage example -if __name__ == '__main__': - sentence = """ABSTRACT: Cyclic gas injection methods have been shown to improve oil recovery -in conventional reservoirs. 
Even though similar technologies have been used in -unconventional reservoirs with some success stories in shale resources, cyclic gas -injection enhanced oil recovery (EOR) is still a little-understood subject in boosting -oil recovery from unconventional reservoirs.""" - tokenizer = transformers.BertTokenizer.from_pretrained('botryan96/GeoBERT',from_tf=True ) - model = transformers.BertModel.from_pretrained('botryan96/GeoBERT', from_tf=True) - extractor = get_model("bert",model_tokenizer= tokenizer,model=model) - tokenized_sentence = extractor.tokenize_sentence(sentence) - print("tokenized--", tokenized_sentence) - model_input = extractor.model_input(tokenized_sentence) - attention = extractor.inference_attention(model_input) - print("Attention-------", attention) - # Pairs-------- [HtPair(head=NounChunk(text='introduction gas injection', doc_start_idx=2, doc_end_idx=4, token_start_idx=3, token_end_idx=5, - # wikidata_id=None), tail=NounChunk(text='a widely used technology', doc_start_idx=7, doc_end_idx=10, token_start_idx=8, token_end_idx=11, - # wikidata_id=None)), -# model_metadata = extractor.dump_metadata() -# model_name = extractor.extract_general_name(model_metadata) -# print(model_name) \ No newline at end of file diff --git a/querent/kg/rel_helperfunctions/attn_based_relationship_seach_scope.py b/querent/kg/rel_helperfunctions/attn_based_relationship_seach_scope.py index 730298a2..bd7836ca 100644 --- a/querent/kg/rel_helperfunctions/attn_based_relationship_seach_scope.py +++ b/querent/kg/rel_helperfunctions/attn_based_relationship_seach_scope.py @@ -4,7 +4,7 @@ import numpy class EntityPair: - def __init__(self, head_entity: Dict, tail_entity: Dict, context: Dict, head_positions, tail_positions): + def __init__(self, head_entity: Dict, tail_entity: Dict, context: str, head_positions, tail_positions): self.head_entity = head_entity self.tail_entity = tail_entity self.context = context @@ -53,7 +53,7 @@ def is_valid_token(token_id, pair: EntityPair, candidate_paths: List[SearchConte -def perform_search(attention_matrix: torch.Tensor, entity_pair: EntityPair, search_candidates: int, require_contiguous: bool, max_relation_length: int, num_initial_tokens: int) -> List[SearchContextualRelationship]: +def perform_search(entity_start_index, attention_matrix: torch.Tensor, entity_pair: EntityPair, search_candidates: int, require_contiguous: bool, max_relation_length: int, num_initial_tokens: int) -> List[SearchContextualRelationship]: """ Initialize the perform search function with the following parameters: :param attention_matrix :Mean attention score, average attention each token pays to every other token showing which tokens are most related to each other in the context of the given sentence(s). @@ -63,44 +63,43 @@ def perform_search(attention_matrix: torch.Tensor, entity_pair: EntityPair, sear :patam num_initial_tokens: Different for different models. E.g. 'Bert' adds a '[CLS]' to the start of a sequence, so it is 1. """ - - queue = [ - SearchContextualRelationship(entity_pair.head_entity['start_idx']) - ] - print("Length of Queue", len(queue)) - candidate_paths = [] - visited_paths = set() - while len(queue) > 0: - current_path = queue.pop(0) - - if len(current_path.relation_tokens) > max_relation_length: - continue - - if require_contiguous and len(current_path.relation_tokens) > 1 and abs(current_path.relation_tokens[-2] - current_path.relation_tokens[-1]) != 1: - continue - - new_paths = [] - - # How all other tokens attend to an entity e.g. 
"Emily Stanton" - # These scores indicate how much importance the model places on each token when considering "Emily Stanton." - # The tokens which consider entity "Emily Stanton" important, highlight entity's relationships and relevance within the sentence. - attention_scores = attention_matrix[:, current_path.current_token] + try: + queue = [ + SearchContextualRelationship(entity_start_index) + ] + candidate_paths = [] + visited_paths = set() + while len(queue) > 0: + current_path = queue.pop(0) + + if len(current_path.relation_tokens) > max_relation_length: + continue + + if require_contiguous and len(current_path.relation_tokens) > 1 and abs(current_path.relation_tokens[-2] - current_path.relation_tokens[-1]) != 1: + continue + + new_paths = [] + + # How all other tokens attend to an entity e.g. "Emily Stanton" + # These scores indicate how much importance the model places on each token when considering "Emily Stanton." + # The tokens which consider entity "Emily Stanton" important, highlight entity's relationships and relevance within the sentence. + + attention_scores = attention_matrix[:, current_path.current_token] + for i in range(num_initial_tokens, len(attention_scores) - 1): + next_path = tuple(current_path.visited_tokens + [i]) + if is_valid_token(i, entity_pair, candidate_paths, current_path, attention_scores[i].detach()) and next_path not in visited_paths and current_path.current_token != i: + new_paths.append( + copy.deepcopy(current_path) + ) + new_paths[-1].add_token(i, attention_scores[i].detach()) + # print("New Paths ------ visited token & relation tokens-",new_paths[-1].visited_tokens, new_paths[-1].relation_tokens ) + # print("New Paths ------ Scoressss-",new_paths[-1].total_score) + visited_paths.add(next_path) + new_paths.sort(key=sort_by_mean_score, reverse=True) + queue += new_paths[:search_candidates] + + return candidate_paths + except Exception as e: + print("Exceptions while performing search ------",e) - print("Length of Attention Matrix....", len(attention_matrix), numpy.shape(attention_matrix)) - print("Length of Attention Scores....", len(attention_scores), numpy.shape(attention_scores)) - for i in range(num_initial_tokens, len(attention_scores) - 1): - next_path = tuple(current_path.visited_tokens + [i]) - if is_valid_token(i, entity_pair, candidate_paths, current_path, attention_scores[i].detach()) and next_path not in visited_paths and current_path.current_token != i: - new_paths.append( - copy.deepcopy(current_path) - ) - new_paths[-1].add_token(i, attention_scores[i].detach()) - print("New Paths ------ visited token & relation tokens-",new_paths[-1].visited_tokens, new_paths[-1].relation_tokens ) - print("New Paths ------ Scoressss-",new_paths[-1].total_score) - visited_paths.add(next_path) - new_paths.sort(key=sort_by_mean_score, reverse=True) - queue += new_paths[:search_candidates] - - return candidate_paths - diff --git a/querent/kg/rel_helperfunctions/contextual_predicate.py b/querent/kg/rel_helperfunctions/contextual_predicate.py index 908196d4..018e5146 100644 --- a/querent/kg/rel_helperfunctions/contextual_predicate.py +++ b/querent/kg/rel_helperfunctions/contextual_predicate.py @@ -42,6 +42,7 @@ class ContextualPredicate(BaseModel): pair_attnscore: float entity1_embedding: List[float] entity2_embedding: List[float] + current_sentence: str @classmethod @@ -63,7 +64,8 @@ def from_tuple(cls, data: Tuple[str, str, str, Dict[str, str], str]) -> 'Context pair_attnscore=data[3].get('pair_attnscore',1), entity1_embedding=entity1_embedding, 
entity2_embedding=entity2_embedding, - file_path=data[4] + file_path=data[4], + current_sentence = data[3].get('current_sentence'), ) except Exception as e: raise ValueError(f"Error creating ContextualPredicate from tuple: {e}") diff --git a/querent/kg/rel_helperfunctions/embedding_store.py b/querent/kg/rel_helperfunctions/embedding_store.py index 1faa4ef5..0c0dcb88 100644 --- a/querent/kg/rel_helperfunctions/embedding_store.py +++ b/querent/kg/rel_helperfunctions/embedding_store.py @@ -40,6 +40,7 @@ def generate_embeddings(self, payload, relationship_finder=False, generate_embed predicate_type = data.get("predicate_type","Unlabeled").replace('"', '\\"') subject_type = data.get("subject_type","Unlabeled").replace('"', '\\"') object_type = data.get("object_type","Unlabeled").replace('"', '\\"') + score = data.get("score") context_embeddings = None predicate_embedding = None context_embeddings = self.get_embeddings([context])[0] @@ -54,7 +55,8 @@ def generate_embeddings(self, payload, relationship_finder=False, generate_embed "predicate": predicate, "subject_type": subject_type, "object_type": object_type, - "predicate_emb": predicate_embedding if predicate_embedding is not None else "Not Implemented" + "predicate_emb": predicate_embedding if predicate_embedding is not None else "Not Implemented", + "score":score } updated_json_string = json.dumps(essential_data) processed_pairs.append( diff --git a/querent/kg/rel_helperfunctions/triple_to_json.py b/querent/kg/rel_helperfunctions/triple_to_json.py index 1e999f4a..434580be 100644 --- a/querent/kg/rel_helperfunctions/triple_to_json.py +++ b/querent/kg/rel_helperfunctions/triple_to_json.py @@ -44,7 +44,8 @@ def convert_graphjson(triple): "object_type": TripleToJsonConverter._normalize_text(predicate_info.get("object_type", "Unlabeled"), replace_space=True), "predicate": TripleToJsonConverter._normalize_text(predicate_info.get("predicate", ""), replace_space=True), "predicate_type": TripleToJsonConverter._normalize_text(predicate_info.get("predicate_type", "Unlabeled"), replace_space=True), - "sentence": predicate_info.get("context", "").lower() + "sentence": predicate_info.get("context", "").lower(), + "score": predicate_info.get("score", 1) } return json_object
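
Example usage of the attention-based path added by this patch. The sketch below strings the new pieces together the same way process_tokens() does; it is illustrative only: the BERT checkpoint, the sentence, and the two noun chunks are placeholders, and the 0.02 filter threshold, the 5 search candidates, and the maximum relation length of 8 simply repeat the values used in process_tokens.

import transformers

from querent.kg.ner_helperfunctions.ner_llm_transformer import NER_LLM
from querent.kg.rel_helperfunctions.attn_based_relationship_model_getter import get_model
from querent.kg.rel_helperfunctions.attn_based_relationship_seach_scope import EntityPair, perform_search
from querent.kg.rel_helperfunctions.attn_based_relationship_filter import (
    IndividualFilter,
    get_best_relation,
    remove_duplicates,
)

# Placeholder checkpoint; any BERT-style encoder that can return attentions works here.
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
model = transformers.BertModel.from_pretrained("bert-base-uncased")

context = "gas injection has been a widely used technology for increasing oil production."
head_entity = {"entity": "gas injection", "noun_chunk": "gas injection", "entity_label": "geo_method"}
tail_entity = {"entity": "oil production", "noun_chunk": "oil production", "entity_label": "geo_method"}

# Locate the wordpiece spans of both noun chunks, as process_tokens() does.
ner_helper = NER_LLM(ner_model_name="dummy", provided_model="dummy", provided_tokenizer=tokenizer)
head_positions = ner_helper.find_subword_indices(context, head_entity["noun_chunk"])
tail_positions = ner_helper.find_subword_indices(context, tail_entity["noun_chunk"])
entity_pair = EntityPair(head_entity, tail_entity, context, head_positions, tail_positions)

# Mean attention of the last layer drives the breadth-first search between the two entities.
extractor = get_model("bert", model_tokenizer=tokenizer, model=model)
attention_matrix = extractor.inference_attention(
    extractor.model_input(extractor.tokenize_sentence(context))
)
candidate_paths = perform_search(
    entity_pair.head_entity["start_idx"],
    attention_matrix,
    entity_pair,
    search_candidates=5,
    require_contiguous=True,
    max_relation_length=8,
    num_initial_tokens=extractor.num_start_tokens(),
)
candidate_paths = remove_duplicates(candidate_paths)

# Lemmatise and score the surviving paths, then keep the best-scoring predicate.
NER_LLM.set_nlp_model("en_core_web_lg")
nlp = NER_LLM.get_class_variable()
token_idx_with_word = ner_helper.tokenize_sentence_with_positions(context)
individual_filter = IndividualFilter(True, 0.02, token_idx_with_word, nlp(context))
filtered = individual_filter.filter(candidates=candidate_paths, e_pair=entity_pair)
predicate, score = get_best_relation(filtered)
print(predicate, score)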