-
Notifications
You must be signed in to change notification settings - Fork 147
/
IR.v
1811 lines (1704 loc) · 96.8 KB
/
IR.v
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
From Coq Require Import ZArith.
From Coq Require Import MSetPositive.
From Coq Require Import FMapPositive.
From Coq Require Import String.
From Coq Require Import Ascii.
From Coq Require Import Bool.
From Coq Require Import HexString.
Require Import Crypto.Util.ListUtil Coq.Lists.List.
Require Crypto.Util.Strings.String.
Require Import Crypto.Util.Strings.Decimal.
Require Import Crypto.Util.Strings.Show.
Require Import Crypto.Util.ZRange.
Require Import Crypto.Util.ZRange.Operations.
Require Import Crypto.Util.ZRange.Show.
Require Import Crypto.Util.Option.
Require Import Crypto.Util.OptionList.
Require Import Rewriter.Language.Language.
Require Import Crypto.Language.API.
Require Import Crypto.Stringification.Language.
Require Import Crypto.AbstractInterpretation.ZRange.
Require Import Crypto.AbstractInterpretation.AbstractInterpretation.
Require Import Crypto.Util.Bool.Equality.
Require Import Crypto.Util.Notations.
Import Coq.Lists.List ListNotations. Local Open Scope zrange_scope. Local Open Scope Z_scope.
Module Compilers.
Local Set Boolean Equality Schemes.
Local Set Decidable Equality Schemes.
Export Language.Compilers.
Export Language.API.Compilers.
Export AbstractInterpretation.Compilers.
Export AbstractInterpretation.ZRange.Compilers.
Export Stringification.Language.Compilers.
Import invert_expr.
Import Compilers.API.
Local Notation tZ := (base.type.type_base base.type.Z).
Module ToString.
Import Stringification.Language.Compilers.ToString.
Import Stringification.Language.Compilers.ToString.ZRange.
Local Open Scope string_scope.
Local Open Scope Z_scope.
Module IR.
(* IR-level types.  A [primitive] is either an integer ([Z]) or the
   pointer-like [Zptr] (produced by [Addr], consumed by [Dereference]
   and [List_nth] in [ident] below); [type] closes primitives under
   binary products and adds [unit]. *)
Module type.
Inductive primitive := Z | Zptr.
Inductive type := type_primitive (t : primitive) | prod (A B : type) | unit.
(* Notations for writing IR types in [Ctype_scope]: [()] is [unit] and
   [A * B] is [prod A B].  The coercion lets a [primitive] be used
   wherever a [type] is expected. *)
Module Export Notations.
Global Coercion type_primitive : primitive >-> type.
Declare Scope Ctype_scope.
Delimit Scope Ctype_scope with Ctype.
Bind Scope Ctype_scope with type.
Notation "()" := unit : Ctype_scope.
Notation "A * B" := (prod A B) : Ctype_scope.
(* Makes the short name [type] refer to [type.type] wherever these
   notations are imported. *)
Notation type := type.
End Notations.
End type.
Import type.Notations.
Import int.Notations.
(* IR identifiers.  [Import type] brings the primitives [Z] and [Zptr]
   into scope for the signatures below. *)
Section ident.
Import type.
(* Binary integer operators; used as [ident]s through [ibinop], whose
   signature is [Z * Z] to [Z]. *)
Inductive Z_binop : Set :=
| Z_land
| Z_lor
| Z_lxor
| Z_add
| Z_mul
| Z_sub
.
(* Unary integer operators; used as [ident]s through [iunop].  The
   shifts carry their constant shift amount; [Z_lnot] and
   [Z_value_barrier] carry the integer type they operate at. *)
Inductive Z_unop : Set :=
| Z_shiftr (offset : BinInt.Z)
| Z_shiftl (offset : BinInt.Z)
(*| Z_opp*)
| Z_lnot (ty:int.type)
| Z_bneg
| Z_value_barrier (ty:int.type)
.
(* [ident s d]: an IR operation taking an argument of type [s] and
   producing a result of type [d].  [Z_mul_split],
   [Z_add_with_get_carry], [Z_sub_with_get_borrow] and [Z_zselect]
   return [unit] but take [Zptr] arguments; presumably they write their
   results through those pointers -- confirm against the printers. *)
Inductive ident : type -> type -> Set :=
| literal (v : BinInt.Z) : ident unit Z
| List_nth (n : Datatypes.nat) : ident Zptr Z
| Addr : ident Z Zptr
| Dereference : ident Zptr Z
| iunop (op : Z_unop) : ident Z Z
| ibinop (op : Z_binop) : ident (Z * Z) Z
| Z_mul_split (lgs:BinInt.Z) : ident ((Zptr * Zptr) * (Z * Z)) unit
| Z_add_with_get_carry (lgs:BinInt.Z) : ident ((Zptr * Zptr) * (Z * Z * Z)) unit
| Z_sub_with_get_borrow (lgs:BinInt.Z) : ident ((Zptr * Zptr) * (Z * Z * Z)) unit
| Z_zselect (ty:int.type) : ident (Zptr * (Z * Z * Z)) unit
| Z_add_modulo : ident (Z * Z * Z) Z
| Z_static_cast (ty : int.type) : ident Z Z
.
End ident.
(* Typed IR arithmetic expressions: application of an [ident] to an
   argument expression, a named variable of a primitive type, a pair,
   and the unit value [TT]. *)
Inductive arith_expr : type -> Set :=
| AppIdent {s d} (idc : ident s d) (arg : arith_expr s) : arith_expr d
| Var (t : type.primitive) (v : string) : arith_expr t
| Pair {A B} (a : arith_expr A) (b : arith_expr B) : arith_expr (A * B)
| TT : arith_expr type.unit.
(* IR statements.
   - [Call val]: a unit-typed expression used as a statement.
   - [Comment lines vars]: comment text; the variables in [vars] are
     counted as live by [name_infos.collect_live_of_stmt].
   - [Assign declare t sz name val]: assigns [val] to [name]; when
     [declare] is [true], [LiftDeclare.split_declare] turns it into a
     [DeclareVar] followed by an [Assign false].  [sz] is an optional
     integer-type annotation.
   - [AssignZPtr name sz val]: presumably stores the integer [val]
     through the pointer variable [name] -- confirm against printers.
   - [DeclareVar t sz name]: declares [name] without assigning it.
   - [AssignNth name n val]: presumably assigns [val] to element [n] of
     [name] -- confirm against printers. *)
Inductive stmt :=
| Call (val : arith_expr type.unit)
| Comment (lines : list string) (mentioned_variables : list { t : _ & OfPHOAS.var_data t})
| Assign (declare : bool) (t : type.primitive) (sz : option int.type) (name : string) (val : arith_expr t)
| AssignZPtr (name : string) (sz : option int.type) (val : arith_expr type.Z)
| DeclareVar (t : type.primitive) (sz : option int.type) (name : string)
| AssignNth (name : string) (n : nat) (val : arith_expr type.Z).
(* An IR program body: a list of statements. *)
Definition expr := list stmt.
(* Notations for IR expressions in [Cexpr_scope]: [idc @@@ arg] is
   [AppIdent idc arg], tuples build left-nested [Pair]s, [()] is [TT],
   and [x ;; y] conses statement [x] onto program [y].  The coercions
   let [Z_unop] and [Z_binop] values be used directly as [ident]s. *)
