Skip to content

Commit

Permalink
saving
Browse files Browse the repository at this point in the history
  • Loading branch information
ocramz committed Sep 2, 2019
2 parents 741f066 + 7741c9e commit b8e69ed
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 62 deletions.
2 changes: 1 addition & 1 deletion app/Main.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{-# language OverloadedStrings #-}
module Main where

import Numeric.Statistics.Inference.Bayes.Exact.VariableElimination (student)
import Data.Graph.Examples (student)

import Algebra.Graph.Export (Doc, literal, render)
import Algebra.Graph.Export.Dot (Style(..), defaultStyle, export, Attribute(..))
Expand Down
5 changes: 3 additions & 2 deletions bayesian-inference.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ library
Numeric.Statistics.Utils
Numeric.Statistics.Sampling.MetropolisHastings
Numeric.Math
Data.Graph.Examples
Data.Permutation
other-modules: System.Random.MWC.Probability.Conditional
build-depends: base >= 4.7 && < 5
Expand All @@ -38,8 +39,8 @@ library
, mtl
, mwc-probability
, mwc-probability-transition
-- , permutation
, primitive
, text
, transformers
, vector
-- DEBUG
Expand Down Expand Up @@ -74,7 +75,7 @@ test-suite doctest
type: exitcode-stdio-1.0
hs-source-dirs: test
main-is: DocTest.hs
other-modules: Bayes.Exact.VariableElimination
-- other-modules: Data.Permutation
build-depends: base
, bayesian-inference
, doctest
Expand Down
27 changes: 27 additions & 0 deletions src/Data/Graph/Examples.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
module Data.Graph.Examples where

-- algebraic-graphs
import Algebra.Graph (Graph(..), vertex, overlay, connect)
-- import qualified Algebra.Graph.Class as GC (Graph(..))
-- import qualified Algebra.Graph.ToGraph as TG (ToGraph(..))

student :: Graph Char
student =
connect c d `overlay`
connect d g `overlay`
connect g h `overlay`
connect g l `overlay`
connect l j `overlay`
connect j h `overlay`
connect s j `overlay`
connect i g `overlay`
connect i s
where
c = vertex 'c'
d = vertex 'd'
g = vertex 'g'
h = vertex 'h'
i = vertex 'i'
j = vertex 'j'
l = vertex 'l'
s = vertex 's'
134 changes: 75 additions & 59 deletions src/Numeric/Statistics/Inference/Bayes/Exact/VariableElimination.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ module Numeric.Statistics.Inference.Bayes.Exact.VariableElimination where
-- import GHC.TypeNats
import Data.List (groupBy, sort, sortBy)
import Data.Ord (comparing)
import Data.Foldable (foldlM, maximumBy)
import Data.Monoid (Sum(..), Product(..))
import Data.Foldable (foldlM, maximumBy, minimumBy)
-- import Data.Monoid (Sum(..), Product(..))

-- algebraic-graphs
import Algebra.Graph (Graph(..), vertex, edge, overlay, connect)
Expand All @@ -17,106 +17,118 @@ import qualified Algebra.Graph.ToGraph as TG (ToGraph(..))
import qualified Data.Bimap as BM
-- containers
import qualified Data.IntMap as IM
import qualified Data.Set as S
import qualified Data.Set as S (Set, empty, singleton, union, intersection, filter, toList)
import Data.Set ((\\))
-- exceptions
import Control.Monad.Catch (MonadThrow(..))
-- massiv
-- import qualified Data.Massiv.Array as A (Array, all, Comp(..), makeArray, Construct(..), Sz(..))
-- import Data.Massiv.Array (Index, Ix1(..), D, (..:), ifoldlWithin', foldlWithin', Lower, Dim(..), Source)
-- mtl
import Control.Monad.State (MonadState(..))
import Control.Monad.State (MonadState(..), gets)
-- permutation
-- import qualified Data.Permute as P (permute, next, elems)
-- transformers
import Control.Monad.State (State(..), evalState)
-- import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT)
import Control.Monad.State (State(..), runState, evalState, execState)
import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT)
-- vector
import qualified Data.Vector as V

