Skip to content

Commit

Permalink
Add tailrecursive keyword/attribute to Definition handling
Browse files Browse the repository at this point in the history
With some simple tests.
  • Loading branch information
mn200 committed Nov 29, 2023
1 parent 679fd63 commit 6d57be5
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/1/ThmAttribute.sml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ struct
val funstore = ref (Map.empty : attrfuns Map.table)

val reserved_attrnames = ["local", "unlisted", "nocompute", "schematic",
"notuserdef", "allow_rebind"]
"notuserdef", "allow_rebind", "tailrecursive"]

fun okchar c = Char.isAlphaNum c orelse c = #"_" orelse c = #"'"
fun illegal_attrname s = Lib.mem s reserved_attrnames orelse
Expand Down
1 change: 1 addition & 0 deletions src/num/termination/TotalDefn.sig
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ sig

val defnDefine : tactic -> defn -> thm * thm option * thm option
val primDefine : defn -> thm * thm option * thm option
val tailrecDefine: string -> term quotation -> thm
val tDefine : string -> term quotation -> tactic -> thm * thm option
val xDefine : string -> term quotation -> thm * thm option
val Define : term quotation -> thm
Expand Down
25 changes: 21 additions & 4 deletions src/num/termination/TotalDefn.sml
Original file line number Diff line number Diff line change
Expand Up @@ -742,13 +742,23 @@ fun find_indoption sl =
set_diff sl [s]
)

fun tailrecDefine nm q =
let
val (t, _) = Defn.parse_absyn (Parse.Absyn q)
val th = tailrecLib.tailrec_define nm t
in
Defn.add_defs_to_EVAL [(nm,th)];
th
end