Module Export Notations.
Export int.Notations.
Export type.Notations.
Declare Scope Cexpr_scope.
Delimit Scope Cexpr_scope with Cexpr.
Bind Scope Cexpr_scope with expr.
Bind Scope Cexpr_scope with stmt.
Bind Scope Cexpr_scope with arith_expr.
Infix "@@@" := AppIdent : Cexpr_scope.
Notation "( x , y , .. , z )" := (Pair .. (Pair x%Cexpr y%Cexpr) .. z%Cexpr) : Cexpr_scope.
Notation "( )" := TT : Cexpr_scope.
Notation "()" := TT : Cexpr_scope.
Notation "x ;; y" := (@cons stmt x%Cexpr y%Cexpr) (at level 70, right associativity, format "'[v' x ;; '/' y ']'") : Cexpr_scope.
Global Coercion iunop : Z_unop >-> ident.
Global Coercion ibinop : Z_binop >-> ident.
End Notations.
(* Recognizes integer literals: [Some v] when [e] is the [literal v]
   identifier applied to its (unit) argument, [None] for every other
   expression. *)
Definition invert_literal {t} (e : arith_expr t) : option (BinInt.Z)
  := match e with
     | AppIdent _ _ head _
       => match head with
          | literal lit => Some lit
          | _ => None
          end
     | Var _ _ | Pair _ _ _ _ | TT => None
     end.
(* Computes [ident_infos] summaries of IR programs: integer bitwidths
   used, mulx and addcarryx sizes, cmovznz and value-barrier integer
   types, and typedef names. *)
Module ident_infos.
(* Infos contributed by one identifier: static casts record their
   integer type, [Z_mul_split] and the carry/borrow operations record
   [lg2s] (as a positive), [Z_zselect] and value barriers record their
   integer type; all other identifiers contribute nothing. *)
Definition collect_infos_of_ident {s d} (idc : ident s d) : ident_infos
:= match idc with
| Z_static_cast ty => ident_info_of_bitwidths_used (IntSet.singleton ty)
| Z_mul_split lg2s
=> ident_info_of_mulx (PositiveSet.add (Z.to_pos lg2s) PositiveSet.empty)
| Z_add_with_get_carry lg2s
| Z_sub_with_get_borrow lg2s
=> ident_info_of_addcarryx (PositiveSet.add (Z.to_pos lg2s) PositiveSet.empty)
| Z_zselect ty
=> ident_info_of_cmovznz (IntSet.singleton ty)
| iunop (Z_value_barrier ty)
=>ident_info_of_value_barrier (IntSet.singleton ty)
| literal _
| List_nth _
| Addr
| Dereference
| ibinop _
| iunop _
| Z_add_modulo
=> ident_info_empty
end.
(* Infos of an expression: union over every identifier applied in it. *)
Fixpoint collect_infos_of_arith_expr {t} (e : arith_expr t) : ident_infos
:= match e with
| AppIdent s d idc arg => ident_info_union (collect_infos_of_ident idc) (@collect_infos_of_arith_expr _ arg)
| Var t v => ident_info_empty
| Pair A B a b => ident_info_union (@collect_infos_of_arith_expr _ a) (@collect_infos_of_arith_expr _ b)
| TT => ident_info_empty
end.
(* Infos of a statement: those of its expression, plus the integer
   type annotation of assignments and declarations when present. *)
Definition collect_infos_of_stmt (e : stmt) : ident_infos
:= match e with
| Assign _ _ (Some sz) _ val
| AssignZPtr _ (Some sz) val
=> ident_info_union (ident_info_of_bitwidths_used (IntSet.singleton sz)) (collect_infos_of_arith_expr val)
| Call val
| Assign _ _ None _ val
| AssignZPtr _ None val
| AssignNth _ _ val
=> collect_infos_of_arith_expr val
| DeclareVar _ (Some sz) _
=> ident_info_of_bitwidths_used (IntSet.singleton sz)
| DeclareVar _ None _
| Comment _ _
=> ident_info_empty
end.
(* Typedef names recorded for a base type; only [Z] and [list Z]
   positions can carry a typedef, products are traversed
   componentwise, and every other base type contributes nothing. *)
Fixpoint collect_infos_of_base_typedef_data {t} : OfPHOAS.base_var_typedef_data t -> ident_infos
:= match t return OfPHOAS.base_var_typedef_data t -> ident_infos with
| tZ
=> fun n => match n with
| Some n => ident_info_of_typedefs [n]
| None => ident_info_empty
end
| base.type.prod A B
=> fun '(tda, tdb)
=> ident_info_union (@collect_infos_of_base_typedef_data A tda)
(@collect_infos_of_base_typedef_data B tdb)
| base.type.list tZ
=> fun n => match n with
| Some n => ident_info_of_typedefs [n]
| None => ident_info_empty
end
| base.type.list _
| base.type.option _
| base.type.type_base _
| base.type.unit
=> fun _ => ident_info_empty
end.
(* Lifts [collect_infos_of_base_typedef_data] to all types; arrow
   types carry no typedef data. *)
Definition collect_infos_of_typedef_data {t} : OfPHOAS.var_typedef_data t -> ident_infos
:= match t return OfPHOAS.var_typedef_data t -> ident_infos with
| type.base t => collect_infos_of_base_typedef_data
| type.arrow _ _ => fun 'tt => ident_info_empty
end.
(* Union of typedef infos over all argument positions of a function
   type. *)
Fixpoint collect_infos_of_arg_typedef_data {t} : type.for_each_lhs_of_arrow OfPHOAS.var_typedef_data t -> ident_infos
:= match t with
| type.base _ => fun 'tt => ident_info_empty
| type.arrow s d
=> fun '(td, tds)
=> ident_info_union (collect_infos_of_typedef_data td)
(@collect_infos_of_arg_typedef_data d tds)
end.
(* Infos of a whole program: union over its statements. *)
Definition collect_infos (e : expr) : ident_infos
:= fold_right
ident_info_union
ident_info_empty
(List.map
collect_infos_of_stmt
e).
(* Infos of a program together with the typedefs of its inputs and
   output; typedef infos are left out when [skip_typedefs] is set. *)
Definition collect_all_infos {skip_typedefs : skip_typedefs_opt}
(e : expr)
{t}
(intypedefs : type.for_each_lhs_of_arrow OfPHOAS.var_typedef_data t)
(outtypedefs : OfPHOAS.base_var_typedef_data (type.final_codomain t))
:= ident_info_union
(collect_infos e)
(if skip_typedefs
then ident_info_empty
else ident_info_union
(ident_infos.collect_infos_of_arg_typedef_data intypedefs)
(ident_infos.collect_infos_of_typedef_data (t:=type.base _) outtypedefs)).
End ident_infos.
(* Sets of variable names (lists, duplicate-free when built with [add])
   and a pass that renames variables not found to be live. *)
Module name_infos.
Notation t := (list string) (only parsing).
(* Membership by string equality. *)
Definition mem (v : string) (m : t) : bool := existsb (fun s => v =? s)%string m.
(* Adds [v] unless it is already present. *)
Definition add (v : string) (m : t) : t
:= if mem v m then m else v :: m.
(* Union: folds [add] over the shorter list into the longer one. *)
Definition union (m1 m2 : t) : t
:= let '(l1, l2) := (List.length m1, List.length m2) in
let '(m1, m2) := if (l1 <=? l2)%nat then (m1, m2) else (m2, m1) in
(* now m1 is shorter, so we recurse over m1 *)
List.fold_right add m2 m1.
Definition empty : t := nil.
Definition singleton (v : string) : t := add v empty.
(* Note: [of_list] does not remove duplicates. *)
Definition of_list (v : list string) : t := v.
Section __.
(* When [consider_retargs_live s d idc] is [false], variables occurring
   under [Addr] anywhere inside the argument of [idc] are not collected
   as live by [collect_live_of_arith_expr]. *)
