Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance Improvement #216

Merged
merged 14 commits into from
Jan 10, 2024
1 change: 1 addition & 0 deletions querent/common/types/querent_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
class EventType:
Graph = "Graph"
Vector = "Vector"
Terminate="Terminate"


class EventState:
Expand Down
8 changes: 6 additions & 2 deletions querent/core/base_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,11 @@ async def set_state(self, new_state: EventState):
"""

async def _listen_for_state_changes(self):
while not self.state_queue.empty() and not self.termination_event.is_set():
new_state = await self.state_queue.get_nowait()
while not self.state_queue.empty() or not self.termination_event.is_set():
new_state = await self.state_queue.get()
if isinstance(new_state, EventState):
if new_state.payload == "Terminate":
break
await self._notify_subscribers(new_state.event_type, new_state)
else:
raise Exception(
Expand Down Expand Up @@ -218,6 +220,8 @@ async def _inner_worker():
await self.process_code(data)
elif data is None:
self.termination_event.set()
current_state = EventState(EventType.Terminate,1.0, "Terminate", "temp.txt")
await self.set_state(new_state=current_state)
else:
raise Exception(
f"Invalid data type {type(data)} for {self.__class__.__name__}. Supported type: {IngestedTokens, IngestedMessages}"
Expand Down
23 changes: 15 additions & 8 deletions querent/core/transformers/bert_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from querent.kg.ner_helperfunctions.filter_triples import TripleFilter
from querent.config.core.bert_llm_config import BERTLLMConfig
from querent.kg.rel_helperfunctions.triple_to_json import TripleToJsonConverter
import time
import psutil
"""
BERTLLM is a class derived from BaseEngine designed for processing language models, particularly focusing on named entity recognition and relationship extraction in text. It integrates various components for handling different types of input data (messages, images, code, tokens), extracting entities, filtering relevant information, and constructing knowledge graphs.

Expand All @@ -40,6 +42,7 @@

The class also incorporates mechanisms for filtering and clustering entities and relationships, as well as extracting embeddings and generating output in different formats.
"""

class BERTLLM(BaseEngine):
def __init__(
self,
Expand All @@ -48,6 +51,8 @@ def __init__(
):
self.logger = setup_logger(__name__, "BERTLLM")
super().__init__(input_queue)
mock_config = RelationshipExtractorConfig()
self.semantic_extractor = RelationExtractor(mock_config)
self.graph_config = GraphConfig(identifier=config.name)
self.graph_config = GraphConfig(identifier=config.name)
self.contextual_graph = QuerentKG(self.graph_config)
Expand Down Expand Up @@ -124,8 +129,10 @@ def set_filter_params(self, **kwargs):
self.triple_filter.set_params(**kwargs)
else:
self.triple_filter = TripleFilter(**kwargs)

async def process_tokens(self, data: IngestedTokens):
start_time = time.time()
start_memory = psutil.Process().memory_info().rss / (1024 * 1024) # Memory in MB
doc_entity_pairs = []
number_sentences = 0
try:
Expand All @@ -135,7 +142,7 @@ async def process_tokens(self, data: IngestedTokens):
else:
clean_text = data.data
if not BERTLLM.validate_ingested_tokens(data):
self.set_termination_event()
self.set_termination_event()
return
file, content = self.file_buffer.add_chunk(
data.get_file_path(), clean_text
Expand All @@ -153,13 +160,14 @@ async def process_tokens(self, data: IngestedTokens):
number_sentences = number_sentences + 1
else:
return

if self.sample_entities:
doc_entity_pairs = self.entity_context_extractor.process_entity_types(doc_entities=doc_entity_pairs)
if doc_entity_pairs:
doc_entity_pairs = self.ner_llm_instance.remove_duplicates(doc_entity_pairs)
pairs_withattn = self.attn_scores_instance.extract_and_append_attention_weights(doc_entity_pairs)
if self.enable_filtering == True and not self.entity_context_extractor and self.count_entity_pairs(pairs_withattn)>1 and not self.predicate_context_extractor:
self.entity_embedding_extractor = EntityEmbeddingExtractor(self.ner_model, self.ner_tokenizer, self.count_entity_pairs(pairs_withattn), number_sentences=number_sentences)
self.entity_embedding_extractor = EntityEmbeddingExtractor(self.ner_model, self.ner_tokenizer)
pairs_withemb = self.entity_embedding_extractor.extract_and_append_entity_embeddings(pairs_withattn)
else:
pairs_withemb = pairs_withattn
Expand All @@ -169,18 +177,17 @@ async def process_tokens(self, data: IngestedTokens):
clustered_triples = cluster_output['filtered_triples']
cluster_labels = cluster_output['cluster_labels']
cluster_persistence = cluster_output['cluster_persistence']

final_clustered_triples = self.triple_filter.filter_by_cluster_persistence(pairs_with_predicates, cluster_persistence, cluster_labels)
if final_clustered_triples:
filtered_triples, _ = self.triple_filter.filter_triples(final_clustered_triples)
filtered_triples, reduction_count = self.triple_filter.filter_triples(final_clustered_triples)
else:
filtered_triples, _ = self.triple_filter.filter_triples(clustered_triples)
self.logger.log(f"Filtering in {self.__class__.__name__} producing 0 entity pairs. Filtering Disabled. ")
else:
filtered_triples = pairs_with_predicates
mock_config = RelationshipExtractorConfig()
semantic_extractor = RelationExtractor(mock_config)
relationships = semantic_extractor.process_tokens(filtered_triples)
embedding_triples = semantic_extractor.generate_embeddings(relationships)
relationships = self.semantic_extractor.process_tokens(filtered_triples[:1])
embedding_triples = self.semantic_extractor.generate_embeddings(relationships)
if self.sample_relationships:
embedding_triples = self.predicate_context_extractor.process_predicate_types(embedding_triples)
for triple in embedding_triples:
Expand Down
4 changes: 2 additions & 2 deletions querent/kg/contextual_predicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class ContextualPredicate(BaseModel):
@classmethod
def from_tuple(cls, data: Tuple[str, str, str, Dict[str, str], str]) -> 'ContextualPredicate':
try:
entity1_embedding = data[3].get('entity1_embedding', []).tolist() if 'entity1_embedding' in data[3] else []
entity2_embedding = data[3].get('entity2_embedding', []).tolist() if 'entity2_embedding' in data[3] else []
entity1_embedding = data[3].get('entity1_embedding', []) if 'entity1_embedding' in data[3] else []
entity2_embedding = data[3].get('entity2_embedding', []) if 'entity2_embedding' in data[3] else []

return cls(
context=data[1],
Expand Down
80 changes: 29 additions & 51 deletions querent/kg/ner_helperfunctions/contextual_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import torch
import umap
import numpy as np
import time
import psutil

def get_memory():
process = psutil.Process()
return process.memory_info().rss / (1024 * 1024) # Memory in MB

"""
EntityEmbeddingExtractor: A class designed to extract embeddings for entities within a given context.
Expand All @@ -29,18 +35,17 @@

