Skip to content

Commit

Permalink
Initial WIP comit of Lean lrat proof checker.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Hendrix committed Mar 14, 2021
0 parents commit 26eb2d4
Show file tree
Hide file tree
Showing 10 changed files with 579 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/build
12 changes: 12 additions & 0 deletions LRat.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import LRat.Dimacs
import LRat.LRat

def main (args:List String) : IO Unit := do
match args with
| [dimacsFile, lratFile] => do
let h ← HStream.fromPath dimacsFile
let cnf ← Dimacs.read h
let h ← HStream.fromPath lratFile
readLRat h cnf.varCount (ClauseDB.fromDimacs cnf)
| _ => do
IO.println "Expected dimacsfile and lratfile."
29 changes: 29 additions & 0 deletions LRat/Common.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
def Char.toUInt8 (c:Char) := UInt8.ofNat c.toNat

def UInt8.toChar (w:UInt8) : Char := Char.ofNat w.toNat
def UInt8.toUInt64 (w:UInt8) : UInt64 := UInt64.ofNat w.toNat

def String.toByteArray (s:String) : ByteArray := s.foldl (λa c => a.push (UInt8.ofNat c.toNat)) ByteArray.empty

partial def ByteArray.beq (x y : ByteArray) : Bool := do
if x.size == y.size then
let rec loop (i : Nat) :=
if i ≥ x.size then
true
else if x.get! i == y.get! i then
loop (i+1)
else
false
loop 0
else
false

instance : BEq ByteArray where
beq := ByteArray.beq

def max {α} [h:HasLessEq α] [DecidableRel (@HasLessEq.LessEq α h)] (x y : α) : α := if x ≥ y then x else y

class Member (α : Type u) (β : Type v) where
member : α → β → Prop

infix:50 "" => Member.member
33 changes: 33 additions & 0 deletions LRat/Cont.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
universes u v w

def ContT (ρ : Type u) (m : Type u → Type v) (α : Type u) : Type (max u v) :=
(α → m ρ) → m ρ

namespace ContT
section
variable {ρ : Type u} {m : Type u → Type v}
variable [Monad m] {α β : Type u}

@[inline] protected def pure (a : α) : ContT ρ m α := fun c => c a

@[inline] protected def bind (x : ContT ρ m α) (f : α → ContT ρ m β) : ContT ρ m β :=
fun c => x (fun a => f a c)

def terminate (r : ρ) : ContT ρ m α := λc => pure r

