
Commit 26fb020: resolved conflicts

ngupta10 committed Dec 31, 2023
2 parents e9fd403 + 0ae631a
Showing 13 changed files with 231 additions and 10 deletions.
7 changes: 2 additions & 5 deletions querent/common/types/ingested_images.py
@@ -8,7 +8,7 @@ def __init__(
        image: str,
        image_name: str,
        page_num: int,
        text: [str],
        coordinates: list = [],
        ocr_text: list = [],
        error: str = None,
@@ -28,14 +28,11 @@ def __init__(
    def __str__(self):
        if self.error:
            return f"Error: {self.error}"
        return f"Data: {self.ocr_text}"

    def is_error(self) -> bool:
        return self.error is not None

    def get_file_path(self) -> str:
        return self.file_path

    def get_extension(self) -> str:
        return self.extension

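For orientation, a minimal usage sketch of this type based on the constructor parameters visible in this hunk; the diff truncates the full signature, so additional required arguments may exist, and all values below are illustrative:

from querent.common.types.ingested_images import IngestedImages

# Illustrative values only; the full constructor signature is truncated in this diff.
img = IngestedImages(
    image="<base64-encoded image data>",
    image_name="figure_1.png",
    page_num=3,
    text=["caption near the figure"],
    coordinates=[],
    ocr_text=["text recovered by OCR"],
)

if img.is_error():
    print(img)  # "Error: <message>"
else:
    print(img)  # "Data: ['text recovered by OCR']" per the updated __str__
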
1 change: 1 addition & 0 deletions querent/common/types/querent_event.py
@@ -12,3 +12,4 @@ def __init__(self, event_type: EventType, timestamp: float, payload: Any, file:
        self.timestamp = timestamp
        self.payload = payload
        self.file = file
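As a quick illustration, constructing an event state might look like the following; the specific EventType member is an assumption made purely for this sketch:

import time
from querent.common.types.querent_event import EventState, EventType

# EventType.Graph is assumed here for illustration only.
state = EventState(EventType.Graph, timestamp=time.time(), payload={"triples": []}, file="report.pdf")
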
1 change: 1 addition & 0 deletions querent/config/core/bert_llm_config.py
@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field
from typing import List, Dict, Any

class BERTLLMConfig(BaseModel):
    name: str = "BERTLLMEngine"
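Since the config is a plain pydantic BaseModel, instantiation is straightforward; a minimal sketch (only the name field is visible in this hunk, so the other fields are assumed to carry defaults):

config = BERTLLMConfig(name="MyBERTLLMEngine")
print(config.name)  # "MyBERTLLMEngine"
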
3 changes: 3 additions & 0 deletions querent/core/base_engine.py
@@ -11,6 +11,7 @@
from querent.config.engine.engine_config import EngineConfig
from querent.logging.logger import setup_logger
from querent.common.types.ingested_images import IngestedImages

"""
BaseEngine is an abstract base class that provides the foundational structure and methods
@@ -214,6 +215,8 @@ async def _inner_worker():
                await self.process_tokens(data)
            elif isinstance(data, IngestedImages):
                await self.process_images(data)
            elif isinstance(data, IngestedCode):
                await self.process_code(data)
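Because BaseEngine dispatches on the runtime type of each queue item, a concrete engine only needs matching handlers. A minimal subclass sketch follows; handler names mirror the dispatch above, while the bodies and any additional abstract methods on BaseEngine are illustrative assumptions:

class PrintEngine(BaseEngine):
    # Minimal sketch: real engines perform NER and graph construction here.
    async def process_tokens(self, data):
        print("tokens from", data.get_file_path())

    async def process_images(self, data):
        print("ocr text:", data.ocr_text)

    async def process_code(self, data):
        print("received a code chunk")
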
11 changes: 11 additions & 0 deletions querent/core/transformers/bert_llm.py
@@ -8,6 +8,7 @@
from querent.kg.contextual_predicate import process_data
from querent.kg.ner_helperfunctions.contextual_embeddings import EntityEmbeddingExtractor
from querent.kg.ner_helperfunctions.fixed_entities import FixedEntityExtractor
from querent.kg.ner_helperfunctions.ner_llm_transformer import NER_LLM
from querent.common.types.querent_event import EventState, EventType
from querent.core.base_engine import BaseEngine
@@ -32,9 +33,11 @@ def __init__(
        input_queue: QuerentQueue,
        config: BERTLLMConfig
    ):
        self.logger = setup_logger(__name__, "BERTLLM")
        super().__init__(input_queue)
        self.graph_config = GraphConfig(identifier=config.name)
        self.contextual_graph = QuerentKG(self.graph_config)
        self.semantic_graph = QuerentKG(self.graph_config)
        self.file_buffer = FileBuffer()
@@ -78,6 +81,9 @@ def validate(self) -> bool:
    def process_messages(self, data: IngestedMessages):
        return super().process_messages(data)

    def process_images(self, data: IngestedImages):
        return super().process_messages(data)

@@ -115,6 +121,7 @@ async def process_tokens(self, data: IngestedTokens):

            return

        file, content = self.file_buffer.add_chunk(
            data.get_file_path(), data.data
        )
@@ -162,8 +169,10 @@ async def process_tokens(self, data: IngestedTokens):
            else:
                filtered_triples, _ = self.triple_filter.filter_triples(clustered_triples)
                self.logger.log(f"Filtering in {self.__class__.__name__} producing 0 entity pairs. Filtering Disabled.")
        else:
            filtered_triples = pairs_with_predicates
        mock_config = RelationshipExtractorConfig()
        semantic_extractor = RelationExtractor(mock_config)
        relationships = semantic_extractor.process_tokens(filtered_triples)
@@ -184,3 +193,5 @@ async def process_tokens(self, data: IngestedTokens):
        except Exception as e:
            self.logger.error(f"Invalid {self.__class__.__name__} configuration. Unable to process tokens. {e}")
            raise Exception(f"An unexpected error occurred while processing tokens: {e}")
118 changes: 118 additions & 0 deletions querent/core/transformers/relationship_extraction_llm.py
@@ -1,18 +1,29 @@
import json
from typing import Any, List, Tuple
from querent.kg.rel_helperfunctions.BSM_relationfilter import BSMBranch
from querent.kg.rel_helperfunctions.embedding_store import EmbeddingStore
from querent.kg.rel_helperfunctions.questionanswer_llama2 import QASystem
from querent.kg.rel_helperfunctions.rag_retriever import RAGRetriever
from querent.kg.rel_helperfunctions.rel_normalize import TextNormalizer
from querent.logging.logger import setup_logger
from querent.config.core.relation_config import RelationshipExtractorConfig
from langchain.docstore.document import Document
import ast

"""
A class for extracting relationships from triples and processing them for various representations.

This class includes methods for validating triples, generating embeddings, processing tokens,
normalizing and building indices for triples, creating semantic triples, extracting relationships,
and trimming triples.
@@ -26,6 +37,13 @@
        logger (Logger): Logger for logging messages and errors.
        create_emb (EmbeddingStore): An instance of EmbeddingStore for generating embeddings.
        qa_system (QASystem): A question-answering system for extracting relationships.
        rag_approach (bool): A flag indicating whether to use the RAG approach for retrieval.
        rag_retriever (RAGRetriever): A RAG retriever for document retrieval, used if rag_approach is True.
        bsmbranch (BSMBranch): An instance of BSMBranch for handling BSM-related tasks.
        sub_tasks (list): List of dynamic sub-tasks for processing.

    Methods:
        validate(data) -> bool:
@@ -42,8 +60,24 @@
            Validates the input data to ensure it's in the correct format for processing.
        generate_embeddings(payload):
            Generates embeddings for the given payload containing triples.
        process_tokens(payload):
            Processes tokens in the given payload and extracts relationships.
        normalizetriples_buildindex(triples):
            Normalizes the given triples and builds an index for them.
        create_semantic_triple(input1, input2):
            Creates a semantic triple from the given inputs.
        extract_relationships(triples):
            Extracts relationships from the given triples.
        trim_triples(data):
            Trims the given data to a more concise format.
"""

class RelationExtractor():
    def __init__(self, config: RelationshipExtractorConfig):
        self.logger = setup_logger(config.logger, "RelationshipExtractor")
@@ -56,6 +90,20 @@ def __init__(self, config: RelationshipExtractorConfig):
                rel_model_type=config.model_type,
            )

            # self.qa_system_bsm_validator = QASystem(
            #     rel_model_path=config.bsm_validator_model_path,
            #     rel_model_type=config.bsm_validator_model_type,
            #     emb_model_name=config.emb_model_name,
            #     faiss_index_path=config.get_faiss_index_path()
            # )
            self.rag_approach = config.rag_approach
            if self.rag_approach == True:
                self.rag_retriever = RAGRetriever(
                    faiss_index_path=config.get_faiss_index_path(),
                    rel_model_path=config.model_path,
                    rel_model_type=config.model_type,
                )
@@ -70,11 +118,16 @@ def __init__(self, config: RelationshipExtractorConfig):
                embedding_store=self.create_emb,
                logger=self.logger)
            self.bsmbranch = BSMBranch()
            self.sub_tasks = config.dynamic_sub_tasks
        except Exception as e:
            self.logger.error(f"Initialization failed: {e}")
            raise Exception(f"Initialization failed: {e}")

    def validate(self, data) -> bool:
        try:
            if not data:
@@ -141,6 +194,39 @@ def generate_embeddings(self, payload):
self.logger.error(f"Error in extracting embeddings: {e}")
raise Exception(f"Error in extracting embeddings: {e}")

def process_tokens(self, payload):
try:
triples = payload
def generate_embeddings(self, payload):
try:
triples = payload
processed_pairs = []

for entity, json_string, related_entity in triples:
data = json.loads(json_string)
context = data.get("context", "")
predicate = data.get("predicate","")
predicate_type = data.get("predicate_type","")
subject_type = data.get("subject_type","")
object_type = data.get("object_type","")
context_embeddings = self.create_emb.get_embeddings([context])[0]
essential_data = {
"context": context,
"context_embeddings" : context_embeddings,
"predicate_type": predicate_type,
"predicate" : predicate,
"subject_type": subject_type,
"object_type": object_type
}
updated_json_string = json.dumps(essential_data)
processed_pairs.append((entity, updated_json_string, related_entity))

return processed_pairs

except Exception as e:
self.logger.error(f"Error in extracting embeddings: {e}")
raise Exception(f"Error in extracting embeddings: {e}")

def process_tokens(self, payload):
try:
triples = payload
@@ -151,6 +237,12 @@ def process_tokens(self, payload):

            if self.rag_approach == True:
                self.rag_retriever.build_faiss_index(trimmed_triples)
            relationships = self.extract_relationships(triples)

            return relationships

        except Exception as e:
            self.logger.error(f"Error in processing event: {e}")
            raise Exception(f"Error in processing event: {e}")
@@ -164,12 +256,31 @@ def normalizetriples_buildindex(self, triples):
            normalized_triples = normalizer.normalize_triples(triples)
            trimmed_triples = self.trim_triples(normalized_triples)

            return trimmed_triples

        except Exception as e:
            self.logger.error(f"Error in normalizing/building index: {e}")
            raise Exception(f"Error in normalizing/building index: {e}")

    def create_semantic_triple(self, input1, input2):
        input1_data = ast.literal_eval(input1)
        input2_data = ast.literal_eval(input2)
        triple = (
            input1_data.get("subject", ""),
            json.dumps({
                "predicate": input1_data.get("predicate", ""),
                "predicate_type": input1_data.get("predicate_type", ""),
                "context": input2_data.get("context", ""),
                "file_path": input2_data.get("file_path", ""),
                "subject_type": input1_data.get("subject_type", ""),
                "object_type": input1_data.get("object_type", "")
            }),
            input1_data.get("object", "")
        )
        return triple
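Since both inputs are parsed with ast.literal_eval, they are expected to be string-encoded dicts. A small usage sketch with illustrative values (the config contents are assumed):

extractor = RelationExtractor(RelationshipExtractorConfig())  # config values assumed
triple = extractor.create_semantic_triple(
    "{'subject': 'reservoir', 'predicate': 'contains', 'predicate_type': 'physical',"
    " 'subject_type': 'geo_entity', 'object_type': 'fluid', 'object': 'hydrocarbons'}",
    "{'context': 'the reservoir contains hydrocarbons', 'file_path': 'well_report.pdf'}",
)
# -> ('reservoir', '{"predicate": "contains", ...}', 'hydrocarbons')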

@@ -229,12 +340,19 @@ def trim_triples(self, data):
                    'context': predicate_dict.get('context', ''),
                    'entity1_nn_chunk': predicate_dict.get('entity1_nn_chunk', ''),
                    'entity2_nn_chunk': predicate_dict.get('entity2_nn_chunk', ''),
                    'file_path': predicate_dict.get('file_path', '')
                }
                trimmed_data.append((entity1, trimmed_predicate, entity2))

            return trimmed_data

        except Exception as e:
            self.logger.error(f"Error in trimming triples: {e}")
            raise Exception(f'Error in trimming triples: {e}')
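The loop above implies each incoming triple is (entity1, predicate_dict, entity2), with any extra keys dropped during trimming. The exact input unpacking is partially hidden by the diff, so treat this as an inferred shape with made-up values:

sample = [(
    "entity_a",
    {
        "context": "entity_a borders entity_b",
        "entity1_nn_chunk": "entity_a",
        "entity2_nn_chunk": "entity_b",
        "file_path": "doc.pdf",
        "extra_key": "dropped by trimming",  # illustrative
    },
    "entity_b",
)]
trimmed = extractor.trim_triples(sample)  # keeps only the four whitelisted keys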

2 changes: 2 additions & 0 deletions querent/ingestors/pdfs/pdf_ingestor_v1.py
@@ -96,6 +96,7 @@ async def extract_and_process_pdf(
                text = page.extract_text()
                if not text:
                    continue

                processed_text = await self.process_data(text)

                # Yield processed text as IngestedTokens
@@ -148,6 +149,7 @@ async def extract_images_and_ocr(self, page, page_num, text, data, file_path):
                text=text,
                coordinates=None,
                ocr_text=None,
                error=f"Exception:{e}",
            )

    async def get_ocr_from_image(self, image):
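The extraction loop skips pages with no extractable text before processing. A standalone sketch of the same pattern, assuming a pypdf-style page object (the ingestor's actual PDF backend is not shown in this diff, so the library choice is an assumption):

from pypdf import PdfReader  # assumed library for this sketch

reader = PdfReader("sample.pdf")
for page_num, page in enumerate(reader.pages):
    text = page.extract_text()
    if not text:
        continue  # skip pages with no extractable text, as in the hunk above
    print(page_num, len(text))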