diff --git a/LRat.lean b/LRat.lean index d137b52..0ccb72a 100644 --- a/LRat.lean +++ b/LRat.lean @@ -1,8 +1,12 @@ + import LRat.Dimacs import LRat.LRat def main (args:List String) : IO Unit := do match args with + | [dimacsFile] => do + let h ← HStream.fromPath dimacsFile + let cnf ← Dimacs.read h | [dimacsFile, lratFile] => do let h ← HStream.fromPath dimacsFile let cnf ← Dimacs.read h diff --git a/LRat/Dimacs.lean b/LRat/Dimacs.lean index c387a7c..0e6b1f1 100644 --- a/LRat/Dimacs.lean +++ b/LRat/Dimacs.lean @@ -9,7 +9,7 @@ def Lit := SignedInt namespace Lit -def isNull (l:Lit) : Bool := !(l.value == 0) +def isNull (l:Lit) : Bool := l.value == 0 def var (l:Lit) : UInt64 := l.magnitude @@ -57,7 +57,8 @@ protected def member (c:Clause) (l:Lit) : Bool := do protected def size (self:Clause) : Nat := self.lits.size /-- Return lit at given index in clause. -/ -def getOp (self:Clause) (idx:Nat) : Lit := self.lits[idx] +def getOp (self:Clause) (idx:Nat) : Lit := + if idx < self.lits.size then self.lits[idx] else ⟨0⟩ --- @Clause.read' h vc a@ Read a list of ints with magnitude between 1 and vc --- and stops when it reads a zero. @@ -89,7 +90,7 @@ partial def Dimacs.read (h:HStream) : IO Dimacs := do read h else if c == 'p'.toUInt8 then let cnf ← h.getWord - if cnf != "cnf".toByteArray then do + if cnf != "cnf".toByteArray then throw (IO.userError ("Expected \"cnf\" -- found " ++ toString cnf)) else let varCnt ← h.getUInt64 @@ -102,7 +103,8 @@ partial def Dimacs.read (h:HStream) : IO Dimacs := do pure a else do let c ← Clause.read h varCnt - let _ ← h.getLine + if !(← h.isEof) then + h.getLine loop (remaining - 1) (a.push c) let a ← loop clauseCnt Array.empty pure { varCount := varCnt, clauses := a } diff --git a/LRat/HStream.lean b/LRat/HStream.lean index d2678a0..70a1946 100644 --- a/LRat/HStream.lean +++ b/LRat/HStream.lean @@ -1,31 +1,60 @@ import LRat.Common structure HStream where - next : IO.Ref UInt8 - rest : IO.FS.Handle + eofRef : IO.Ref Bool + next : IO.Ref UInt8 + rest : IO.FS.Handle namespace HStream def fromPath (path:String) : IO HStream := do let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let n ← h.read 1 - let r ← IO.mkRef (n.get! 0) - pure { next := r, rest := h } -def peekByte (h:HStream) : IO UInt8 := h.next.get + let eof ← h.isEof + let b ← if eof then pure ByteArray.empty else h.read 1 + let n := if b.size > 0 then b.get! 0 else 0 + + let eofRef ← IO.mkRef (b.size == 0) + let nextRef ← IO.mkRef n + pure { eofRef := eofRef, next := nextRef, rest := h } + +def isEof (h:HStream) : IO Bool := h.eofRef.get + +def peekByte (h:HStream) : IO UInt8 := do + if ← h.isEof then + throw $ IO.userError "Attempt to read past end of file." + h.next.get def skipByte (h:HStream) : IO Unit := do - let n ← h.rest.read 1 - h.next.set (n.get! 0) + if ← h.isEof then + throw $ IO.userError "Attempt to read past end of file." + if ← h.rest.isEof then + h.eofRef.set true + else + let b ← h.rest.read 1 + if b.size == 0 then + h.eofRef.set true + else + h.next.set (b.get! 0) def getByte (h:HStream) : IO UInt8 := do - let b ← peekByte h - h.skipByte + if ← h.eofRef.get then + throw $ IO.userError "Attempt to read past end of file." + let b ← h.next.get + if ← h.rest.isEof then + h.eofRef.set true + else + let b ← h.rest.read 1 + if b.size == 0 then + h.eofRef.set true + else + h.next.set (b.get! 0) pure b def getLine (h:HStream) : IO Unit := do - let b ← h.next.get - if b == 10 then + if ← h.eofRef.get then + throw $ IO.userError "Attempt to read past end of file." + if (← h.next.get) == 10 then h.skipByte else let _ ← h.rest.getLine @@ -33,8 +62,7 @@ def getLine (h:HStream) : IO Unit := do -- Skip whitespace partial def skipWS (h:HStream) : IO Unit := do - let b ← h.peekByte - if b == ' '.toUInt8 then + if (← h.peekByte) == ' '.toUInt8 then h.skipByte h.skipWS else diff --git a/LRat/LRat.lean b/LRat/LRat.lean index 5d40975..9bde79d 100644 --- a/LRat/LRat.lean +++ b/LRat/LRat.lean @@ -19,25 +19,26 @@ end Std structure Assignment where values : Std.HashMap UInt64 Bool -/- -application type mismatch - (0, c) -argument - c -has type - Clause : Type -but is expected to have type - Array Lit : Type --/ - namespace Assignment +protected def toString (a:Assignment) : String := do + let av := a.values.toArray + if av.size == 0 then + "[]" + else + let ppElt (p:UInt64 × Bool) := if p.snd then s! "{p.fst}" else s! "!{p.fst}" + let mut r : String := s! "[{ppElt av[0]}" + for i in [1:av.size] do + r := s! "{r},{ppElt av[i]}" + s! "{r}]" + +instance : ToString Assignment where + toString := Assignment.toString + def empty : Assignment := { values := Std.HashMap.empty } --- Set the value of the literal in the assignment --- --- Fails if literal already assigned. -def set! (a:Assignment) (l:Lit) : Assignment := +def set (a:Assignment) (l:Lit) : Assignment := { values := a.values.insert l.var l.polarity } --- Set the value of the literal in the assignment @@ -70,6 +71,8 @@ end Assignment @[reducible] def ClauseId := UInt64 +def ClauseId.ofNat := UInt64.ofNat + -- A set of clauses for checking. structure ClauseDB where -- First clause index (0 if empty) @@ -83,9 +86,18 @@ structure ClauseDB where namespace ClauseDB -def fromDimacs (d:Dimacs) : ClauseDB := sorry - ---def empty : ClauseDB := { maxIdx := 0, clauses := ∅ } +def fromDimacs (d:Dimacs) : ClauseDB := do + let cl := d.clauses + if cl.size == 0 then + pure { headClauseId := 0, lastClauseId := 0, maxClauseId := 0, clauses := ∅ } + else + let cnt := ClauseId.ofNat cl.size + let mut cm : Std.HashMap ClauseId (ClauseId × ClauseId × Array Lit) := ∅ + for i in [1:cl.size+1], c in cl do + let i := ClauseId.ofNat i + let n := if i == cnt then 0 else i+1 + cm := cm.insert i (i-1, n, c.lits) + pure { headClauseId := 1, lastClauseId := cnt, maxClauseId := cnt, clauses := cm } protected partial def forIn {β} [Monad m] (db : ClauseDB) (b : β) @@ -146,7 +158,6 @@ inductive RupResult | trueLit : RupResult -- Returned if literal in clause is true | multipleUnassigned : RupResult -- Returned if here are multiple unassigned literals. - /-- Apply unit propagation to an assignment and clause-/ def rup (a:Assignment) (cl:Clause) : RupResult := do -- Return conflict if we do not find an unassigned or true literal @@ -188,11 +199,12 @@ partial def getRup (h:HStream) (db:ClauseDB) (a:Assignment) : IO ClauseId := do | RupResult.conflict => do let r ← SignedInt.read h db.maxClauseId if !r.isZero then - throw $ IO.userError "Expected zero after conflict." - let _ <- h.getLine + throw $ IO.userError $ s! "Expected zero instead of {r} after conflict." + if !(←h.isEof) then + h.getLine pure 0 | RupResult.unit l => do - getRup h db (a.set! l) + getRup h db (a.set l) | RupResult.trueLit => throw $ IO.userError "Found true literal in clause." | RupResult.multipleUnassigned => @@ -223,14 +235,14 @@ partial def readRup (h:HStream) (db:ClauseDB) (pivot:Lit) (a:Assignment) : IO Un continue match a[l] with -- Assign proof - | none => a := a.set! l.negate + | none => a := a.set l.negate -- If literal is already false then do nothing | some false => pure () -- If literal is true, then we should be able to resolve. | some true => resolved := true break - -- We already resolved this so there should just be end of + -- We already resolved this so there should just be end of clauses. if resolved then let n ← SignedInt.read h db.maxClauseId if !n.isZero && n.isPos then @@ -239,10 +251,11 @@ partial def readRup (h:HStream) (db:ClauseDB) (pivot:Lit) (a:Assignment) : IO Un else clId ← getRup h db a +def lratError {α} (msg:String) : IO α := do + throw $ IO.userError msg + partial def readLRat (h:HStream) (varCount:UInt64) (db:ClauseDB) : IO Unit := do let newClauseId ← h.getUInt64 - if h : newClauseId ≤ db.maxClauseId then - throw $ IO.userError "Expected new clause to exceed maximum clause." h.skipWS; let c ← h.peekByte -- If deletion information @@ -258,7 +271,17 @@ partial def readLRat (h:HStream) (varCount:UInt64) (db:ClauseDB) : IO Unit := do loop db readLRat h varCount db else + if h : newClauseId ≤ db.maxClauseId then + lratError $ s! "Expected new clause id {newClauseId} to exceed maximum clause id {db.maxClauseId}." let cl ← Clause.read h varCount - readRup h db cl.pivot.negate (Assignment.negatedClause cl) - let db := db.insertClause newClauseId cl - readLRat h varCount db + if cl.size == 0 then + let clId0 ← getRup h db (Assignment.negatedClause cl) + if clId0 != 0 then + lratError "Final conflict clause only resolvable through unit propagation." + else + readRup h db cl.pivot.negate (Assignment.negatedClause cl) + if cl.size == 0 then + IO.println "UNSAT" + else + let db := db.insertClause newClauseId cl + readLRat h varCount db diff --git a/LRat/SignedInt.lean b/LRat/SignedInt.lean index f0834f0..62160e6 100644 --- a/LRat/SignedInt.lean +++ b/LRat/SignedInt.lean @@ -18,6 +18,16 @@ def SignedInt.isPos (l:SignedInt) : Bool := l.value &&& 1 == 0 /-- Return true if literal is negative. -/ def SignedInt.isNeg (l:SignedInt) : Bool := l.value &&& 1 == 1 +namespace SignedInt + +protected def toString (s:SignedInt) : String := + if s.isNeg then s! "-{s.magnitude}" else s! "{s.magnitude}" + +instance : ToString SignedInt where + toString := SignedInt.toString + +end SignedInt + -- @Lit.read h vc@ read the next signed numeral from @h@ with magnitude -- between 0 and vc and returns a literal for it. def SignedInt.read (h:HStream) (varCount: UInt64) : IO SignedInt := do @@ -31,9 +41,10 @@ def SignedInt.read (h:HStream) (varCount: UInt64) : IO SignedInt := do throw (IO.userError $ s! "Negated variable too large (idx = {w}, limit = {varCount})") pure ⟨w <<< 1 ||| 1⟩ else if b == '0'.toUInt8 then - let b ← h.peekByte - if !(b == ' '.toUInt8 || b == 10 || b == 13) then - throw (IO.userError $ s! "Expected whitespace or end-of-line.") + if !(← h.isEof) then + let b ← h.peekByte + if !(b == ' '.toUInt8 || b == 10 || b == 13) then + throw (IO.userError $ s! "Expected whitespace or end-of-line (found = {b}).") pure ⟨0⟩ else if '1'.toUInt8 ≤ b && b ≤ '9'.toUInt8 then let w ← h.getUInt64' (b - '0'.toUInt8).toUInt64 diff --git a/examples/handcrafted/lrat-fig1.dimacs b/examples/handcrafted/lrat-fig1.dimacs new file mode 100644 index 0000000..c98114e --- /dev/null +++ b/examples/handcrafted/lrat-fig1.dimacs @@ -0,0 +1,9 @@ +p cnf 4 8 + 1 2 -3 0 +-1 -2 3 0 + 2 3 -4 0 +-2 -3 4 0 +-1 -3 -4 0 + 1 3 4 0 +-1 2 4 0 + 1 -2 -4 0 \ No newline at end of file diff --git a/examples/handcrafted/lrat-fig1.lrat b/examples/handcrafted/lrat-fig1.lrat new file mode 100644 index 0000000..d137b76 --- /dev/null +++ b/examples/handcrafted/lrat-fig1.lrat @@ -0,0 +1,9 @@ + 9 1 2 0 1 6 3 0 + 9 d 1 0 +10 1 3 0 9 8 6 0 +10 d 6 0 +11 1 0 10 9 4 8 0 +11 d 10 9 8 0 +12 2 0 11 7 5 3 0 +12 d 7 3 0 +13 0 11 12 2 4 5 0 \ No newline at end of file diff --git a/leanpkg.toml b/leanpkg.toml index 91a6c1f..853ad7d 100644 --- a/leanpkg.toml +++ b/leanpkg.toml @@ -1,3 +1,4 @@ [package] name = "LRat" version = "0.1" +lean_version = "leanprover/lean4:nightly-2021-03-14"