Commit e9fd403

fixed and sample types

ngupta10 committed Dec 31, 2023
1 parent fc6bc00 commit e9fd403
Showing 13 changed files with 332 additions and 146 deletions.
2 changes: 1 addition & 1 deletion querent/callback/event_callback_dispatcher.py
@@ -32,7 +32,7 @@ async def dispatch_event(self, event_type: EventType, event_data: EventState):
             event_data (Any): Data associated with the event.
         """
         for callback in self.callbacks[event_type]:
-            callback.handle_event(event_type, event_data)
+            await callback.handle_event(event_type, event_data)

     def register_webhook(self, event_type: EventType, webhook: str):
         """
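The change above awaits each handler instead of discarding the coroutine it returns. A minimal, self-contained sketch of the pattern (the PrintCallback handler is hypothetical, not the actual querent callback interface):

    import asyncio

    class Dispatcher:
        def __init__(self):
            self.callbacks = {}  # event_type -> list of registered callbacks

        def register_callback(self, event_type, callback):
            self.callbacks.setdefault(event_type, []).append(callback)

        async def dispatch_event(self, event_type, event_data):
            for callback in self.callbacks.get(event_type, []):
                # Without `await`, an async handle_event returns an
                # un-awaited coroutine and its body never runs.
                await callback.handle_event(event_type, event_data)

    class PrintCallback:
        async def handle_event(self, event_type, event_data):
            print(event_type, event_data)

    async def main():
        dispatcher = Dispatcher()
        dispatcher.register_callback("Graph", PrintCallback())
        await dispatcher.dispatch_event("Graph", {"payload": "example"})

    asyncio.run(main())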
1 change: 0 additions & 1 deletion querent/collectors/drive/google_drive_collector.py
@@ -134,7 +134,6 @@ async def walk_files(self, root: Path) -> AsyncGenerator[Path, None]:
             item_split = set(str(item).split("/"))
             item_split.remove("")
             if item_split.intersection(self.items_to_ignore):
-                print(item_split, "\n\n", self.items_to_ignore)
                 continue
             if item.is_file():
                 yield item
24 changes: 1 addition & 23 deletions querent/common/types/querent_event.py
@@ -1,34 +1,12 @@
-from typing import Any, Literal
+from typing import Any


 class EventType:
-    """
-    Custom type for representing event types in the querent system.
-    Attributes:
-        TOKEN_PROCESSED (Literal["token_processed"]): Event type for token processing completion.
-        CHAT_COMPLETED (Literal["chat_completed"]): Event type for chat completion.
-    """
-
     ContextualTriples = "ContextualTriples"
     RdfContextualTriples = "RdfContextualTriples"
     RdfSemanticTriples = "RdfSemanticTriples"
     ContextualEmbeddings = "ContextualEmbeddings"
     Graph = "Graph"
     Vector = "Vector"


-class EventState:
-    """
-    Custom type for base class implementors to tie into the event system.
-    EventState has a event_type, a timestamp, and a payload.
-    Attributes:
-        event_type (EventType): The type of event.
-        timestamp (float): The timestamp of the event.
-        payload (Any): The payload of the event.
-    """
-
-    def __init__(self, event_type: EventType, timestamp: float, payload: Any, file: str):
-        self.event_type = event_type
-        self.timestamp = timestamp
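For reference, a self-contained sketch of how these event types are used when engines emit state (it mirrors the signatures shown in this hunk; the payload and file values are illustrative only):

    import time
    from typing import Any

    class EventType:
        Graph = "Graph"
        Vector = "Vector"

    class EventState:
        def __init__(self, event_type: str, timestamp: float, payload: Any, file: str):
            self.event_type = event_type
            self.timestamp = timestamp
            self.payload = payload
            self.file = file

    # Engines emit state tagged with an EventType constant and the source file.
    state = EventState(EventType.Graph, time.time(), '{"subject": "s", "predicate": "p", "object": "o"}', "sample.pdf")
    print(state.event_type, state.file)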
4 changes: 2 additions & 2 deletions querent/config/core/bert_llm_config.py
@@ -19,7 +19,7 @@ class BERTLLMConfig(BaseModel):

     sample_entities: List[str] = Field(default_factory=list, description="List of sample entities")
     fixed_entities: List[str] = Field(default_factory=list, description="List of fixed entities")
-    fixed_relationships: List[Dict[str, Any]] = Field(default_factory=list, description="List of fixed relationships represented as dictionaries")
-    sample_relationships: List[Dict[str, Any]] = Field(default_factory=list, description="List of sample relationships represented as dictionaries")
+    fixed_relationships: List[str] = Field(default_factory=list, description="List of fixed relationships")
+    sample_relationships: List[str] = Field(default_factory=list, description="List of sample relationships")
     user_context: Dict[str, Any] = Field(default_factory=dict, description="User-specific context information")

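With the narrowed types, relationships are supplied as plain predicate strings rather than dictionaries. A sketch of constructing the config after this change (field values are illustrative; descriptions omitted for brevity):

    from typing import Any, Dict, List
    from pydantic import BaseModel, Field

    class BERTLLMConfig(BaseModel):
        sample_entities: List[str] = Field(default_factory=list)
        fixed_entities: List[str] = Field(default_factory=list)
        fixed_relationships: List[str] = Field(default_factory=list)
        sample_relationships: List[str] = Field(default_factory=list)
        user_context: Dict[str, Any] = Field(default_factory=dict)

    config = BERTLLMConfig(
        fixed_entities=["eocene", "miocene"],
        sample_entities=["geological_era"],
        fixed_relationships=["overlies"],               # plain strings now, not dicts
        sample_relationships=["stratigraphic_relation"],
    )
    print(config.fixed_relationships)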
12 changes: 0 additions & 12 deletions querent/core/base_engine.py
@@ -127,18 +127,6 @@ async def process_images(self, data: IngestedImages):
         """
         raise NotImplementedError

-    @abstractmethod
-    async def process_images(self, data: IngestedImages):
-        """
-        Process images asynchronously.
-        Args:
-            data (IngestedImages): The input data to process.
-        Returns:
-            EventState: The state of the event is set with the event type and the timestamp
-            of the event and set using `self.set_state(event_state)`.
-        """
-        raise NotImplementedError
-
     @abstractmethod
     def validate(self) -> bool:
         """
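The removed block was an exact duplicate of the process_images declaration directly above it. Since a Python class body executes top to bottom, a repeated def silently rebinds the name, so the duplicate was dead code:

    class Example:
        def method(self):
            return "first"

        def method(self):  # rebinds the name; the first definition is discarded
            return "second"

    print(Example().method())  # prints "second"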
95 changes: 39 additions & 56 deletions querent/core/transformers/bert_llm.py
@@ -1,4 +1,5 @@
 from transformers import AutoTokenizer
+from querent.kg.ner_helperfunctions.fixed_predicate import FixedPredicateExtractor
 from querent.common.types.ingested_images import IngestedImages
 from querent.config.core.relation_config import RelationshipExtractorConfig
 from querent.core.transformers.relationship_extraction_llm import RelationExtractor
@@ -25,44 +26,6 @@
 from querent.config.core.bert_llm_config import BERTLLMConfig
 from querent.kg.rel_helperfunctions.triple_to_json import TripleToJsonConverter

-
-
-"""
-    BERT-based Named Entity Recognition (NER) and Linking Language Model (LLM) for extracting entities and relationships from text.
-    Inherits from:
-        BaseEngine: Base class for processing engines.
-    Attributes:
-        graph_config (GraphConfig): Configuration for the graph.
-        logger (Logger): Logger instance for logging errors and information.
-        file_buffer (FileBuffer): Buffer for storing files.
-        ner_tokenizer (AutoTokenizer): Tokenizer for the NER model.
-        ner_model (Model): Pre-trained NER model.
-        ner_llm_instance (NER_LLM): Instance of the NER_LLM class.
-        attn_scores_instance (EntityAttentionExtractor): Instance for extracting attention scores.
-        entity_embedding_extractor (EntityEmbeddingExtractor, optional): Instance for extracting entity embeddings.
-        triple_filter_instance (EntityTripleFilter
-    Methods:
-        validate() -> bool:
-            Validates if the NER model and tokenizer are initialized.
-        process_messages(data: IngestedMessages):
-            Processes the ingested messages.
-        process_code(data: IngestedCode):
-            Processes the ingested code.
-        validate_ingested_tokens(data: IngestedTokens) -> bool:
-            Validates the ingested tokens.
-        process_tokens(data: IngestedTokens):
-            Processes the ingested tokens, extracts entities, and builds the knowledge graph.
-    """
-
-
 class BERTLLM(BaseEngine):
     def __init__(
         self,
@@ -88,13 +51,25 @@ def __init__(
         self.triple_filter = TripleFilter(**self.filter_params)
         self.sample_entities = config.sample_entities
         self.fixed_entities = config.fixed_entities
+        if self.fixed_entities and not self.sample_entities:
+            raise ValueError("If specific entities are provided, their types should also be provided.")
+        if self.fixed_entities and self.sample_entities:
+            self.entity_context_extractor = FixedEntityExtractor(fixed_entities=self.fixed_entities, entity_types=self.sample_entities)
+        elif self.sample_entities:
+            self.entity_context_extractor = FixedEntityExtractor(entity_types=self.sample_entities)
+        else:
+            self.entity_context_extractor = None
         self.fixed_relationships = config.fixed_relationships
         self.sample_relationships = config.sample_relationships
-        self.user_context = config.user_context
-        if config.fixed_entities:
-            self.entity_context_extractor = FixedEntityExtractor(config.fixed_entities)
+        if self.fixed_relationships and not self.sample_relationships:
+            raise ValueError("If specific predicates are provided, their types should also be provided.")
+        if self.fixed_relationships and self.sample_relationships:
+            self.predicate_context_extractor = FixedPredicateExtractor(fixed_predicates=self.fixed_relationships, predicate_types=self.sample_relationships)
+        elif self.sample_relationships:
+            self.predicate_context_extractor = FixedPredicateExtractor(predicate_types=self.sample_relationships)
         else:
-            self.entity_context_extractor = None
+            self.predicate_context_extractor = None
+        self.user_context = config.user_context


     def validate(self) -> bool:
@@ -143,36 +118,40 @@ async def process_tokens(self, data: IngestedTokens):
             file, content = self.file_buffer.add_chunk(
                 data.get_file_path(), data.data
             )
+            print("--------------------------------", content)
             if content:
-                if self.entity_context_extractor:
+                if self.fixed_entities:
                     content = self.entity_context_extractor.find_entity_sentences(content)
+                    print("--------------------------------", content)
+                if self.fixed_relationships:
+                    content = self.predicate_context_extractor.find_predicate_sentences(content)
+                    print("--------------------------------", content)
                 tokens = self.ner_llm_instance._tokenize_and_chunk(content)
+                print("tokens: ", tokens)
                 for tokenized_sentence, original_sentence, sentence_idx in tokens:
-                    (
-                        entities,
-                        entity_pairs,
-                    ) = self.ner_llm_instance.extract_entities_from_sentence(
-                        original_sentence, sentence_idx, [s[1] for s in tokens]
-                    )
-                    doc_entity_pairs.append(
-                        self.ner_llm_instance.transform_entity_pairs(entity_pairs)
-                    )
+                    (entities, entity_pairs,) = self.ner_llm_instance.extract_entities_from_sentence(original_sentence, sentence_idx, [s[1] for s in tokens],False, [''])
+                    print("entity pairs", entity_pairs)
+                    if entity_pairs:
+                        doc_entity_pairs.append(self.ner_llm_instance.transform_entity_pairs(entity_pairs))
                     number_sentences = number_sentences + 1
-
-
-
             else:
                 if not BERTLLM.validate_ingested_tokens(data):
                     self.set_termination_event()
+            print("doc entities-----------------------", doc_entity_pairs)
+            if self.sample_entities:
+                doc_entity_pairs = self.entity_context_extractor.process_entity_types(doc_entities=doc_entity_pairs)
+                print("doc entities---------------------", doc_entity_pairs)
             if doc_entity_pairs:
                 pairs_withattn = self.attn_scores_instance.extract_and_append_attention_weights(doc_entity_pairs)
+                print("-----------",pairs_withattn)
                 if self.count_entity_pairs(pairs_withattn)>1:
                     self.entity_embedding_extractor = EntityEmbeddingExtractor(self.ner_model, self.ner_tokenizer, self.count_entity_pairs(pairs_withattn), number_sentences=number_sentences)
                 else :
                     self.entity_embedding_extractor = EntityEmbeddingExtractor(self.ner_model, self.ner_tokenizer, 2, number_sentences=number_sentences)
                 pairs_withemb = self.entity_embedding_extractor.extract_and_append_entity_embeddings(pairs_withattn)
+                print("-----------",pairs_withemb)
                 pairs_with_predicates = process_data(pairs_withemb, file)
-                if self.enable_filtering == True:
+                if self.enable_filtering == True and not self.entity_context_extractor:
                     cluster_output = self.triple_filter.cluster_triples(pairs_with_predicates)
                     clustered_triples = cluster_output['filtered_triples']
                     cluster_labels = cluster_output['cluster_labels']
@@ -189,8 +168,12 @@ async def process_tokens(self, data: IngestedTokens):
                     semantic_extractor = RelationExtractor(mock_config)
                     relationships = semantic_extractor.process_tokens(filtered_triples)
                     embedding_triples = semantic_extractor.generate_embeddings(relationships)
+                    print("-------------------------------- embedding triples: {}".format(embedding_triples))
+                    if self.sample_relationships:
+                        embedding_triples = self.predicate_context_extractor.process_predicate_types(embedding_triples)
                     for triple in embedding_triples:
                         graph_json = TripleToJsonConverter.convert_graphjson(triple)
+                        print("-------------------------------- Graph : {}".format(graph_json))
                         if graph_json:
                             current_state = EventState(EventType.Graph,1.0, graph_json, file)
                             await self.set_state(new_state=current_state)
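The new constructor guards enforce a pairing rule: fixed entities or predicates may only be supplied together with their types (the sample_* lists), while types alone are allowed. A simplified, standalone sketch of that decision logic (plain dicts stand in for FixedEntityExtractor / FixedPredicateExtractor):

    from typing import List, Optional

    def build_extractor(fixed: List[str], types: List[str]) -> Optional[dict]:
        # Mirrors the guard added in BERTLLM.__init__: fixed values require types.
        if fixed and not types:
            raise ValueError("If specific values are provided, their types should also be provided.")
        if fixed and types:
            return {"fixed": fixed, "types": types}  # constrain by both
        if types:
            return {"types": types}                  # type-only filtering
        return None                                  # extractor disabled

    print(build_extractor(["overlies"], ["stratigraphic_relation"]))  # both
    print(build_extractor([], ["stratigraphic_relation"]))            # types only
    print(build_extractor([], []))                                    # None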
1 change: 1 addition & 0 deletions querent/core/transformers/relationship_extraction_llm.py
@@ -209,6 +209,7 @@ def extract_relationships(self, triples):
             sub_task_list_llm = self.bsmbranch.create_sub_tasks(llm = self.qa_system.llm, template=self.config.get_template("default"), tasks=all_tasks,model_type=self.qa_system.rel_model_type)
             for task in sub_task_list_llm:
                 answer_relation = self.qa_system.ask_question(prompt=task[2], top_docs=documents, llm_chain=task[0])
+                print("answersssssss", answer_relation)
                 try:
                     updated_triple= self.create_semantic_triple(answer_relation, predicate_str)
                     updated_triples.append(updated_triple)