class EntityEmbeddingExtractor:

def __init__(self, model, tokenizer, number_entity_pairs, number_sentences):
def __init__(self, model, tokenizer):
self.logger = setup_logger(__name__, "EntityEmbeddingExtractor")
try:
self.model = model
self.tokenizer = tokenizer
self.reducer = umap.UMAP(init='random',n_neighbors=min(15, number_entity_pairs), min_dist=0.1, n_components=10, metric='cosine')
self.sentence_reducer = umap.UMAP(init='random', n_neighbors=min(15, number_sentences), min_dist=0.1, n_components=10, metric='cosine')
except Exception as e:
self.logger.error(f"Error Initializing Entity Embedding Extractor Class: {e}")

def extract_entity_embedding(self, entity, context):
try:

inputs = self.tokenizer(context, return_tensors="pt", truncation=True, padding=True, max_length=512)
with torch.no_grad():
outputs = self.model(**inputs, output_hidden_states=True)
Expand All @@ -51,23 +56,13 @@ def extract_entity_embedding(self, entity, context):
entity_embedding = last_hidden_state[entity_positions].mean(dim=0)
sentence_embedding = last_hidden_state.mean(dim=0)
combined_embedding = torch.cat((entity_embedding, sentence_embedding), dim=0)

return combined_embedding, sentence_embedding

except Exception as e:
self.logger.error(f"Error extracting entity embedding: {e}")
raise Exception("Error extracting entity embedding: {}".format(e))


def fit_umap(self, all_embeddings):
try:
if not all_embeddings:
raise ValueError("Embedding lists are empty or contain invalid data.")
self.reducer.fit(np.array(all_embeddings))
except Exception as e:
self.logger.error(f"Error fitting UMAP: {e}")
raise Exception(f"Error fitting UMAP: {e}")


def _get_relevant_context(self, entity1, entity2, full_context):
sentences = NER_LLM.split_into_sentences(full_context)
Expand All @@ -76,51 +71,34 @@ def _get_relevant_context(self, entity1, entity2, full_context):
return sentence
return full_context

def append_if_not_present(self,item, item_embedding, all_items, all_embeddings, sentence=None):
if item not in all_items and sentence == None:
all_items.append(item)
all_embeddings.append(item_embedding.tolist())
elif sentence is not None:
if (item + " " + sentence) not in all_items:
all_items.append(item + " " + sentence)
all_embeddings.append(item_embedding.tolist())


