Added persona (#12)
* Updated requirements

* Added persona to llm

* Removed num_parallel_processes from LLM class
NeonBohdan authored Nov 16, 2023
1 parent 5fb146d commit 73d0182
Showing 2 changed files with 10 additions and 10 deletions.
neon_llm_chatgpt/chatgpt.py: 18 changes (9 additions, 9 deletions)
@@ -45,7 +45,6 @@ def __init__(self, config):
         self.context_depth = config["context_depth"]
         self.max_tokens = config["max_tokens"]
         self.api_key = config["key"]
-        self.num_parallel_processes = config["num_parallel_processes"]
         self.warmup()
 
     @property
@@ -74,7 +73,7 @@ def _system_prompt(self) -> str:
     def warmup(self):
         self.model
 
-    def get_sorted_answer_indexes(self, question: str, answers: List[str]) -> List[int]:
+    def get_sorted_answer_indexes(self, question: str, answers: List[str], persona: dict) -> List[int]:
         """
         Creates sorted list of answer indexes with respect to order provided in :param answers based on PPL score
         Answers are sorted from best to worst
@@ -84,7 +83,7 @@ def get_sorted_answer_indexes(self, question: str, answers: List[str]) -> List[int]:
         """
         if not answers:
             return []
-        scores = self._score(prompt=question, targets=answers)
+        scores = self._score(prompt=question, targets=answers, persona=persona)
         sorted_items = sorted(zip(range(len(answers)), scores), key=lambda x: x[1])
         sorted_items_indexes = [x[0] for x in sorted_items]
         return sorted_items_indexes
@@ -106,7 +105,7 @@ def _call_model(self, prompt: List[Dict[str, str]]) -> str:
 
         return text
 
-    def _assemble_prompt(self, message: str, chat_history: List[List[str]]) -> List[Dict[str, str]]:
+    def _assemble_prompt(self, message: str, chat_history: List[List[str]], persona: dict) -> List[Dict[str, str]]:
         """
         Assembles prompt engineering logic
         Setup Guidance:
@@ -116,8 +115,9 @@ def _assemble_prompt(self, message: str, chat_history: List[List[str]]) -> List[Dict[str, str]]:
         :param chat_history: History of preceding conversation
         :returns: assembled prompt
         """
+        system_prompt = persona.get("description", self._system_prompt)
         messages = [
-            {"role": "system", "content": self._system_prompt},
+            {"role": "system", "content": system_prompt},
         ]
         # Context N messages
         for role, content in chat_history[-self.context_depth:]:
@@ -126,29 +126,29 @@ def _assemble_prompt(self, message: str, chat_history: List[List[str]]) -> List[Dict[str, str]]:
         messages.append({"role": "user", "content": message})
         return messages
 
-    def _score(self, prompt: str, targets: List[str]) -> List[float]:
+    def _score(self, prompt: str, targets: List[str], persona: dict) -> List[float]:
         """
         Calculates logarithmic probabilities for the list of provided text sequences
         :param prompt: Input text sequence
         :param targets: Output text sequences
         :returns: List of calculated logarithmic probabilities per output text sequence
         """
 
-        question_embeddings, answers_embeddings = self._embeddings(question=prompt, answers=targets)
+        question_embeddings, answers_embeddings = self._embeddings(question=prompt, answers=targets, persona=persona)
        scores_list = distances_from_embeddings(question_embeddings, answers_embeddings)
         return scores_list
 
     def _tokenize(self, prompt: str) -> None:
         pass
 
-    def _embeddings(self, question: str, answers: List[str]) -> (List[float], List[List[float]]):
+    def _embeddings(self, question: str, answers: List[str], persona: dict) -> (List[float], List[List[float]]):
         """
         Computes embeddings for the list of provided answers
         :param question: Question for LLM to response to
         :param answers: List of provided answers
         :returns ppl values for each answer
         """
-        response = self.ask(question, [])
+        response = self.ask(question, [], persona=persona)
         texts = [response] + answers
         embeddings = get_embeddings(texts, engine="text-embedding-ada-002")
         question_embeddings = embeddings[0]
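Taken together, the chatgpt.py changes thread a persona dict through every scoring and prompting path, and _assemble_prompt now prefers persona["description"] over the class-level default system prompt. A minimal usage sketch of that behavior follows; the class name ChatGPT, the config values, and the persona contents are assumptions for illustration, not part of this commit:

# Hypothetical sketch; class/module names and config values are assumed.
from neon_llm_chatgpt.chatgpt import ChatGPT

config = {
    "key": "sk-...",              # OpenAI API key (placeholder)
    "model": "gpt-3.5-turbo",     # assumed model field
    "role": "helpful assistant",  # assumed source of the default system prompt
    "context_depth": 3,
    "max_tokens": 100,
}
llm = ChatGPT(config)

# A persona with a description overrides the default system prompt:
persona = {"name": "neon", "description": "You are Neon, a concise assistant."}
messages = llm._assemble_prompt("Hello!", chat_history=[], persona=persona)
assert messages[0] == {"role": "system", "content": "You are Neon, a concise assistant."}

# An empty persona falls back to the default via persona.get("description", ...):
messages = llm._assemble_prompt("Hello!", chat_history=[], persona={})
assert messages[0]["content"] == llm._system_prompt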
requirements/requirements.txt: 2 changes (1 addition, 1 deletion)
@@ -1,4 +1,4 @@
 # model
 openai[embeddings]~=0.27
 # networking
-neon_llm_core==0.0.6
+neon_llm_core~=0.1.0
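The switch from == to the compatible-release operator means the dependency now tracks the 0.1.x series rather than one pinned build, matching the looser pin already used for openai:

pip install "neon_llm_core~=0.1.0"   # accepts 0.1.0, 0.1.5, ...; rejects 0.2.0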
