Skip to content

Commit 372bee2

Browse files
Datalog: add [not_equal] and [filter] predicates (#4018)
Co-authored-by: Basile Clément <[email protected]>
1 parent 185c5a2 commit 372bee2

File tree

6 files changed

+110
-12
lines changed

6 files changed

+110
-12
lines changed

middle_end/flambda2/datalog/cursor.ml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
open Datalog_imports
1717

18+
type _ value_repr = Int_repr : int value_repr
19+
20+
let int_repr = Int_repr
21+
1822
(* Note: we don't use [with_name] here to avoid the extra indirection during
1923
execution. *)
2024
type vm_action =
@@ -25,6 +29,12 @@ type vm_action =
2529
* string
2630
* string list
2731
-> vm_action
32+
| Unless_eq :
33+
'k option ref * 'k option ref * string * string * 'k value_repr
34+
-> vm_action
35+
| Filter :
36+
('k Constant.hlist -> bool) * 'k Option_ref.hlist * string list
37+
-> vm_action
2838

2939
type action =
3040
| Bind_iterator :
@@ -39,6 +49,11 @@ let unless id cell args =
3949
(Unless
4050
(Table.Id.is_trie id, cell, args.values, Table.Id.name id, args.names))
4151

52+
let unless_eq repr cell1 cell2 =
53+
VM_action (Unless_eq (cell1.value, cell2.value, cell1.name, cell2.name, repr))
54+
55+
let filter f args = VM_action (Filter (f, args.values, args.names))
56+
4257
type binder = Bind_table : ('t, 'k, 'v) Table.Id.t * 't ref -> binder
4358

4459
type actions = { mutable rev_actions : action list }
@@ -55,6 +70,14 @@ let pp_cursor_action ff = function
5570
~pp_sep:(fun ff () -> Format.fprintf ff ", ")
5671
Format.pp_print_string)
5772
l_names
73+
| Unless_eq (_x1, _x2, x1_name, x2_name, _repr) ->
74+
Format.fprintf ff "if %s == %s:@ continue" x1_name x2_name
75+
| Filter (_f, _args, args_names) ->
76+
Format.fprintf ff "<filter>(%a)"
77+
(Format.pp_print_list
78+
~pp_sep:(fun ff () -> Format.fprintf ff ", ")
79+
Format.pp_print_string)
80+
args_names
5881

5982
module Order : sig
6083
type t
@@ -323,6 +346,14 @@ let evaluate = function
323346
(Trie.find_opt is_trie (Option_ref.get args) cell.contents)
324347
then Virtual_machine.Skip
325348
else Virtual_machine.Accept
349+
| Unless_eq (cell1, cell2, _cell1_name, _cell2_name, Int_repr) ->
350+
if Int.equal (Option.get !cell1) (Option.get !cell2)
351+
then Virtual_machine.Skip
352+
else Virtual_machine.Accept
353+
| Filter (f, args, _args_names) ->
354+
if f (Option_ref.get args)
355+
then Virtual_machine.Accept
356+
else Virtual_machine.Skip
326357

327358
let naive_iter cursor db f =
328359
with_bound_cursor ~callback:f cursor db @@ fun () ->

middle_end/flambda2/datalog/cursor.mli

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
open Datalog_imports
1717

18+
type 'a value_repr
19+
20+
val int_repr : int value_repr
21+
1822
type action
1923

2024
val bind_iterator :
@@ -23,6 +27,12 @@ val bind_iterator :
2327
val unless :
2428
('t, 'k, 'v) Table.Id.t -> 't ref -> 'k Option_ref.hlist with_names -> action
2529

30+
val unless_eq :
31+
'k value_repr -> 'k option ref with_name -> 'k option ref with_name -> action
32+
33+
val filter :
34+
('k Constant.hlist -> bool) -> 'k Option_ref.hlist with_names -> action
35+
2636
type actions
2737

2838
val add_action : actions -> action -> unit

middle_end/flambda2/datalog/datalog.ml

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -156,20 +156,19 @@ let rec find_last_binding0 : type a. order:_ -> _ -> a Term.hlist -> _ =
156156
let find_last_binding post_level args =
157157
find_last_binding0 ~order:Cursor.Order.parameters post_level args
158158

159+
let compile_term : 'a Term.t -> 'a option ref with_name = function
160+
| Constant cte -> { value = ref (Some cte); name = "<constant>" }
161+
| Parameter param -> { value = param.cell; name = param.name }
162+
| Variable var -> Cursor.Level.use_output var
163+
159164
let rec compile_terms : type a. a Term.hlist -> a Option_ref.hlist with_names =
160165
fun vars ->
161166
match vars with
162167
| [] -> { values = []; names = [] }
163-
| term :: terms -> (
168+
| term :: terms ->
169+
let { value; name } = compile_term term in
164170
let { values; names } = compile_terms terms in
165-
match term with
166-
| Constant cte ->
167-
{ values = ref (Some cte) :: values; names = "<constant>" :: names }
168-
| Parameter param ->
169-
{ values = param.cell :: values; names = param.name :: names }
170-
| Variable var ->
171-
let { value; name } = Cursor.Level.use_output var in
172-
{ values = value :: values; names = name :: names })
171+
{ values = value :: values; names = name :: names }
173172

174173
let unless_atom id args k info =
175174
let refs = compile_terms args in
@@ -180,6 +179,23 @@ let unless_atom id args k info =
180179
Cursor.add_action post_level (Cursor.unless id r refs);
181180
k info
182181

182+
let unless_eq repr arg1 arg2 k info =
183+
let ref1 = compile_term arg1 in
184+
let ref2 = compile_term arg2 in
185+
let post_level =
186+
find_last_binding (Cursor.initial_actions info.context) [arg1; arg2]
187+
in
188+
Cursor.add_action post_level (Cursor.unless_eq repr ref1 ref2);
189+
k info
190+
191+
let filter f args k info =
192+
let refs = compile_terms args in
193+
let post_level =
194+
find_last_binding (Cursor.initial_actions info.context) args
195+
in
196+
Cursor.add_action post_level (Cursor.filter f refs);
197+
k info
198+
183199
type callback =
184200
| Callback :
185201
{ func : 'a Constant.hlist -> unit;

middle_end/flambda2/datalog/datalog.mli

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,19 @@ val unless_atom :
6363
('p, 'a) program ->
6464
('p, 'a) program
6565

66+
val unless_eq :
67+
'k Cursor.value_repr ->
68+
'k Term.t ->
69+
'k Term.t ->
70+
('p, 'a) program ->
71+
('p, 'a) program
72+
73+
val filter :
74+
('k Constant.hlist -> bool) ->
75+
'k Term.hlist ->
76+
('p, 'a) program ->
77+
('p, 'a) program
78+
6679
type callback
6780

6881
val create_callback :

middle_end/flambda2/datalog/flambda2_datalog.ml

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ module Datalog = struct
6060
| { repr = Patricia_tree_repr; _ } :: (_ :: _ as columns) ->
6161
Trie.patricia_tree_of_trie (is_trie columns)
6262

63+
let key_repr : type t k v. (t, k, v) id -> k Cursor.value_repr = function
64+
| { repr = Patricia_tree_repr; _ } -> Cursor.int_repr
65+
6366
module type S = sig
6467
type t
6568

@@ -245,20 +248,33 @@ module Datalog = struct
245248

246249
let deduce = Schedule.deduce
247250

251+
type equality =
252+
| Equality : 'k Cursor.value_repr * 'k Term.t * 'k Term.t -> equality
253+
254+
type filter = Filter : ('k Constant.hlist -> bool) * 'k Term.hlist -> filter
255+
248256
type hypothesis =
249257
[ `Atom of atom
250-
| `Not_atom of atom ]
258+
| `Not_atom of atom
259+
| `Distinct of equality
260+
| `Filter of filter ]
251261

252262
let atom id args = `Atom (Atom (id, args))
253263

254264
let not (`Atom atom) = `Not_atom atom
255265

266+
let distinct c x y = `Distinct (Equality (Column.key_repr c, x, y))
267+
268+
let filter f args = `Filter (Filter (f, args))
269+
256270
let where predicates f =
257271
List.fold_left
258272
(fun f predicate ->
259273
match predicate with
260274
| `Atom (Atom (id, args)) -> where_atom id args f
261-
| `Not_atom (Atom (id, args)) -> unless_atom id args f)
275+
| `Not_atom (Atom (id, args)) -> unless_atom id args f
276+
| `Distinct (Equality (repr, t1, t2)) -> unless_eq repr t1 t2 f
277+
| `Filter (Filter (p, args)) -> Datalog.filter p args f)
262278
f predicates
263279

264280
module Cursor = struct

middle_end/flambda2/datalog/flambda2_datalog.mli

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,15 @@ module Datalog : sig
132132

133133
type atom
134134

135+
type equality
136+
137+
type filter
138+
135139
type hypothesis =
136140
[ `Atom of atom
137-
| `Not_atom of atom ]
141+
| `Not_atom of atom
142+
| `Distinct of equality
143+
| `Filter of filter ]
138144

139145
(** [atom rel args] represents the application of relation [rel] to the
140146
arguments [args].
@@ -153,6 +159,12 @@ module Datalog : sig
153159

154160
val not : [< `Atom of atom] -> [> `Not_atom of atom]
155161

162+
val distinct :
163+
(_, 'k, _) Column.id -> 'k Term.t -> 'k Term.t -> [> `Distinct of equality]
164+
165+
val filter :
166+
('k Constant.hlist -> bool) -> 'k Term.hlist -> [> `Filter of filter]
167+
156168
type database
157169

158170
val print : Format.formatter -> database -> unit

0 commit comments

Comments
 (0)