Skip to content

Commit

Permalink
Merge PR coq#18241: Simplify the encoding of the recursive disjunctiv…
Browse files Browse the repository at this point in the history
…e-conjunctive structure of (co)inductive types in rtree.ml by taking nodes of arrays of arrays

Reviewed-by: ppedrot
Co-authored-by: ppedrot <[email protected]>
  • Loading branch information
coqbot-app[bot] and ppedrot authored Nov 16, 2023
2 parents f8aa0d5 + 5a2b4dc commit 29718b6
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 58 deletions.
2 changes: 1 addition & 1 deletion checker/values.ml
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ let v_recarg = v_sum "recarg" 1 (* Norec *)

let rec v_wfp = Sum ("wf_paths",0,
[|[|Int;Int|]; (* Rtree.Param *)
[|v_recarg;Array v_wfp|]; (* Rtree.Node *)
[|v_recarg;Array (Array v_wfp)|]; (* Rtree.Node *)
[|Int;Array v_wfp|] (* Rtree.Rec *)
|])

Expand Down
1 change: 1 addition & 0 deletions dev/base_include
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#install_printer (* constant *) ppcon;;
#install_printer (* projection *) ppproj;;
#install_printer (* projection *) ppprojrepr;;
#install_printer (* recarg *) pprecarg;;
#install_printer (* recarg Rtree.t *) ppwf_paths;;
#install_printer (* constr *) print_pure_constr;;
#install_printer (* patch *) ppripos;;
Expand Down
4 changes: 2 additions & 2 deletions dev/top_printers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ let ppsp sp = pp(pr_path sp)
let ppqualid qid = pp(pr_qualid qid)
let ppscheme k = pp (Ind_tables.pr_scheme_kind k)

let prrecarg = Declareops.pp_recarg
let ppwf_paths x = pp (Declareops.pp_wf_paths x)
let pprecarg = Declareops.pr_recarg
let ppwf_paths x = pp (Declareops.pr_wf_paths x)

let get_current_context () =
try Vernacstate.Declare.get_current_context ()
Expand Down
2 changes: 1 addition & 1 deletion dev/top_printers.mli
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ val ppqualid : Libnames.qualid -> unit

val ppscheme : 'a Ind_tables.scheme_kind -> unit

val prrecarg : Declarations.recarg -> Pp.t
val pprecarg : Declarations.recarg -> Pp.t
val ppwf_paths : Declarations.recarg Rtree.t -> unit

val pr_evar : Evar.t -> Pp.t
Expand Down
12 changes: 6 additions & 6 deletions kernel/declareops.ml
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,16 @@ let eq_recarg r1 r2 = match r1, r2 with
| Nested ty1, Nested ty2 -> eq_nested_type ty1 ty2
| Nested _, _ -> false

let pp_recarg = let open Pp in function
| Declarations.Norec -> str "Norec"
let pr_recarg = let open Pp in function
| Declarations.Norec -> Pp.str "Norec"
| Declarations.Mrec (mind,i) ->
str "Mrec[" ++ Names.MutInd.print mind ++ pr_comma () ++ int i ++ str "]"
| Declarations.(Nested (NestedInd (mind,i))) ->
str "Nested[" ++ Names.MutInd.print mind ++ pr_comma () ++ int i ++ str "]"
| Declarations.(Nested (NestedPrimitive c)) ->
str "Nested[" ++ Names.Constant.print c ++ str "]"

let pp_wf_paths x = Rtree.pp_tree pp_recarg x
let pr_wf_paths x = Rtree.pr_tree pr_recarg x

let subst_nested_type subst ty = match ty with
| NestedInd (kn,i) ->
Expand All @@ -203,7 +203,7 @@ let mk_norec = Rtree.mk_node Norec [||]

let mk_paths r recargs =
Rtree.mk_node r
(Array.map (fun l -> Rtree.mk_node Norec (Array.of_list l)) recargs)
(Array.map Array.of_list recargs)

let dest_recarg p = fst (Rtree.dest_node p)