def _update_pairs_with_embeddings(self, doc_entity_pairs):
updated_pairs = []
for inner_list in doc_entity_pairs:
updated_inner_list = []
for pair in inner_list:
entity1, full_context, entity2, pair_dict = pair
context = self._get_relevant_context(entity1, entity2, full_context)
entity1_embedding, _ = self.extract_entity_embedding(entity1, context)
entity2_embedding, _ = self.extract_entity_embedding(entity2, context)
pair_dict['entity1_embedding'] = self.reducer.transform([entity1_embedding.tolist()])[0]
pair_dict['entity2_embedding'] = self.reducer.transform([entity2_embedding.tolist()])[0]
updated_inner_list.append((entity1, full_context, entity2, pair_dict))
updated_pairs.append(updated_inner_list)
return updated_pairs

def extract_and_append_entity_embeddings(self, doc_entity_pairs):
all_embeddings = []
all_entities = []
try:
for inner_list in doc_entity_pairs:
for pair in inner_list:
entity1, full_context, entity2, _ = pair
start_time = time.time()
start_memory = get_memory()
for inner_list_index, inner_list in enumerate(doc_entity_pairs):
to_remove = [] # List to hold indices of pairs to remove
for pair_index, pair in enumerate(inner_list):
entity1, full_context, entity2, pair_dict = pair
context = self._get_relevant_context(entity1, entity2, full_context)
entity1_embedding, _ = self.extract_entity_embedding(entity1, context)
entity2_embedding, _ = self.extract_entity_embedding(entity2, context)
self.append_if_not_present(entity1, entity1_embedding, all_entities, all_embeddings, sentence=context)
self.append_if_not_present(entity2, entity2_embedding, all_entities, all_embeddings, sentence=context)

self.fit_umap(all_embeddings=all_embeddings)
return self._update_pairs_with_embeddings(doc_entity_pairs)
is_nan_entity1 = np.isnan(entity1_embedding).any()
is_nan_entity2 = np.isnan(entity2_embedding).any()

if is_nan_entity1 or is_nan_entity2:
# Record the index of the pair that needs to be removed
to_remove.append(pair_index)
if not is_nan_entity1:
pair_dict['entity1_embedding'] = entity1_embedding.tolist()
if not is_nan_entity2:
pair_dict['entity2_embedding'] = entity2_embedding.tolist()
doc_entity_pairs[inner_list_index] = [pair for i, pair in enumerate(inner_list) if i not in to_remove]

return doc_entity_pairs

except Exception as e:
self.logger.error(f"Error extracting and appending entity embedding: {e}")
raise Exception(f"Error extracting and appending entity embedding: {e}")




raise Exception(f"Error extracting and appending entity embedding: {e}")
1 change: 0 additions & 1 deletion querent/kg/ner_helperfunctions/filter_triples.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def cluster_triples(self, triples: List[Tuple[str, str, str]]) -> Dict[str, any]
cluster_persistence = clusterer.cluster_persistence_

filtered_triples = [triples[index] for index, label in enumerate(cluster_labels) if label != -1]

