Skip to content

Commit

Permalink
Revert "Switch to using extractLsb' to avoid casting in the goal"
Browse files Browse the repository at this point in the history
This reverts commit 0d3f545.
  • Loading branch information
pennyannn committed Sep 20, 2024
1 parent 40104f2 commit 0247d6a
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 49 deletions.
69 changes: 33 additions & 36 deletions Arm/Insts/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -613,10 +613,13 @@ example : rev_vector 32 16 8 0xaabbccdd#32 (by decide)
/-- Divide bv `vector` into elements, each of size `size`. This function gets
the `e`'th element from the `vector`. -/
@[state_simp_rules]
def elem_get (vector : BitVec n) (e : Nat) (size : Nat): BitVec size :=
def elem_get (vector : BitVec n) (e : Nat) (size : Nat)
(h: size > 0): BitVec size :=
-- assert (e+1)*size <= n
let lo := e * size
extractLsb' lo size vector
let hi := lo + size - 1
have h : hi - lo + 1 = size := by simp only [hi, lo]; omega
BitVec.cast h $ extractLsb hi lo vector

/-- Divide bv `vector` into elements, each of size `size`. This function sets
the `e`'th element in the `vector`. -/
Expand Down Expand Up @@ -645,7 +648,7 @@ deriving DecidableEq, Repr
export ShiftInfo (esize elements shift unsigned round accumulate)

@[state_simp_rules]
def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool)
def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool) (h : n > 0)
: BitVec n :=
-- assert shift > 0
let fn := if unsigned then ushiftRight else sshiftRight
Expand All @@ -655,7 +658,8 @@ def RShr (unsigned : Bool) (value : Int) (shift : Nat) (round : Bool)
BitVec.ofInt (n + 1) rounded
else
BitVec.ofInt (n + 1) value
extractLsb' 0 n (fn rounded_bv shift)
have h₀ : n - 1 - 0 + 1 = n := by omega
BitVec.cast h₀ $ extractLsb (n-1) 0 (fn rounded_bv shift)

@[state_simp_rules]
def Int_with_unsigned (unsigned : Bool) (value : BitVec n) : Int :=
Expand All @@ -667,9 +671,9 @@ def shift_right_common_aux
if h : info.elements ≤ e then
result
else
let elem := Int_with_unsigned info.unsigned $ elem_get operand e info.esize
let shift_elem := RShr info.unsigned elem info.shift info.round
let acc_elem := elem_get operand2 e info.esize + shift_elem
let elem := Int_with_unsigned info.unsigned $ elem_get operand e info.esize info.h
let shift_elem := RShr info.unsigned elem info.shift info.round info.h
let acc_elem := elem_get operand2 e info.esize info.h + shift_elem
let result := elem_set result e info.esize acc_elem info.h
have _ : info.elements - (e + 1) < info.elements - e := by omega
shift_right_common_aux (e + 1) info operand operand2 result
Expand All @@ -681,13 +685,6 @@ theorem shift_le (x : Nat) (shift :Nat) :
simp only [Nat.shiftRight_eq_div_pow]
exact Nat.div_le_self x (2 ^ shift)

-- FIXME: should this be upstreamed?
theorem extractLsb'_ofNat (x n : Nat) (lo size : Nat) :
extractLsb' lo size (BitVec.ofNat n x) = .ofNat size ((x % 2^n) >>> lo) := by
apply eq_of_getLsbD_eq
intro ⟨i, _lt⟩
simp [BitVec.ofNat]

