Skip to content

Commit

Permalink
Lift, tweak and publish some tailrecLib code
Browse files Browse the repository at this point in the history
The new entry-point is independently useful, at least while I am
working at the prototype sort of level.
  • Loading branch information
mn200 committed Jan 8, 2024
1 parent 85c8ef4 commit c4c8009
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -454,49 +454,12 @@ Theorem kunifywl_thm =
REWRITE_RULE [GSYM kunifywl_def] (CONJ unifywl0_NIL $ cj 3 unifywl0)

(* now to do guard-elimination *)

fun findin f t =
if aconv f t then SOME []
else
case total dest_comb t of
NONE => NONE
| SOME (t1,t2) =>
case findin f t1 of
NONE => NONE
| SOME pfx => SOME (pfx @ [t2])


fun tcallify fn_t inty t =
if TypeBase.is_case t then
let val (f, ts) = strip_comb t
val {Thy,Name,Ty} = dest_thy_const f
val f0 = prim_mk_const{Name=Name,Thy=Thy}
val basety = type_of f0
val (argtys, rngty) = strip_fun basety
val rng_th = match_type rngty (sumSyntax.mk_sum(inty,type_of t))
val argty_th = match_type (hd argtys) (type_of (hd ts))
val ft = Term.inst (rng_th @ argty_th) f0
val ft1 = mk_comb(ft, hd ts)
val ts' = map (tcallify fn_t inty) (tl ts)
in
list_mk_comb(ft1, ts')
end
else
case dest_term t of
CONST _ => sumSyntax.mk_inr(t,inty)
| VAR _ => sumSyntax.mk_inr(t,inty)
| LAMB(vt,bt) => mk_abs(vt, tcallify fn_t inty bt)
| COMB _ =>
case findin fn_t t of
NONE => sumSyntax.mk_inr(t,inty)
| SOME args => sumSyntax.mk_inl(pairSyntax.list_mk_pair args, type_of t)

fun tcallify_th th =
let val (l,r) = dest_eq (concl th)
val (lf, args) = strip_comb l
val atup = pairSyntax.list_mk_pair args
val inty = type_of atup
val body_t = tcallify lf inty r
val body_t = tailrecLib.mk_sum_term lf inty r
in
pairSyntax.mk_pabs(atup, body_t)
end
Expand Down
10 changes: 10 additions & 0 deletions src/num/theories/cv_compute/tailrecLib.sig
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@ sig

include Abbrev

val mk_sum_term : term -> hol_type -> term -> term

val tailrec_define : string -> term -> thm
val prove_tailrec_exists : term -> thm

end

(* [mk_sum_term fnt inty t] generates an abstraction term c that can be an
argument to TAILREC (or TAILCALL) such that ("roughly")
TAILCALL c fnt x = fnt x
The argument inty is type of the argument to fnt (x above)
*)
54 changes: 34 additions & 20 deletions src/num/theories/cv_compute/tailrecLib.sml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ struct

open HolKernel Parse boolLib simpLib boolSimps

fun mk_HOL_ERR f msg = HOL_ERR {origin_structure = "tailrecLib",
origin_function = f, message = msg}

val Cases = BasicProvers.Cases
val PairCases = pairLib.PairCases

Expand All @@ -29,6 +32,36 @@ val TAILREC_def = whileTheory.TAILREC
|> CONV_RULE (DEPTH_CONV ETA_CONV)
|> REWRITE_RULE [GSYM combinTheory.I_EQ_IDABS];

fun mk_sum_term fn_t inty tm =
let
fun build_sum t =
if TypeBase.is_case t then
let val (a,b,xs) = TypeBase.dest_case t
val ys = map (apsnd build_sum) xs
in
TypeBase.mk_case (b,ys)
end
else if can pairSyntax.dest_anylet t then
let val (xs,x) = pairSyntax.dest_anylet t
in pairSyntax.mk_anylet(xs,build_sum x) end
else if cvSyntax.is_cv_if tm then
let val (b,x,y) = cvSyntax.dest_cv_if tm
in mk_cond(cvSyntax.mk_c2b b,build_sum x,build_sum y) end
else
let val (f, xs) = strip_comb t
in
if aconv f fn_t then
if null xs then raise mk_HOL_ERR "mk_sum_term" "malformed term"
else
sumSyntax.mk_inl (pairSyntax.list_mk_pair xs, type_of t)
else if is_abs t then
mk_abs (apsnd build_sum (dest_abs t))
else sumSyntax.mk_inr(t,inty)
end
in
build_sum tm
end

fun prove_simple_tailrec_exists tm = let
val (l,r) = dest_eq tm
val (f_tm,arg_tm) = dest_comb l
Expand All @@ -39,26 +72,7 @@ fun prove_simple_tailrec_exists tm = let
fun mk_inl x = sumSyntax.mk_inl(x,output_ty)
fun mk_inr x = sumSyntax.mk_inr(x,input_ty)
(* building the witness *)
fun build_sum tm =
if is_comb tm andalso aconv (rator tm) f_tm then
mk_inl (rand tm)
else if List.all (not o aconv f_tm) (free_vars tm) then
mk_inr tm
else if is_cond tm then let
val (b,x,y) = dest_cond tm
in mk_cond(b,build_sum x,build_sum y) end
else if cvSyntax.is_cv_if tm then let
val (b,x,y) = cvSyntax.dest_cv_if tm
in mk_cond(cvSyntax.mk_c2b b,build_sum x,build_sum y) end
else if can pairSyntax.dest_anylet tm then let
val (xs,x) = pairSyntax.dest_anylet tm
in pairSyntax.mk_anylet(xs,build_sum x) end
else if TypeBase.is_case tm then let
val (a,b,xs) = TypeBase.dest_case tm
val ys = map (fn (x,tm) => (x,build_sum tm)) xs
in TypeBase.mk_case(b,ys) end
else failwith ("Unsupported: " ^ term_to_string tm)
val sum_tm = build_sum r
val sum_tm = mk_sum_term f_tm input_ty r
val abs_sum_tm = pairSyntax.mk_pabs(arg_tm,sum_tm)
val witness = ISPEC abs_sum_tm whileTheory.TAILREC |> SPEC_ALL
|> concl |> dest_eq |> fst |> rator
Expand Down

0 comments on commit c4c8009

Please sign in to comment.