Skip to content

Commit

Permalink
Feature/lean external arith shifts (#968)
Browse files Browse the repository at this point in the history
Add lean4 support for the following functions in arith.sail:

    _shl8
    _shl32
    _shl1
    _shl_int
    _shr32
    _shr_int
  • Loading branch information
benjaminselfridge authored Feb 13, 2025
1 parent a434a0a commit 7dee571
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 9 deletions.
26 changes: 18 additions & 8 deletions lib/arith.sail
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,30 @@ Similarly, we define shifts of 32 and 1 (i.e., powers of two).

The most general shift operations also allow negative shifts which go in the opposite direction, for compatibility with ASL.
*/
val _shl8 = pure {c: "shl_mach_int", _: "shl_int"} :
val _shl8 = pure {c: "shl_mach_int", lean: "Int.shiftl", _: "shl_int"} :
forall 'n, 0 <= 'n <= 3. (int(8), int('n)) -> {'m, 'm in {8, 16, 32, 64}. int('m)}

val _shl32 = pure {c: "shl_mach_int", _: "shl_int"} :
val _shl32 = pure {c: "shl_mach_int", lean: "Int.shiftl", _: "shl_int"} :
forall 'n, 'n in {0, 1}. (int(32), int('n)) -> {'m, 'm in {32, 64}. int('m)}

val _shl1 = pure {c: "shl_mach_int", _: "shl_int"} :
val _shl1 = pure {c: "shl_mach_int", lean: "Int.shiftl", _: "shl_int"} :
forall 'n, 0 <= 'n <= 3. (int(1), int('n)) -> {'m, 'm in {1, 2, 4, 8}. int('m)}

val _shl_int = pure "shl_int" : forall 'n, 0 <= 'n. (int, int('n)) -> int

val _shr32 = pure {c: "shr_mach_int", _: "shr_int"} : forall 'n, 0 <= 'n <= 31. (int('n), int(1)) -> {'m, 0 <= 'm <= 15. int('m)}

val _shr_int = pure "shr_int" : forall 'n, 0 <= 'n. (int, int('n)) -> int
val _shl_int = pure {
lean: "Int.shiftl",
_: "shl_int"
} : forall 'n, 0 <= 'n. (int, int('n)) -> int

val _shr32 = pure {
c: "shr_mach_int",
lean: "Int.shiftl",
_: "shr_int"
} : forall 'n, 0 <= 'n <= 31. (int('n), int(1)) -> {'m, 0 <= 'm <= 15. int('m)}

val _shr_int = pure {
lean: "Int.shiftr",
_: "shr_int"
} : forall 'n, 0 <= 'n. (int, int('n)) -> int

function _shl_int_general(m: int, n: int) -> int = if n >= 0 then _shl_int(m, n) else _shr_int(m, negate(n))
function _shr_int_general(m: int, n: int) -> int = if n >= 0 then _shr_int(m, n) else _shl_int(m, negate(n))
Expand Down
21 changes: 20 additions & 1 deletion src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,29 @@ def addInt {w : Nat} (x : BitVec w) (i : Int) : BitVec w :=

end BitVec

namespace Nat

-- NB: below is taken from Mathlib.Logic.Function.Iterate
/-- Iterate a function. -/
def iterate {α : Sort u} (op : α → α) : Nat → α → α
| 0, a => a
| Nat.succ k, a => iterate op k (op a)

end Nat

namespace Int

def intAbs (x : Int) : Int := Int.ofNat (Int.natAbs x)

end Int
def shiftl (a : Int) (n : Int) : Int :=
match n with
| Int.ofNat n => Sail.Nat.iterate (fun x => x * 2) n a
| Int.negSucc n => Sail.Nat.iterate (fun x => x / 2) (n+1) a

def shiftr (a : Int) (n : Int) : Int :=
match n with
| Int.ofNat n => Sail.Nat.iterate (fun x => x / 2) n a
| Int.negSucc n => Sail.Nat.iterate (fun x => x * 2) (n+1) a

end Int
end Sail
18 changes: 18 additions & 0 deletions test/lean/extern.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,24 @@ def extern_negate (_ : Unit) : Int :=
def extern_mult (_ : Unit) : Int :=
(HMul.hMul 5 4)

def extern__shl8 (_ : Unit) : Int :=
(Int.shiftl 8 2)

def extern__shl32 (_ : Unit) : Int :=
(Int.shiftl 32 1)

def extern__shl1 (_ : Unit) : Int :=
(Int.shiftl 1 2)

def extern__shl_int (_ : Unit) : Int :=
(Int.shiftl 4 2)

def extern__shr32 (_ : Unit) : Int :=
(Int.shiftl 30 1)

def extern__shr_int (_ : Unit) : Int :=
(Int.shiftr 8 2)

def extern_tdiv (_ : Unit) : Int :=
(Int.tdiv 5 4)

Expand Down
24 changes: 24 additions & 0 deletions test/lean/extern.sail
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,30 @@ function extern_mult() -> int = {
return mult_int(5, 4)
}

function extern__shl8() -> int = {
return _shl8(8, 2)
}

function extern__shl32() -> int = {
return _shl32(32, 1)
}

function extern__shl1() -> int = {
return _shl1(1, 2)
}

function extern__shl_int() -> int = {
return _shl_int(4, 2)
}

function extern__shr32() -> int = {
return _shr32(30, 1)
}

function extern__shr_int() -> int = {
return _shr_int(8, 2)
}

function extern_tdiv() -> int = {
return tdiv_int(5, 4)
}
Expand Down

0 comments on commit 7dee571

Please sign in to comment.