Merge pull request #38 from nextcloud/feat/adjust-model-size
feat: set chunk size to model context size for LLM chains
marcelklehr authored Oct 30, 2024
2 parents 10dc872 + fcec41b commit 644bbae
Showing 5 changed files with 21 additions and 11 deletions.
13 changes: 8 additions & 5 deletions lib/chains.py
@@ -4,6 +4,7 @@
 """
 
 import json
+from math import ceil
 import os
 
 from free_prompt import FreePromptChain
@@ -82,15 +83,17 @@ def generate_chains():
 
 def generate_chain_for_model(file_name, chains):
     model_name = file_name.split('.gguf')[0]
+    n_ctx = get_model_config(file_name)["loader_config"]["n_ctx"]
+    # chunk_size = int(ceil(0.7 * n_ctx))
 
     chain = [None]
     llm_chain = lambda: chain[-1] if chain[-1] is not None else chain.append(generate_llm_chain(file_name)) or chain[-1]
 
-    chains[model_name + ":core:text2text:summary"] = lambda: SummarizeChain(llm_chain=llm_chain())
+    chains[model_name + ":core:text2text:summary"] = lambda: SummarizeChain(llm_chain=llm_chain(), n_ctx=n_ctx)
     chains[model_name + ":core:text2text:headline"] = lambda: HeadlineChain(llm_chain=llm_chain())
     chains[model_name + ":core:text2text:topics"] = lambda: TopicsChain(llm_chain=llm_chain())
-    # chains[model_name + ":core:text2text:simplification"] = lambda: SimplifyChain(llm_chain=llm_chain())
-    # chains[model_name + ":core:text2text:formalization"] = lambda: FormalizeChain(llm_chain=llm_chain())
-    # chains[model_name + ":core:text2text:reformulation"] = lambda: ReformulateChain(llm_chain=llm_chain())
+    # chains[model_name + ":core:text2text:simplification"] = lambda: SimplifyChain(llm_chain=llm_chain(), chunk_size=chunk_size)
+    # chains[model_name + ":core:text2text:formalization"] = lambda: FormalizeChain(llm_chain=llm_chain(), chunk_size=chunk_size)
+    # chains[model_name + ":core:text2text:reformulation"] = lambda: ReformulateChain(llm_chain=llm_chain(), chunk_size=chunk_size)
     chains[model_name + ":core:text2text"] = lambda: FreePromptChain(llm_chain=llm_chain())
-    #chains[model_name + ":core:contextwrite"] = lambda: ContextWriteChain(llm_chain=llm_chain())
+    # chains[model_name + ":core:contextwrite"] = lambda: ContextWriteChain(llm_chain=llm_chain())
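
The `llm_chain` lambda in this hunk memoizes `generate_llm_chain` in a one-element list, so the underlying model is loaded at most once no matter how many task chains share it, while the new `n_ctx` lookup reads the context size from the model's loader config. A minimal standalone sketch of the memoization idiom, with a hypothetical `load_model` standing in for `generate_llm_chain`:

```python
# Minimal sketch of the memoization idiom from generate_chain_for_model;
# load_model is a hypothetical stand-in for generate_llm_chain.
def make_lazy_loader(load_model):
    cache = [None]  # one-element list acts as a mutable closure cell

    def get():
        if cache[-1] is None:
            cache[-1] = load_model()  # runs at most once, on first use
        return cache[-1]

    return get

calls = []
lazy = make_lazy_loader(lambda: calls.append(1) or "model")
lazy(); lazy()
assert calls == [1]  # the expensive loader ran exactly once
```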
3 changes: 2 additions & 1 deletion lib/formalize.py
@@ -39,6 +39,7 @@ class FormalizeChain(Chain):
 
     """Prompt object to use."""
     llm_chain: LLMChain
+    chunk_size: int = 8000
    output_key: str = "text" #: :meta private:
 
     class Config:
@@ -75,7 +76,7 @@ def _call(
             raise ValueError(f"llm_chain must have output_keys [{self.output_key}]")
 
         text_splitter = CharacterTextSplitter(
-            separator='\n\n|\\.|\\?|\\!', chunk_size=8000, chunk_overlap=0, keep_separator=True)
+            separator='\n\n|\\.|\\?|\\!', chunk_size=self.chunk_size, chunk_overlap=0, keep_separator=True)
         texts = text_splitter.split_text(inputs['input'])
         outputs = self.llm_chain.apply([{"user_prompt": self.user_prompt.format_prompt(text=t), "system_prompt": self.system_prompt} for t in texts])
         texts = [output['text'] for output in outputs]
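
`FormalizeChain` now takes the chunk size as a field instead of hard-coding 8000, so callers can match it to the model's context window; the same change is applied to `ReformulateChain` and `SimplifyChain` below. A rough sketch of the resulting behavior, assuming the same LangChain `CharacterTextSplitter` API the diff uses:

```python
# Rough sketch of configurable chunking, assuming the LangChain
# CharacterTextSplitter API used in the diff.
from langchain.text_splitter import CharacterTextSplitter

def split_for_model(text: str, chunk_size: int = 8000) -> list[str]:
    splitter = CharacterTextSplitter(
        separator='\n\n|\\.|\\?|\\!',  # same separator pattern as above
        chunk_size=chunk_size,         # caller-configurable instead of hard-coded
        chunk_overlap=0,
        keep_separator=True,
    )
    return splitter.split_text(text)

chunks = split_for_model("One sentence. " * 1000, chunk_size=2048)
print(len(chunks), max(len(c) for c in chunks))
```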
3 changes: 2 additions & 1 deletion lib/reformulate.py
@@ -40,6 +40,7 @@ class ReformulateChain(Chain):
 
     """Prompt object to use."""
     llm_chain: LLMChain
+    chunk_size: int = 8000
    output_key: str = "text" #: :meta private:
 
     class Config:
@@ -76,7 +77,7 @@ def _call(
             raise ValueError(f"llm_chain must have output_keys [{self.output_key}]")
 
         text_splitter = CharacterTextSplitter(
-            separator='\n\n|\\.|\\?|\\!', chunk_size=8000, chunk_overlap=0, keep_separator=True)
+            separator='\n\n|\\.|\\?|\\!', chunk_size=self.chunk_size, chunk_overlap=0, keep_separator=True)
         texts = text_splitter.split_text(inputs['input'])
         outputs = self.llm_chain.apply([{"user_prompt": self.user_prompt.format_prompt(text=t), "system_prompt": self.system_prompt} for t in texts])
         texts = [output['text'] for output in outputs]
3 changes: 2 additions & 1 deletion lib/simplify.py
@@ -37,6 +37,7 @@ class SimplifyChain(Chain):
     # Multilingual output doesn't work with llama3.1
 
     llm_chain: LLMChain
+    chunk_size: int = 8000
    output_key: str = "text" #: :meta private:
 
     class Config:
@@ -72,7 +73,7 @@ def _call(
             raise ValueError(f"llm_chain must have output_keys [{self.output_key}]")
 
         text_splitter = CharacterTextSplitter(
-            separator="\n\n|\\.|\\?|\\!", chunk_size=8000, chunk_overlap=0, keep_separator=True
+            separator="\n\n|\\.|\\?|\\!", chunk_size=self.chunk_size, chunk_overlap=0, keep_separator=True
         )
         texts = text_splitter.split_text(inputs["input"])
         outputs = self.llm_chain.apply([{"user_prompt": self.user_prompt.format_prompt(text=t), "system_prompt": self.system_prompt} for t in texts])
10 changes: 7 additions & 3 deletions lib/summarize.py
@@ -36,6 +36,7 @@ class SummarizeChain(Chain):
 
 
     llm_chain: LLMChain
+    n_ctx: int = 8000
    output_key: str = "text" #: :meta private:
 
     class Config:
@@ -70,11 +71,14 @@ def _call(
             raise ValueError("llm_chain must have input_keys ['user_prompt', 'system_prompt']")
         if not self.llm_chain.output_keys == [self.output_key]:
             raise ValueError(f"llm_chain must have output_keys [{self.output_key}]")
-
+
+        summary_size = max(len(inputs['input']) * 0.2, 1000) # 2000 chars summary per 10.000 chars original text
+        chunk_size = max(self.n_ctx - summary_size, 2048)
+
         text_splitter = CharacterTextSplitter(
-            separator='\n\n|\\.|\\?|\\!', chunk_size=8000, chunk_overlap=0, keep_separator=True)
+            separator='\n\n|\\.|\\?|\\!', chunk_size=chunk_size, chunk_overlap=0, keep_separator=True)
         texts = text_splitter.split_text(inputs['input'])
-        while sum([len(text) for text in texts]) > max(len(inputs['input']) * 0.2, 1000): # 2000 chars summary per 10.000 chars original text
+        while sum([len(text) for text in texts]) > summary_size:
             docs = [texts[i:i + 3] for i in range(0, len(texts), 3)]
             outputs = self.llm_chain.apply([{"user_prompt": self.user_prompt.format_prompt(text=''.join(doc)), "system_prompt": self.system_prompt} for doc in docs])
             texts = [output[self.output_key] for output in outputs]
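
The new sizing logic reserves part of the context window for the summary itself: the target summary length is 20% of the input (at least 1000 characters), and the chunk size is whatever remains of `n_ctx`, floored at 2048. Note that `n_ctx` counts tokens while the lengths here count characters, so the subtraction is a character-level heuristic rather than an exact token budget. A small worked example of the same arithmetic:

```python
# Worked example of the sizing arithmetic introduced in summarize.py.
def sizes(input_len: int, n_ctx: int):
    summary_size = max(input_len * 0.2, 1000)     # 2000 chars of summary per 10,000 chars of input
    chunk_size = max(n_ctx - summary_size, 2048)  # leave room for the summary, floor at 2048
    return summary_size, chunk_size

print(sizes(10_000, 8192))  # (2000.0, 6192.0): plenty of room in an 8192 context
print(sizes(50_000, 8192))  # (10000.0, 2048): the 2048 floor kicks in for long inputs
```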