cluster_output = {
'filtered_triples': filtered_triples,
'reduction_count': len(triples) - len(filtered_triples),
Expand Down
File renamed without changes.
25 changes: 10 additions & 15 deletions tests/kg_tests/contextual_predicate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,15 @@
"""

def test_contextual_predicate():
sample_data = [[('eocene', 'In this study, we present evidence of a PaleoceneEocene Thermal Maximum (PETM) record within a 543-m-thick (1780 ft) deep-marine section in the Gulf of Mexico (GoM) using organic carbon stable isotopes and biostratigraphic constraints.', 'ft', {'entity1_score': 1.0, 'entity2_score': 0.69, 'entity1_label': 'B-GeoMeth, B-GeoTime', 'entity2_label': 'B-GeoMeth', 'entity1_nn_chunk': 'a PaleoceneEocene Thermal Maximum (PETM) record', 'entity2_nn_chunk': 'a 543-m-thick (1780 ft) deep-marine section', 'entity1_attnscore': 0.46, 'entity2_attnscore': 0.21, 'pair_attnscore': 0.13,'entity1_embedding': np.array([ 5.513507 , 6.14687 , 0.56821245, 3.7250893 , 8.519092 ,2.1298776 , 6.7030797 , 8.760443 , -2.4095411 , 14.959248 ],dtype=np.float32), 'entity2_embedding': np.array([ 4.3211513, 5.3283153, 1.2105073, 5.3618913, 8.23375 ,2.951651 , 7.3403625, 10.785665 , -2.5593305, 14.518231 ],
dtype=np.float32), 'sentence_embedding': np.array([ 8.136787 , 12.801951 , 3.1658218 , 7.360018 , 9.823584 ,
0.28562617, 12.840015 , 0.40643066, 9.059556 , 12.759513 ],
dtype=np.float32)}),
('eocene', 'In this study, we present evidence of a PaleoceneEocene Thermal Maximum (PETM) record within a 543-m-thick (1780 ft) deep-marine section in the Gulf of Mexico (GoM) using organic carbon stable isotopes and biostratigraphic constraints.', 'mexico', {'entity1_score': 1.0, 'entity2_score': 0.92, 'entity1_label': 'B-GeoMeth, B-GeoTime', 'entity2_label': 'B-GeoLoc', 'entity1_nn_chunk': 'a PaleoceneEocene Thermal Maximum (PETM) record', 'entity2_nn_chunk': 'Mexico', 'entity1_attnscore': 0.46, 'entity2_attnscore': 0.17, 'pair_attnscore': 0.13,'entity1_embedding': np.array([ 5.355203 , 6.1266084 , 0.60222036, 3.7390788 , 8.5242195 ,
2.1033056 , 6.6313214 , 8.70998 , -2.432465 , 15.200483 ],
dtype=np.float32), 'entity2_embedding': np.array([ 5.601423 , 6.058842 , 0.33065754, 6.1470265 , 8.568694 ,
3.922125 , 7.0688643 , 11.551212 , -2.5106885 , 14.04761 ],
dtype=np.float32),'sentence_embedding': np.array([ 8.136787 , 12.801951 , 3.1658218 , 7.360018 , 9.823584 ,
0.28562617, 12.840015 , 0.40643066, 9.059556 , 12.759513 ],
dtype=np.float32)})]]
sample_data = [[('temperatures', 'The Paleocene–Eocene Thermal Maximum (PETM) (ca. 56 Ma) was a rapid global warming event characterized by the rise of temperatures to5–9 °C (Kennett and Stott, 1991), which caused substantial environmental changes around the globe.', 'kenn', {'entity1_score': 0.91, 'entity2_score': 0.97, 'entity1_label': 'B-GeoMeth', 'entity2_label': 'B-GeoPetro', 'entity1_nn_chunk': 'temperatures', 'entity2_nn_chunk': 'Kennett', 'entity1_attnscore': 0.86, 'entity2_attnscore': 0.09, 'pair_attnscore': 0.16, 'entity1_embedding': [0.0975915715098381], 'entity2_embedding': [-1.2051959037780762]})]]

result_list = process_data(sample_data, "dummy1.pdf")
result_string = result_list[0][1] if result_list else ""
print(result_string)
expected_string = '{"context": "In this study, we present evidence of a PaleoceneEocene Thermal Maximum (PETM) record within a 543-m-thick (1780 ft) deep-marine section in the Gulf of Mexico (GoM) using organic carbon stable isotopes and biostratigraphic constraints.", "entity1_score": 1.0, "entity2_score": 0.69, "entity1_label": "B-GeoMeth, B-GeoTime", "entity2_label": "B-GeoMeth", "entity1_nn_chunk": "a PaleoceneEocene Thermal Maximum (PETM) record", "entity2_nn_chunk": "a 543-m-thick (1780 ft) deep-marine section", "file_path": "dummy1.pdf", "entity1_attnscore": 0.46, "entity2_attnscore": 0.21, "pair_attnscore": 0.13, "entity1_embedding": [5.513506889343262, 6.146870136260986, 0.5682124495506287, 3.7250893115997314, 8.519091606140137, 2.1298775672912598, 6.703079700469971, 8.760442733764648, -2.409541130065918, 14.959247589111328], "entity2_embedding": [4.321151256561279, 5.328315258026123, 1.2105072736740112, 5.361891269683838, 8.233750343322754, 2.951651096343994, 7.340362548828125, 10.785664558410645, -2.559330463409424, 14.518231391906738]}'
assert result_string == expected_string, f"Expected {expected_string}, but got {result_string}"

# Check if result_list is a list of tuples with three strings each
if result_list and all(isinstance(item, tuple) and len(item) == 3 and all(isinstance(element, str) for element in item) for item in result_list):
# If the condition is True, the assertion passes
assert True
else:
# If the condition is False, the assertion fails
assert False, "result_list is not a list of tuples like List[Tuple[str, str, str]]"

1 change: 1 addition & 0 deletions tests/llm_tests/bert_llm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class StateChangeCallback(EventCallbackInterface):
async def handle_event(self, event_type: EventType, event_state: EventState):
assert event_state.event_type == EventType.Graph
triple = json.loads(event_state.payload)
print("triple: {}".format(triple))
assert isinstance(triple['subject'], str) and triple['subject']
llm_instance.subscribe(EventType.Graph, StateChangeCallback())
querent = Querent(
Expand Down
Loading
Loading