Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/lean external arith shifts #968

Merged
merged 17 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions lib/arith.sail
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,30 @@ Similarly, we define shifts of 32 and 1 (i.e., powers of two).

The most general shift operations also allow negative shifts which go in the opposite direction, for compatibility with ASL.
*/
val _shl8 = pure {c: "shl_mach_int", _: "shl_int"} :
val _shl8 = pure {c: "shl_mach_int", lean: "Int.shiftl", _: "shl_int"} :
forall 'n, 0 <= 'n <= 3. (int(8), int('n)) -> {'m, 'm in {8, 16, 32, 64}. int('m)}

val _shl32 = pure {c: "shl_mach_int", _: "shl_int"} :
val _shl32 = pure {c: "shl_mach_int", lean: "Int.shiftl", _: "shl_int"} :
forall 'n, 'n in {0, 1}. (int(32), int('n)) -> {'m, 'm in {32, 64}. int('m)}

val _shl1 = pure {c: "shl_mach_int", _: "shl_int"} :
val _shl1 = pure {c: "shl_mach_int", lean: "Int.shiftl", _: "shl_int"} :
forall 'n, 0 <= 'n <= 3. (int(1), int('n)) -> {'m, 'm in {1, 2, 4, 8}. int('m)}

val _shl_int = pure "shl_int" : forall 'n, 0 <= 'n. (int, int('n)) -> int

val _shr32 = pure {c: "shr_mach_int", _: "shr_int"} : forall 'n, 0 <= 'n <= 31. (int('n), int(1)) -> {'m, 0 <= 'm <= 15. int('m)}

val _shr_int = pure "shr_int" : forall 'n, 0 <= 'n. (int, int('n)) -> int
val _shl_int = pure {
lean: "Int.shiftl",
_: "shl_int"
} : forall 'n, 0 <= 'n. (int, int('n)) -> int

val _shr32 = pure {
c: "shr_mach_int",
lean: "Int.shiftl",
_: "shr_int"
} : forall 'n, 0 <= 'n <= 31. (int('n), int(1)) -> {'m, 0 <= 'm <= 15. int('m)}

val _shr_int = pure {
lean: "Int.shiftr",
_: "shr_int"
} : forall 'n, 0 <= 'n. (int, int('n)) -> int

function _shl_int_general(m: int, n: int) -> int = if n >= 0 then _shl_int(m, n) else _shr_int(m, negate(n))
function _shr_int_general(m: int, n: int) -> int = if n >= 0 then _shr_int(m, n) else _shl_int(m, negate(n))
Expand Down
5 changes: 4 additions & 1 deletion src/bin/dune
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,7 @@
src/gen_lib/sail2_values.lem)
(%{workspace_root}/src/sail_lean_backend/Sail/Sail.lean
as
src/sail_lean_backend/Sail/Sail.lean)))
src/sail_lean_backend/Sail/Sail.lean)
(%{workspace_root}/src/sail_lean_backend/Sail/BitVec.lean
as
src/sail_lean_backend/Sail/BitVec.lean)))
212 changes: 212 additions & 0 deletions src/sail_lean_backend/Sail/BitVec.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
/-
Copyright (c) 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Shilpi Goel, Siddharth Bhat
-/

-- Taken from https://github.com/leanprover/LNSym/blob/main/Arm/BitVec.lean

import Lean.Elab.Term
import Lean.Meta.Reduce
import Std.Tactic.BVDecide

open BitVec

/- Bitvector pattern component syntax category, originally written by
Leonardo de Moura. -/
declare_syntax_cat bvpat_comp
syntax num : bvpat_comp
syntax ident (":" num)? : bvpat_comp
syntax "_" ":" num : bvpat_comp

/--
Bitvector pattern syntax category.
Example: [sf:1,0011010000,Rm:5,000000,Rn:5,Rd:5]
-/
declare_syntax_cat bvpat
syntax "[" bvpat_comp,* "]" : bvpat

open Lean

abbrev BVPatComp := TSyntax `bvpat_comp
abbrev BVPat := TSyntax `bvpat

/-- Return the number of bits in a bit-vector component pattern. -/
def BVPatComp.length (c : BVPatComp) : Nat := Id.run do
match c with
| `(bvpat_comp| $n:num) =>
let some str := n.raw.isLit? `num | pure 0
return str.length
| `(bvpat_comp| $_:ident : $n:num) =>
return n.raw.toNat
| `(bvpat_comp| $_:ident ) =>
return 1
| `(bvpat_comp| _ : $n:num) =>
return n.raw.toNat
| _ =>
return 0

/--
If the pattern component is a bitvector literal, convert it into a bit-vector term
denoting it.
-/
def BVPatComp.toBVLit? (c : BVPatComp) : MacroM (Option Term) := do
match c with
| `(bvpat_comp| $n:num) =>
let len := c.length
let some str := n.raw.isLit? `num | Macro.throwErrorAt c "invalid bit-vector literal"
let bs := str.toList
let mut val := 0
for b in bs do
if b = '1' then
val := 2*val + 1
else if b = '0' then
val := 2*val
else
Macro.throwErrorAt c "invalid bit-vector literal, '0'/'1's expected"
let r ← `(BitVec.ofNat $(quote len) $(quote val))
return some r
| _ => return none

/--
If the pattern component is a pattern variable of the form `<id>:<size>` return
`some id`.
-/
def BVPatComp.toBVVar? (c : BVPatComp) : MacroM (Option (TSyntax `ident)) := do
match c with
| `(bvpat_comp| $x:ident $[: $_:num]?) =>
return some x
| _ => return none

def BVPat.getComponents (p : BVPat) : Array BVPatComp :=
match p with
| `(bvpat| [$comp,*]) => comp.getElems.reverse
| _ => #[]

/--
Return the number of bits in a bit-vector pattern.
-/
def BVPat.length (p : BVPat) : Nat := Id.run do
let mut sz := 0
for c in p.getComponents do
sz := sz + c.length
return sz

/--
Return a term that evaluates to `true` if `var` is an instance of the pattern `pat`.
-/
def genBVPatMatchTest (vars : Array Term) (pats : Array BVPat) : MacroM Term := do
if vars.size != pats.size then
Macro.throwError "incorrect number of patterns"
let mut result ← `(true)

for (pat, var) in pats.zip vars do
let mut shift := 0
for c in pat.getComponents do
let len := c.length
if let some bv ← c.toBVLit? then
let test ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var == $bv)
result ← `($result && $test)
shift := shift + len
return result

/--
Given a variable `var` representing a term that matches the pattern `pat`, and a term `rhs`,
return a term of the form
```
let y₁ := var.extract ..
...
let yₙ := var.extract ..
rhs
```
where `yᵢ`s are the pattern variables in `pat`.
-/
def declBVPatVars (vars : Array Term) (pats : Array BVPat) (rhs : Term) : MacroM Term := do
let mut result := rhs
for (pat, var) in pats.zip vars do
let mut shift := 0
for c in pat.getComponents do
let len := c.length
if let some y ← c.toBVVar? then
let rhs ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var)
result ← `(let $y := $rhs; $result)
shift := shift + len
return result

/--
Define the `match_bv .. with | bvpat => rhs | _ => rhs`.
The last entry is the `else`-case since we currently do not check whether
the patterns are exhaustive or not.
-/
syntax (name := matchBv) "match_bv " term,+ "with" (atomic("| " bvpat,+) " => " term)* ("| " "_ " " => " term)? : term

open Lean
open Elab
open Term

def checkBVPatLengths (lens : Array (Option Nat)) (pss : Array (Array BVPat)) : TermElabM Unit := do
for (len, i) in lens.zipWithIndex do
let mut patLen := none
for ps in pss do
unless ps.size == lens.size do
throwError "Expected {lens.size} patterns, found {ps.size}"
let p := ps[i]!
let pLen := p.length

-- compare the length to that of the type of the discriminant
if let some pLen' := len then
unless pLen == pLen' do
throwErrorAt p "Exprected pattern of length {pLen}, found {pLen'} instead"

-- compare the lengths of the patterns
if let some pLen' := patLen then
unless pLen == pLen' do
throwErrorAt p "patterns have differrent lengths"
else
patLen := some pLen

-- We use this to gather all the conditions expressing that the
-- previous pattern matches failed. This allows in turn to prove
-- exaustivity of the pattern matching.
abbrev dite_gather {α : Sort u} {old : Prop} (c : Prop) [h : Decidable c]
(t : old ∧ c → α) (e : old ∧ ¬ c → α) (ho : old) : α :=
h.casesOn (λ hc => e (And.intro ho hc)) (λ hc => t (And.intro ho hc))

@[term_elab matchBv]
partial
def elabMatchBv : TermElab := fun stx typ? =>
match stx with
| `(match_bv $[$discrs:term],* with
$[ | $[$pss:bvpat],* => $rhss:term ]*
$[| _ => $rhsElse?:term]?) => do
let xs := discrs

-- try to get the length of the BV to error-out
-- if a pattern has the wrong length
-- TODO: is it the best way to do that?
let lens ← discrs.mapM (fun x => do
let x ← elabTerm x none
let typ ← Meta.inferType x
match_expr typ with
| BitVec n =>
let n ← Meta.reduce n
match n with
| .lit (.natVal n) => return some n
| _ => return none
| _ => return none)

checkBVPatLengths lens pss

let mut result :=
← if let some rhsElse := rhsElse? then
`(Function.const _ $rhsElse)
else
`(fun _ => by bv_decide)

for ps in pss.reverse, rhs in rhss.reverse do
let test ← liftMacroM <| genBVPatMatchTest xs ps
let rhs ← liftMacroM <| declBVPatVars xs ps rhs
result ← `(dite_gather $test (Function.const _ $rhs) $result)
let res ← liftMacroM <| `($result True.intro)
elabTerm res typ?
| _ => throwError "invalid syntax"
21 changes: 20 additions & 1 deletion src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,29 @@ def addInt {w : Nat} (x : BitVec w) (i : Int) : BitVec w :=

end BitVec

namespace Nat

-- NB: below is taken from Mathlib.Logic.Function.Iterate
/-- Iterate a function. -/
def iterate {α : Sort u} (op : α → α) : Nat → α → α
| 0, a => a
| Nat.succ k, a => iterate op k (op a)

end Nat

namespace Int

def intAbs (x : Int) : Int := Int.ofNat (Int.natAbs x)

end Int
def shiftl (a : Int) (n : Int) : Int :=
match n with
| Int.ofNat n => Sail.Nat.iterate (fun x => x * 2) n a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Were you still planning on using shift operators for these? We're going to merge, but maybe you can open another PR with those updates.

| Int.negSucc n => Sail.Nat.iterate (fun x => x / 2) (n+1) a

def shiftr (a : Int) (n : Int) : Int :=
match n with
| Int.ofNat n => Sail.Nat.iterate (fun x => x / 2) n a
| Int.negSucc n => Sail.Nat.iterate (fun x => x * 2) (n+1) a

end Int
end Sail
31 changes: 27 additions & 4 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,12 @@ let doc_lit (L_aux (lit, l)) =
| L_string s -> utf8string ("\"" ^ lean_escape_string s ^ "\"")
| L_real s -> utf8string s (* TODO test if this is really working *)

let doc_vec_lit (L_aux (lit, _) as l) =
match lit with
| L_zero -> string "0"
| L_one -> string "1"
| _ -> failwith "Unexpected litteral found in vector: " ^^ doc_lit l

let string_of_exp_con (E_aux (e, _)) =
match e with
| E_block _ -> "E_block"
Expand Down Expand Up @@ -362,17 +368,25 @@ let string_of_pat_con (P_aux (p, _)) =
let fixup_match_id (Id_aux (id, l) as id') =
match id with Id id -> Id_aux (Id (match id with "Some" -> "some" | "None" -> "none" | _ -> id), l) | _ -> id'

let rec doc_pat (P_aux (p, (l, annot)) as pat) =
let rec doc_pat ?(in_vector = false) (P_aux (p, (l, annot)) as pat) =
match p with
| P_wild -> underscore
| P_lit lit when in_vector -> doc_vec_lit lit
| P_lit lit -> doc_lit lit
| P_typ (Typ_aux (Typ_id (Id_aux (Id "bit", _)), _), p) when in_vector -> doc_pat p ^^ string ":1"
| P_typ (Typ_aux (Typ_app (Id_aux (Id id, _), [A_aux (A_nexp (Nexp_aux (Nexp_constant i, _)), _)]), _), p)
when in_vector && (id = "bits" || id = "bitvector") ->
doc_pat p ^^ string ":" ^^ doc_big_int i
| P_typ (ptyp, p) -> doc_pat p
| P_id id -> fixup_match_id id |> doc_id_ctor
| P_tuple pats -> separate (string ", ") (List.map doc_pat pats) |> parens
| P_list pats -> separate (string ", ") (List.map doc_pat pats) |> brackets
| P_vector pats -> concat (List.map (doc_pat ~in_vector:true) pats)
| P_vector_concat pats -> separate (string ",") (List.map (doc_pat ~in_vector:true) pats) |> brackets
| P_app (Id_aux (Id "None", _), p) -> string "none"
| P_app (cons, pats) -> doc_id_ctor (fixup_match_id cons) ^^ space ^^ separate_map (string ", ") doc_pat pats
| _ -> failwith ("Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.")
| P_as (pat, id) -> doc_pat pat
| _ -> failwith ("Doc Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.")

(* Copied from the Coq PP *)
let rebind_cast_pattern_vars pat typ exp =
Expand Down Expand Up @@ -412,6 +426,13 @@ let get_fn_implicits (Typ_aux (t, _)) : bool list =
in
match t with Typ_fn (args, cod) -> List.map arg_implicit args | _ -> []

let rec is_bitvector_pattern (P_aux (pat, _)) =
match pat with P_vector _ | P_vector_concat _ -> true | P_as (pat, _) -> is_bitvector_pattern pat | _ -> false

let match_or_match_bv brs =
if List.exists (function Pat_aux (Pat_exp (pat, _), _) -> is_bitvector_pattern pat | _ -> false) brs then "match_bv "
else "match "

let rec doc_match_clause (as_monadic : bool) ctx (Pat_aux (cl, l)) =
match cl with
| Pat_exp (pat, branch) ->
Expand Down Expand Up @@ -493,8 +514,10 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
wrap_with_pure as_monadic
(braces (space ^^ doc_exp false ctx exp ^^ string " with " ^^ separate (comma ^^ space) args ^^ space))
| E_match (discr, brs) ->
let cases = separate_map hardline (fun br -> doc_match_clause as_monadic ctx br) brs in
string "match " ^^ doc_exp (effectful (effect_of discr)) ctx discr ^^ string " with" ^^ hardline ^^ cases
let cases = separate_map hardline (doc_match_clause as_monadic ctx) brs in
string (match_or_match_bv brs)
^^ doc_exp (effectful (effect_of discr)) ctx discr
^^ string " with" ^^ hardline ^^ cases
| E_assign ((LE_aux (le_act, tannot) as le), e) -> (
match le_act with
| LE_id id | LE_typ (_, id) -> string "writeReg " ^^ doc_id_ctor id ^^ space ^^ doc_exp false ctx e
Expand Down
11 changes: 6 additions & 5 deletions src/sail_lean_backend/sail_plugin_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ let lean_rewrites =
("move_termination_measures", []);
("instantiate_outcomes", [String_arg "coq"]);
("realize_mappings", []);
("remove_vector_subrange_pats", []);
(* ("remove_vector_subrange_pats", []); *)
("remove_duplicate_valspecs", []);
("toplevel_string_append", []);
("pat_string_append", []);
Expand All @@ -107,8 +107,8 @@ let lean_rewrites =
("tuple_assignments", []);
("vector_concat_assignments", []);
("simple_assignments", []);
("remove_vector_concat", []);
("remove_bitvector_pats", []);
(* ("remove_vector_concat", []); *)
(* ("remove_bitvector_pats", []); *)
(* ("remove_numeral_pats", []); *)
(* ("pattern_literals", [Literal_arg "lem"]); *)
("guarded_pats", []);
Expand All @@ -129,7 +129,7 @@ let lean_rewrites =
(* We need to do the exhaustiveness check before merging, because it may
introduce new wildcard clauses *)
("recheck_defs", []);
("make_cases_exhaustive", []);
(* ("make_cases_exhaustive", []); *)
(* merge funcls before adding the measure argument so that it doesn't
disappear into an internal pattern match *)
("merge_function_clauses", []);
Expand Down Expand Up @@ -185,7 +185,8 @@ let start_lean_output (out_name : string) default_sail_dir =
("cp -r " ^ Filename.quote (sail_dir ^ "/src/sail_lean_backend/Sail") ^ " " ^ Filename.quote lean_src_dir)
in
let main_file = open_out (Filename.concat project_dir (out_name_camel ^ ".lean")) in
output_string main_file ("import " ^ out_name_camel ^ ".Sail.Sail\n\n");
output_string main_file ("import " ^ out_name_camel ^ ".Sail.Sail\n");
output_string main_file ("import " ^ out_name_camel ^ ".Sail.BitVec\n\n");
output_string main_file "open Sail\n\n";
let lakefile = open_out (Filename.concat project_dir "lakefile.toml") in
{ out_name; out_name_camel; sail_dir; main_file; lakefile }
Expand Down
Loading
Loading