make factscore use chatgpt for atomic fact generation #42

Open · wants to merge 2 commits into base: main
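
This PR switches atomic fact generation from the InstructGPT completions API to ChatGPT (gpt-3.5-turbo). Few-shot demonstrations are now encoded as alternating user/assistant chat messages rather than one concatenated prompt string; the gpt3_cache_file argument becomes the model-agnostic cache_file; cost estimates use the gpt-3.5-turbo rate; and lm.py / openai_lm.py gain helpers (cache_key, _construct_messages, _validate_message) so a prompt may be a plain string, a single message dict, or a list of messages.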
factscore/atomic_facts.py (33 changes: 14 additions & 19 deletions)

@@ -18,12 +18,12 @@


 class AtomicFactGenerator(object):
-    def __init__(self, key_path, demon_dir, gpt3_cache_file=None):
+    def __init__(self, key_path, demon_dir, cache_file=None):
         self.nlp = spacy.load("en_core_web_sm")
         self.is_bio = True
         self.demon_path = os.path.join(demon_dir, "demons.json" if self.is_bio else "demons_complex.json")

-        self.openai_lm = OpenAIModel("InstructGPT", cache_file=gpt3_cache_file, key_path=key_path)
+        self.openai_lm = OpenAIModel("ChatGPT", cache_file=cache_file, key_path=key_path)

         # get the demos
         with open(self.demon_path, 'r') as f:
@@ -103,40 +103,35 @@ def get_init_atomic_facts_from_sentence(self, sentences, cost_estimate=None):
         n = 7 if is_bio else 8

         prompts = []
-        prompt_to_sent = {}
         atoms = {}
         for sentence in sentences:
             if sentence in atoms:
                 continue
             top_machings = best_demos(sentence, self.bm25, list(demons.keys()), k)
-            prompt = ""
+            prompt = []

             for i in range(n):
-                prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(list(demons.keys())[i])
-                for fact in demons[list(demons.keys())[i]]:
-                    prompt = prompt + "- {}\n".format(fact)
-                prompt = prompt + "\n"
+                prompt.append({"role": "user", "content": f"Please breakdown the following sentence into independent facts: {list(demons.keys())[i]}"})
+                prompt.append({"role": "assistant", "content": "\n".join(f"- {fact}" for fact in demons[list(demons.keys())[i]])})

             for match in top_machings:
-                prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(match)
-                for fact in demons[match]:
-                    prompt = prompt + "- {}\n".format(fact)
-                prompt = prompt + "\n"
-            prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(sentence)
+                prompt.append({"role": "user", "content": f"Please breakdown the following sentence into independent facts: {match}"})
+                prompt.append({"role": "assistant", "content": "\n".join(f"- {fact}" for fact in demons[match])})
+
+            prompt.append({"role": "user", "content": f"Please breakdown the following sentence into independent facts: {sentence}"})
             prompts.append(prompt)
-            prompt_to_sent[prompt] = sentence

         if cost_estimate:
             total_words_estimate = 0
             for prompt in prompts:
-                if cost_estimate == "consider_cache" and (prompt.strip() + "_0") in self.openai_lm.cache_dict:
+                if cost_estimate == "consider_cache" and self.openai_lm.cache_key(prompt) in self.openai_lm.cache_dict:
                     continue
-                total_words_estimate += len(prompt.split())
+                total_words_estimate += sum(len(m["content"].split()) for m in prompt)
             return total_words_estimate
         else:
-            for prompt in prompts:
+            for prompt, sentence in zip(prompts, sentences):
                 output, _ = self.openai_lm.generate(prompt)
-                atoms[prompt_to_sent[prompt]] = text_to_sentences(output)
+                atoms[sentence] = text_to_sentences(output)

         for key, value in demons.items():
             if key not in atoms:
@@ -342,4 +337,4 @@ def main():
     print(para_breaks)

 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
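
For reference, the new prompt format in get_init_atomic_facts_from_sentence replaces one long completion string with a list of chat messages, where each retrieved demonstration becomes a user/assistant exchange and the target sentence arrives as the final user turn. A minimal runnable sketch with toy demonstrations (the real ones are loaded from demons.json):

demons = {
    "He made his acting debut in 1996.": [
        "He made his acting debut.",
        "His acting debut was in 1996.",
    ],
}

def build_chat_prompt(sentence, demons):
    # Each demonstration becomes a user request plus an assistant answer,
    # mirroring the message-building loop in the diff above.
    prompt = []
    for demo, facts in demons.items():
        prompt.append({"role": "user",
                       "content": f"Please breakdown the following sentence into independent facts: {demo}"})
        prompt.append({"role": "assistant",
                       "content": "\n".join(f"- {fact}" for fact in facts)})
    # The sentence we actually want decomposed goes last, as a user turn.
    prompt.append({"role": "user",
                   "content": f"Please breakdown the following sentence into independent facts: {sentence}"})
    return prompt

messages = build_chat_prompt("She is a French actress and singer.", demons)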
factscore/factscorer.py (6 changes: 3 additions & 3 deletions)

@@ -94,7 +94,7 @@ def print_cost_estimates(self, total_words, task, model):
         if model == "davinci-003":
             rate = 0.02
         elif model == "gpt-3.5-turbo":
-            rate = 0.002
+            rate = 0.0015

         total_cost = total_tokens * rate / 1000
@@ -128,14 +128,14 @@ def get_score(self,
         if self.af_generator is None:
             self.af_generator = AtomicFactGenerator(key_path=self.openai_key,
                                                     demon_dir=os.path.join(self.data_dir, "demos"),
-                                                    gpt3_cache_file=os.path.join(self.cache_dir, "InstructGPT.pkl"))
+                                                    cache_file=os.path.join(self.cache_dir, "ChatGPT.pkl"))

         # estimate the total cost of atomic fact generation
         total_words = 0
         for gen in generations:
             total_words += self.af_generator.run(gen, cost_estimate=self.cost_estimate)

-        self.print_cost_estimates(total_words, task="atomic fact generation", model="davinci-003")
+        self.print_cost_estimates(total_words, task="atomic fact generation", model="gpt-3.5-turbo")

         if verbose:
             topics = tqdm(topics)
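
For context on the rate change: print_cost_estimates multiplies a token count by rate / 1000, so the rate is dollars per 1K tokens, and 0.0015 matches the gpt-3.5-turbo input price at the time of this PR. A back-of-the-envelope sketch, assuming the common ~4/3 tokens-per-word heuristic (the words-to-tokens conversion itself sits outside this hunk):

def estimate_cost(total_words, rate_per_1k_tokens=0.0015, tokens_per_word=4 / 3):
    # Dollars = tokens * (rate per 1K tokens) / 1000.
    total_tokens = total_words * tokens_per_word
    return total_tokens * rate_per_1k_tokens / 1000

# 300,000 words of prompts -> ~400,000 tokens -> ~$0.60 at the new rate
print(f"${estimate_cost(300_000):.2f}")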
factscore/lm.py (14 changes: 11 additions & 3 deletions)

@@ -14,17 +14,25 @@ def load_model(self):
         # load the model and put it as self.model
         raise NotImplementedError()

+    def cache_key(self, prompt, sample_idx=0):
+        if isinstance(prompt, str):
+            prompt = prompt.strip() # it's important not to end with a whitespace
+
+        return f"{prompt}_{sample_idx}"
+
     def generate(self, prompt, sample_idx=0, max_sequence_length=2048, max_output_length=128):
-        prompt = prompt.strip() # it's important not to end with a whitespace
-        cache_key = f"{prompt}_{sample_idx}"
+        if isinstance(prompt, str):
+            prompt = prompt.strip() # it's important not to end with a whitespace
+
+        cache_key = self.cache_key(prompt, sample_idx=sample_idx)

         if cache_key in self.cache_dict:
             return self.cache_dict[cache_key]

         if self.model is None:
             self.load_model()

-        if prompt.endswith(" True or False?\nAnswer:"):
+        if isinstance(prompt, str) and prompt.endswith(" True or False?\nAnswer:"):
             generated = self._generate(prompt, max_sequence_length=max_sequence_length, max_output_length=1)
         else:
             generated = self._generate(prompt, max_sequence_length=max_sequence_length, max_output_length=max_output_length)
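
One behavior of the extracted cache_key worth noting: for string prompts the key is the stripped prompt text, while for a message list the f-string falls back to the list's repr(). A small standalone illustration (toy prompts, not from the repo):

def cache_key(prompt, sample_idx=0):
    # Same logic as the new cache_key method in lm.py: strip strings, otherwise
    # let f-string formatting produce the repr() of the message list.
    if isinstance(prompt, str):
        prompt = prompt.strip()
    return f"{prompt}_{sample_idx}"

print(cache_key("Tell me a fact. "))
# Tell me a fact._0
print(cache_key([{"role": "user", "content": "Tell me a fact."}]))
# [{'role': 'user', 'content': 'Tell me a fact.'}]_0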
factscore/openai_lm.py (58 changes: 53 additions & 5 deletions)

@@ -24,16 +24,43 @@ def load_model(self):
         openai.api_key = api_key.strip()
         self.model = self.model_name

+    def _construct_messages(self, prompt):
+        if isinstance(prompt, str):
+            return [{"role": "user", "content": prompt}]
+
+        if isinstance(prompt, dict):
+            role = prompt.get("role", None)
+            if not _valid_role(role):
+                raise ValueError(f"Invalid role {role}!")
+
+            content = prompt.get("content", None)
+            if not isinstance(content, str):
+                raise ValueError("Invalid content! Must be a string.")
+
+            return [prompt]
+
+        if isinstance(prompt, list):
+            for i, m in enumerate(prompt, 1):
+                _validate_message(m)
+
+                if i > 1 and m["role"] == "system":
+                    raise ValueError("Only the first message can be a system message")
+
+                if i == len(prompt) and m["role"] != "user":
+                    raise ValueError("The last message must be a user message")
+
+            return prompt
+
     def _generate(self, prompt, max_sequence_length=2048, max_output_length=128):
         if self.add_n % self.save_interval == 0:
             self.save_cache()
         # return a tuple of string (generated text) and metadata (any format)
         # This should be about generating a response from the prompt, no matter what the application is
         if self.model_name == "ChatGPT":
-            # Construct the prompt send to ChatGPT
-            message = [{"role": "user", "content": prompt}]
+            messages = self._construct_messages(prompt)
             # Call API
-            response = call_ChatGPT(message, temp=self.temp, max_len=max_sequence_length)
+            response = call_ChatGPT(messages, temp=self.temp, max_len=max_sequence_length)
             # Get the output from the response
             output = response["choices"][0]["message"]["content"]
             return output, response
@@ -46,15 +73,36 @@ def _generate(self, prompt, max_sequence_length=2048, max_output_length=128):
         else:
             raise NotImplementedError()

-def call_ChatGPT(message, model_name="gpt-3.5-turbo", max_len=1024, temp=0.7, verbose=False):
+
+ROLES = ["system", "user", "assistant"]
+
+
+def _valid_role(role):
+    return role in ROLES
+
+
+def _validate_message(message):
+    if not isinstance(message, dict):
+        raise ValueError("message must be a dict")
+
+    role = message.get("role", None)
+    if not _valid_role(role):
+        raise ValueError(f"Invalid role {role}!")
+
+    content = message.get("content", None)
+    if not isinstance(content, str):
+        raise ValueError("Invalid content! Must be a string.")
+
+
+def call_ChatGPT(messages, model_name="gpt-3.5-turbo", max_len=1024, temp=0.7, verbose=False):
     # call GPT-3 API until result is provided and then return it
     response = None
     received = False
     num_rate_errors = 0
     while not received:
         try:
             response = openai.ChatCompletion.create(model=model_name,
-                                                    messages=message,
+                                                    messages=messages,
                                                     max_tokens=max_len,
                                                     temperature=temp)
             received = True
@@ -64,7 +112,7 @@ def call_ChatGPT(message, model_name="gpt-3.5-turbo", max_len=1024, temp=0.7, verbose=False):
             error = sys.exc_info()[0]
             if error == openai.error.InvalidRequestError:
                 # something is wrong: e.g. prompt too long
-                logging.critical(f"InvalidRequestError\nPrompt passed in:\n\n{message}\n\n")
+                logging.critical(f"InvalidRequestError\nPrompt passed in:\n\n{messages}\n\n")
                 assert False

             logging.error("API error: %s (%d). Waiting %dsec" % (error, num_rate_errors, np.power(2, num_rate_errors)))
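
To summarize the prompt shapes the refactored openai_lm.py now accepts, here is a runnable sketch; construct_messages below mirrors the logic of the new _construct_messages method, reproduced at module level so the example is self-contained:

ROLES = ["system", "user", "assistant"]

def construct_messages(prompt):
    if isinstance(prompt, str):
        # A bare string becomes a single user turn.
        return [{"role": "user", "content": prompt}]
    if isinstance(prompt, dict):
        # A single message dict is validated and wrapped in a list.
        assert prompt.get("role") in ROLES and isinstance(prompt.get("content"), str)
        return [prompt]
    for i, m in enumerate(prompt, 1):
        # A message list passes through, subject to ordering rules.
        assert m.get("role") in ROLES and isinstance(m.get("content"), str)
        assert not (i > 1 and m["role"] == "system"), "system message must come first"
        assert not (i == len(prompt) and m["role"] != "user"), "must end with a user turn"
    return prompt

print(construct_messages("Who wrote Hamlet?"))
print(construct_messages({"role": "user", "content": "Who wrote Hamlet?"}))
print(construct_messages([
    {"role": "system", "content": "You break sentences into atomic facts."},
    {"role": "user", "content": "Please breakdown the following sentence into independent facts: He was a poet and a diplomat."},
]))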