Expand All @@ -215,11 +215,11 @@ let dest_recarg p = fst (Rtree.dest_node p)
let dest_subterms p =
let (ra,cstrs) = Rtree.dest_node p in
assert (match ra with Norec -> false | _ -> true);
Array.map (fun t -> Array.to_list (snd (Rtree.dest_node t))) cstrs
Array.map Array.to_list cstrs

let recarg_length p j =
let (_,cstrs) = Rtree.dest_node p in
Array.length (snd (Rtree.dest_node cstrs.(j-1)))
Array.length cstrs.(j-1)

let subst_wf_paths subst p = Rtree.Smart.map (subst_recarg subst) p

Expand Down
4 changes: 2 additions & 2 deletions kernel/declareops.mli
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ val is_opaque : 'a pconstant_body -> bool

val eq_recarg : recarg -> recarg -> bool

val pp_recarg : recarg -> Pp.t
val pp_wf_paths : wf_paths -> Pp.t
val pr_recarg : recarg -> Pp.t
val pr_wf_paths : wf_paths -> Pp.t

val subst_recarg : substitution -> recarg -> recarg

Expand Down
74 changes: 40 additions & 34 deletions lib/rtree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,42 @@
open Util

(* Type of regular trees:
- Param denotes tree variables (like de Bruijn indices)
- Var denotes tree variables (like de Bruijn indices)
the first int is the depth of the occurrence, and the second int
is the index in the array of trees introduced at that depth.
Warning: Param's indices both start at 0!
- Node denotes the usual tree node, labelled with 'a
Warning: Var's indices both start at 0!
- Node denotes the usual tree node, labelled with 'a, to the
exception that it takes an array of arrays as argument
- Rec(j,v1..vn) introduces infinite tree. It denotes
v(j+1) with parameters 0..n-1 replaced by
Rec(0,v1..vn)..Rec(n-1,v1..vn) respectively.
*)
type 'a t =
Param of int * int
| Node of 'a * 'a t array
Var of int * int
| Node of 'a * 'a t array array
| Rec of int * 'a t array

(* Building trees *)
let mk_rec_calls i = Array.init i (fun j -> Param(0,j))
let mk_rec_calls i = Array.init i (fun j -> Var(0,j))
let mk_node lab sons = Node (lab, sons)

(* The usual lift operation *)
let rec lift_rtree_rec depth n = function
Param (i,j) as t -> if i < depth then t else Param (i+n,j)
| Node (l,sons) -> Node (l,Array.map (lift_rtree_rec depth n) sons)
Var (i,j) as t -> if i < depth then t else Var (i+n,j)
| Node (l,sons) -> Node (l,Array.map (Array.map (lift_rtree_rec depth n)) sons)
| Rec(j,defs) ->
Rec(j, Array.map (lift_rtree_rec (depth+1) n) defs)

let lift n t = if Int.equal n 0 then t else lift_rtree_rec 0 n t

(* The usual subst operation *)
let rec subst_rtree_rec depth sub = function
Param (i,j) as t ->
Var (i,j) as t ->
if i < depth then t
else if i = depth then
lift depth (Rec (j, sub))
else Param (i - 1, j)
| Node (l,sons) -> Node (l,Array.map (subst_rtree_rec depth sub) sons)
else Var (i - 1, j)
| Node (l,sons) -> Node (l,Array.map (Array.map (subst_rtree_rec depth sub)) sons)
| Rec(j,defs) ->
Rec(j, Array.map (subst_rtree_rec (depth+1) sub) defs)

Expand All @@ -65,7 +66,7 @@ let rec expand = function
accept definitions like rec X=Y and Y=f(X,Y) *)
let mk_rec defs =
let rec check histo d = match expand d with
| Param (0, j) ->
| Var (0, j) ->
if Int.Set.mem j histo then failwith "invalid rec call"
else check (Int.Set.add j histo) defs.(j)
| _ -> ()
Expand All @@ -79,10 +80,10 @@ the last one should be accepted
*)

(* Tree destructors, expanding loops when necessary *)
let dest_param t =
let dest_var t =
match expand t with
Param (i,j) -> (i,j)
| _ -> failwith "Rtree.dest_param"
Var (i,j) -> (i,j)
| _ -> failwith "Rtree.dest_var"

