diff --git a/src/num/theories/cv_compute/selftest.sml b/src/num/theories/cv_compute/selftest.sml index 66e3a29f66..2db5065e25 100644 --- a/src/num/theories/cv_compute/selftest.sml +++ b/src/num/theories/cv_compute/selftest.sml @@ -1,5 +1,5 @@ open testutils HolKernel Parse boolLib cv_computeLib cvSyntax cvTheory; -open arithmeticTheory +open arithmeticTheory tailrecLib fun simp ths = simpLib.ASM_SIMP_TAC (BasicProvers.srw_ss()) ths @@ -57,3 +57,21 @@ fun test n = val _ = List.app test [0, 1, 5, 10, 13, 74, 157, 180]; +(* tail-recursion *) +val base_t = “fib A N x = if x = 0 then A else fib (N + A) A (x -1) ” +val expected = “?fib. !A N x. ^base_t” +val _ = tprint "tailrecursive fibonacci" +val _ = require_msg (check_result (aconv expected o concl)) thm_to_string + prove_tailrec_exists + base_t + +val odd = “(odd n = if n = 0 then F else even (n - 1))” +val even = “(even i = if i = 0 then T else odd (i - 1))” +val base_t = mk_conj(odd,even) +val expected = “?odd even. (!n. ^odd) /\ (!i. ^even) ” +val _ = tprint "tailrecursive even/odd" +val _ = require_msg (check_result (aconv expected o concl)) thm_to_string + prove_tailrec_exists + base_t + +val _ = require_msg diff --git a/src/num/theories/cv_compute/tailrecLib.sml b/src/num/theories/cv_compute/tailrecLib.sml index 06dcf7b2ac..a3dfdce2c2 100644 --- a/src/num/theories/cv_compute/tailrecLib.sml +++ b/src/num/theories/cv_compute/tailrecLib.sml @@ -1,24 +1,15 @@ structure tailrecLib :> tailrecLib = struct -open HolKernel Parse boolLib bossLib; +open HolKernel Parse boolLib simpLib boolSimps + +val Cases = BasicProvers.Cases +val PairCases = pairLib.PairCases (*----------------------------------------------------------------------* Miscellaneous helper functions *----------------------------------------------------------------------*) -fun list_dest_conj tm = - if is_conj tm then let - val (x,y) = dest_conj tm - in list_dest_conj x @ list_dest_conj y end - else [tm]; - -fun list_dest_exists tm = let - val (v,y) = dest_exists tm - val (vs,t) = list_dest_exists y - in (v::vs,t) end - handle HOL_ERR _ => ([],tm); - fun list_mk_pair_case pat r = if not (pairSyntax.is_pair pat) then (pat,r) else let val v = genvar (type_of pat) @@ -27,7 +18,7 @@ fun list_mk_pair_case pat r = val new_pat = pairSyntax.mk_pair(x1,y1) in (v,TypeBase.mk_case(v,[(new_pat,r1)])) end -fun auto_prove goal_tm (tac:tactic) = snd (tac ([],goal_tm)) []; +fun auto_prove goal_tm (tac:tactic) = TAC_PROOF (([],goal_tm), tac) (*----------------------------------------------------------------------* Function for proving that non-mutually recursive tail-recursive @@ -89,7 +80,8 @@ fun prove_simple_tailrec_exists tm = let val vs = xs |> map (fn (x,y) => (y,genvar (type_of y))) val specs = foldl (fn (x,t) => SPEC_TAC x THEN t) ALL_TAC vs val gens = foldr (fn ((_,x),t) => - if can pairSyntax.dest_prod (type_of x) then PairCases THEN t + if can pairSyntax.dest_prod (type_of x) then + PairCases THEN t else gen_tac THEN t) ALL_TAC vs fun expand_lets 0 = ALL_CONV | expand_lets 1 = (REWR_CONV LET_THM THENC PairRules.PBETA_CONV) @@ -129,7 +121,7 @@ fun prove_simple_tailrec_exists tm = let *----------------------------------------------------------------------*) fun prove_tailrec_exists def_tm = let - val defs = list_dest_conj def_tm + val defs = strip_conj def_tm (* build the goal to prove *) fun extract_def def_tm = let val (l,r) = dest_eq def_tm @@ -190,7 +182,7 @@ fun prove_tailrec_exists def_tm = let fun tailrec_define name def_tm = let val lemma = prove_tailrec_exists def_tm - val names = lemma |> concl |> list_dest_exists |> fst |> map (fst o dest_var) + val names = lemma |> concl |> strip_exists |> fst |> map (fst o dest_var) in new_specification(name,names,lemma) end end