fun qDefine stem q tacopt =
let
val (corename, attrs) = ThmAttribute.extract_attributes stem
val (nocomp, attrs) = test_remove "nocompute" attrs
val (svarsok, attrs) = test_remove "schematic" attrs
val (notuserdef, attrs) = test_remove "notuserdef" attrs
val (rebindok, attrs) = test_remove "allow_rebind" attrs
val (tailrecp, attrs) = test_remove "tailrecursive" attrs
val (indopt,attrs) = find_indoption attrs
fun fmod f =
f |> (if nocomp then trace ("computeLib.auto_import_definitions", 0)
Expand All @@ -761,14 +771,21 @@ fun qDefine stem q tacopt =
|> (if rebindok then trace ("Theory.allow_rebinds", 1)
else (fn f => f))
val (thm,indopt) =
case tacopt of
NONE => fmod (xDefine corename) q
| SOME tac => fmod (tDefine corename q) tac
case (tailrecp, tacopt) of
(true, NONE) => (fmod (tailrecDefine corename) q, NONE)
| (true, SOME _) =>
raise ERR "qDefine"
"Termination tactic for tail-recursive definition makes \
\no sense"
| (false, NONE) => fmod (xDefine corename) q
| (false, SOME tac) => fmod (tDefine corename q) tac
fun proc_attr a =
ThmAttribute.store_at_attribute{name = corename, attrname = a,
thm = thm}
val attrs = if notuserdef then attrs else "userdef" :: attrs
val gen_ind = Prim_rec.gen_indthm {lookup_ind = TypeBase.induction_of}
val gen_ind =
if tailrecp then (fn th => raise ERR "Unseen" "")
else Prim_rec.gen_indthm {lookup_ind = TypeBase.induction_of}
in
List.app proc_attr attrs;
if notuserdef then ()
Expand Down
74 changes: 74 additions & 0 deletions src/num/termination/selftest.sml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
open HolKernel Parse boolLib
open testutils TotalDefn
val _ = Feedback.emit_MESG := false

fun badpp x = HOLPP.add_string "<Can't print this>"

fun EVAL t = computeLib.CBV_CONV computeLib.the_compset t
val _ = tprint "Testing mutually recursive function definition"

val f_def = require (check_result (K true)) Define`
Expand Down Expand Up @@ -110,3 +114,73 @@ val _ = require_msg (check_result lhs_has_two_args)
(TotalDefn.qDefine "foo2[schematic]" q)
(SOME (WF_REL_TAC ‘$<’)))
‘foo2 x = if x = 0 then y else foo2(x - 1)*2’;

val _ = tprint "tailrecDefine (simple recursion: fact2)"
val expected_pat = “!A n. ff A n = if n < 1 then A else ff (A * n) (n - 1)”
fun check1 th =
case match_term expected_pat (concl th) of
([{redex,residue}], []) =>
aconv redex “ff:num->num->num” andalso
#1 (dest_const residue) = "fact2"
| _ => false
fun check2 _ = convtest("fact2 evaluates OK", EVAL, “fact2 1 6”, “720”)
val _ = require_msgk (check_result check1) pp_thm
(fn q => TotalDefn.qDefine "fact2_def[tailrecursive]"
q NONE)
check2
‘fact2 A n = if n < 1 then A else fact2 (A * n) (n-1)’;

val _ = tprint "tailrecDefine (simple recursion + rebind: fact)"
val expected_pat = “!A n. ff A n = if n < 1 then A else ff (A * n) (n - 1)”
fun check1 th =
case match_term expected_pat (concl th) of
([{redex,residue}], []) =>
aconv redex “ff:num->num->num” andalso
#1 (dest_const residue) = "fact"
| _ => false
fun check2 _ = convtest("fact evaluates OK", EVAL, “fact 1 5”, “120”)
val _ = require_msgk (check_result check1) pp_thm
(allquiet
(fn q =>
TotalDefn.qDefine
"fact_def[tailrecursive,allow_rebind]" q NONE))
check2
‘fact A n = if n < 1 then A else fact (A * n) (n-1)’;

val _ = tprint "tailrecDefine (2-way mutual recursion)"
val expected_pat = “(!x. ff1 x = if x = 0 then F else ff2 (x + 1)) /\
(!n. ff2 n = if n = 0 then T
else if n = 1 then F
else if n = 2 then T
else ff1 (n - 3))”
fun check1 th =
case match_term expected_pat (concl th) of
(tms as [rr1,rr2], []) => List.all (is_const o #residue) tms
| _ => false
fun check2 _ =
(tprint "DefnBase has record for even";
require_msgk (check_result Option.isSome) badpp DefnBase.lookup_userdef
(fn _ => convtest ("even evaluates OK", EVAL, “even 11”, “F”))
“even”)
val _ = require_msgk (check_result check1) pp_thm
(allquiet
(fn q =>
TotalDefn.qDefine
"odd_def[tailrecursive]" q NONE))
check2
‘(odd x = if x = 0 then F else even (x + 1)) /\
(even n = if n = 0 then T
else if n = 1 then F
else if n = 2 then T
else odd (n - 3))’;

val _ = tprint "tailrecDefine (2-way + isolate & nocompute)"
val _ = require_msgk (check_result (K true)) pp_thm
(fn q =>
TotalDefn.qDefine
"urk_def[tailrecursive,nocompute]" q NONE)
(fn _ => convtest("urk unevals", EVAL,
“urk x + urk' m”, “urk x + urk' m”))
‘(urk n = urk2 (n + 1)) /\
(urk' n = n + 1) /\
(urk2 m = if m = 0 then 1 else urk (2 * m))’;
1 change: 1 addition & 0 deletions src/tfl/src/Defn.sig
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ sig
val ind_suffix : string ref
val def_suffix : string ref
val const_eq_ref : conv ref
val add_defs_to_EVAL : (string * thm) list -> unit

val wfrec_eqns : thry -> term ->
{SV : term list,
Expand Down
6 changes: 3 additions & 3 deletions src/tfl/src/Defn.sml
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ local fun is_suc tm =
in
val SUC_TO_NUMERAL_DEFN_CONV_hook =
ref (fn _ => raise ERR "SUC_TO_NUMERAL_DEFN_CONV_hook" "not initialized")
fun add_persistent_funs l =
fun add_defs_to_EVAL l =
if not (!computeLib.auto_import_definitions) then () else
let val has_lhs_SUC = List.exists
(can (find_term is_suc) o lhs o #2 o strip_forall)
Expand Down Expand Up @@ -440,7 +440,7 @@ local
val _ = Feedback.register_btrace("Define.storage_message", chatting)
in
fun been_stored (s,thm) =
(add_persistent_funs [(s,thm)];
(add_defs_to_EVAL [(s,thm)];
if !chatting then
mesg (if !Globals.interactive then
"Definition has been stored under " ^ Lib.quote s ^ "\n"
Expand Down Expand Up @@ -478,7 +478,7 @@ fun store(stem,eqs,ind) =
fun save x = Feedback.trace ("Theory.save_thm_reporting", 0) save_thm x
val _ = save (ind_bind, ind)
val eqns = save (eqs_bind, eqs)
val _ = add_persistent_funs [(eqs_bind,eqs)]
val _ = add_defs_to_EVAL [(eqs_bind,eqs)]
handle e => HOL_MESG ("Unable to add "^eqs_bind^" to global compset")
in
if !chatting then
Expand Down

0 comments on commit 6d57be5

Please sign in to comment.