let dest_node t =
match expand t with
Expand All @@ -95,17 +96,17 @@ let is_node t =
| _ -> false

let rec map f t = match t with
Param(i,j) -> Param(i,j)
| Node (a,sons) -> Node (f a, Array.map (map f) sons)
Var(i,j) -> Var(i,j)
| Node (a,sons) -> Node (f a, Array.map (Array.map (map f)) sons)
| Rec(j,defs) -> Rec (j, Array.map (map f) defs)

module Smart =
struct

let map f t = match t with
Param _ -> t
Var _ -> t
| Node (a,sons) ->
let a'=f a and sons' = Array.Smart.map (map f) sons in
let a'=f a and sons' = Array.Smart.map (Array.Smart.map (map f)) sons in
if a'==a && sons'==sons then t
else Node (a',sons')
| Rec(j,defs) ->
Expand All @@ -118,8 +119,8 @@ end
(** Structural equality test, parametrized by an equality on elements *)

let rec raw_eq cmp t t' = match t, t' with
| Param (i,j), Param (i',j') -> Int.equal i i' && Int.equal j j'
| Node (x, a), Node (x', a') -> cmp x x' && Array.equal (raw_eq cmp) a a'
| Var (i,j), Var (i',j') -> Int.equal i i' && Int.equal j j'
| Node (x, a), Node (x', a') -> cmp x x' && Array.equal (Array.equal (raw_eq cmp)) a a'
| Rec (i, a), Rec (i', a') -> Int.equal i i' && Array.equal (raw_eq cmp) a a'
| _ -> false

