From aa8de9170676be62e11875dec2b7ea3e0336ee71 Mon Sep 17 00:00:00 2001
From: zhourunlong
Date: Thu, 28 Nov 2024 16:18:12 -0800
Subject: [PATCH] add caching

---
 .gitignore                                     |  1 +
 .../ai-math-autoformalization/python/utils.py  | 69 ++++++++++++++++++++
 2 files changed, 70 insertions(+)
 create mode 100644 projects/ai-math-autoformalization/python/utils.py

diff --git a/.gitignore b/.gitignore
index 49275dbe2f..d6790d72b5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -348,3 +348,4 @@ out
 data/mini-husky/basic/
 experiments/transformer-vs-recursive/results/
 experiments/transformer-vs-recursive/figures/
+projects/ai-math-autoformalization/python/.chat_completion_cache/
diff --git a/projects/ai-math-autoformalization/python/utils.py b/projects/ai-math-autoformalization/python/utils.py
new file mode 100644
index 0000000000..0b0ecb6d15
--- /dev/null
+++ b/projects/ai-math-autoformalization/python/utils.py
@@ -0,0 +1,69 @@
+from openai import OpenAI
+import hashlib
+from diskcache import Cache
+
+# Persistent on-disk cache for chat completions (SQLite-backed via diskcache).
+cache = Cache(".chat_completion_cache")
+
+class ChatCompletionAPI:
+    """Thin wrapper around an OpenAI-compatible chat API with disk caching."""
+
+    def __init__(self, api_type, model=None):
+        if api_type == "local":
+            self.client = OpenAI(base_url="http://172.28.6.60:8000/v1") # DO NOT CHANGE THIS LINE
+            self.model = "/local1/models/models--meta-llama--Llama-3.1-70B-Instruct/snapshots/945c8663693130f8be2ee66210e062158b2a9693" # DO NOT CHANGE THIS LINE
+        elif api_type == "openai":
+            self.client = OpenAI()
+            self.model = model
+        else:
+            raise ValueError(f"Invalid api_type: {api_type!r}")
+
+    def _generate_cache_key(self, model, messages):
+        # Join with a NUL separator so ("a", "bc") and ("ab", "c") cannot
+        # collide, then hash down to a fixed-length key.
+        messages_str = "\x00".join(msg["role"] + "\x00" + msg["content"] for msg in messages)
+        key_base = f"{model}:{messages_str}"
+        return hashlib.md5(key_base.encode()).hexdigest()
+
+    def _fetch_completion(self, model, messages, use_cache):
+        # Compute the key unconditionally so the write below never references
+        # an undefined name when use_cache is False.
+        cache_key = self._generate_cache_key(model, messages)
+        if use_cache:
+            cached_result = cache.get(cache_key)
+            if cached_result is not None:
+                return cached_result
+
+        completion = self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+        )
+        result = {
+            "role": completion.choices[0].message.role,
+            "content": completion.choices[0].message.content,
+        }
+        # Write back even when the cache read was skipped, so later cached
+        # calls can reuse this result.
+        cache.set(cache_key, result)
+        return result
+
+    def chat_completion(self, messages, use_cache=True):
+        return self._fetch_completion(self.model, messages, use_cache)
+
+if __name__ == "__main__":
+    api = ChatCompletionAPI("local")
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is the capital of France?"},
+    ]
+    for _ in range(3):
+        completion = api.chat_completion(messages)
+        print(completion["content"])
+
+        # Extend the conversation so each turn sends a distinct prompt.
+        messages.append(completion)
+        messages.append({
+            "role": "user",
+            "content": "Tell me more fun facts about this city.",
+        })
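
A minimal sketch of how the cache behaves from a caller's perspective, assuming the new module is importable and OPENAI_API_KEY is exported; the model name "gpt-4o-mini" is illustrative, not part of the patch:

    from utils import ChatCompletionAPI

    api = ChatCompletionAPI("openai", model="gpt-4o-mini")  # hypothetical model name
    messages = [{"role": "user", "content": "What is the capital of France?"}]

    first = api.chat_completion(messages)    # network call; result written to disk
    second = api.chat_completion(messages)   # identical prompt: served from the cache
    assert first == second

    # use_cache=False skips the cache read but still writes the fresh result back.
    fresh = api.chat_completion(messages, use_cache=False)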