Skip to content

Commit

Permalink
initial infrastructure for chain-of-thought
Browse files Browse the repository at this point in the history
  • Loading branch information
WeetHet committed Sep 24, 2024
1 parent f9c0a9f commit b84f182
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 8 deletions.
52 changes: 52 additions & 0 deletions prompts/dafny_invariants_cot/001.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
===== QUESTION_EXAMPLE =====
Given the following program, for each function, determine if it has a specification. If it does not, is it itself a specification? If it is a specification, what does it specify?
```dafny
function sum(s: seq<int>) : int {
if |s| == 0 then 0 else s[0] + sum(s[1..])
}

lemma sum_prop(s: seq<int>)
requires |s| > 0
ensures sum(s) == sum(s[..|s| - 1]) + s[ |s| - 1 ]
{
if (|s| > 1) {
assert (s[1..][..|s[1..]| - 1]) == s[1..|s| - 1];
}
}
method sum_loop(numbers: seq<int>) returns (s: int)
ensures s == sum(numbers)
{
assert numbers[..|numbers|] == numbers;
s := 0;
var i := 0;
while (i < |numbers|)
invariant 0 <= i <= |numbers|
invariant s == sum(numbers[..i])
{
assert sum(numbers[..i + 1]) == sum(numbers[..i]) + numbers[i] by {
assert numbers[..i+1][..i] == numbers[..i];
sum_prop(numbers[..i + 1]);
}
s := s + numbers[i];

i := i + 1;
}
}
```
===== ANSWER =====
Let's go function by function. The first function, `sum` does not have a specification, as can be seen by the absence of either a `requires` or `ensures` clause. This means that this function is probably a specification itself. Let's consider, what it can be step by step. We can see that the aruguments to this function are a sequence of integers. We can see that the function returns an integer. We can conclude that the function does some accumulation of integers from the sequence. Looking at the body, we can see that the function takes the first element of the sequence and adds it to the accumulated result of evaluating the `sum` function on the rest of the sequence. This means the function is a specification for the sum of the elements of the sequence.

Next, we have the `sum_prop` lemma. Considering it's specification, we can see that it requires a non-empty sequence of integers, as seen by the `requires |s| > 0` clause. Looking at the ensures clause, we can see that the lemma ensures that the sum of the sequence (as computed by the `sum` function) is equal to the sum of the elements of the sequence without the last element, plus that last element. This is important to prove, because the sum function drops the first element and thus it is not obvious for the prover that the last element can be dropped in the same way.

Finally, we have the `sum_loop` method. This method has a specification, as indicated by the `ensures` clause. The specification states that the returned value `s` should be equal to the sum of the input sequence `numbers`, as computed by the `sum` function, as seen by the `ensures s == sum(numbers)` clause.

In summary:
1. The `sum` function is itself a specification for computing the sum of a sequence of integers.
2. The `sum_prop` lemma has a specification that proves a property about the `sum` function.
3. The `sum_loop` method has a specification that ensures it correctly computes the sum of a sequence of integers, matching the result of the `sum` function.
===== QUESTION =====
Given the following program, for each function, determine if it has a specification. If it does not, is it itself a specification? If it is a specification, what does it specify?

```dafny
{program}
```
4 changes: 2 additions & 2 deletions verified_cogen/experiments/use_houdini.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def collect_invariants(args: ProgramArgs, prg: str) -> list[str]:
llm.user_prompts.append(
INVARIANTS_JSON_PROMPT.replace("{program}", prg).replace("{function}", func)
)
response = llm._make_request() # type: ignore
response = llm.make_request() # type: ignore
try:
invariants = json.loads(response)
result_invariants.extend(invariants)
Expand All @@ -144,7 +144,7 @@ def remove_failed_invariants(
llm: LLM, invariants: list[str], err: str
) -> Optional[list[str]]:
llm.user_prompts.append(REMOVE_FAILED_INVARIANTS_PROMPT.format(error=err))
response = llm._make_request() # type: ignore
response = llm.make_request() # type: ignore
try:
new_invariants = json.loads(response)
log.debug("REMOVED: {}".format(set(invariants).difference(set(new_invariants))))
Expand Down
18 changes: 12 additions & 6 deletions verified_cogen/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def __init__(
system_prompt if system_prompt else prompts.sys_prompt(self.prompt_dir)
)

def add_user_prompt(self, prompt: str):
self.user_prompts.append(prompt)

def add_response(self, response: str):
self.responses.append(response)

