
Commit

Reka LM fix, add example
Taurus3301 committed Jan 17, 2025
1 parent 68db0a6 commit bc720f5
Showing 2 changed files with 212 additions and 7 deletions.
205 changes: 205 additions & 0 deletions examples/storm_examples/run_storm_wiki_reka.py
@@ -0,0 +1,205 @@
"""
STORM Wiki pipeline powered by Reka API and search engine.
You need to set up the following environment variables to run this script:
- REKA_API_KEY: Reka API key
- YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key,
BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key
Output will be structured as below
args.output_dir/
topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
conversation_log.json # Log of information-seeking conversation
raw_search_results.json # Raw search results from search engine
direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
storm_gen_outline.txt # Outline refined with collected information
url_to_info.json # Sources that are used in the final article
storm_gen_article.txt # Final article generated
storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
"""

import os
from argparse import ArgumentParser

from knowledge_storm import (
    STORMWikiRunnerArguments,
    STORMWikiRunner,
    STORMWikiLMConfigs,
)
from knowledge_storm.rm import (
    YouRM,
    BingSearch,
    BraveRM,
    SerperRM,
    DuckDuckGoSearchRM,
    TavilySearchRM,
    SearXNG,
)
from knowledge_storm.utils import load_api_key
from knowledge_storm.lm import RekaModel


def main(args):
    load_api_key(toml_file_path="secrets.toml")
    lm_configs = STORMWikiLMConfigs()
    reka_kwargs = {
        "api_key": os.getenv("REKA_API_KEY"),
        "temperature": 1.0,
        "top_p": 0.9,
    }

    # STORM is an LM system, so different components can be powered by different models.
    # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm,
    # which is used to split queries and synthesize answers in the conversation. We recommend using stronger models
    # for outline_gen_lm, which is responsible for organizing the collected information, and article_gen_lm,
    # which is responsible for generating sections with citations.
    conv_simulator_lm = RekaModel(model="reka-flash", max_tokens=500, **reka_kwargs)
    question_asker_lm = RekaModel(model="reka-core", max_tokens=500, **reka_kwargs)
    outline_gen_lm = RekaModel(model="reka-core", max_tokens=400, **reka_kwargs)
    article_gen_lm = RekaModel(model="reka-core", max_tokens=700, **reka_kwargs)
    article_polish_lm = RekaModel(model="reka-core", max_tokens=4000, **reka_kwargs)

    lm_configs.set_conv_simulator_lm(conv_simulator_lm)
    lm_configs.set_question_asker_lm(question_asker_lm)
    lm_configs.set_outline_gen_lm(outline_gen_lm)
    lm_configs.set_article_gen_lm(article_gen_lm)
    lm_configs.set_article_polish_lm(article_polish_lm)

    engine_args = STORMWikiRunnerArguments(
        output_dir=args.output_dir,
        max_conv_turn=args.max_conv_turn,
        max_perspective=args.max_perspective,
        search_top_k=args.search_top_k,
        max_thread_num=args.max_thread_num,
    )

    # STORM is a knowledge curation system that consumes information from the retrieval module.
    # Currently, the information source is the Internet, and we use a search engine API as the retrieval module.
    match args.retriever:
        case "bing":
            rm = BingSearch(
                bing_search_api=os.getenv("BING_SEARCH_API_KEY"),
                k=engine_args.search_top_k,
            )
        case "you":
            rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k)
        case "brave":
            rm = BraveRM(
                brave_search_api_key=os.getenv("BRAVE_API_KEY"),
                k=engine_args.search_top_k,
            )
        case "duckduckgo":
            rm = DuckDuckGoSearchRM(
                k=engine_args.search_top_k, safe_search="On", region="us-en"
            )
        case "serper":
            rm = SerperRM(
                serper_search_api_key=os.getenv("SERPER_API_KEY"),
                query_params={"autocorrect": True, "num": 10, "page": 1},
            )
        case "tavily":
            rm = TavilySearchRM(
                tavily_search_api_key=os.getenv("TAVILY_API_KEY"),
                k=engine_args.search_top_k,
                include_raw_content=True,
            )
        case "searxng":
            rm = SearXNG(
                searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k
            )
        case _:
            raise ValueError(
                f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"'
            )

    runner = STORMWikiRunner(engine_args, lm_configs, rm)

    topic = input("Topic: ")
    runner.run(
        topic=topic,
        do_research=args.do_research,
        do_generate_outline=args.do_generate_outline,
        do_generate_article=args.do_generate_article,
        do_polish_article=args.do_polish_article,
    )
    runner.post_run()
    runner.summary()


if __name__ == "__main__":
parser = ArgumentParser()
# global arguments
parser.add_argument(
"--output-dir",
type=str,
default="./results/reka",
help="Directory to store the outputs.",
)
parser.add_argument(
"--max-thread-num",
type=int,
default=3,
help="Maximum number of threads to use. The information seeking part and the article generation"
"part can speed up by using multiple threads. Consider reducing it if keep getting "
'"Exceed rate limit" error when calling LM API.',
)
parser.add_argument(
"--retriever",
type=str,
choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"],
help="The search engine API to use for retrieving information.",
)
# stage of the pipeline
parser.add_argument(
"--do-research",
action="store_true",
help="If True, simulate conversation to research the topic; otherwise, load the results.",
)
parser.add_argument(
"--do-generate-outline",
action="store_true",
help="If True, generate an outline for the topic; otherwise, load the results.",
)
parser.add_argument(
"--do-generate-article",
action="store_true",
help="If True, generate an article for the topic; otherwise, load the results.",
)
parser.add_argument(
"--do-polish-article",
action="store_true",
help="If True, polish the article by adding a summarization section and (optionally) removing "
"duplicate content.",
)
# hyperparameters for the pre-writing stage
parser.add_argument(
"--max-conv-turn",
type=int,
default=3,
help="Maximum number of questions in conversational question asking.",
)
parser.add_argument(
"--max-perspective",
type=int,
default=3,
help="Maximum number of perspectives to consider in perspective-guided question asking.",
)
parser.add_argument(
"--search-top-k",
type=int,
default=3,
help="Top k search results to consider for each search query.",
)
# hyperparameters for the writing stage
parser.add_argument(
"--retrieve-top-k",
type=int,
default=3,
help="Top k collected references for each section title.",
)
parser.add_argument(
"--remove-duplicate",
action="store_true",
help="If True, remove duplicate content from the article.",
)

main(parser.parse_args())
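Before launching a full pipeline run, it can be worth smoke-testing the Reka credentials with a single cheap call. The snippet below is an illustrative sketch, not part of this commit: it assumes knowledge_storm is installed, that secrets.toml provides REKA_API_KEY (the same key the script reads via os.getenv), and relies only on the RekaModel interface visible in this diff, namely that it is constructed with model/max_tokens/api_key and that calling it returns a list of completion strings.

import os

from knowledge_storm.lm import RekaModel
from knowledge_storm.utils import load_api_key

# Load REKA_API_KEY (and any retriever keys) from secrets.toml into the environment.
load_api_key(toml_file_path="secrets.toml")

# One cheap call against the same model class the example script wires into
# STORMWikiLMConfigs; per the lm.py diff below, __call__ returns a list of completions.
lm = RekaModel(model="reka-flash", max_tokens=50, api_key=os.getenv("REKA_API_KEY"))
print(lm("Reply with the single word: pong")[0])

With the keys in place, the example script itself is run with explicit stage flags, e.g. python examples/storm_examples/run_storm_wiki_reka.py --retriever you --do-research --do-generate-outline --do-generate-article --do-polish-article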
14 changes: 7 additions & 7 deletions knowledge_storm/lm.py
@@ -988,8 +988,8 @@ def log_usage(self, response):
         usage_data = response.usage
         if usage_data:
             with self._token_usage_lock:
-                self.prompt_tokens += usage_data.prompt_tokens
-                self.completion_tokens += usage_data.completion_tokens
+                self.prompt_tokens += usage_data.input_tokens
+                self.completion_tokens += usage_data.output_tokens

     def get_usage_and_reset(self):
         """Get the total tokens used and reset the token usage."""
@@ -1009,16 +1009,16 @@ def basic_request(self, prompt: str, **kwargs):
         # caching mechanism requires hashable kwargs
         kwargs["messages"] = [{"role": "user", "content": prompt}]
         kwargs.pop("n")
-        response = self.client.completions.create(**kwargs)
+        response = self.client.chat.create(**kwargs)

         json_serializable_history = {
             "prompt": prompt,
             "response": {
-                "text": response.choices[0].text,
+                "text": response.responses[0].message.content,
                 "model": response.model,
                 "usage": {
-                    "prompt_tokens": response.usage.prompt_tokens,
-                    "completion_tokens": response.usage.completion_tokens,
+                    "prompt_tokens": response.usage.input_tokens,
+                    "completion_tokens": response.usage.output_tokens,
                 },
             },
             "kwargs": kwargs,
@@ -1058,5 +1058,5 @@ def __call__(self, prompt: str, only_completed=True, return_sorted=False, **kwargs):
         for _ in range(n):
             response = self.request(prompt, **kwargs)
             self.log_usage(response)
-            completions = [choice.text for choice in response.choices]
+            completions = [choice.message.content for choice in response.responses]
         return completions

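For readers unfamiliar with Reka's chat-style responses, the fix above swaps the completions-style accessors (choices[0].text, usage.prompt_tokens / usage.completion_tokens) for the shape returned by client.chat.create. The mock below is a hedged sketch derived purely from the accessors used in this diff, not from Reka SDK documentation; the dataclass names are hypothetical stand-ins for whatever objects the SDK actually returns.

from dataclasses import dataclass
from typing import List

@dataclass
class _Message:
    content: str

@dataclass
class _Response:
    message: _Message

@dataclass
class _Usage:
    input_tokens: int   # replaces prompt_tokens
    output_tokens: int  # replaces completion_tokens

@dataclass
class _ChatResult:
    responses: List[_Response]
    usage: _Usage
    model: str

# The shape the updated RekaModel code now reads:
result = _ChatResult(
    responses=[_Response(message=_Message(content="Hello from Reka."))],
    usage=_Usage(input_tokens=12, output_tokens=5),
    model="reka-core",
)
text = result.responses[0].message.content  # was: response.choices[0].text
tokens_in = result.usage.input_tokens       # was: response.usage.prompt_tokens
tokens_out = result.usage.output_tokens     # was: response.usage.completion_tokens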