Context (consider_retargs_live : forall s d, ident s d -> bool).
(* Argument filter that refuses to descend into [Addr] and accepts
   every other identifier. *)
Let consider_addr_dead s d (idc : ident s d) : bool
:= match idc with
| Addr => false
| _ => true
end.
(* Names of the variables occurring in an expression.  At each
   identifier, [consider_args_live] decides whether its argument is
   traversed at all; below identifiers whose return arguments are not
   considered live, traversal switches to [consider_addr_dead]. *)
Fixpoint collect_live_of_arith_expr
(consider_args_live : forall s d, ident s d -> bool)
{t'} (e : arith_expr t') : t
:= match e with
| AppIdent s d idc arg
=> if consider_args_live s d idc
then @collect_live_of_arith_expr
(if consider_retargs_live s d idc
then consider_args_live
else consider_addr_dead)
_ arg
else empty
| Var t v => singleton v
| Pair A B a b => union (@collect_live_of_arith_expr consider_args_live _ a) (@collect_live_of_arith_expr consider_args_live _ b)
| TT => empty
end.
(* Live names of a statement: the variables of its expression, or the
   variables a [Comment] mentions; [DeclareVar] contributes none. *)
Definition collect_live_of_stmt (e : stmt) : t
:= match e with
| Assign _ _ _ _ val
| AssignZPtr _ _ val
| Call val
| AssignNth _ _ val
=> collect_live_of_arith_expr (fun _ _ _ => true) val
| Comment _ live_vars
=> of_list (List.flat_map (fun v => OfPHOAS.names_list_of_var_data (projT2 v)) live_vars)
| DeclareVar _ _ _ => empty
end.
(* Live names of a whole program: union over its statements. *)
Definition collect_live (e : expr) : t
:= fold_right
union
empty
(List.map
collect_live_of_stmt
e).
Section adjust_dead.
Context (rename_dead : string -> string)
(live : t).
(* Applies [rename_dead] to names that are not in [live]. *)
Definition rename_if_dead (v : string) : string
:= if mem v live then v else rename_dead v.
(* Renames the dead variable occurrences of an expression. *)
Fixpoint adjust_dead_of_arith_expr {t'} (e : arith_expr t') : arith_expr t'
:= match e with
| AppIdent s d idc arg => AppIdent idc (@adjust_dead_of_arith_expr _ arg)
| Var t v => Var t (rename_if_dead v)
| Pair A B a b => Pair (@adjust_dead_of_arith_expr _ a) (@adjust_dead_of_arith_expr _ b)
| TT => TT
end.
(* Renames dead variables inside a statement's expression and drops
   [DeclareVar]s of dead names; assignment target names are left
   unchanged. *)
Definition adjust_dead_of_stmt (e : stmt) : list stmt
:= match e with
| Call val
=> [Call (adjust_dead_of_arith_expr val)]
| Assign declare t sz name val
=> [Assign declare t sz name (adjust_dead_of_arith_expr val)]
| AssignZPtr name sz val
=> [AssignZPtr name sz (adjust_dead_of_arith_expr val)]
| AssignNth name n val
=> [AssignNth name n (adjust_dead_of_arith_expr val)]
| DeclareVar t sz name
=> if mem name live
then [DeclareVar t sz name]
else []
| Comment _ _ as e
=> [e]
end.
Definition adjust_dead_of_expr (e : expr) : expr
:= flat_map adjust_dead_of_stmt e.
End adjust_dead.
(* Computes the live names of [e] with [collect_live], then renames all
   other variable occurrences and drops their declarations. *)
Definition adjust_dead (rename_dead : string -> string) (e : expr) : expr
:= adjust_dead_of_expr rename_dead (collect_live e) e.
End __.
End name_infos.
(* Hoists every variable declaration of a program to its front. *)
Module LiftDeclare.
  Local Open Scope list_scope.
  (* Splits one statement into (declaration part, remaining part):
     a [DeclareVar] is all declaration, a declaring [Assign] becomes a
     [DeclareVar] plus a non-declaring [Assign], and every other
     statement has an empty declaration part. *)
  Definition split_declare (e : stmt) : list stmt (* decls *) * list stmt (* non-decls *)
    := match e with
       | DeclareVar t sz name
         => ([DeclareVar t sz name], [])
       | Assign true t sz name val
         => ([DeclareVar t sz name], [Assign false t sz name val])
       | Assign false _ _ _ _
       | Call _
       | Comment _ _
       | AssignZPtr _ _ _
       | AssignNth _ _ _
         => ([], [e])
       end.
  (* Splits every statement and gathers the declaration parts and the
     remaining parts separately, keeping the relative order of each. *)
  Definition split_declarations (e : expr) : expr (* decls *) * expr (* non-decls *)
    := let parts := List.map split_declare e in
       let decls := List.flat_map (fun p => fst p) parts in
       let others := List.flat_map (fun p => snd p) parts in
       (decls, others).
  (* Reorders a program so that all declarations come first. *)
  Definition lift_declarations (e : expr) : expr
    := let (decls, others) := split_declarations e in
       List.app decls others.
End LiftDeclare.
Module OfPHOAS.
Export Stringification.Language.Compilers.ToString.OfPHOAS.
(* How a PHOAS value of base type [t] is represented during translation
   to IR: a [Z] becomes an IR integer expression paired with the
   integer type it is known to fit in (if any); products, lists and
   options are represented componentwise; [unit] and the remaining base
   types keep their Coq interpretation. *)
Fixpoint arith_expr_for_base (t : base.type) : Type
:= match t with
| tZ
=> arith_expr type.Z * option int.type
| base.type.prod A B
=> arith_expr_for_base A * arith_expr_for_base B
| base.type.list A => list (arith_expr_for_base A)
| base.type.option A => option (arith_expr_for_base A)
| base.type.unit as t
| base.type.type_base _ as t
=> base.interp t
end.
(* Lifts [arith_expr_for_base] to arbitrary types.  Function types have
   no arithmetic representation, so they map to the empty type. *)
Definition arith_expr_for (t : Compilers.type.type base.type) : Type
  := match t with
     | type.arrow _ _ => Empty_set
     | type.base bt => arith_expr_for_base bt
     end.
(* Parametrizes the PHOAS -> IR translation over language-specific
numeric conversions *)
Class LanguageCasts :=
{
(* [bin_op_natural_output] takes in a binary operation and
the known types that the inputs fit into, and returns
the type that the output will land in if no casts are
present. *)
bin_op_natural_output
: Z_binop -> int.type * int.type -> int.type;
(* [bin_op_casts] takes in a binary operation, the known
type that the output fits in, and the pair of known
types that the inputs fit in. It returns the triple of
(output, (input1, input2)) of casts that are necessary
for running the operation. No-op casts on the inputs
will later be discarded; the cast on the output, if
given, will always be used. *)
bin_op_casts
: Z_binop -> option int.type -> int.type * int.type -> option int.type * (option int.type * option int.type);
(* [un_op_casts] takes in a unary operation, the known
type that the output fits in, and the known type that
the input fits in. It returns the pair of (output,
input) casts that are necessary for running the
operation. No-op casts on the inputs will later be
discarded; the cast on the output, if given, will
always be used. *)
un_op_casts
: Z_unop -> option int.type -> int.type -> option int.type * option int.type;
(* Are upcasts necessary on assignments? *)
upcast_on_assignment : bool;
(* Are upcasts necessary for arguments to function calls? *)
upcast_on_funcall : bool;
(* Should we declare pointer variables explicitly (rather than declaring non-pointer variables and taking references to them)? *)
explicit_pointer_variables : bool;
}.
(* Option classes configuring the translation, supplied as typeclass
   instances.  [consider_retargs_live]: per identifier, whether
   variables passed by address inside its argument count as live
   (presumably handed to [name_infos.collect_live] -- confirm). *)