import Prelude hiding (lookup)

import Data.Graph.Examples (student)
import Data.Permutation (Permutation, permutation, getPermutation, permutations)


student :: Graph Char
student =
connect c d `overlay`
connect d g `overlay`
connect g h `overlay`
connect g l `overlay`
connect l j `overlay`
connect j h `overlay`
connect s j `overlay`
connect i g `overlay`
connect i s
where
c = vertex 'c'
d = vertex 'd'
g = vertex 'g'
h = vertex 'h'
i = vertex 'i'
j = vertex 'j'
l = vertex 'l'
s = vertex 's'
import Prelude hiding (lookup)

{- |
minimumMaxDegOrdering
:: (TG.ToGraph g, Ord (TG.ToVertex g)) =>
g -> TG.ToVertex g -> Permutation (TG.ToVertex g)
minimumMaxDegOrdering g v =
minimumBy (compareOrderings g) $ permutations vs where
vs = verticesWithout g $ S.singleton v

λ> permutations <$> TG.topSort student
compareOrderings :: (TG.ToGraph g, Ord (TG.ToVertex g)) =>
g
-> Permutation (TG.ToVertex g)
-> Permutation (TG.ToVertex g)
-> Ordering
compareOrderings g vp1 vp2 =
compare (maxFS g $ getPermutation vp1) (maxFS g $ getPermutation vp2)

Just ["iscdgljh","sicdgljh","csidgljh","dscigljh","gscdiljh","lscdgijh","jscdglih","hscdglji","icsdgljh","idcsgljh","igcdsljh","ilcdgsjh","ijcdglsh","ihcdgljs","isdcgljh","isgdcljh","isldgcjh","isjdglch","ishdgljc","iscgdljh","isclgdjh","iscjgldh","ischgljd","iscdlgjh","iscdjlgh","iscdhljg","iscdgjlh","iscdghjl","iscdglhj"]
-}
maxFS :: (TG.ToGraph g, Foldable t, Ord (TG.ToVertex g)) =>
g -> t (TG.ToVertex g) -> Int
maxFS g vs = maxFactorSize $ snd $ runLex (elims g vs)


-- | All vertices in the graph but a given subset
verticesWithout :: (TG.ToGraph g, Ord (TG.ToVertex g)) =>
g -> S.Set (TG.ToVertex g) -> [TG.ToVertex g]
verticesWithout g vs = S.toList $ TG.vertexSet g \\ vs


-- data SumProduct a =
-- SumOver a (Factor a)
-- | Product (S.Set (Factor a))

type VarId = Int
data Temp = Temp { varId :: VarId, maxFactorSize :: Int } deriving (Eq, Show)
-- newtype Lex m a = Lex { unLex :: StateT VarId m a } deriving (Functor, Applicative, Monad, MonadState VarId)
-- runLex :: Monad m => Lex m a -> m a
-- runLex lx = evalStateT (unLex lx) 0
newtype Lex a = Lex { unLex :: State VarId a } deriving (Functor, Applicative, Monad, MonadState VarId)
runLex :: Lex a -> a
runLex lx = evalState (unLex lx) 0
newtype Lex a = Lex { unLex :: State Temp a } deriving (Functor, Applicative, Monad, MonadState Temp)

initTemp :: Temp
initTemp = Temp 0 0

insert :: a -> IM.IntMap a -> Lex (IM.IntMap a)
insert x mx = do
k <- get
-- | Run a 'Lex' computation and return its result along with the final 'Temp' value
runLex :: Lex a -> (a, Temp)
runLex lx = runState (unLex lx) initTemp

