diff --git a/backend/cmm_helpers.ml b/backend/cmm_helpers.ml
index bcf0cf7855b..9c8199d604c 100644
--- a/backend/cmm_helpers.ml
+++ b/backend/cmm_helpers.ml
@@ -362,6 +362,7 @@ let neg_int c dbg = sub_int (Cconst_int (0, dbg)) c dbg
 
 let rec lsl_int c1 c2 dbg =
   match c1, c2 with
+  | c1, Cconst_int (0, _) -> c1
   | Cop (Clsl, [c; Cconst_int (n1, _)], _), Cconst_int (n2, _)
     when n1 > 0 && n2 > 0 && n1 + n2 < size_int * 8 ->
     Cop (Clsl, [c; Cconst_int (n1 + n2, dbg)], dbg)
@@ -1280,35 +1281,108 @@ let addr_array_initialize arr ofs newval dbg =
       [array_indexing log2_size_addr arr ofs dbg; newval],
       dbg )
 
-(* low_32 x is a value which agrees with x on at least the low 32 bits *)
-let rec low_32 dbg = function
-  (* Ignore sign and zero extensions, which do not affect the low bits *)
-  | Cop (Casr, [Cop (Clsl, [x; Cconst_int (32, _)], _); Cconst_int (32, _)], _)
-  | Cop (Cand, [x; Cconst_natint (0xFFFFFFFFn, _)], _) ->
-    low_32 dbg x
-  | Clet (id, e, body) -> Clet (id, e, low_32 dbg body)
-  | x -> x
-
-(* sign_extend_32 sign-extends values from 32 bits to the word size. *)
-let sign_extend_32 dbg e =
-  match low_32 dbg e with
-  | Cop
-      ( Cload
-          { memory_chunk = Thirtytwo_unsigned | Thirtytwo_signed;
-            mutability;
-            is_atomic
-          },
-        args,
-        dbg ) ->
-    Cop
-      ( Cload { memory_chunk = Thirtytwo_signed; mutability; is_atomic },
-        args,
-        dbg )
-  | e ->
-    Cop
-      ( Casr,
-        [Cop (Clsl, [e; Cconst_int (32, dbg)], dbg); Cconst_int (32, dbg)],
-        dbg )
+(** [get_const_bitmask x] returns [Some (y, mask)] if [x] is [y & mask] *)
+let get_const_bitmask = function
+  | Cop (Cand, ([x; Cconst_natint (mask, _)] | [Cconst_natint (mask, _); x]), _)
+    ->
+    Some (x, mask)
+  | Cop (Cand, ([x; Cconst_int (mask, _)] | [Cconst_int (mask, _); x]), _) ->
+    Some (x, Nativeint.of_int mask)
+  | _ -> None
+
+(** [low_bits ~bits x] is a (potentially simplified) value which agrees with x on at least
+    the low [bits] bits. E.g., [low_bits ~bits x & mask = x & mask], where [mask] is a
+    bitmask of the low [bits] bits. *)
+let rec low_bits ~bits ~dbg x =
+  assert (bits > 0);
+  if bits >= arch_bits
+  then x
+  else
+    let unused_bits = arch_bits - bits in
+    let does_mask_keep_low_bits test_mask =
+      (* If the mask has all the low bits set, then the low bits are unchanged.
+         This could happen from zero-extension.
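+         For example, zero-extension of a 32-bit value masks with
+         [0xFFFF_FFFFn], which has all of the low [bits] bits set whenever
+         [bits <= 32].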
+      *)
+      let mask = Nativeint.pred (Nativeint.shift_left 1n bits) in
+      Nativeint.equal mask (Nativeint.logand test_mask mask)
+    in
+    (* Ignore sign and zero extensions, which do not affect the low bits *)
+    map_tail
+      (function
+        | Cop
+            ( (Casr | Clsr),
+              [Cop (Clsl, [x; Cconst_int (left, _)], _); Cconst_int (right, _)],
+              _ )
+          when 0 <= right && right <= left && left <= unused_bits ->
+          (* these sign/zero extensions can be replaced with a left shift since
+             we don't care about the high bits that it changes *)
+          low_bits ~bits (lsl_const x (left - right) dbg) ~dbg
+        | x -> (
+          match get_const_bitmask x with
+          | Some (x, bitmask) when does_mask_keep_low_bits bitmask ->
+            low_bits ~bits x ~dbg
+          | _ -> x))
+      x
+
+(** [zero_extend ~bits ~dbg e] returns [e] with the most significant
+    [arch_bits - bits] bits set to 0 *)
+let zero_extend ~bits ~dbg e =
+  assert (0 < bits && bits <= arch_bits);
+  let mask = Nativeint.pred (Nativeint.shift_left 1n bits) in
+  let zero_extend_via_mask e =
+    Cop (Cand, [e; natint_const_untagged dbg mask], dbg)
+  in
+  if bits = arch_bits
+  then e
+  else
+    map_tail
+      (function
+        | Cop (Cload { memory_chunk; mutability; is_atomic }, args, dbg) as e
+          -> (
+          let load memory_chunk =
+            Cop (Cload { memory_chunk; mutability; is_atomic }, args, dbg)
+          in
+          match memory_chunk, bits with
+          | (Byte_signed | Byte_unsigned), 8 -> load Byte_unsigned
+          | (Sixteen_signed | Sixteen_unsigned), 16 -> load Sixteen_unsigned
+          | (Thirtytwo_signed | Thirtytwo_unsigned), 32 ->
+            load Thirtytwo_unsigned
+          | _ -> zero_extend_via_mask e)
+        | e -> zero_extend_via_mask e)
+      (low_bits ~bits e ~dbg)
+
+(** [sign_extend ~bits ~dbg e] returns [e] with the most significant
+    [arch_bits - bits] bits set to a copy of bit [bits - 1] (the sign bit) *)
+let sign_extend ~bits ~dbg e =
+  assert (0 < bits && bits <= arch_bits);
+  let unused_bits = arch_bits - bits in
+  let sign_extend_via_shift e =
+    asr_const (lsl_const e unused_bits dbg) unused_bits dbg
+  in
+  if bits = arch_bits
+  then e
+  else
+    map_tail
+      (function
+        | Cop ((Casr | Clsr), [inner; Cconst_int (n, _)], _) as e
+          when 0 <= n && n < arch_bits ->
+          (* see middle_end/flambda2/z3/sign_extension.py for proof *)
+          if n > unused_bits
+          then
+            (* sign-extension is a no-op since the top n bits already match *)
+            e
+          else
+            let e = lsl_const inner (unused_bits - n) dbg in
+            asr_const e unused_bits dbg
+        | Cop (Cload { memory_chunk; mutability; is_atomic }, args, dbg) as e
+          -> (
+          let load memory_chunk =
+            Cop (Cload { memory_chunk; mutability; is_atomic }, args, dbg)
+          in
+          match memory_chunk, bits with
+          | (Byte_signed | Byte_unsigned), 8 -> load Byte_signed
+          | (Sixteen_signed | Sixteen_unsigned), 16 -> load Sixteen_signed
+          | (Thirtytwo_signed | Thirtytwo_unsigned), 32 -> load Thirtytwo_signed
+          | _ -> sign_extend_via_shift e)
+        | e -> sign_extend_via_shift e)
+      (low_bits ~bits e ~dbg)
 
 let unboxed_packed_array_ref arr index dbg ~memory_chunk ~elements_per_word =
   bind "arr" arr (fun arr ->
@@ -1335,18 +1409,19 @@ let unboxed_int32_array_ref =
 
 let unboxed_mutable_int32_unboxed_product_array_ref arr ~array_index dbg =
   bind "arr" arr (fun arr ->
       bind "index" array_index (fun index ->
-          sign_extend_32 dbg
+          sign_extend ~bits:32
             (Cop
                ( mk_load_mut Thirtytwo_signed,
                  [array_indexing log2_size_addr arr index dbg],
-                 dbg ))))
+                 dbg ))
+            ~dbg))
 
 let unboxed_mutable_int32_unboxed_product_array_set arr ~array_index ~new_value
     dbg =
   bind "arr" arr (fun arr ->
       bind "index" array_index (fun index ->
           bind "new_value" new_value (fun new_value ->
-              let new_value = sign_extend_32 dbg new_value in
+              let new_value = sign_extend ~bits:32 new_value ~dbg in
               Cop
                 ( Cstore (Word_int, Assignment),
                   [array_indexing log2_size_addr arr index dbg; new_value],
@@ -1448,7 +1523,7 @@ let set_field_unboxed ~dbg memory_chunk block ~index_in_words newval =
     let field_address =
       array_indexing log2_size_addr block index_in_words dbg
     in
-    let newval = if size_in_bytes = 4 then low_32 dbg newval else newval in
+    let newval = low_bits newval ~dbg ~bits:(8 * size_in_bytes) in
     return_unit dbg
       (Cop (Cstore (memory_chunk, Assignment), [field_address; newval], dbg))
@@ -1647,16 +1722,12 @@ let call_cached_method obj tag cache pos args args_type result (apos, mode) dbg
 
 (* Allocation *)
 
-(* CR layouts 5.1: When we pack int32s/float32s more efficiently, this code will
-   need to change. *)
+(* CR layouts 5.1: When we pack int8/16/32s/float32s more efficiently, this code
+   will need to change. *)
 let memory_chunk_size_in_words_for_mixed_block = function
-  | (Byte_unsigned | Byte_signed | Sixteen_unsigned | Sixteen_signed) as
-    memory_chunk ->
-    Misc.fatal_errorf
-      "Fields with memory chunk %s are not allowed in mixed blocks"
-      (Printcmm.chunk memory_chunk)
+  | Byte_unsigned | Byte_signed | Sixteen_unsigned | Sixteen_signed
   | Thirtytwo_unsigned | Thirtytwo_signed ->
-    (* Int32s are currently stored using a whole word *)
+    (* Small integers are currently stored using a whole word *)
     1
   | Single _ | Double ->
     (* Float32s are currently stored using a whole word *)
@@ -1896,107 +1967,6 @@ let bigarray_word_kind : Lambda.bigarray_kind -> memory_chunk = function
   | Pbigarray_complex32 -> Single { reg = Float64 }
   | Pbigarray_complex64 -> Double
 
-(* the three functions below assume 64-bit words *)
-let () = assert (size_int = 8)
-
-let check_64_bit_target func =
-  if size_int <> 8
-  then
-    Misc.fatal_errorf
-      "Cmm helpers function %s can only be used on 64-bit targets" func
-
-(* Like [low_32] but for 63-bit integers held in 64-bit registers. *)
-(* CR gbury: Why not use Cmm.map_tail here ? It seems designed for that kind of
-   thing (and covers more cases than just Clet). *)
-let rec low_63 dbg e =
-  check_64_bit_target "low_63";
-  match e with
-  | Cop (Casr, [Cop (Clsl, [x; Cconst_int (1, _)], _); Cconst_int (1, _)], _) ->
-    low_63 dbg x
-  | Cop (Cand, [x; Cconst_natint (0x7FFF_FFFF_FFFF_FFFFn, _)], _) ->
-    low_63 dbg x
-  | Clet (id, x, body) -> Clet (id, x, low_63 dbg body)
-  | _ -> e
-
-(* CR-someday mshinwell/gbury: sign_extend_63 then tag_int should simplify to
-   just tag_int. *)
-let sign_extend_63 dbg e =
-  check_64_bit_target "sign_extend_63";
-  match e with
-  | Cop (Casr, [_; Cconst_int (n, _)], _) when n > 0 && n < 64 ->
-    (* [asr] by a positive constant is sign-preserving. However:
-
-       - Some architectures treat the shift length modulo the word size.
-
-       - OCaml does not define behavior of shifts by more than the word size.
-
-       So we don't make the simplification for shifts of length 64 or more. *)
-    e
-  | _ ->
-    let e = low_63 dbg e in
-    Cop
-      ( Casr,
-        [Cop (Clsl, [e; Cconst_int (1, dbg)], dbg); Cconst_int (1, dbg)],
-        dbg )
-
-(* zero_extend_32 zero-extends values from 32 bits to the word size. *)
-let zero_extend_32 dbg e =
-  (* CR mshinwell for gbury: same question as above *)
-  match low_32 dbg e with
-  | Cop
-      ( Cload
-          { memory_chunk = Thirtytwo_signed | Thirtytwo_unsigned;
-            mutability;
-            is_atomic
-          },
-        args,
-        dbg ) ->
-    Cop
-      ( Cload { memory_chunk = Thirtytwo_unsigned; mutability; is_atomic },
-        args,
-        dbg )
-  | e -> Cop (Cand, [e; natint_const_untagged dbg 0xFFFFFFFFn], dbg)
-
-let zero_extend_63 dbg e =
-  check_64_bit_target "zero_extend_63";
-  let e = low_63 dbg e in
-  Cop (Cand, [e; natint_const_untagged dbg 0x7FFF_FFFF_FFFF_FFFFn], dbg)
-
-let zero_extend ~bits ~dbg e =
-  assert (0 < bits && bits <= arch_bits);
-  if bits = arch_bits
-  then e
-  else
-    match bits with
-    | 63 -> zero_extend_63 dbg e
-    | 32 -> zero_extend_32 dbg e
-    | bits -> Misc.fatal_errorf "zero_extend not implemented for %d bits" bits
-
-let sign_extend ~bits ~dbg e =
-  assert (0 < bits && bits <= arch_bits);
-  if bits = arch_bits
-  then e
-  else
-    match bits with
-    | 63 -> sign_extend_63 dbg e
-    | 32 -> sign_extend_32 dbg e
-    | bits -> Misc.fatal_errorf "sign_extend not implemented for %d bits" bits
-
-let low_bits ~bits ~(dbg : Debuginfo.t) e =
-  assert (0 < bits && bits <= arch_bits);
-  if bits = arch_bits
-  then e
-  else
-    match bits with
-    | 63 -> low_63 dbg e
-    | 32 -> low_32 dbg e
-    | bits -> Misc.fatal_errorf "low_bits not implemented for %d bits" bits
-
-let ignore_low_bits ~bits ~dbg:(_ : Debuginfo.t) e =
-  if bits = 1
-  then ignore_low_bit_int e
-  else Misc.fatal_error "ignore_low_bits expected bits=1 for now"
-
 let and_int e1 e2 dbg =
   let is_mask32 = function
     | Cconst_natint (0xFFFF_FFFFn, _) -> true
@@ -2004,8 +1974,8 @@ let and_int e1 e2 dbg =
     | _ -> false
   in
   match e1, e2 with
-  | e, m when is_mask32 m -> zero_extend_32 dbg e
-  | m, e when is_mask32 m -> zero_extend_32 dbg e
+  | e, m when is_mask32 m -> zero_extend ~bits:32 e ~dbg
+  | m, e when is_mask32 m -> zero_extend ~bits:32 e ~dbg
   | e1, e2 -> Cop (Cand, [e1; e2], dbg)
 
 let or_int e1 e2 dbg = Cop (Cor, [e1; e2], dbg)
@@ -2033,9 +2003,7 @@ let box_int_gen dbg (bi : Primitive.boxed_integer) mode arg =
   let arg' =
     if bi = Primitive.Boxed_int32
     then
-      if big_endian
-      then Cop (Clsl, [arg; Cconst_int (32, dbg)], dbg)
-      else sign_extend_32 dbg arg
+      if big_endian then lsl_const arg 32 dbg else sign_extend ~bits:32 arg ~dbg
     else arg
   in
   Cop
@@ -2079,12 +2047,12 @@ let unbox_int dbg bi =
       when bi = Primitive.Boxed_int32 && big_endian
           && alloc_matches_boxed_int bi ~hdr ~ops ->
      (* Force sign-extension of low 32 bits *)
-      sign_extend_32 dbg contents
+      sign_extend ~bits:32 contents ~dbg
    | Cop (Calloc _, [hdr; ops; contents], _dbg)
      when bi = Primitive.Boxed_int32 && (not big_endian)
          && alloc_matches_boxed_int bi ~hdr ~ops ->
      (* Force sign-extension of low 32 bits *)
-      sign_extend_32 dbg contents
+      sign_extend ~bits:32 contents ~dbg
    | Cop (Calloc _, [hdr; ops; contents], _dbg)
      when alloc_matches_boxed_int bi ~hdr ~ops ->
      contents
@@ -2100,7 +2068,7 @@
    | cmm -> default cmm)
 
 let make_unsigned_int bi arg dbg =
-  if bi = Primitive.Unboxed_int32 then zero_extend_32 dbg arg else arg
+  if bi = Primitive.Unboxed_int32 then zero_extend ~bits:32 arg ~dbg else arg
 
 let unaligned_load_16 ptr idx dbg =
   if Arch.allow_unaligned_access
@@ -4315,7 +4283,7 @@ let make_unboxed_int32_array_payload dbg unboxed_int32_list =
             ( Cor,
               [ (* [a] is sign-extended by default. We need to change it to be
                    zero-extended for the `or` operation to be correct. *)
-                zero_extend_32 dbg a;
+                zero_extend ~bits:32 a ~dbg;
                 Cop (Clsl, [b; Cconst_int (32, dbg)], dbg) ],
               dbg )
         in
diff --git a/backend/cmm_helpers.mli b/backend/cmm_helpers.mli
index 3381191f5b5..dea4644ab67 100644
--- a/backend/cmm_helpers.mli
+++ b/backend/cmm_helpers.mli
@@ -375,8 +375,9 @@ val bigarray_word_kind : Lambda.bigarray_kind -> memory_chunk
 
 (** Operations on n-bit integers *)
 
-(** Simplify the given expression knowing low [bits] bits will be irrelevant *)
-val ignore_low_bits : bits:int -> dbg:Debuginfo.t -> expression -> expression
+(** Simplify the given expression knowing that the low bit of the argument will
+    be irrelevant *)
+val ignore_low_bit_int : expression -> expression
 
 (** Simplify the given expression knowing that bits other than the low [bits]
     bits will be irrelevant *)
@@ -700,7 +701,7 @@ val create_ccatch :
 (** Shift operations.
     Inputs: a tagged caml integer and an untagged machine integer.
     Outputs: a tagged caml integer.
 
-    Take as first argument a tagged caml integer, and as
+    Takes as first argument a tagged caml integer, and as
     second argument an untagged machine integer which is the amount
     to shift the first argument by. *)
diff --git a/middle_end/flambda2/to_cmm/to_cmm_expr.ml b/middle_end/flambda2/to_cmm/to_cmm_expr.ml
index 0cdb343ac8a..bff2a7e27fd 100644
--- a/middle_end/flambda2/to_cmm/to_cmm_expr.ml
+++ b/middle_end/flambda2/to_cmm/to_cmm_expr.ml
@@ -110,9 +110,9 @@ let translate_external_call env res ~free_vars apply ~callee_simple ~args
        2. All of the [machtype_component]s are singleton arrays. *)
     Array.map (fun machtype -> [| machtype |]) return_ty
   in
-  (* Returned int32 values need to be sign_extended because it's not clear
-     whether C code that returns an int32 returns one that is sign extended or
-     not. There is no need to wrap other return arities. *)
+  (* Returned small integer values need to be sign-extended because it's not
+     clear whether C code that returns a small integer returns one that is
+     sign-extended or not. There is no need to wrap other return arities.
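+     For instance, the x86-64 SysV ABI does not guarantee that a C function
+     returning [int32_t] sign- or zero-extends the upper half of the result
+     register.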
+  *)
   let maybe_sign_extend kind dbg cmm =
     match Flambda_kind.With_subkind.kind kind with
     | Naked_number Naked_int32 -> C.sign_extend ~bits:32 ~dbg cmm
diff --git a/middle_end/flambda2/to_cmm/to_cmm_primitive.ml b/middle_end/flambda2/to_cmm/to_cmm_primitive.ml
index 4a650024cc6..bd4e8faf02b 100644
--- a/middle_end/flambda2/to_cmm/to_cmm_primitive.ml
+++ b/middle_end/flambda2/to_cmm/to_cmm_primitive.ml
@@ -742,7 +742,7 @@ let binary_int_shift_primitive _env dbg kind op x y =
   | (Naked_int64 | Naked_nativeint), Asr -> C.asr_int x y dbg
 
 let binary_int_comp_primitive _env dbg kind cmp x y =
-  let ignore_low_bit_int = C.ignore_low_bits ~bits:1 ~dbg in
+  let ignore_low_bit_int = C.ignore_low_bit_int in
   match
     ( (kind : Flambda_kind.Standard_int.t),
       (cmp : P.signed_or_unsigned P.comparison) )
diff --git a/middle_end/flambda2/to_cmm/to_cmm_shared.ml b/middle_end/flambda2/to_cmm/to_cmm_shared.ml
index 7da2054eee5..a1e7e069935 100644
--- a/middle_end/flambda2/to_cmm/to_cmm_shared.ml
+++ b/middle_end/flambda2/to_cmm/to_cmm_shared.ml
@@ -205,7 +205,7 @@ let name_static res name =
     ~symbol:(fun s -> `Static_data [symbol_address (To_cmm_result.symbol res s)])
 
-let const_static cst =
+let const_static cst : Cmm.data_item list =
   match Reg_width_const.descr cst with
   | Naked_immediate i ->
     [cint (nativeint_of_targetint (Targetint_31_63.to_targetint i))]
@@ -394,7 +394,7 @@ let make_update env res dbg ({ kind; stride } : Update_kind.t) ~symbol var
        By using [Word_int] in the "fields" cases (see [Update_kind], above) we
        maintain the convention that 32-bit integers in 64-bit fields are sign
       extended. *)
-    if stride > 4 then Word_int else Thirtytwo_signed
+    if stride = Arch.size_addr then Word_int else Thirtytwo_signed
   | Naked_int64 -> Word_int
   | Naked_float -> Double
   | Naked_float32 ->
diff --git a/middle_end/flambda2/z3/sign_extension.py b/middle_end/flambda2/z3/sign_extension.py
new file mode 100755
index 00000000000..6613558072d
--- /dev/null
+++ b/middle_end/flambda2/z3/sign_extension.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+from z3 import *
+
+ARCH_BITS = 64
+
+
+class Op:
+    def __init__(self, name):
+        self.inner = BitVec(f"{name}.inner", ARCH_BITS)
+        self.shift_right = BitVec(f"{name}.shift_right", ARCH_BITS)
+        self.arith = Bool(f"{name}.arith")
+
+    def as_ast(self) -> BitVecRef:
+        return If(
+            self.arith,
+            self.inner >> self.shift_right,
+            LShR(self.inner, self.shift_right),
+        )
+
+    def __repr__(self):
+        return repr(self.as_ast())
+
+    def size(self):
+        return self.as_ast().size()
+
+    def reference_sign_extend(self, bits):
+        unused_bits = self.size() - bits
+        return (self.as_ast() << unused_bits) >> unused_bits
+
+    def experimental_sign_extend(self, bits) -> BitVecRef:
+        unused_bits = self.size() - bits
+        return If(
+            self.shift_right > unused_bits,
+            self.as_ast(),
+            (self.inner << (unused_bits - self.shift_right)) >> unused_bits,
+        )
+
+
+if __name__ == "__main__":
+    s = Solver()
+
+    x = Op("x")
+    bits = BitVec("bits", ARCH_BITS)  # Number of low bits to preserve
+
+    # assumptions
+    s.add(And(0 <= x.shift_right, x.shift_right < x.size()))
+    s.add(And(0 < bits, bits <= x.size()))
+
+    # sanity check that we haven't introduced something crazy
+    assert s.check() == sat
+
+    # falsify this
+    s.add(x.reference_sign_extend(bits) != x.experimental_sign_extend(bits))
+
+    print(s.to_smt2())
+
+    print("Verifying sign_extend...")
+    if s.check() == unsat:
+        print("sign_extend optimization is correct.")
+    else:
+        print("sign_extend is incorrect.")
+        model = s.model()
+        print("Counterexample:", model)
+        exit(1)
+
+
+# ; benchmark generated from python API
+# (set-info :status unknown)
+# (declare-fun x.shift_right () (_ BitVec 64))
+# (declare-fun bits () (_ BitVec 64))
+# (declare-fun x.inner () (_ BitVec 64))
+# (declare-fun x.arith () Bool)
+# (assert
+#  (and (bvsge x.shift_right (_ bv0 64)) (bvslt x.shift_right (_ bv64 64))))
+# (assert
+#  (let (($x36 (bvsle bits (_ bv64 64))))
+#  (and (bvsgt bits (_ bv0 64)) $x36)))
+# (assert
+#  (let ((?x51 (bvsub (_ bv64 64) bits)))
+#  (let ((?x47 (ite x.arith (bvashr x.inner x.shift_right) (bvlshr x.inner x.shift_right))))
+#  (let ((?x66 (ite (bvsgt x.shift_right ?x51) ?x47 (bvashr (bvshl x.inner (bvsub ?x51 x.shift_right)) ?x51))))
+#  (and (distinct (bvashr (bvshl ?x47 ?x51) ?x51) ?x66) true)))))
+# (check-sat)
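
The script above machine-checks the shift rewrite used by [sign_extend]. The other rewrite this patch introduces, [low_bits] looking through a constant bitmask when [does_mask_keep_low_bits] holds, can be checked in the same style. The following is a minimal companion sketch, not part of the diff above; it is hypothetical and assumes the same z3 Python bindings used by sign_extension.py:

#!/usr/bin/env python3
# Hypothetical companion to sign_extension.py (not in the PR): checks the
# low_bits mask rewrite from cmm_helpers.ml.
from z3 import *

ARCH_BITS = 64

x = BitVec("x", ARCH_BITS)
mask = BitVec("mask", ARCH_BITS)
bits = BitVec("bits", ARCH_BITS)

# All-ones over the low `bits` bits; for bits = 64, bvshl saturates to 0 and
# low_mask becomes all-ones, matching the bits = arch_bits case.
low_mask = (BitVecVal(1, ARCH_BITS) << bits) - 1

s = Solver()
s.add(And(0 < bits, bits <= ARCH_BITS))
# assumption checked by does_mask_keep_low_bits: mask has all the low bits set
s.add(mask & low_mask == low_mask)
# falsify this: looking through the mask changes one of the low bits
s.add((x & mask) & low_mask != x & low_mask)

print("Verifying low_bits mask rewrite...")
if s.check() == unsat:
    print("low_bits mask rewrite is correct.")
else:
    print("Counterexample:", s.model())
    exit(1)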