@[state_simp_rules]
theorem shift_right_common_aux_64_2_tff (operand : BitVec 128)
(shift : Nat) (result : BitVec 128):
Expand All @@ -696,8 +693,8 @@ theorem shift_right_common_aux_64_2_tff (operand : BitVec 128)
unsigned := true, round := false, accumulate := false,
h := (by omega)}
operand 0#128 result =
(ushiftRight (extractLsb' 64 64 operand) shift)
++ (ushiftRight (extractLsb' 0 64 operand) shift) := by
(ushiftRight (extractLsb 127 64 operand) shift)
++ (ushiftRight (extractLsb 63 0 operand) shift) := by
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_right_common_aux
Expand Down Expand Up @@ -733,17 +730,16 @@ theorem shift_right_common_aux_64_2_tff (operand : BitVec 128)
-- Eliminating casting functions
Int.ofNat_eq_coe, ofInt_natCast, ofNat_toNat
]
simp only [reduceExtracLsb', BitVec.zero_add]
generalize (extractLsb' 64 64 operand) = x
generalize (extractLsb' 0 64 operand) = y
have h0 : ∀ (z : BitVec 64), extractLsb' 0 64 ((zeroExtend 65 z).ushiftRight shift)
generalize (extractLsb 127 64 operand) = x; simp at x
generalize (extractLsb 63 0 operand) = y; simp at y
have h0 : ∀ (z : BitVec 64), extractLsb 63 0 ((zeroExtend 65 z).ushiftRight shift)
= z.ushiftRight shift := by
intro z
simp only [ushiftRight, toNat_truncate]
have h1: z.toNat % 2 ^ 65 = z.toNat := by omega
simp only [h1]
simp only [Std.Tactic.BVDecide.Normalize.BitVec.ofNatLt_reduce]
simp only [Nat.sub_zero, Nat.reduceAdd, extractLsb'_ofNat, Nat.shiftRight_zero]
simp only [Nat.sub_zero, Nat.reduceAdd, BitVec.extractLsb_ofNat, Nat.shiftRight_zero]
have h2 : z.toNat >>> shift % 2 ^ 65 = z.toNat >>> shift := by
refine Nat.mod_eq_of_lt ?h3
have h4 : z.toNat >>> shift ≤ z.toNat := by exact shift_le z.toNat shift
Expand Down Expand Up @@ -786,10 +782,10 @@ theorem shift_right_common_aux_32_4_fff (operand : BitVec 128)
unsigned := false, round := false, accumulate := false,
h := (by omega) }
operand 0#128 result =
(sshiftRight (extractLsb' 96 32 operand) shift)
++ (sshiftRight (extractLsb' 64 32 operand) shift)
++ (sshiftRight (extractLsb' 32 32 operand) shift)
++ (sshiftRight (extractLsb' 0 32 operand) shift) := by
(sshiftRight (extractLsb 127 96 operand) shift)
++ (sshiftRight (extractLsb 95 64 operand) shift)
++ (sshiftRight (extractLsb 63 32 operand) shift)
++ (sshiftRight (extractLsb 31 0 operand) shift) := by
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_right_common_aux
Expand Down Expand Up @@ -829,19 +825,20 @@ theorem shift_right_common_aux_32_4_fff (operand : BitVec 128)
-- Eliminating casting functions
ofInt_eq_signExtend
]
generalize extractLsb' 0 32 operand = a
generalize extractLsb' 32 32 operand = b
generalize extractLsb' 64 32 operand = c
generalize extractLsb' 96 32 operand = d
generalize extractLsb 31 0 operand = a; simp at a
generalize extractLsb 63 32 operand = b; simp at b
generalize extractLsb 95 64 operand = c; simp at c
generalize extractLsb 127 96 operand = d; simp at d
have h : ∀ (x : BitVec 32),
extractLsb' 0 32 ((signExtend 33 x).sshiftRight shift)
extractLsb 31 0 ((signExtend 33 x).sshiftRight shift)
= x.sshiftRight shift := by
intros x
apply eq_of_getLsbD_eq; intros i; simp at i
simp only [getLsbD_sshiftRight]
simp only [getLsbD_extractLsb', Fin.is_lt, decide_True,
Nat.zero_add, getLsbD_sshiftRight,
getLsbD_signExtend, Bool.true_and]
simp only [Nat.sub_zero, Nat.reduceAdd, getLsbD_extract, Nat.zero_add,
getLsbD_sshiftRight, getLsbD_signExtend]
simp only [show (i : Nat) ≤ 31 by omega,
decide_True, Bool.true_and]
simp only [show ¬33 ≤ (i : Nat) by omega,
decide_False, Bool.not_false, Bool.true_and]
simp only [show ¬32 ≤ (i : Nat) by omega,
Expand Down Expand Up @@ -881,7 +878,7 @@ def shift_left_common_aux
if h : info.elements ≤ e then
result
else
let elem := elem_get operand e info.esize
let elem := elem_get operand e info.esize info.h
let shift_elem := elem <<< info.shift
let result := elem_set result e info.esize shift_elem info.h
have _ : info.elements - (e + 1) < info.elements - e := by omega
Expand All @@ -896,8 +893,8 @@ theorem shift_left_common_aux_64_2 (operand : BitVec 128)
unsigned := unsigned, round := round, accumulate := accumulate,
h := (by omega)}
operand result =
(extractLsb' 64 64 operand <<< shift)
++ (extractLsb' 0 64 operand <<< shift) := by
(extractLsb 127 64 operand <<< shift)
++ (extractLsb 63 0 operand <<< shift) := by
unfold shift_left_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_left_common_aux
Expand Down
6 changes: 3 additions & 3 deletions Arm/Insts/DPSFP/Advanced_simd_copy.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def exec_dup_element (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState :
let elements := datasize / esize
let operand := read_sfp idxdsize inst.Rn s
have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide)
let element := elem_get operand index esize
let element := elem_get operand index esize h₀
let result := dup_aux 0 elements esize element (BitVec.zero datasize) h₀
-- State Updates
let s := write_pc ((read_pc s) + 4#64) s
Expand Down Expand Up @@ -81,7 +81,7 @@ def exec_ins_element (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState :
let operand := read_sfp idxdsize inst.Rn s
let result := read_sfp 128 inst.Rd s
have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide)
let elem := elem_get operand src_index esize
let elem := elem_get operand src_index esize h₀
let result := elem_set result dst_index esize elem h₀
-- State Updates
let s := write_pc ((read_pc s) + 4#64) s
Expand Down Expand Up @@ -123,7 +123,7 @@ def exec_smov_umov (inst : Advanced_simd_copy_cls) (s : ArmState) (signed : Bool
-- if index == 0 then CheckFPEnabled64 else CheckFPAdvSIMDEnabled64
let operand := read_sfp idxdsize inst.Rn s
have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide)
let element := elem_get operand index esize
let element := elem_get operand index esize h₀
let result := if signed then signExtend datasize element else zeroExtend datasize element
-- State Updates
let s := write_pc ((read_pc s) + 4#64) s
Expand Down
4 changes: 2 additions & 2 deletions Arm/Insts/DPSFP/Advanced_simd_permute.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def trn_aux (p : Nat) (pairs : Nat) (esize : Nat) (part : Nat)
result
else
let idx_from := 2 * p + part
let op1_part := elem_get operand1 idx_from esize
let op2_part := elem_get operand2 idx_from esize
let op1_part := elem_get operand1 idx_from esize h
let op2_part := elem_get operand2 idx_from esize h
let result := elem_set result (2 * p) esize op1_part h
let result := elem_set result (2 * p + 1) esize op2_part h
have h₁ : pairs - (p + 1) < pairs - p := by omega
Expand Down
3 changes: 2 additions & 1 deletion Arm/Insts/DPSFP/Advanced_simd_scalar_copy.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def exec_advanced_simd_scalar_copy
let idxdsize := 64 <<< (lsb inst.imm5 4).toNat
let esize := 8 <<< size
let operand := read_sfp idxdsize inst.Rn s
let result := elem_get operand index.toNat esize
have h : esize > 0 := by apply zero_lt_shift_left_pos (by decide)
let result := elem_get operand index.toNat esize h
-- State Updates
let s := write_pc ((read_pc s) + 4#64) s
let s := write_sfp esize inst.Rd result s
Expand Down
4 changes: 2 additions & 2 deletions Arm/Insts/DPSFP/Advanced_simd_table_lookup.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def tblx_aux (i : Nat) (elements : Nat) (indices : BitVec datasize)
result
else
have h₁ : 8 > 0 := by decide
let index := (elem_get indices i 8).toNat
let index := (elem_get indices i 8 h₁).toNat
let result :=
if index < 16 * regs then
let val := elem_get table index 8
let val := elem_get table index 8 h₁
elem_set result i 8 val h₁
else
result
Expand Down
4 changes: 2 additions & 2 deletions Arm/Insts/DPSFP/Advanced_simd_three_different.lean
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def pmull_op (e : Nat) (esize : Nat) (elements : Nat) (x : BitVec n)
if h₀ : elements <= e then
result
else
let element1 := elem_get x e esize
let element2 := elem_get y e esize
let element1 := elem_get x e esize H
let element2 := elem_get y e esize H
let elem_result := polynomial_mult element1 element2
have h₁ : esize + esize = 2 * esize := by omega
have h₂ : 2 * esize > 0 := by omega
Expand Down
4 changes: 2 additions & 2 deletions Arm/Insts/DPSFP/Advanced_simd_three_same.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def binary_vector_op_aux (e : Nat) (elems : Nat) (esize : Nat)
result
else
have h₁ : e < elems := by omega
let element1 := elem_get x e esize
let element2 := elem_get y e esize
let element1 := elem_get x e esize H
let element2 := elem_get y e esize H
let elem_result := op element1 element2
let result := elem_set result e esize elem_result H
have ht1 : elems - (e + 1) < elems - e := by omega
Expand Down
1 change: 0 additions & 1 deletion Proofs/AES-GCM/GCMInitV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,5 @@ theorem gcm_init_v8_program_correct (s0 sf : ArmState)
, shift_right_common_aux_32_4_fff
, DPSFP.AdvSIMDExpandImm
, DPSFP.dup_aux_0_4_32]
simp only [BitVec.extractLsb'_eq_extractLsb.symm]
generalize read_mem_bytes 16 (r (StateField.GPR 1#5) s0) s0 = Hinit
sorry

0 comments on commit 0247d6a

Please sign in to comment.