-- @[inline] protected def map (f : α → β) (x : ContT ρ m α) : ContT ρ m β :=
-- fun c => x (fun a => pure (c ) do let (a, s) ← x s; pure (f a, s)

instance : Monad (ContT σ m) where
pure := ContT.pure
bind := ContT.bind
-- map := StateT.Maps

end
end ContT

export ContT(terminate)

abbrev Cont ρ := ContT ρ Id

def Cont.runSame {ρ} (c:Cont ρ ρ) : ρ := c id
110 changes: 110 additions & 0 deletions LRat/Dimacs.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import LRat.HStream
import LRat.SignedInt

/-- A literal
- A literal is encoded as the literal index shifted by one.
- The least significant bit is set if the literal is negated.
-/
def Lit := SignedInt

namespace Lit

def isNull (l:Lit) : Bool := !(l.value == 0)

def var (l:Lit) : UInt64 := l.magnitude

/-- Return true if literal is positive and false if negative. -/
def polarity (l:Lit) : Bool := l.isPos

-- @Lit.read h vc@ read the next signed numeral from @h@ with magnitude
-- between 0 and vc and returns a literal for it.
def read (h:HStream) (varCount: UInt64) : IO Lit := SignedInt.read h varCount

-- Negate literal
def negate (l:Lit) : Lit := ⟨l.value ^^^ 1

instance : Inhabited Lit := ⟨{value := 0}⟩

protected def beq (x y : Lit) : Bool := x.value == y.value

instance : BEq Lit where
beq := Lit.beq

end Lit

structure Clause :=
(lits : Array Lit)

namespace Clause

-- def empty : Clause := ⟨Array.empty⟩

def pivot (c:Clause) : Lit := c.lits[0]

protected def forIn {β} [Monad m] (x : Clause) (b : β) (f : Lit → β → m (ForInStep β)) : m β := x.lits.forIn b f

instance : ForIn m Clause Lit where
forIn := Clause.forIn

protected def member (c:Clause) (l:Lit) : Bool := do
let mut r : Bool := false
for k in c.lits do
if l == k then
r := true
break
r

protected def size (self:Clause) : Nat := self.lits.size

/-- Return lit at given index in clause. -/
def getOp (self:Clause) (idx:Nat) : Lit := self.lits[idx]

--- @Clause.read' h vc a@ Read a list of ints with magnitude between 1 and vc
--- and stops when it reads a zero.
partial def read' (h:HStream) (varCount: UInt64) (a:Array Lit) : IO Clause := do
h.skipWS
let l ← Lit.read h varCount
if l.isNull then
pure ⟨a⟩
else
read' h varCount (a.push l)

/--- Read a line expected to contain a clause. -/
def read (h:HStream) (varCount:UInt64): IO Clause := do
read' h varCount Array.empty

end Clause


structure Dimacs :=
(varCount : UInt64)
(clauses : Array Clause)

def Dimacs.clauseCount (d:Dimacs) : Nat := d.clauses.size

partial def Dimacs.read (h:HStream) : IO Dimacs := do
let c ← h.getByte
if c == 'c'.toUInt8 then
let _ ← h.getLine
read h
else if c == 'p'.toUInt8 then
let cnf ← h.getWord
if cnf != "cnf".toByteArray then do
throw (IO.userError ("Expected \"cnf\" -- found " ++ toString cnf))
else
let varCnt ← h.getUInt64
let clauseCnt ← h.getUInt64
if varCnt ≥ UInt64.ofNat (UInt64.size >>> 1) then
throw $ IO.userError "Variable count is too large."
let _ ← h.getLine
let rec loop (remaining:UInt64) (a:Array Clause) : IO (Array Clause) := do
if remaining == 0 then
pure a
else do
let c ← Clause.read h varCnt
let _ ← h.getLine
loop (remaining - 1) (a.push c)
let a ← loop clauseCnt Array.empty
pure { varCount := varCnt, clauses := a }
else
throw (IO.userError ("Unknown command: " ++ toString c))
81 changes: 81 additions & 0 deletions LRat/HStream.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import LRat.Common

structure HStream where
next : IO.Ref UInt8
rest : IO.FS.Handle

namespace HStream

def fromPath (path:String) : IO HStream := do
let h ← IO.FS.Handle.mk path IO.FS.Mode.read
let n ← h.read 1
let r ← IO.mkRef (n.get! 0)
pure { next := r, rest := h }

def peekByte (h:HStream) : IO UInt8 := h.next.get

def skipByte (h:HStream) : IO Unit := do
let n ← h.rest.read 1
h.next.set (n.get! 0)

def getByte (h:HStream) : IO UInt8 := do
let b ← peekByte h
h.skipByte
pure b

def getLine (h:HStream) : IO Unit := do
let b ← h.next.get
if b == 10 then
h.skipByte
else
let _ ← h.rest.getLine
h.skipByte

-- Skip whitespace
partial def skipWS (h:HStream) : IO Unit := do
let b ← h.peekByte
if b == ' '.toUInt8 then
h.skipByte
h.skipWS
else
pure ()

partial def getWord' (h:HStream) (a:ByteArray) : IO ByteArray := do
let b ← h.peekByte
if b == ' '.toUInt8 then
pure a
else
h.skipByte
h.getWord' (a.push b)

partial def getWord (h:HStream) : IO ByteArray := do
let b ← h.getByte
if b == ' '.toUInt8 then
h.getWord
else
h.getWord' (ByteArray.empty.push b)

partial def getUInt64' (h:HStream) (c : UInt64) : IO UInt64 := do
let b ← h.peekByte
if '0'.toUInt8 ≤ b && b ≤ '9'.toUInt8 then
h.skipByte
let d := (b - '0'.toUInt8)
let c' := 10 * c + d.toUInt64
if c' < c then
throw (IO.userError <| s! "Numeric overflow: 10 * {c} + {d} .")
h.getUInt64' c'
else if b == ' '.toUInt8 || b == 10 || b == 13 then
pure c
else
throw (IO.userError <| s! "Expected digit {b}.")

partial def getUInt64 (h:HStream) : IO UInt64 := do
h.skipWS
let b ← h.peekByte
if '0'.toUInt8 ≤ b && b ≤ '9'.toUInt8 then
h.skipByte
h.getUInt64' (b - '0'.toUInt8).toUInt64
else
throw (IO.userError "Expected initial digit.")

end HStream
Loading

0 comments on commit 26eb2d4

Please sign in to comment.