Skip to content

Commit

Permalink
Tweak tailrecLib to compile and write two simple regression tests
Browse files Browse the repository at this point in the history
(It can't include bossLib in the early context of src/num.)
  • Loading branch information
mn200 committed Nov 28, 2023
1 parent 84d2679 commit 679fd63
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
20 changes: 19 additions & 1 deletion src/num/theories/cv_compute/selftest.sml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
open testutils HolKernel Parse boolLib cv_computeLib cvSyntax cvTheory;
open arithmeticTheory
open arithmeticTheory tailrecLib

fun simp ths = simpLib.ASM_SIMP_TAC (BasicProvers.srw_ss()) ths

Expand Down Expand Up @@ -57,3 +57,21 @@ fun test n =

val _ = List.app test [0, 1, 5, 10, 13, 74, 157, 180];

(* tail-recursion *)
val base_t = “fib A N x = if x = 0 then A else fib (N + A) A (x -1) ”
val expected = “?fib. !A N x. ^base_t”
val _ = tprint "tailrecursive fibonacci"
val _ = require_msg (check_result (aconv expected o concl)) thm_to_string
prove_tailrec_exists
base_t

val odd = “(odd n = if n = 0 then F else even (n - 1))”
val even = “(even i = if i = 0 then T else odd (i - 1))”
val base_t = mk_conj(odd,even)
val expected = “?odd even. (!n. ^odd) /\ (!i. ^even) ”
val _ = tprint "tailrecursive even/odd"
val _ = require_msg (check_result (aconv expected o concl)) thm_to_string
prove_tailrec_exists
base_t

val _ = require_msg
26 changes: 9 additions & 17 deletions src/num/theories/cv_compute/tailrecLib.sml
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
structure tailrecLib :> tailrecLib =
struct

open HolKernel Parse boolLib bossLib;
open HolKernel Parse boolLib simpLib boolSimps

val Cases = BasicProvers.Cases
val PairCases = pairLib.PairCases

(*----------------------------------------------------------------------*
Miscellaneous helper functions
*----------------------------------------------------------------------*)

fun list_dest_conj tm =
if is_conj tm then let
val (x,y) = dest_conj tm
in list_dest_conj x @ list_dest_conj y end
else [tm];

fun list_dest_exists tm = let
val (v,y) = dest_exists tm
val (vs,t) = list_dest_exists y
in (v::vs,t) end
handle HOL_ERR _ => ([],tm);

fun list_mk_pair_case pat r =
if not (pairSyntax.is_pair pat) then (pat,r) else let
val v = genvar (type_of pat)
Expand All @@ -27,7 +18,7 @@ fun list_mk_pair_case pat r =
val new_pat = pairSyntax.mk_pair(x1,y1)
in (v,TypeBase.mk_case(v,[(new_pat,r1)])) end

fun auto_prove goal_tm (tac:tactic) = snd (tac ([],goal_tm)) [];
fun auto_prove goal_tm (tac:tactic) = TAC_PROOF (([],goal_tm), tac)

(*----------------------------------------------------------------------*
Function for proving that non-mutually recursive tail-recursive
Expand Down Expand Up @@ -89,7 +80,8 @@ fun prove_simple_tailrec_exists tm = let
val vs = xs |> map (fn (x,y) => (y,genvar (type_of y)))
val specs = foldl (fn (x,t) => SPEC_TAC x THEN t) ALL_TAC vs
val gens = foldr (fn ((_,x),t) =>
if can pairSyntax.dest_prod (type_of x) then PairCases THEN t
if can pairSyntax.dest_prod (type_of x) then
PairCases THEN t
else gen_tac THEN t) ALL_TAC vs
fun expand_lets 0 = ALL_CONV
| expand_lets 1 = (REWR_CONV LET_THM THENC PairRules.PBETA_CONV)
Expand Down Expand Up @@ -129,7 +121,7 @@ fun prove_simple_tailrec_exists tm = let
*----------------------------------------------------------------------*)

fun prove_tailrec_exists def_tm = let
val defs = list_dest_conj def_tm
val defs = strip_conj def_tm
(* build the goal to prove *)
fun extract_def def_tm = let
val (l,r) = dest_eq def_tm
Expand Down Expand Up @@ -190,7 +182,7 @@ fun prove_tailrec_exists def_tm = let

fun tailrec_define name def_tm = let
val lemma = prove_tailrec_exists def_tm
val names = lemma |> concl |> list_dest_exists |> fst |> map (fst o dest_var)
val names = lemma |> concl |> strip_exists |> fst |> map (fst o dest_var)
in new_specification(name,names,lemma) end

end

0 comments on commit 679fd63

Please sign in to comment.