Skip to content

Commit

Permalink
fixed error
Browse files Browse the repository at this point in the history
  • Loading branch information
ngupta10 committed Dec 31, 2023
1 parent 6129312 commit e754931
Showing 1 changed file with 0 additions and 118 deletions.
118 changes: 0 additions & 118 deletions querent/core/transformers/relationship_extraction_llm.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,18 @@
import json
from typing import Any, List, Tuple
from querent.kg.rel_helperfunctions.BSM_relationfilter import BSMBranch
from typing import Any, List, Tuple
from querent.kg.rel_helperfunctions.BSM_relationfilter import BSMBranch
from querent.kg.rel_helperfunctions.embedding_store import EmbeddingStore
from querent.kg.rel_helperfunctions.questionanswer_llama2 import QASystem
from querent.kg.rel_helperfunctions.rag_retriever import RAGRetriever
from querent.kg.rel_helperfunctions.rag_retriever import RAGRetriever
from querent.kg.rel_helperfunctions.rel_normalize import TextNormalizer
from querent.logging.logger import setup_logger
from querent.config.core.relation_config import RelationshipExtractorConfig
from langchain.docstore.document import Document
import ast

"""
A class for extracting relationships from triples and processing them for various representations.
from langchain.docstore.document import Document
import ast
"""
A class for extracting relationships from triples and processing them for various representations.
This class includes methods for validating triples, generating embeddings, processing tokens,
normalizing and building indices for triples, creating semantic triples, extracting relationships,
and trimming triples.
This class includes methods for validating triples, generating embeddings, processing tokens,
normalizing and building indices for triples, creating semantic triples, extracting relationships,
and trimming triples.
Expand All @@ -37,13 +26,6 @@
rag_retriever (RAGRetriever): A RAG retriever for document retrieval, used if rag_approach is True.
bsmbranch (BSMBranch): An instance of BSMBranch for handling BSM-related tasks.
sub_tasks (list): List of dynamic sub-tasks for processing.
logger (Logger): Logger for logging messages and errors.
create_emb (EmbeddingStore): An instance of EmbeddingStore for generating embeddings.
qa_system (QASystem): A question-answering system for extracting relationships.
rag_approach (bool): A flag indicating whether to use the RAG approach for retrieval.
rag_retriever (RAGRetriever): A RAG retriever for document retrieval, used if rag_approach is True.
bsmbranch (BSMBranch): An instance of BSMBranch for handling BSM-related tasks.
sub_tasks (list): List of dynamic sub-tasks for processing.
Methods:
validate(data) -> bool:
Expand All @@ -60,24 +42,8 @@
Extracts relationships from the given triples.
trim_triples(data):
Trims the given data to a more concise format.
validate(data) -> bool:
Validates the input data to ensure it's in the correct format for processing.
generate_embeddings(payload):
Generates embeddings for the given payload containing triples.
process_tokens(payload):
Processes tokens in the given payload and extracts relationships.
normalizetriples_buildindex(triples):
Normalizes the given triples and builds an index for them.
create_semantic_triple(input1, input2):
Creates a semantic triple from the given inputs.
extract_relationships(triples):
Extracts relationships from the given triples.
trim_triples(data):
Trims the given data to a more concise format.
"""

class RelationExtractor():
def __init__(self, config: RelationshipExtractorConfig):
class RelationExtractor():
def __init__(self, config: RelationshipExtractorConfig):
self.logger = setup_logger(config.logger, "RelationshipExtractor")
Expand All @@ -90,20 +56,6 @@ def __init__(self, config: RelationshipExtractorConfig):
rel_model_type=config.model_type,
)

# self.qa_system_bsm_validator = QASystem(
# rel_model_path=config.bsm_validator_model_path,
# rel_model_type=config.bsm_validator_model_type,
# emb_model_name=config.emb_model_name,
# faiss_index_path=config.get_faiss_index_path()
# )
self.rag_approach = config.rag_approach
if self.rag_approach == True:
self.rag_retriever = RAGRetriever(
faiss_index_path=config.get_faiss_index_path(),
rel_model_path=config.model_path,
rel_model_type=config.model_type,
)

# self.qa_system_bsm_validator = QASystem(
# rel_model_path=config.bsm_validator_model_path,
# rel_model_type=config.bsm_validator_model_type,
Expand All @@ -118,16 +70,11 @@ def __init__(self, config: RelationshipExtractorConfig):
embedding_store=self.create_emb,
logger=self.logger)
self.bsmbranch = BSMBranch()
self.sub_tasks = config.dynamic_sub_tasks
embedding_store=self.create_emb,
logger=self.logger)
self.bsmbranch = BSMBranch()
self.sub_tasks = config.dynamic_sub_tasks
except Exception as e:
self.logger.error(f"Initialization failed: {e}")
raise Exception(f"Initialization failed: {e}")


def validate(self, data) -> bool:
try:
if not data:
Expand Down Expand Up @@ -194,39 +141,6 @@ def generate_embeddings(self, payload):
self.logger.error(f"Error in extracting embeddings: {e}")
raise Exception(f"Error in extracting embeddings: {e}")

def process_tokens(self, payload):
try:
triples = payload
def generate_embeddings(self, payload):
    """Attach a context embedding to the JSON payload of every triple.

    Each item of ``payload`` is unpacked as
    ``(entity, json_string, related_entity)``. ``json_string`` is decoded
    with ``json.loads`` and rebuilt from a fixed set of keys ("context",
    "predicate", "predicate_type", "subject_type", "object_type"), each
    defaulting to an empty string when absent. Any other key present in
    the incoming JSON is not copied into the result.

    Args:
        payload: Iterable of 3-item triples whose middle element is a
            JSON-encoded object.

    Returns:
        A list of ``(entity, updated_json_string, related_entity)``
        tuples, where ``updated_json_string`` additionally carries a
        "context_embeddings" entry.

    Raises:
        Exception: On any failure while processing; the error is logged
            first and the original exception is chained as the cause.
    """
    try:
        processed_pairs = []

        for entity, json_string, related_entity in payload:
            data = json.loads(json_string)
            context = data.get("context", "")
            # get_embeddings is called with a one-element list and only
            # its first result is kept, i.e. one request per triple.
            # NOTE(review): the embedding must be JSON-serializable for
            # json.dumps below to succeed -- confirm against EmbeddingStore.
            context_embeddings = self.create_emb.get_embeddings([context])[0]
            essential_data = {
                "context": context,
                "context_embeddings": context_embeddings,
                "predicate_type": data.get("predicate_type", ""),
                "predicate": data.get("predicate", ""),
                "subject_type": data.get("subject_type", ""),
                "object_type": data.get("object_type", ""),
            }
            processed_pairs.append(
                (entity, json.dumps(essential_data), related_entity)
            )

        return processed_pairs

    except Exception as e:
        self.logger.error(f"Error in extracting embeddings: {e}")
        # Chain explicitly so the root cause stays attached to the
        # generic exception that callers receive.
        raise Exception(f"Error in extracting embeddings: {e}") from e

def process_tokens(self, payload):
try:
triples = payload
Expand All @@ -237,12 +151,6 @@ def process_tokens(self, payload):

return relationships

if self.rag_approach == True:
self.rag_retriever.build_faiss_index(trimmed_triples)
relationships = self.extract_relationships(triples)

return relationships

except Exception as e:
self.logger.error(f"Error in processing event: {e}")
raise Exception(f"Invalid in processing event: {e}")
Expand All @@ -256,31 +164,12 @@ def normalizetriples_buildindex(self, triples):
normalized_triples = normalizer.normalize_triples(triples)
trimmed_triples = self.trim_triples(normalized_triples)


return trimmed_triples


except Exception as e:
self.logger.error(f"Error in normalizing/building index: {e}")
raise Exception(f"Error in normalizing/building index: {e}")

def create_semantic_triple(self, input1, input2):
    """Build a ``(subject, json_payload, object)`` triple from two inputs.

    Both arguments are parsed with ``ast.literal_eval`` and then read
    with ``.get``, so each is expected to be the text of a Python dict
    literal. Every field falls back to an empty string when missing.

    Args:
        input1: Literal supplying "subject", "object", "predicate",
            "predicate_type", "subject_type" and "object_type".
        input2: Literal supplying "context" and "file_path".

    Returns:
        A 3-tuple: the subject, a JSON string holding the predicate
        details plus context and file path, and the object.

    Raises:
        ValueError, SyntaxError: Propagated from ``ast.literal_eval``
            when an input is not a valid Python literal.
    """
    input1_data = ast.literal_eval(input1)
    input2_data = ast.literal_eval(input2)

    # Key order is kept as-is because it determines the serialized string.
    predicate_payload = json.dumps({
        "predicate": input1_data.get("predicate", ""),
        "predicate_type": input1_data.get("predicate_type", ""),
        "context": input2_data.get("context", ""),
        "file_path": input2_data.get("file_path", ""),
        "subject_type": input1_data.get("subject_type", ""),
        "object_type": input1_data.get("object_type", ""),
    })

    return (
        input1_data.get("subject", ""),
        predicate_payload,
        input1_data.get("object", ""),
    )

def create_semantic_triple(self, input1, input2):
input1_data = ast.literal_eval(input1)
input2_data = ast.literal_eval(input2)
Expand Down Expand Up @@ -340,19 +229,12 @@ def trim_triples(self, data):
'entity1_nn_chunk': predicate_dict.get('entity1_nn_chunk', ''),
'entity2_nn_chunk': predicate_dict.get('entity2_nn_chunk', ''),
'file_path': predicate_dict.get('file_path', '')
'context': predicate_dict.get('context', ''),
'entity1_nn_chunk': predicate_dict.get('entity1_nn_chunk', ''),
'entity2_nn_chunk': predicate_dict.get('entity2_nn_chunk', ''),
'file_path': predicate_dict.get('file_path', '')
}
trimmed_data.append((entity1, trimmed_predicate, entity2))

return trimmed_data


except Exception as e:
self.logger.error(f"Error in trimming triples: {e}")
raise Exception(f'Error in trimming triples: {e}')

raise Exception(f'Error in trimming triples: {e}')

0 comments on commit e754931

Please sign in to comment.