Expand All @@ -137,7 +138,7 @@ let equiv cmp cmp' =
| Node(x,v), Node(x',v') ->
cmp' x x' &&
Int.equal (Array.length v) (Array.length v') &&
Array.for_all2 (compare ((t,t')::histo)) v v'
Array.for_all2 (Array.for_all2 (compare ((t,t')::histo))) v v'
| _ -> false
in compare []

Expand All @@ -151,15 +152,15 @@ let equal cmp t t' =
let rec inter cmp interlbl def n histo t t' =
try
let (i,j) = List.assoc_f (raw_eq2 cmp) (t,t') histo in
Param (n-i-1,j)
Var (n-i-1,j)
with Not_found ->
match t, t' with
| Param (i,j), Param (i',j') ->
| Var (i,j), Var (i',j') ->
assert (Int.equal i i' && Int.equal j j'); t
| Node (x, a), Node (x', a') ->
(match interlbl x x' with
| None -> mk_node def [||]
| Some x'' -> Node (x'', Array.map2 (inter cmp interlbl def n histo) a a'))
| Some x'' -> Node (x'', Array.map2 (Array.map2 (inter cmp interlbl def n histo)) a a'))
| Rec (i,v), Rec (i',v') ->
(* If possible, we preserve the shape of input trees *)
if Int.equal i i' && Int.equal (Array.length v) (Array.length v') then
Expand Down Expand Up @@ -187,25 +188,30 @@ let is_infinite cmp t =
let rec is_inf histo t =
List.mem_f (raw_eq cmp) t histo ||
match expand t with
| Node (_,v) -> Array.exists (is_inf (t::histo)) v
| Node (_,v) -> Array.exists (Array.exists (is_inf (t::histo))) v
| _ -> false
in
is_inf [] t

(* Pretty-print a tree (not so pretty) *)
open Pp

let rec pp_tree prl t =
let rec pr_tree prl t =
match t with
Param (i,j) -> str"#"++int i++str","++int j
| Node(lab,[||]) -> hov 2 (str"("++prl lab++str")")
| Var (i,j) -> str"#"++int i++str":"++int j
| Node(lab,[||]) -> prl lab
| Node(lab,v) ->
hov 2 (str"("++prl lab++str","++brk(1,0)++
prvect_with_sep pr_comma (pp_tree prl) v++str")")
hov 0 (prl lab++str","++spc()++
str"["++
hv 0 (prvect_with_sep pr_comma (fun a ->
str"("++
hv 0 (prvect_with_sep pr_comma (pr_tree prl) a)++
str")") v)++
str"]")
| Rec(i,v) ->
if Int.equal (Array.length v) 0 then str"Rec{}"
else if Int.equal (Array.length v) 1 then
hov 2 (str"Rec{"++pp_tree prl v.(0)++str"}")
hv 2 (str"Rec{"++pr_tree prl v.(0)++str"}")
else
hov 2 (str"Rec{"++int i++str","++brk(1,0)++
prvect_with_sep pr_comma (pp_tree prl) v++str"}")
hv 2 (str"Rec{"++int i++str","++brk(1,0)++
prvect_with_sep pr_comma (pr_tree prl) v++str"}")
30 changes: 18 additions & 12 deletions lib/rtree.mli
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@

(** Type of regular tree with nodes labelled by values of type 'a
The implementation uses de Bruijn indices, so binding capture
is avoided by the lift operator (see example below) *)
is avoided by the lift operator (see example below).
Note that it differs from standard regular trees by accepting
vectors of vectors in nodes, which is useful for encoding
disjunctive-conjunctive recursive trees such as inductive types.
Standard regular trees can however easily be simulated by using
singletons of vectors *)
type 'a t

(** Building trees *)

(** build a node given a label and the vector of sons *)
val mk_node : 'a -> 'a t array -> 'a t
(** Build a node given a label and a vector of vectors of sons *)
val mk_node : 'a -> 'a t array array -> 'a t

(** Build mutually recursive trees:
X_1 = f_1(X_1,..,X_n) ... X_n = f_n(X_1,..,X_n)
Expand All @@ -27,14 +33,14 @@ val mk_node : 'a -> 'a t array -> 'a t
First example: build rec X = a(X,Y) and Y = b(X,Y,Y)
let [|vx;vy|] = mk_rec_calls 2 in
let [|x;y|] = mk_rec [|mk_node a [|vx;vy|]; mk_node b [|vx;vy;vy|]|]
let [|x;y|] = mk_rec [|mk_node a [|[|vx;vy|]|]; mk_node b [|[|vx;vy;vy|]|]|]
Another example: nested recursive trees rec Y = b(rec X = a(X,Y),Y,Y)
let [|vy|] = mk_rec_calls 1 in
let [|vx|] = mk_rec_calls 1 in
let [|x|] = mk_rec[|mk_node a vx;lift 1 vy|]
let [|y|] = mk_rec[|mk_node b x;vy;vy|]
(note the lift to avoid
let [|x|] = mk_rec[|mk_node a [|[|vx;lift 1 vy|]|]|]
let [|y|] = mk_rec[|mk_node b [|[|x;vy;vy|]|]|]
(note the lift so that Y links to the "rec Y" skipping the "rec X")
*)
val mk_rec_calls : int -> 'a t array
val mk_rec : 'a t array -> 'a t array
Expand All @@ -46,10 +52,10 @@ val lift : int -> 'a t -> 'a t
val is_node : 'a t -> bool

(** Destructors (recursive calls are expanded) *)
val dest_node : 'a t -> 'a * 'a t array
val dest_node : 'a t -> 'a * 'a t array array

(** dest_param is not needed for closed trees (i.e. with no free variable) *)
val dest_param : 'a t -> int * int
(** dest_var is not needed for closed trees (i.e. with no free variable) *)
val dest_var : 'a t -> int * int

(** Tells if a tree has an infinite branch. The first arg is a comparison
used to detect already seen elements, hence loops *)
Expand Down Expand Up @@ -77,8 +83,8 @@ val incl : ('a -> 'a -> bool) -> ('a -> 'a -> 'a option) -> 'a -> 'a t -> 'a t -
(** See also [Smart.map] *)
val map : ('a -> 'b) -> 'a t -> 'b t

(** A rather simple minded pretty-printer *)
val pp_tree : ('a -> Pp.t) -> 'a t -> Pp.t
(** Pretty-printer *)
val pr_tree : ('a -> Pp.t) -> 'a t -> Pp.t

module Smart :
sig
Expand Down

0 comments on commit 29718b6

Please sign in to comment.