rework postprocessing, add rewrite step
WeetHet committed Sep 30, 2024
1 parent f3e3593 commit de81a83
Showing 10 changed files with 198 additions and 29 deletions.
83 changes: 83 additions & 0 deletions prompts/humaneval-dafny-cot/rewrite/examples/001/answer.txt
@@ -0,0 +1,83 @@
In the previous steps, we found that the loop in the `sum_loop` method iterates over the elements of the sequence `numbers` and adds each one to `s`, so that `s` equals `sum(numbers[..i])` for every `i` with `0 <= i <= |numbers|`. This means that the following invariants hold during the execution of the loop:
1. `0 <= i <= |numbers|`
2. `s == sum(numbers[..i])`

Let's put them in:
```dafny
method sum_loop(numbers: seq<int>) returns (s: int)
    ensures s == sum(numbers)
{
    s := 0;
    var i := 0;
    while (i < |numbers|)
        invariant 0 <= i <= |numbers|
        invariant s == sum(numbers[..i])
    {
        s := s + numbers[i];

        i := i + 1;
    }
}
```

We've also worked out how to prove the second invariant. To do so, we need to show the verifier that `sum(numbers[..i + 1]) == sum(numbers[..i]) + numbers[i]`. This follows from the `sum_prop` lemma, applied as follows:
```dafny
assert numbers[..i + 1][..i] == numbers[..i]; // assert that taking a slice of a slice is the same as taking a slice of the original array
sum_prop(numbers[..i + 1]); // use a lemma to prove the invariant
assert sum(numbers[..i + 1]) == sum(numbers[..i + 1][..i]) + numbers[i];
```
Let's add these lines to the code:
```dafny
method sum_loop(numbers: seq<int>) returns (s: int)
    ensures s == sum(numbers)
{
    s := 0;
    var i := 0;
    while (i < |numbers|)
        invariant 0 <= i <= |numbers|
        invariant s == sum(numbers[..i])
    {
        s := s + numbers[i];

        assert numbers[..i + 1][..i] == numbers[..i]; // assert that taking a slice of a slice is the same as taking a slice of the original array
        sum_prop(numbers[..i + 1]); // use a lemma to prove the invariant
        assert sum(numbers[..i + 1]) == sum(numbers[..i + 1][..i]) + numbers[i];

        i := i + 1;
    }
}
```

We didn't find any other invariants or patterns that hold during the execution of the program. Therefore, the resulting program is (notice the lack of backticks):
<answer>
function sum(s: seq<int>) : int {
    if |s| == 0 then 0 else s[0] + sum(s[1..])
}

lemma sum_prop(s: seq<int>)
    requires |s| > 0
    ensures sum(s) == sum(s[..|s| - 1]) + s[ |s| - 1 ]
{
    if (|s| > 1) {
        assert (s[1..][..|s[1..]| - 1]) == s[1..|s| - 1];
    }
}
method sum_loop(numbers: seq<int>) returns (s: int)
    ensures s == sum(numbers)
{
    s := 0;
    var i := 0;
    while (i < |numbers|)
        invariant 0 <= i <= |numbers|
        invariant s == sum(numbers[..i])
    {
        s := s + numbers[i];

        assert numbers[..i + 1][..i] == numbers[..i]; // assert that taking a slice of a slice is the same as taking a slice of the original array
        sum_prop(numbers[..i + 1]); // use a lemma to prove the invariant
        assert sum(numbers[..i + 1]) == sum(numbers[..i + 1][..i]) + numbers[i];

        i := i + 1;
    }
}
</answer>
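
The invariants in this example can also be exercised at runtime as a quick sanity check. The sketch below is a Python analogue and is not part of the commit: `seq_sum` is a hypothetical stand-in for the Dafny `sum` function, and the assertions mirror the two loop invariants and the `sum_prop` step. Runtime checks on a few inputs are much weaker than a Dafny proof; they only help catch an obviously wrong invariant early.

```python
from typing import List


def seq_sum(s: List[int]) -> int:
    """Recursive sum mirroring the Dafny `sum` function (hypothetical name)."""
    return 0 if len(s) == 0 else s[0] + seq_sum(s[1:])


def sum_loop(numbers: List[int]) -> int:
    """Iterative sum that asserts both loop invariants on every iteration."""
    s = 0
    i = 0
    while i < len(numbers):
        # The two invariants, checked on entry to the loop body.
        assert 0 <= i <= len(numbers)
        assert s == seq_sum(numbers[:i])

        s = s + numbers[i]

        # Runtime analogue of the slice assertion and the `sum_prop` lemma.
        prefix = numbers[: i + 1]
        assert prefix[:i] == numbers[:i]
        assert seq_sum(prefix) == seq_sum(prefix[:i]) + numbers[i]

        i = i + 1
    # Invariants plus the negated loop guard give the postcondition.
    assert i == len(numbers) and s == seq_sum(numbers)
    return s
```

Calling `sum_loop` on any list of integers runs all the assertions; a failing assertion points at the invariant that does not hold.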
28 changes: 28 additions & 0 deletions prompts/humaneval-dafny-cot/rewrite/examples/001/question.txt
@@ -0,0 +1,28 @@
It's time to rewrite the program to contain the invariants and assertions that we've found. Work on the following program and go step by step through the reasoning process. After that, output the rewritten program in the `<answer>` tag without backticks.

```dafny
function sum(s: seq<int>) : int {
    if |s| == 0 then 0 else s[0] + sum(s[1..])
}

lemma sum_prop(s: seq<int>)
    requires |s| > 0
    ensures sum(s) == sum(s[..|s| - 1]) + s[ |s| - 1 ]
{
    if (|s| > 1) {
        assert (s[1..][..|s[1..]| - 1]) == s[1..|s| - 1];
    }
}
method sum_loop(numbers: seq<int>) returns (s: int)
    ensures s == sum(numbers)
{
    s := 0;
    var i := 0;
    while (i < |numbers|)
    {
        s := s + numbers[i];

        i := i + 1;
    }
}
```
Empty file.
4 changes: 2 additions & 2 deletions prompts/humaneval-dafny-cot/steps/003/001/answer.txt
@@ -7,8 +7,8 @@ Let's work out how to use it. If applied to the `numbers[..i + 1]`, we get:
`sum(numbers[..i + 1]) == sum(numbers[..i + 1][..i]) + numbers[i]`. It might be difficult for dafny to prove that `sum(numbers[..i + 1][..i]) == sum(numbers[..i])`, so I'd suggest adding an assertion before calling `sum_prop` to make sure that this is the case, namely `assert numbers[..i + 1][..i] == numbers[..i]`.
So, to use the `sum_prop`, I'd suggest the following code:
```
-assert numbers[..i + 1][..i] == numbers[..i];
-sum_prop(numbers[..i + 1][..i]);
+assert numbers[..i + 1][..i] == numbers[..i]; // assert that taking a slice of a slice is the same as taking a slice of the original array
+sum_prop(numbers[..i + 1]); // use a lemma to prove the invariant
 assert sum(numbers[..i + 1]) == sum(numbers[..i + 1][..i]) + numbers[i];
```

1 change: 1 addition & 0 deletions pyproject.toml
@@ -15,6 +15,7 @@ src_paths = ["verified_cogen"]

[tool.pyright]
typeCheckingMode = "strict"
+pythonVersion = "3.9"

[[tool.poetry.source]]
name = "PyPI"
4 changes: 2 additions & 2 deletions verified_cogen/experiments/use_houdini.py
@@ -125,7 +125,7 @@ def collect_invariants(args: ProgramArgs, prg: str) -> list[str]:
         temperature=temperature,
     )

-    llm.user_prompts.append(
+    llm.add_user_prompt(
         INVARIANTS_JSON_PROMPT.replace("{program}", prg).replace("{function}", func)
     )
     response = llm.make_request()  # type: ignore
@@ -143,7 +143,7 @@ def collect_invariants(args: ProgramArgs, prg: str) -> list[str]:
 def remove_failed_invariants(
     llm: LLM, invariants: list[str], err: str
 ) -> Optional[list[str]]:
-    llm.user_prompts.append(REMOVE_FAILED_INVARIANTS_PROMPT.format(error=err))
+    llm.add_user_prompt(REMOVE_FAILED_INVARIANTS_PROMPT.format(error=err))
     response = llm.make_request()  # type: ignore
     try:
         new_invariants = json.loads(response)
34 changes: 26 additions & 8 deletions verified_cogen/llm/llm.py
@@ -34,18 +34,36 @@ def __init__(
         self.prompt_dir = prompt_dir
         self.is_gpt = "gpt" in self.profile.name
         self.user_prompts: list[str] = []
+        self.is_user_prompt_temporary: list[bool] = []
         self.responses: list[str] = []
+        self.is_response_temporary: list[bool] = []
         self.had_errors = False
         self.temperature = temperature
         self.system_prompt = (
             system_prompt if system_prompt else prompts.sys_prompt(self.prompt_dir)
         )

-    def add_user_prompt(self, prompt: str):
+    def add_user_prompt(self, prompt: str, temporary: bool = False):
         self.user_prompts.append(prompt)
+        self.is_user_prompt_temporary.append(temporary)

-    def add_response(self, response: str):
+    def add_response(self, response: str, temporary: bool = False):
         self.responses.append(response)
+        self.is_response_temporary.append(temporary)

+    def wipe_temporary(self):
+        self.user_prompts = [
+            prompt
+            for prompt, temporary in zip(
+                self.user_prompts, self.is_user_prompt_temporary
+            )
+            if not temporary
+        ]
+        self.responses = [
+            response
+            for response, temporary in zip(self.responses, self.is_response_temporary)
+            if not temporary
+        ]

     def _request(
         self, temperature: Optional[float] = None, tries: int = 5
@@ -81,11 +99,11 @@ def _request(

     def make_request(self) -> str:
         response = self._request().content
-        self.responses.append(response)
+        self.add_response(response)
         return extract_code_from_llm_output(response)

     def produce(self, prg: str) -> str:
-        self.user_prompts.append(
+        self.add_user_prompt(
             prompts.produce_prompt(self.prompt_dir).format(program=prg)
         )
         return self.make_request()
@@ -94,11 +112,11 @@ def add(self, prg: str, checks: str, function: Optional[str] = None) -> str:
         prompt = prompts.add_prompt(self.prompt_dir).format(program=prg, checks=checks)
         if "{function}" in prompt and function is not None:
             prompt = prompt.replace("{function}", function)
-        self.user_prompts.append(prompt)
+        self.add_user_prompt(prompt, False)
         return self.make_request()

     def rewrite(self, prg: str) -> str:
-        self.user_prompts.append(
+        self.add_user_prompt(
             prompts.rewrite_prompt(self.prompt_dir).replace("{program}", prg)
         )
         return self.make_request()
@@ -109,9 +127,9 @@ def ask_for_fixed(self, err: str) -> str:
             if self.had_errors
             else prompts.ask_for_fixed_prompt(self.prompt_dir)
         )
-        self.user_prompts.append(prompt.format(error=err))
+        self.add_user_prompt(prompt.format(error=err))
         return self.make_request()

     def ask_for_timeout(self) -> str:
-        self.user_prompts.append(prompts.ask_for_timeout_prompt(self.prompt_dir))
+        self.add_user_prompt(prompts.ask_for_timeout_prompt(self.prompt_dir))
         return self.make_request()
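
The temporary-message bookkeeping added to `LLM` can be illustrated in isolation. The sketch below is a minimal stand-alone version, with `History` as a hypothetical class name that does not exist in the repository. As far as the hunk above shows, `wipe_temporary` filters only the two message lists; the sketch additionally resets the flag lists, which is an assumption about the intended behaviour, so that a later wipe does not pair surviving messages with stale flags.

```python
from typing import List


class History:
    """Hypothetical stand-in for the prompt/response bookkeeping in `LLM`."""

    def __init__(self) -> None:
        self.user_prompts: List[str] = []
        self.is_user_prompt_temporary: List[bool] = []
        self.responses: List[str] = []
        self.is_response_temporary: List[bool] = []

    def add_user_prompt(self, prompt: str, temporary: bool = False) -> None:
        self.user_prompts.append(prompt)
        self.is_user_prompt_temporary.append(temporary)

    def add_response(self, response: str, temporary: bool = False) -> None:
        self.responses.append(response)
        self.is_response_temporary.append(temporary)

    def wipe_temporary(self) -> None:
        # Keep only the messages that were not flagged as temporary.
        self.user_prompts = [
            p
            for p, temporary in zip(self.user_prompts, self.is_user_prompt_temporary)
            if not temporary
        ]
        self.responses = [
            r
            for r, temporary in zip(self.responses, self.is_response_temporary)
            if not temporary
        ]
        # Assumption beyond the diff: every survivor is permanent, so reset
        # the flags to keep each flag list aligned with its message list.
        self.is_user_prompt_temporary = [False] * len(self.user_prompts)
        self.is_response_temporary = [False] * len(self.responses)
```

Keeping each flag list the same length as its message list matters because `zip` stops at the shorter input and would otherwise silently misalign flags and messages on the next wipe.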
14 changes: 11 additions & 3 deletions verified_cogen/runners/__init__.py
@@ -15,6 +15,7 @@ class Runner:
     logger: Logger
     verifier: Verifier
     log_tries: Optional[pathlib.Path]
+    starting_prg: Optional[str] = None

     def __init__(
         self,
@@ -56,6 +57,9 @@ def precheck(self, prg: str, mode: Mode):
     def preprocess(self, prg: str, mode: Mode) -> str:
         return prg

+    def postprocess(self, inv_prg: str) -> str:
+        return inv_prg

     def invoke(self, prg: str, mode: Mode) -> str:
         self.logger.info("Invoking LLM")
         if mode == Mode.LLM_SINGLE_STEP:
@@ -97,7 +101,7 @@ def try_fixing(
                 self.logger.info("Verification timed out")
                 tries -= 1
                 if tries > 0:
-                    inv_prg = self.ask_for_timeout()
+                    inv_prg = self.postprocess(self.ask_for_timeout())
             else:
                 verified_inv, out_inv, err_inv = verification_result
                 if verified_inv:
@@ -109,7 +113,9 @@ def try_fixing(
                     tries -= 1
                     if tries > 0:
                         self.logger.info(f"Retrying with {tries} tries left...")
-                        inv_prg = self.ask_for_fixed(out_inv + err_inv)
+                        inv_prg = self.postprocess(
+                            self.ask_for_fixed(out_inv + err_inv)
+                        )
         return None

     def run_on_file(
@@ -124,11 +130,13 @@ def run_on_file(
         with open(file, "r") as f:
             prg = self.preprocess(f.read(), mode)

+        self.starting_prg = prg

         verification_result = self.verify_program(name, 0, prg)
         if verification_result is not None and verification_result[0]:
             return 0
         elif verification_result is None:
             self.logger.info("Verification timed out")
         self.precheck(prg, mode)
-        inv_prg = self.invoke(prg, mode)
+        inv_prg = self.postprocess(self.invoke(prg, mode))
         return self.try_fixing(total_tries, inv_prg, name)
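
The effect of this change is that every program coming back from the LLM, whether from the first invocation or from a retry, passes through `postprocess` before it is verified. The sketch below shows that control flow in simplified form. All names (`run_with_postprocess` and its callable parameters) are hypothetical, and unlike `try_fixing` it returns the verified program rather than a try count and does not model timeouts.

```python
from typing import Callable, Optional


def run_with_postprocess(
    invoke: Callable[[], str],
    ask_for_fixed: Callable[[str], str],
    verify: Callable[[str], Optional[str]],
    postprocess: Callable[[str], str],
    tries: int,
) -> Optional[str]:
    """Return the first candidate that verifies, or None if all tries fail.

    `verify` returns None on success and an error message on failure.
    """
    # The first LLM output is postprocessed before it is ever verified.
    candidate = postprocess(invoke())
    while tries > 0:
        error = verify(candidate)
        if error is None:
            return candidate
        tries -= 1
        if tries > 0:
            # Each retry goes through the same hook, so the verifier never
            # sees a raw LLM output.
            candidate = postprocess(ask_for_fixed(error))
    return None
```

Routing every output through one hook is what lets wrappers such as `ValidatingRunner` drop their per-method post-processing in favour of a single `postprocess` override.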
37 changes: 33 additions & 4 deletions verified_cogen/runners/step_by_step.py
@@ -1,16 +1,28 @@
+import pathlib
+from typing import Optional

 from verified_cogen.runners import Runner
+from verified_cogen.tools import extract_code_from_llm_output
 from verified_cogen.tools.modes import Mode
 from verified_cogen.runners.chain_of_thought import Step


+class StepByStepConfig:
+    remove_old_examples: bool = False
+
+    @classmethod
+    def default(cls) -> "StepByStepConfig":
+        return cls()


 class StepByStepRunner(Runner):
     wrapped_runner: Runner
+    config: StepByStepConfig

-    def __init__(self, wrapping: Runner):
+    def __init__(self, wrapping: Runner, config: Optional[StepByStepConfig] = None):
         super().__init__(wrapping.llm, wrapping.logger, wrapping.verifier)
         self.wrapped_runner = wrapping
+        self.config = StepByStepConfig.default() if config is None else config

     def preprocess(self, prg: str, mode: Mode) -> str:
         return self.wrapped_runner.preprocess(prg, mode)
@@ -22,11 +34,28 @@ def rewrite(self, prg: str) -> str:
             steps.append(Step(step))
         for step in steps:
             for substep in step.substeps:
-                self.llm.add_user_prompt(substep.question)
-                self.llm.add_response(substep.answer)
+                self.llm.add_user_prompt(
+                    substep.question, self.config.remove_old_examples
+                )
+                self.llm.add_response(substep.answer, self.config.remove_old_examples)
             self.llm.add_user_prompt(step.question.replace("{program}", prg))
             _ = self.llm.make_request()
-        return self.wrapped_runner.rewrite(prg)
+            if self.config.remove_old_examples:
+                self.llm.wipe_temporary()

+        rewrite_step = Step(pathlib.Path(self.llm.prompt_dir) / "rewrite")
+        for substep in rewrite_step.substeps:
+            self.llm.add_user_prompt(substep.question, self.config.remove_old_examples)
+            self.llm.add_response(substep.answer, self.config.remove_old_examples)
+        self.llm.add_user_prompt(rewrite_step.question.replace("{program}", prg))
+        response = self.llm.make_request()
+        if self.config.remove_old_examples:
+            self.llm.wipe_temporary()

+        return extract_code_from_llm_output(response)

+    def postprocess(self, inv_prg: str) -> str:
+        return self.wrapped_runner.postprocess(inv_prg)

     def produce(self, prg: str) -> str:
         return self.wrapped_runner.produce(prg)
22 changes: 12 additions & 10 deletions verified_cogen/runners/validating.py
@@ -8,7 +8,6 @@
 class ValidatingRunner(Runner):
     wrapped_runner: Runner
     language: Language
-    prg: Optional[str] = None

     def __init__(
         self,
@@ -29,29 +28,32 @@ def _add_validators(self, prg: str, inv_prg: str):
     def preprocess(self, prg: str, mode: Mode) -> str:
         return self.language.remove_asserts_and_invariants(prg)

+    def postprocess(self, inv_prg: str) -> str:
+        assert self.starting_prg is not None
+        return self._add_validators(
+            self.starting_prg, self.wrapped_runner.postprocess(inv_prg)
+        )

     def rewrite(self, prg: str) -> str:
-        self.prg = prg
-        return self._add_validators(prg, self.wrapped_runner.rewrite(prg))
+        return self.wrapped_runner.rewrite(prg)

     def produce(self, prg: str) -> str:
-        self.prg = prg
         return self.wrapped_runner.produce(prg)

     def insert(self, prg: str, checks: str, mode: Mode) -> str:
-        self.prg = prg
-        return self._add_validators(prg, self.wrapped_runner.insert(prg, checks, mode))
+        return self.wrapped_runner.insert(prg, checks, mode)

     def ask_for_timeout(self) -> str:
         assert (
-            self.prg is not None
+            self.starting_prg is not None
         ), "one of: rewrite, produce, insert must be called before ask_for_timeout"
-        return self._add_validators(self.prg, self.wrapped_runner.ask_for_timeout())
+        return self.wrapped_runner.ask_for_timeout()

     def ask_for_fixed(self, err: str) -> str:
         assert (
-            self.prg is not None
+            self.starting_prg is not None
         ), "one of: rewrite, produce, insert must be called before ask_for_fixed"
-        return self._add_validators(self.prg, self.wrapped_runner.ask_for_fixed(err))
+        return self.wrapped_runner.ask_for_fixed(err)

     def precheck(self, prg: str, mode: Mode):
         return self.wrapped_runner.precheck(prg, mode)
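
The reworked `ValidatingRunner` follows a wrapper pattern: the LLM-facing methods simply delegate, and validators are attached in one place, `postprocess`, using the program recorded in `starting_prg`. The sketch below reproduces that shape with hypothetical stand-ins (`BaseRunner`, `ValidatingWrapper`, and a string-based `_add_validators`); the real classes take an LLM, a logger and a verifier, and the real validators are language-specific. It also assumes the validating wrapper is the outermost runner, since `starting_prg` is set on the runner whose `run_on_file` is called.

```python
from typing import Optional


class BaseRunner:
    """Hypothetical innermost runner; `rewrite` fakes an LLM response."""

    starting_prg: Optional[str] = None

    def postprocess(self, inv_prg: str) -> str:
        # Default hook: leave the LLM output unchanged.
        return inv_prg

    def rewrite(self, prg: str) -> str:
        return prg + "\n// invariant added"


class ValidatingWrapper(BaseRunner):
    """Hypothetical wrapper that attaches validators only in `postprocess`."""

    def __init__(self, wrapped: BaseRunner) -> None:
        self.wrapped_runner = wrapped

    def _add_validators(self, prg: str, inv_prg: str) -> str:
        # Placeholder for the real, language-specific validator generation.
        return inv_prg + "\n// validators for: " + prg.splitlines()[0]

    def postprocess(self, inv_prg: str) -> str:
        assert self.starting_prg is not None
        return self._add_validators(
            self.starting_prg, self.wrapped_runner.postprocess(inv_prg)
        )

    def rewrite(self, prg: str) -> str:
        # Pure delegation: no validators are added here any more.
        return self.wrapped_runner.rewrite(prg)


runner = ValidatingWrapper(BaseRunner())
runner.starting_prg = "method m()"  # normally recorded by run_on_file
result = runner.postprocess(runner.rewrite("method m()"))
```

With this split, a retry path only needs to call `postprocess` on the new LLM output; it no longer has to remember which of `rewrite`, `produce` or `insert` was called first.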