Class consider_retargs_live_opt := consider_retargs_live : forall {s d}, ident s d -> bool.
(* [rename_dead]: the renaming applied to variables found to be dead
   (presumably handed to [name_infos.adjust_dead] -- confirm). *)
Class rename_dead_opt := rename_dead : string -> string.
(* [lift_declarations]: whether declarations should be hoisted to the
   front (presumably via [LiftDeclare.lift_declarations] -- confirm). *)
Class lift_declarations_opt := lift_declarations : bool.
(* The translation below is parametrized by these instances: the
   language's cast rules, zrange relaxation, liveness and dead-name
   renaming options, declaration lifting, and whether language-specific
   cast adjustment is enabled. *)
Section __.
Context {lang_casts : LanguageCasts}
{relax_zrange : relax_zrange_opt}
{consider_retargs_live : consider_retargs_live_opt}
{rename_dead : rename_dead_opt}
{lift_declarations : lift_declarations_opt}
{language_specific_cast_adjustment : language_specific_cast_adjustment_opt}.
(* Lifts [bin_op_natural_output] to optional operand types, where [None]
   means unconstrained.  The natural output type is only known when
   both operand types are known; otherwise the result is [None]. *)
Definition bin_op_natural_output_opt
  : Z_binop -> option int.type * option int.type -> option int.type
  := fun binop tys
     => match tys with
        | (Some ty1, Some ty2) => Some (bin_op_natural_output binop (ty1, ty2))
        | (Some _, None) | (None, _) => None
        end.
(* Lifts [bin_op_casts] to optional operand types.  If either operand
   type is unknown, no input casts are requested and the desired output
   type is passed through unchanged as the output cast. *)
Definition bin_op_casts_opt
  : Z_binop -> option int.type -> option int.type * option int.type -> option int.type * (option int.type * option int.type)
  := fun binop tout tys
     => match tys with
        | (Some ty1, Some ty2) => bin_op_casts binop tout (ty1, ty2)
        | (Some _, None) | (None, _) => (tout, (None, None))
        end.
(* Lifts [un_op_casts] to an optional operand type.  If the operand type
   is unknown, no input cast is requested and the desired output type
   is passed through as the output cast. *)
Definition un_op_casts_opt
  : Z_unop -> option int.type -> option int.type -> option int.type * option int.type
  := fun unop tout tin
     => match tin with
        | None => (tout, None)
        | Some ty => un_op_casts unop tout ty
        end.
(* Optionally wraps an IR integer expression in [Z_static_cast] to
   [desired_type], returning the expression together with its
   (possibly updated) known type.  With [always] set, a cast is
   inserted whenever a [desired_type] is given.  Otherwise a cast is
   inserted only when [language_specific_cast_adjustment] is set and
   [desired_type] is not known to equal the current known type.
   Without a [desired_type], the input is returned unchanged. *)