def _request(
self, temperature: Optional[float] = None, tries: int = 5
) -> ChatResponse:
Expand Down Expand Up @@ -73,7 +79,7 @@ def _request(
logger.warning("Grazie API is down, retrying...")
return self._request(temperature, tries - 1)

def _make_request(self) -> str:
def make_request(self) -> str:
response = self._request().content
self.responses.append(response)
return extract_code_from_llm_output(response)
Expand All @@ -82,20 +88,20 @@ def produce(self, prg: str) -> str:
self.user_prompts.append(
prompts.produce_prompt(self.prompt_dir).format(program=prg)
)
return self._make_request()
return self.make_request()

def add(self, prg: str, checks: str, function: Optional[str] = None) -> str:
prompt = prompts.add_prompt(self.prompt_dir).format(program=prg, checks=checks)
if "{function}" in prompt and function is not None:
prompt = prompt.replace("{function}", function)
self.user_prompts.append(prompt)
return self._make_request()
return self.make_request()

def rewrite(self, prg: str) -> str:
self.user_prompts.append(
prompts.rewrite_prompt(self.prompt_dir).replace("{program}", prg)
)
return self._make_request()
return self.make_request()

def ask_for_fixed(self, err: str) -> str:
prompt = (
Expand All @@ -104,8 +110,8 @@ def ask_for_fixed(self, err: str) -> str:
else prompts.ask_for_fixed_prompt(self.prompt_dir)
)
self.user_prompts.append(prompt.format(error=err))
return self._make_request()
return self.make_request()

def ask_for_timeout(self) -> str:
self.user_prompts.append(prompts.ask_for_timeout_prompt(self.prompt_dir))
return self._make_request()
return self.make_request()
3 changes: 3 additions & 0 deletions verified_cogen/runners/chain_of_thought/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .step import Step as Step

__ALL__ = ["Step"]
39 changes: 39 additions & 0 deletions verified_cogen/runners/chain_of_thought/step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
class Substep:
question: str
answer: str

def __init__(self, question: str, answer: str):
self.question = question
self.answer = answer


class Step:
substeps: list[Substep] = []
question: str

def __init__(self, data: str):
data_lines = data.split("\n")
current_question: list[str] = []
current_answer: list[str] = []
appending_to = current_question
for i, line in enumerate(data_lines):
if line.startswith("===== QUESTION_EXAMPLE ====="):
self.substeps.append(Substep("\n".join(current_question), "\n".join(current_answer)))
current_question, current_answer = [], []
appending_to = current_question
elif line.startswith("===== ANSWER =====")
appending_to = current_answer
elif line.startswith("===== QUESTION ====="):
self.question = "\n".join(data_lines[i + 1:])
break
else:
appending_to.append(line)

def __repr__(self) -> str:
return f"Step(question={self.question}, substeps={self.substeps})"

def __str__(self) -> str:
result = ""
for substep in self.substeps:
result += f"===== QUESTION_EXAMPLE =====\n{substep.question}\n===== ANSWER =====\n{substep.answer}\n"
return result + f"===== QUESTION =====\n{self.question}\n"
43 changes: 43 additions & 0 deletions verified_cogen/runners/step_by_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pathlib

from verified_cogen.runners import Runner
from verified_cogen.tools.modes import Mode
from verified_cogen.runners.chain_of_thought import Step


class StepByStepRunner(Runner):
wrapped_runner: Runner

def __init__(self, wrapping: Runner):
super().__init__(wrapping.llm, wrapping.logger, wrapping.verifier)
self.wrapped_runner = wrapping

def preprocess(self, prg: str, mode: Mode) -> str:
return self.wrapped_runner.preprocess(prg, mode)

def rewrite(self, prg: str) -> str:
steps: list[Step] = []
for step in sorted(pathlib.Path("prompts/step_by_step").glob("*.txt")):
steps.append(Step(step.read_text()))
for step in steps:
for substep in step.substeps:
self.llm.add_user_prompt(substep.question)
self.llm.add_response(substep.answer)
self.llm.add_user_prompt(step.question)
_ = self.llm.make_request()
return self.wrapped_runner.rewrite(prg)

def produce(self, prg: str) -> str:
return self.wrapped_runner.produce(prg)

def insert(self, prg: str, checks: str, mode: Mode) -> str:
return self.wrapped_runner.insert(prg, checks, mode)

def ask_for_timeout(self) -> str:
return self.wrapped_runner.ask_for_timeout()

def ask_for_fixed(self, err: str) -> str:
return self.wrapped_runner.ask_for_fixed(err)

def precheck(self, prg: str, mode: Mode):
return self.wrapped_runner.precheck(prg, mode)

0 comments on commit b84f182

Please sign in to comment.