-- | Insert a new factor into scope and update the maximum size of scopes seen so far
insertFactor :: Factor a -> IM.IntMap (Factor a) -> Lex (IM.IntMap (Factor a))
insertFactor x mx = do
Temp k mfs <- get
let mx' = IM.insert k x mx
put $ succ k
sz = factorSize x
put $ Temp (succ k) (max sz mfs)
pure mx'

fromList :: Foldable t => t a -> Lex (IM.IntMap a)
fromList xs = foldlM (flip insert) IM.empty xs
factorSize :: Factor a -> Int
factorSize = length

fromList :: Foldable t => t (Factor a) -> Lex (IM.IntMap (Factor a))
fromList xs = foldlM (flip insertFactor) IM.empty xs


-- | sequential sum-product elimination of graph factors

-- | sequential sum-product elimination of graph factors, given a vertex elimination order
--
-- >>> elims student "cdihg"
-- fromList [(5,{ 'j' 'l' 's' }),(12,{ 'j' 'l' 's' })]
-- >>> runLex $ elims student "cdihg"
-- (fromList [(4,{ 'j' 'l' 's' }),(5,{ 'j' 'l' 's' })],Temp {varId = 5, maxFactorSize = 3})
elims :: (TG.ToGraph g, Foldable t, Ord (TG.ToVertex g)) =>
g
-> t (TG.ToVertex g)
-> IM.IntMap (Factor (TG.ToVertex g))
elims g vs = runLex $ do
im0 <- factorIM g
foldlM (flip spe) im0 vs
-> Lex (IM.IntMap (Factor (TG.ToVertex g)))
elims g vs = do
let im0 = factorIM g
foldlM (flip sumProductElim) im0 vs

factorIM :: (TG.ToGraph g, Ord (TG.ToVertex g)) =>
g
-> Lex (IM.IntMap (Factor (TG.ToVertex g)))
factorIM g = do
im <- fromList $ TG.vertexList g
pure $ IM.map (`moralFactor` g) im
-> IM.IntMap (Factor (TG.ToVertex g))
factorIM g = IM.map (`moralFactor` g) im where
im = IM.fromList $ zip [0..] (TG.vertexList g)

-- | Sum-product elimination
spe :: (Ord a) => a -> IM.IntMap (Factor a) -> Lex (IM.IntMap (Factor a))
spe z pphi = insert tau pphiC where
pphi' = factorsContaining z pphi
pphiC = pphi `IM.difference` pphi'
tau = eliminate z pphi'
sumProductElim :: (Ord a) => a -> IM.IntMap (Factor a) -> Lex (IM.IntMap (Factor a))
sumProductElim z pphi = insertFactor tau pphiC
where
pphi' = factorsContaining z pphi
pphiC = pphi `IM.difference` pphi'
tau = eliminate z pphi'

forgetIndices :: Ord a => IM.IntMap a -> S.Set a
forgetIndices = S.fromList . map snd . IM.toList
-- forgetIndices :: Ord a => IM.IntMap a -> S.Set a
-- forgetIndices = S.fromList . map snd . IM.toList

-- | Factors containing a given variable
factorsContaining :: Ord a => a -> IM.IntMap (Factor a) -> IM.IntMap (Factor a)
Expand All @@ -140,6 +152,10 @@ sumOver v f = Factor $ S.filter (/= v) $ scope f
hasInScope :: Ord a => a -> Factor a -> Bool
hasInScope v f = not $ null $ scope f `S.intersection` S.singleton v

factors :: (TG.ToGraph t, Ord (TG.ToVertex t)) =>
t -> [Factor (TG.ToVertex t)]
factors g = (`moralFactor` g) `map` TG.vertexList g

moralFactor :: (TG.ToGraph g, Ord (TG.ToVertex g)) =>
TG.ToVertex g -> g -> Factor (TG.ToVertex g)
moralFactor v g = Factor $ TG.preSet v g `S.union` S.singleton v
Expand Down

0 comments on commit b8e69ed

Please sign in to comment.