Definition Zcast {always : bool}
: option int.type -> arith_expr_for_base tZ -> arith_expr_for_base tZ
:= fun desired_type '(e, known_type)
=> let eq_known_type_desired_type
:= (known_type <- known_type;
desired_type <- desired_type;
Some (int.type_beq known_type desired_type))%option in
(* [eq_known_type_desired_type] is [None] when either type is unknown. *)
match always, language_specific_cast_adjustment, desired_type, eq_known_type_desired_type with
| true, _, Some desired_type, _ (* if we always insert the cast, and there's a cast, then insert it *)
| _, true, Some desired_type, Some false (* if we are doing language-specific casts and we know the desired type and it's not the same as the known type, insert it *)
| _, true, Some desired_type, None (* if we are doing language-specific casts and we know the desired type and there was no known type, insert it *)
=> (Z_static_cast desired_type @@@ e, Some desired_type)
| false, false, _, _ (* if we're not doing cast insertion (neither by forcing nor by language-specific casts), don't insert a cast *)
| _, true, Some _, Some true (* if we're doing language-specific casts but the new type is the same as the known type, we don't need to insert a cast *)
| _, _, None, _ (* if we don't know the type to insert, we can't insert a cast *)
=> (e, known_type)
end%core%Cexpr%bool.
(* The cast needed to bring a value of type [known_type] down to
   [desired_type]: none when no type is desired or when the known type
   is already tighter than the desired one; otherwise, including when
   the known type is unknown, a cast to [desired_type]. *)
Definition get_Zcast_down_if_needed
  : option int.type -> option int.type -> option int.type
  := fun desired_type known_type
     => match desired_type with
        | None => None
        | Some want_ty
          => match known_type with
             | None => Some want_ty
             | Some known_ty
               => if int.is_tighter_than known_ty want_ty
                  then None
                  else Some want_ty
             end
        end.
(* Narrows an IR integer expression to [desired_type] when
   [get_Zcast_down_if_needed] asks for a cast, delegating to [Zcast]
   without forcing the cast. *)
Definition Zcast_down_if_needed
  : option int.type -> arith_expr_for_base tZ -> arith_expr_for_base tZ
  := fun desired_type arg
     => let '(e, known_type) := arg in
        let requested := get_Zcast_down_if_needed desired_type known_type in
        Zcast (always:=false) requested (e, known_type).
(* Lifts [Zcast_down_if_needed] structurally to every base type:
   products componentwise, lists and options elementwise against the
   given ranges.  A [None] range for a list or option leaves it
   unchanged; [unit] and other base types are returned as is.
   NOTE(review): [List.combine] truncates to the shorter list, so a
   range list shorter than the value list would drop elements --
   confirm the lengths always agree. *)
Fixpoint cast_down_if_needed {t}
: int.option.interp t -> arith_expr_for_base t -> arith_expr_for_base t
:= match t with
| tZ => Zcast_down_if_needed
| base.type.type_base _
| base.type.unit
=> fun _ x => x
| base.type.prod A B
=> fun '(r1, r2) '(e1, e2) => (@cast_down_if_needed A r1 e1,
@cast_down_if_needed B r2 e2)
| base.type.list A
=> fun r1 ls
=> match r1 with
| Some r1 => List.map (fun '(r, e) => @cast_down_if_needed A r e)
(List.combine r1 ls)
| None => ls
end
| base.type.option A
=> fun r1 ls
=> match r1 with
| Some r1 => Option.map (fun '(r, e) => @cast_down_if_needed A r e)
(Option.combine r1 ls)
| None => ls
end
end.
(* The cast needed to widen a value of type [known_type] to
   [desired_type]: only when both types are known and the desired type
   is not tighter than the known one; [None] in every other case. *)
Definition get_Zcast_up_if_needed
  : option int.type -> option int.type -> option int.type
  := fun desired_type known_type
     => match desired_type with
        | None => None
        | Some want_ty
          => match known_type with
             | None => None
             | Some known_ty
               => if int.is_tighter_than want_ty known_ty
                  then None
                  else Some want_ty
             end
        end.
(* Widens an IR integer expression to [desired_type] when
   [get_Zcast_up_if_needed] asks for a cast, delegating to [Zcast]
   without forcing the cast. *)
Definition Zcast_up_if_needed
  : option int.type -> arith_expr_for_base tZ -> arith_expr_for_base tZ
  := fun desired_type arg
     => let '(e, known_type) := arg in
        let requested := get_Zcast_up_if_needed desired_type known_type in
        Zcast (always:=false) requested (e, known_type).
(* Lifts [Zcast_up_if_needed] structurally to every base type:
   products componentwise, lists and options elementwise against the
   given ranges.  A [None] range for a list or option leaves it
   unchanged; [unit] and other base types are returned as is.
   NOTE(review): [List.combine] truncates to the shorter list --
   confirm range and value lists always have equal length. *)
Fixpoint cast_up_if_needed {t}
: int.option.interp t -> arith_expr_for_base t -> arith_expr_for_base t
:= match t with
| tZ => Zcast_up_if_needed
| base.type.type_base _
| base.type.unit
=> fun _ x => x
| base.type.prod A B
=> fun '(r1, r2) '(e1, e2) => (@cast_up_if_needed A r1 e1,
@cast_up_if_needed B r2 e2)
| base.type.list A
=> fun r1 ls
=> match r1 with
| Some r1 => List.map (fun '(r, e) => @cast_up_if_needed A r e)
(List.combine r1 ls)
| None => ls
end
| base.type.option A
=> fun r1 ls
=> match r1 with
| Some r1 => Option.map (fun '(r, e) => @cast_up_if_needed A r e)
(Option.combine r1 ls)
| None => ls
end
end.
(* Lifts [Zcast] (with the given [always] flag) structurally to every
   base type: products componentwise, lists and options elementwise
   against the given types.  A [None] for a list or option leaves it
   unchanged; [unit] and other base types are returned as is.
   NOTE(review): [List.combine] truncates to the shorter list --
   confirm lengths always agree. *)
Fixpoint cast {always:bool} {t}
: int.option.interp t -> arith_expr_for_base t -> arith_expr_for_base t
:= match t with
| tZ => Zcast (always:=always)
| base.type.type_base _
| base.type.unit
=> fun _ x => x
| base.type.prod A B
=> fun '(r1, r2) '(e1, e2) => (@cast always A r1 e1,
@cast always B r2 e2)
| base.type.list A
=> fun r1 ls
=> match r1 with
| Some r1 => List.map (fun '(r, e) => @cast always A r e)
(List.combine r1 ls)
| None => ls
end
| base.type.option A
=> fun r1 ls
=> match r1 with
| Some r1 => Option.map (fun '(r, e) => @cast always A r e)
(Option.combine r1 ls)
| None => ls
end
end.
(* Translates a binary integer operation.  It asks [bin_op_casts_opt]
   for the output and input casts given the desired output type and
   the known operand types, applies the input casts, derives the
   natural output type from the cast operand types (falling back to
   the known types when no cast is requested), and finally applies the
   output cast to the application of [idc]. *)
Definition arith_bin_arith_expr_of_PHOAS_ident
(s:=(tZ * tZ)%etype)
(d:=tZ)
(idc : Z_binop)
: option int.type -> arith_expr_for (type.base s) -> arith_expr_for (type.base d)
:= fun desired_type '((e1, t1), (e2, t2)) =>
let '(cstout, (cst1, cst2)) := bin_op_casts_opt idc desired_type (t1, t2) in
let typ := bin_op_natural_output_opt idc (Option.or_else cst1 t1, Option.or_else cst2 t2) in
let '((e1, t1), (e2, t2)) := (Zcast (always:=false) cst1 (e1, t1), Zcast (always:=false) cst2 (e2, t2)) in
Zcast (always:=false) cstout ((idc @@@ (e1, e2))%Cexpr, typ).
(* Translates a unary integer operation.  It asks [un_op_casts_opt] for
   the output and input casts, applies the input cast, and applies the
   output cast to the application of [idc].  The pre-output-cast type
   is taken to be the input cast type (or the known input type when no
   input cast is requested); no unary natural-output function is
   consulted. *)
Definition arith_un_arith_expr_of_PHOAS_ident
(s:=tZ)
(d:=tZ)
(idc : Z_unop)
: option int.type -> arith_expr_for (type.base s) -> arith_expr_for (type.base d)
:= fun desired_type '(e, t) =>
let '(cstout, cst) := un_op_casts_opt idc desired_type t in
let typ := (*un_op_natural_output_opt idc*) Option.or_else cst t in
let '(e, t) := Zcast (always:=false) cst (e, t) in
Zcast (always:=false) cstout ((idc @@@ e)%Cexpr, typ).
(* Product of two types, used with [type.uncurried_domain].  There is
   no product when either side is a function type, so [unit] is used
   as a placeholder in that case. *)
Local Definition fakeprod (A B : Compilers.type.type base.type) : Compilers.type.type base.type
  := match A with
     | type.arrow _ _ => type.base base.type.unit
     | type.base a
       => match B with
          | type.arrow _ _ => type.base base.type.unit
          | type.base b => type.base (base.type.prod a b)
          end
     end.
(* Representation of the uncurried argument tuple of a function type
   (tupled with [fakeprod]); non-function types take no arguments, so
   they map to [unit]. *)
Definition arith_expr_for_uncurried_domain (t : Compilers.type.type base.type)
  := match t return Type with
     | type.arrow src dst => arith_expr_for (type.uncurried_domain fakeprod src dst)
     | type.base _ => unit
     end.
Section with_bind.
(* N.B. If we make the [bind*_err] notations, then Coq can't
infer types correctly; if we make them [Local
Definition]s or [Let]s, then [ocamlopt] fails with
"Error: The type of this module, [...], contains type
variables that cannot be generalized". We need to run
[cbv] below to actually unfold them. >.< *)
(* [ErrT T]: either a result of type [T] or a list of error messages. *)
Local Notation ErrT T := (T + list string)%type.
(* [ret v]: a successful result. *)
Local Notation ret v := (@inl _ (list string) v) (only parsing).
(* Sequential bind on [ErrT]: propagates the first error encountered. *)
Local Notation "x <- v ; f" := (match v with
| inl x => f
| inr err => inr err
end).
(* The same bind, also available under the [err] delimiter. *)
Declare Scope err_scope.
(*Local*) Delimit Scope err_scope with err.
Local Notation "x <- v ; f" := (match v with
| inl x => f
| inr err => inr err
end) : err_scope.
(* Syntax for binding two to five [ErrT] computations at once; each
   form is given meaning by the corresponding [bindN_err] below, which
   accumulates the errors of every failing computation. *)
Reserved Notation "A1 ,, A2 <- X ; B" (at level 70, A2 at next level, right associativity, format "'[v' A1 ,, A2 <- X ; '/' B ']'").
Reserved Notation "A1 ,, A2 <- X1 , X2 ; B" (at level 70, A2 at next level, right associativity, format "'[v' A1 ,, A2 <- X1 , X2 ; '/' B ']'").
Reserved Notation "A1 ,, A2 ,, A3 <- X ; B" (at level 70, A2 at next level, A3 at next level, right associativity, format "'[v' A1 ,, A2 ,, A3 <- X ; '/' B ']'").
Reserved Notation "A1 ,, A2 ,, A3 <- X1 , X2 , X3 ; B" (at level 70, A2 at next level, A3 at next level, right associativity, format "'[v' A1 ,, A2 ,, A3 <- X1 , X2 , X3 ; '/' B ']'").
Reserved Notation "A1 ,, A2 ,, A3 ,, A4 <- X ; B" (at level 70, A2 at next level, A3 at next level, A4 at next level, right associativity, format "'[v' A1 ,, A2 ,, A3 ,, A4 <- X ; '/' B ']'").
Reserved Notation "A1 ,, A2 ,, A3 ,, A4 <- X1 , X2 , X3 , X4 ; B" (at level 70, A2 at next level, A3 at next level, A4 at next level, right associativity, format "'[v' A1 ,, A2 ,, A3 ,, A4 <- X1 , X2 , X3 , X4 ; '/' B ']'").
Reserved Notation "A1 ,, A2 ,, A3 ,, A4 ,, A5 <- X ; B" (at level 70, A2 at next level, A3 at next level, A4 at next level, A5 at next level, right associativity, format "'[v' A1 ,, A2 ,, A3 ,, A4 ,, A5 <- X ; '/' B ']'").
Reserved Notation "A1 ,, A2 ,, A3 ,, A4 ,, A5 <- X1 , X2 , X3 , X4 , X5 ; B" (at level 70, A2 at next level, A3 at next level, A4 at next level, A5 at next level, right associativity, format "'[v' A1 ,, A2 ,, A3 ,, A4 ,, A5 <- X1 , X2 , X3 , X4 , X5 ; '/' B ']'").
(* Binds two [ErrT] computations.  When both succeed, [f] receives both
   results; otherwise the errors of every failing computation are
   returned, concatenated in order. *)
Let bind2_err {A B C} (v1 : ErrT A) (v2 : ErrT B) (f : A -> B -> ErrT C) : ErrT C
  := match v1 with
     | inl a
       => match v2 with
          | inl b => f a b
          | inr errs => inr errs
          end
     | inr errs1
       => match v2 with
          | inl _ => inr errs1
          | inr errs2 => inr (List.app errs1 errs2)
          end
     end.
(* Notations for [bind2_err]: two computations given separately, or as
   the two components of a pair. *)
Local Notation "x1 ,, x2 <- v1 , v2 ; f"
:= (bind2_err v1 v2 (fun x1 x2 => f)).
Local Notation "x1 ,, x2 <- v ; f"
:= (x1 ,, x2 <- fst v , snd v; f).
(* Binds three [ErrT] computations via two uses of [bind2_err], so the
   errors of every failing computation are accumulated. *)
Let bind3_err {A B C R} (v1 : ErrT A) (v2 : ErrT B) (v3 : ErrT C) (f : A -> B -> C -> ErrT R) : ErrT R
:= (x12 ,, x3 <- (x1 ,, x2 <- v1, v2; inl (x1, x2)), v3;
let '(x1, x2) := x12 in
f x1 x2 x3).
(* Notations for [bind3_err]: three computations given separately, or
   destructured from a triple. *)
Local Notation "x1 ,, x2 ,, x3 <- v1 , v2 , v3 ; f"
:= (bind3_err v1 v2 v3 (fun x1 x2 x3 => f)).
Local Notation "x1 ,, x2 ,, x3 <- v ; f"
:= (let '(v1, v2, v3) := v in x1 ,, x2 ,, x3 <- v1 , v2 , v3; f).
(* Binds four [ErrT] computations by pairing them two-by-two with
   [bind2_err], accumulating the errors of every failing one. *)
Let bind4_err {A B C D R} (v1 : ErrT A) (v2 : ErrT B) (v3 : ErrT C) (v4 : ErrT D) (f : A -> B -> C -> D -> ErrT R) : ErrT R
:= (x12 ,, x34 <- (x1 ,, x2 <- v1, v2; inl (x1, x2)), (x3 ,, x4 <- v3, v4; inl (x3, x4));
let '((x1, x2), (x3, x4)) := (x12, x34) in
f x1 x2 x3 x4).
(* Notations for [bind4_err]: four computations given separately, or
   destructured from a 4-tuple. *)
Local Notation "x1 ,, x2 ,, x3 ,, x4 <- v1 , v2 , v3 , v4 ; f"
:= (bind4_err v1 v2 v3 v4 (fun x1 x2 x3 x4 => f)).
Local Notation "x1 ,, x2 ,, x3 ,, x4 <- v ; f"
:= (let '(v1, v2, v3, v4) := v in x1 ,, x2 ,, x3 ,, x4 <- v1 , v2 , v3 , v4; f).
(* Binds five [ErrT] computations as a pair plus a triple, accumulating
   the errors of every failing one. *)
Let bind5_err {A B C D E R} (v1 : ErrT A) (v2 : ErrT B) (v3 : ErrT C) (v4 : ErrT D) (v5 : ErrT E) (f : A -> B -> C -> D -> E -> ErrT R) : ErrT R
:= (x12 ,, x345 <- (x1 ,, x2 <- v1, v2; inl (x1, x2)), (x3 ,, x4 ,, x5 <- v3, v4, v5; inl (x3, x4, x5));
let '((x1, x2), (x3, x4, x5)) := (x12, x345) in
f x1 x2 x3 x4 x5).
(* Notations for [bind5_err]: five computations given separately, or
   destructured from a 5-tuple. *)
Local Notation "x1 ,, x2 ,, x3 ,, x4 ,, x5 <- v1 , v2 , v3 , v4 , v5 ; f"
:= (bind5_err v1 v2 v3 v4 v5 (fun x1 x2 x3 x4 x5 => f)).
Local Notation "x1 ,, x2 ,, x3 ,, x4 ,, x5 <- v ; f"
:= (let '(v1, v2, v3, v4, v5) := v in x1 ,, x2 ,, x3 ,, x4 ,, x5 <- v1 , v2 , v3 , v4 , v5; f).
(* Exact base-2 logarithm: [Some k] when [s] equals [2^k], [None]
   otherwise (in particular for all [s <= 0]). *)
Definition maybe_log2 (s : Z) : option Z
  := let lg := Z.log2 s in
     if Z.eqb (Z.pow 2 lg) s then Some lg else None.
(* Exact double logarithm: [Some k] (as a [nat]) when [s] equals
   [2^(2^k)], [None] otherwise.  The sign test on the inner logarithm
   is kept for parity with the original, although [Z.log2] never
   returns a negative value. *)
Definition maybe_loglog2 (s : Z) : option nat
  := match maybe_log2 s with
     | None => None
     | Some lg
       => match maybe_log2 lg with
          | None => None
          | Some lglg
            => if Z.leb 0 lglg
               then Some (Z.to_nat lglg)
               else None
          end
     end.
(* Translates an integer literal [v].  The literal is tagged with the
   integer type of its relaxed singleton range [r[v~>v]], then cast
   down to the requested type [r] when [cast_down_if_needed] asks for
   it. *)
Definition arith_expr_of_PHOAS_literal_Z
(t:=tZ)
v
: int.option.interp (type.final_codomain t) -> arith_expr_for_base t
:= fun r
=> cast_down_if_needed
r
(literal v @@@ TT, Some (int.of_zrange_relaxed (relax_zrange r[v~>v])))%core%Cexpr%option%zrange.
Definition arith_expr_of_PHOAS_ident
{t}
(idc : ident.ident t)
: int.option.interp (type.final_codomain t) -> type.interpM_final (fun T => ErrT T) arith_expr_for_base t
:= match idc in ident.ident t return int.option.interp (type.final_codomain t) -> type.interpM_final (fun T => ErrT T) arith_expr_for_base t with
| ident.Literal base.type.Z v
=> fun r => ret (arith_expr_of_PHOAS_literal_Z v r)
| ident.tt => fun _ => ret tt
| ident.nil t
=> fun _ => ret nil
| ident.cons t
=> fun r x xs => ret (cast_down_if_needed r (cons x xs))
| ident.fst A B => fun r xy => ret (cast_down_if_needed r (@fst _ _ xy))
| ident.snd A B => fun r xy => ret (cast_down_if_needed r (@snd _ _ xy))
| ident.List_nth_default tZ
=> fun r d ls n
=> List.nth_default (inr ["Invalid list index " ++ show n]%string)
(List.map (fun x => ret (cast_down_if_needed r x)) ls) n
| ident.Z_shiftr
=> fun r e '(offset, roffset)
=> match invert_literal offset with
| Some offset
=> ret (arith_un_arith_expr_of_PHOAS_ident (Z_shiftr offset) r e)
| None => inr ["Invalid right-shift by a non-literal"]%string
end
| ident.Z_shiftl
=> fun r e '(offset, roffset)
=> match invert_literal offset with
| Some offset
=> ret (arith_un_arith_expr_of_PHOAS_ident (Z_shiftl offset) r e)
| None => inr ["Invalid left-shift by a non-literal"]%string
end
| ident.Z_truncating_shiftl
=> fun r '(bitwidth, rbitwidth) e '(offset, roffset)
=> match invert_literal bitwidth, invert_literal offset with
| Some bitwidth, Some offset
=> let rpre_out := Some (int.of_zrange_relaxed r[0 ~> Z.max (2^offset) (2^bitwidth-1)]%zrange) in
let shifted := arith_un_arith_expr_of_PHOAS_ident (Z_shiftl offset) rpre_out e in
let truncated := arith_bin_arith_expr_of_PHOAS_ident Z_land r (shifted, arith_expr_of_PHOAS_literal_Z (2^bitwidth-1) (Some (int.of_zrange_relaxed r[0~>2^bitwidth - 1]))) in
ret truncated
| _, None => inr ["Invalid (truncating) left-shift by a non-literal"]%string
| None, _ => inr ["Invalid left-shift truncated to a non-literal bitwidth"]%string
end
| ident.value_barrier
=> fun r '(x, rx)
=> match rx with
| Some ty
=> ret (cast_down_if_needed r (Z_value_barrier ty @@@ x, Some ty))
| None => inr ["Invalid unknown integer size for value_barrier"]%string
end
| ident.Z_bneg => fun r x => ret (arith_un_arith_expr_of_PHOAS_ident Z_bneg r x)
| ident.Z_land => fun r x y => ret (arith_bin_arith_expr_of_PHOAS_ident Z_land r (x, y))
| ident.Z_lor => fun r x y => ret (arith_bin_arith_expr_of_PHOAS_ident Z_lor r (x, y))
| ident.Z_lxor => fun r x y => ret (arith_bin_arith_expr_of_PHOAS_ident Z_lxor r (x, y))
| ident.Z_add => fun r x y => ret (arith_bin_arith_expr_of_PHOAS_ident Z_add r (x, y))
| ident.Z_mul => fun r x y => ret (arith_bin_arith_expr_of_PHOAS_ident Z_mul r (x, y))
| ident.Z_sub => fun r x y => ret (arith_bin_arith_expr_of_PHOAS_ident Z_sub r (x, y))
| ident.Z_lnot_modulo
=> fun rout '(e, r) '(modulus, _)
=> match invert_literal modulus with
| Some modulus
=> match maybe_loglog2 modulus with
| Some lgbitwidth
=> let ty := int.unsigned lgbitwidth in
let rin' := Some ty in
let '(e, _) := Zcast (always:=false) rin' (e, r) in
ret (cast_down_if_needed rout (cast_up_if_needed rout (Z_lnot ty @@@ e, rin')))
| None => inr ["Invalid modulus for Z.lnot (not 2^(2^_)): " ++ show modulus]%string
end
| None => inr ["Invalid non-literal modulus for Z.lnot"]%string
end
| ident.pair A B
=> fun _ _ _ => inr ["Invalid identifier in arithmetic expression " ++ show idc]%string
| ident.Z_opp (* we pretend this is [0 - _] *)
=> fun r x =>
let zero := (literal 0 @@@ TT, Some (int.of_zrange_relaxed (relax_zrange r[0~>0]))) in
ret (arith_bin_arith_expr_of_PHOAS_ident Z_sub r (zero, x))
| ident.Literal _ v
=> fun _ => ret v
| ident.comment _
| ident.comment_no_keep _
| ident.Nat_succ
| ident.Nat_pred
| ident.Nat_max
| ident.Nat_mul
| ident.Nat_add
| ident.Nat_sub
| ident.Nat_eqb
| ident.Pos_add
| ident.Pos_mul
| ident.prod_rect _ _ _
| ident.bool_rect _
| ident.bool_rect_nodep _
| ident.nat_rect _
| ident.eager_nat_rect _
| ident.nat_rect_arrow _ _
| ident.eager_nat_rect_arrow _ _
| ident.Some _
| ident.None _
| ident.option_rect _ _
| ident.list_rect _ _
| ident.eager_list_rect _ _
| ident.list_rect_arrow _ _ _
| ident.eager_list_rect_arrow _ _ _
| ident.nat_rect_fbb_b _ _ _
| ident.nat_rect_fbb_b_b _ _ _ _
| ident.list_rect_fbb_b _ _ _ _
| ident.list_rect_fbb_b_b _ _ _ _ _
| ident.list_rect_fbb_b_b_b _ _ _ _ _ _
| ident.list_rect_fbb_b_b_b_b _ _ _ _ _ _ _
| ident.list_rect_fbb_b_b_b_b_b _ _ _ _ _ _ _ _
| ident.list_case _ _
| ident.List_length _
| ident.List_seq
| ident.List_repeat _
| ident.List_firstn _
| ident.List_skipn _
| ident.List_combine _ _
| ident.List_map _ _
| ident.List_app _
| ident.List_rev _
| ident.List_flat_map _ _
| ident.List_partition _
| ident.List_filter _
| ident.List_fold_right _ _
| ident.List_update_nth _
| ident.List_nth_default _
| ident.eager_List_nth_default _
| ident.Z_pow
| ident.Z_div
| ident.Z_modulo
| ident.Z_eqb
| ident.Z_leb
| ident.Z_ltb
| ident.Z_geb
| ident.Z_gtb
| ident.Z_min
| ident.Z_max
| ident.Z_abs
| ident.Z_log2
| ident.Z_log2_up
| ident.Z_of_nat
| ident.Z_to_nat
| ident.Z_pos
| ident.Z_to_pos
| ident.Z_ltz
| ident.Z_zselect
| ident.Z_mul_split
| ident.Z_mul_high
| ident.Z_add_get_carry
| ident.Z_add_with_carry
| ident.Z_add_with_get_carry
| ident.Z_sub_get_borrow
| ident.Z_sub_with_get_borrow
| ident.Z_add_modulo
| ident.Z_rshi
| ident.Z_cc_m
| ident.Z_combine_at_bitwidth
| ident.Z_cast
| ident.Z_cast2
| ident.Build_zrange
| ident.zrange_rect _
| ident.fancy_add
| ident.fancy_addc
| ident.fancy_sub
| ident.fancy_subb
| ident.fancy_mulll
| ident.fancy_mullh
| ident.fancy_mulhl
| ident.fancy_mulhh
| ident.fancy_rshi
| ident.fancy_selc
| ident.fancy_selm
| ident.fancy_sell
| ident.fancy_addm
=> fun _ => type.interpM_return _ _ _ (inr ["Invalid identifier in arithmetic expression " ++ show idc]%string)
end%core%Cexpr%option%zrange.
(** Builds the curried argument-collecting form of an operation whose
    argument casts are not known.  [f] expects the range/size
    information of the final result, followed by the already-converted
    arguments.  The returned function takes one argument converter per
    base-typed argument, each of type
    [int.option.interp s -> ErrT (arith_expr_for_base s)].  Every
    converter is run with [int.option.None], i.e. with no size
    information.  A converter that fails produces its error list via
    [type.interpM_return].  Arguments of function type are rejected. *)
Fixpoint collect_args_and_apply_unknown_casts {t}
: (int.option.interp (type.final_codomain t) -> type.interpM_final (fun T => ErrT T) arith_expr_for_base t)
-> type.interpM_final
(fun T => ErrT T)
(fun t => int.option.interp t -> ErrT (arith_expr_for_base t))
t
:= match t
return ((int.option.interp (type.final_codomain t) -> type.interpM_final (fun T => ErrT T) arith_expr_for_base t)
-> type.interpM_final
(fun T => ErrT T)
(fun t => int.option.interp t -> ErrT (arith_expr_for_base t))
t)
with
(* No arguments left: return the continuation, which (per the
   signature above) still takes the range/size information of the
   result. *)
| type.base t => fun v => ret v
| type.arrow (type.base s) d
=> fun f
(x : (int.option.interp s -> ErrT (arith_expr_for_base s)))
(* Convert this argument without any size information. *)
=> match x int.option.None return _ with
| inl x'
(* Feed the converted argument to [f], keeping the result-range
   slot first, and recurse on the remaining arguments. *)
=> @collect_args_and_apply_unknown_casts
d
(fun rout => f rout x')
(* Propagate the conversion errors. *)
| inr errs => type.interpM_return _ _ _ (inr errs)
end
(* Higher-order arguments are not supported. *)
| type.arrow (type.arrow _ _) _
=> fun _ => type.interpM_return _ _ _ (inr ["Invalid higher-order function"])
end.
(** Handles the identifiers whose argument conversion depends on the
    requested range/size information: [ident.Z_cast], [ident.Z_cast2],
    [ident.pair], [ident.nil] and [ident.cons].  For these it returns
    [inl] of the curried argument-collecting function.  Every other
    identifier yields [inr] with an error message, which
    [collect_args_and_apply_casts] treats as the cue to fall back to
    [collect_args_and_apply_unknown_casts]. *)
Definition collect_args_and_apply_known_casts {t}
(idc : ident.ident t)
: ErrT (type.interpM_final
(fun T => ErrT T)
(fun t => int.option.interp t -> ErrT (arith_expr_for_base t))
t)
:= match idc in ident.ident t
return ErrT
(type.interpM_final
(fun T => ErrT T)
(fun t => int.option.interp t -> ErrT (arith_expr_for_base t))
t)
with
| ident.Z_cast
=> inl
(fun r arg
(* [r] first names the range argument of the cast.  It is evaluated
   at [tt], and [r] is then rebound to the matching relaxed integer
   type. *)
=> r <- r tt;
let r := Some (int.of_zrange_relaxed r) in
(* [r'] is the range/size requested by the consumer of the cast. *)
inl (fun r'
=> if language_specific_cast_adjustment
then
(* Convert the argument at the type of the cast, then cast down
   to [r'] only if needed. *)
arg <- arg r; ret (Zcast_down_if_needed r' arg)
else
(* Convert the argument with no size information and always emit
   a cast to the type of the cast.  [r'] is not consulted here. *)
arg <- arg None; ret (Zcast (always:=true) r arg)))
| ident.Z_cast2
(* Same as [ident.Z_cast], applied componentwise to a pair of ranges. *)
=> inl
(fun r arg
=> r <- r (tt, tt);
let r := (Some (int.of_zrange_relaxed (fst r)), Some (int.of_zrange_relaxed (snd r))) in
inl (fun r'
=> if language_specific_cast_adjustment
then
(arg <- arg r; ret (cast_down_if_needed (t:=tZ*tZ) r' arg))
else
(arg <- arg (None, None); ret (cast (always:=true) (t:=tZ*tZ) r arg))))
| ident.pair A B
(* Pass each component of the requested range pair to the
   corresponding component converter. *)
=> inl (fun ea eb
=> inl
(fun '(ra, rb)
=> ea' ,, eb' <- ea ra, eb rb;
inl (ea', eb')))
| ident.nil _
(* The empty list converts to [nil] whatever range is requested. *)
=> inl (inl (fun _ => inl nil))
| ident.cons t
=> inl
(fun x xs
=> inl
(fun rls
(* [mkcons] converts the head and tail at the given ranges and
   conses the results. *)
=> let mkcons (r : int.option.interp t)
(rs : int.option.interp (base.type.list t))
:= x ,, xs <- x r, xs rs;
inl (cons x xs) in
(* Split a known list range into head and tail ranges.  When there
   is no range, or an exhausted one, both get [int.option.None]. *)
match rls with
| Some (cons r rs) => mkcons r (Some rs)
| Some nil
| None
=> mkcons int.option.None int.option.None
end))
(* Anything else is not a cast or constructor handled here. *)
| _ => inr ["Invalid identifier where cast or constructor was expected: " ++ show idc]%string
end.
(** Converts the arguments of [idc], preferring the range-aware handling
    of [collect_args_and_apply_known_casts].  When [idc] is not one of the
    casts or constructors handled there, the error reported by that
    function is discarded.  The arguments are then collected with no size
    information, using [collect_args_and_apply_unknown_casts] applied to
    [convert_no_cast]. *)
Definition collect_args_and_apply_casts {t} (idc : ident.ident t)
  (convert_no_cast : int.option.interp (type.final_codomain t) -> type.interpM_final (fun T => ErrT T) arith_expr_for_base t)
  : type.interpM_final
      (fun T => ErrT T)
      (fun t => int.option.interp t -> ErrT (arith_expr_for_base t))
      t
  := match collect_args_and_apply_known_casts idc with
     | inr _
       (* not a known cast or constructor: fall back *)
       => collect_args_and_apply_unknown_casts convert_no_cast
     | inl known => known
     end.
Fixpoint arith_expr_of_base_PHOAS_Var
{t}
: base_var_data t -> int.option.interp t -> ErrT (arith_expr_for_base t)
:= match t with
| tZ
=> fun '(n, is_ptr, r, td) r'
=> let v := if is_ptr
then (Dereference @@@ Var type.Zptr n)%Cexpr
else Var type.Z n in
ret (cast_down_if_needed r' (v, r))
| base.type.prod A B
=> fun '(da, db) '(ra, rb)
=> (ea,, eb <- @arith_expr_of_base_PHOAS_Var A da ra, @arith_expr_of_base_PHOAS_Var B db rb;
inl (ea, eb))
| base.type.list tZ