Skip to content

Commit b8e69ed

Browse files
committed
saving
2 parents 741f066 + 7741c9e commit b8e69ed

File tree

4 files changed

+106
-62
lines changed

4 files changed

+106
-62
lines changed

app/Main.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{-# language OverloadedStrings #-}
22
module Main where
33

4-
import Numeric.Statistics.Inference.Bayes.Exact.VariableElimination (student)
4+
import Data.Graph.Examples (student)
55

66
import Algebra.Graph.Export (Doc, literal, render)
77
import Algebra.Graph.Export.Dot (Style(..), defaultStyle, export, Attribute(..))

bayesian-inference.cabal

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ library
2525
Numeric.Statistics.Utils
2626
Numeric.Statistics.Sampling.MetropolisHastings
2727
Numeric.Math
28+
Data.Graph.Examples
2829
Data.Permutation
2930
other-modules: System.Random.MWC.Probability.Conditional
3031
build-depends: base >= 4.7 && < 5
@@ -38,8 +39,8 @@ library
3839
, mtl
3940
, mwc-probability
4041
, mwc-probability-transition
41-
-- , permutation
4242
, primitive
43+
, text
4344
, transformers
4445
, vector
4546
-- DEBUG
@@ -74,7 +75,7 @@ test-suite doctest
7475
type: exitcode-stdio-1.0
7576
hs-source-dirs: test
7677
main-is: DocTest.hs
77-
other-modules: Bayes.Exact.VariableElimination
78+
-- other-modules: Data.Permutation
7879
build-depends: base
7980
, bayesian-inference
8081
, doctest

src/Data/Graph/Examples.hs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
module Data.Graph.Examples where
2+
3+
-- algebraic-graphs
4+
import Algebra.Graph (Graph(..), vertex, overlay, connect)
5+
-- import qualified Algebra.Graph.Class as GC (Graph(..))
6+
-- import qualified Algebra.Graph.ToGraph as TG (ToGraph(..))
7+
8+
student :: Graph Char
9+
student =
10+
connect c d `overlay`
11+
connect d g `overlay`
12+
connect g h `overlay`
13+
connect g l `overlay`
14+
connect l j `overlay`
15+
connect j h `overlay`
16+
connect s j `overlay`
17+
connect i g `overlay`
18+
connect i s
19+
where
20+
c = vertex 'c'
21+
d = vertex 'd'
22+
g = vertex 'g'
23+
h = vertex 'h'
24+
i = vertex 'i'
25+
j = vertex 'j'
26+
l = vertex 'l'
27+
s = vertex 's'

src/Numeric/Statistics/Inference/Bayes/Exact/VariableElimination.hs

Lines changed: 75 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ module Numeric.Statistics.Inference.Bayes.Exact.VariableElimination where
55
-- import GHC.TypeNats
66
import Data.List (groupBy, sort, sortBy)
77
import Data.Ord (comparing)
8-
import Data.Foldable (foldlM, maximumBy)
9-
import Data.Monoid (Sum(..), Product(..))
8+
import Data.Foldable (foldlM, maximumBy, minimumBy)
9+
-- import Data.Monoid (Sum(..), Product(..))
1010

1111
-- algebraic-graphs
1212
import Algebra.Graph (Graph(..), vertex, edge, overlay, connect)
@@ -17,106 +17,118 @@ import qualified Algebra.Graph.ToGraph as TG (ToGraph(..))
1717
import qualified Data.Bimap as BM
1818
-- containers
1919
import qualified Data.IntMap as IM
20-
import qualified Data.Set as S
20+
import qualified Data.Set as S (Set, empty, singleton, union, intersection, filter, toList)
21+
import Data.Set ((\\))
2122
-- exceptions
2223
import Control.Monad.Catch (MonadThrow(..))
2324
-- massiv
2425
-- import qualified Data.Massiv.Array as A (Array, all, Comp(..), makeArray, Construct(..), Sz(..))
2526
-- import Data.Massiv.Array (Index, Ix1(..), D, (..:), ifoldlWithin', foldlWithin', Lower, Dim(..), Source)
2627
-- mtl
27-
import Control.Monad.State (MonadState(..))
28+
import Control.Monad.State (MonadState(..), gets)
2829
-- permutation
2930
-- import qualified Data.Permute as P (permute, next, elems)
3031
-- transformers
31-
import Control.Monad.State (State(..), evalState)
32-
-- import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT)
32+
import Control.Monad.State (State(..), runState, evalState, execState)
33+
import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT)
3334
-- vector
3435
import qualified Data.Vector as V
3536

36-
import Prelude hiding (lookup)
37-
37+
import Data.Graph.Examples (student)
3838
import Data.Permutation (Permutation, permutation, getPermutation, permutations)
3939

4040

41-
student :: Graph Char
42-
student =
43-
connect c d `overlay`
44-
connect d g `overlay`
45-
connect g h `overlay`
46-
connect g l `overlay`
47-
connect l j `overlay`
48-
connect j h `overlay`
49-
connect s j `overlay`
50-
connect i g `overlay`
51-
connect i s
52-
where
53-
c = vertex 'c'
54-
d = vertex 'd'
55-
g = vertex 'g'
56-
h = vertex 'h'
57-
i = vertex 'i'
58-
j = vertex 'j'
59-
l = vertex 'l'
60-
s = vertex 's'
41+
import Prelude hiding (lookup)
6142

62-
{- |
43+
minimumMaxDegOrdering
44+
:: (TG.ToGraph g, Ord (TG.ToVertex g)) =>
45+
g -> TG.ToVertex g -> Permutation (TG.ToVertex g)
46+
minimumMaxDegOrdering g v =
47+
minimumBy (compareOrderings g) $ permutations vs where
48+
vs = verticesWithout g $ S.singleton v
6349

64-
λ> permutations <$> TG.topSort student
50+
compareOrderings :: (TG.ToGraph g, Ord (TG.ToVertex g)) =>
51+
g
52+
-> Permutation (TG.ToVertex g)
53+
-> Permutation (TG.ToVertex g)
54+
-> Ordering
55+
compareOrderings g vp1 vp2 =
56+
compare (maxFS g $ getPermutation vp1) (maxFS g $ getPermutation vp2)
6557

66-
Just ["iscdgljh","sicdgljh","csidgljh","dscigljh","gscdiljh","lscdgijh","jscdglih","hscdglji","icsdgljh","idcsgljh","igcdsljh","ilcdgsjh","ijcdglsh","ihcdgljs","isdcgljh","isgdcljh","isldgcjh","isjdglch","ishdgljc","iscgdljh","isclgdjh","iscjgldh","ischgljd","iscdlgjh","iscdjlgh","iscdhljg","iscdgjlh","iscdghjl","iscdglhj"]
67-
-}
58+
maxFS :: (TG.ToGraph g, Foldable t, Ord (TG.ToVertex g)) =>
59+
g -> t (TG.ToVertex g) -> Int
60+
maxFS g vs = maxFactorSize $ snd $ runLex (elims g vs)
61+
62+
63+
-- | All vertices in the graph but a given subset
64+
verticesWithout :: (TG.ToGraph g, Ord (TG.ToVertex g)) =>
65+
g -> S.Set (TG.ToVertex g) -> [TG.ToVertex g]
66+
verticesWithout g vs = S.toList $ TG.vertexSet g \\ vs
6867

6968

7069
-- data SumProduct a =
7170
-- SumOver a (Factor a)
7271
-- | Product (S.Set (Factor a))
7372

7473
type VarId = Int
74+
data Temp = Temp { varId :: VarId, maxFactorSize :: Int } deriving (Eq, Show)
7575
-- newtype Lex m a = Lex { unLex :: StateT VarId m a } deriving (Functor, Applicative, Monad, MonadState VarId)
7676
-- runLex :: Monad m => Lex m a -> m a
7777
-- runLex lx = evalStateT (unLex lx) 0
78-
newtype Lex a = Lex { unLex :: State VarId a } deriving (Functor, Applicative, Monad, MonadState VarId)
79-
runLex :: Lex a -> a
80-
runLex lx = evalState (unLex lx) 0
78+
newtype Lex a = Lex { unLex :: State Temp a } deriving (Functor, Applicative, Monad, MonadState Temp)
79+
80+
initTemp :: Temp
81+
initTemp = Temp 0 0
8182

82-
insert :: a -> IM.IntMap a -> Lex (IM.IntMap a)
83-
insert x mx = do
84-
k <- get
83+
-- | Run a 'Lex' computation and return its result along with the final 'Temp' value
84+
runLex :: Lex a -> (a, Temp)
85+
runLex lx = runState (unLex lx) initTemp
86+
87+
-- | Insert a new factor into scope and update the maximum size of scopes seen so far
88+
insertFactor :: Factor a -> IM.IntMap (Factor a) -> Lex (IM.IntMap (Factor a))
89+
insertFactor x mx = do
90+
Temp k mfs <- get
8591
let mx' = IM.insert k x mx
86-
put $ succ k
92+
sz = factorSize x
93+
put $ Temp (succ k) (max sz mfs)
8794
pure mx'
8895

89-
fromList :: Foldable t => t a -> Lex (IM.IntMap a)
90-
fromList xs = foldlM (flip insert) IM.empty xs
96+
factorSize :: Factor a -> Int
97+
factorSize = length
98+
99+
fromList :: Foldable t => t (Factor a) -> Lex (IM.IntMap (Factor a))
100+
fromList xs = foldlM (flip insertFactor) IM.empty xs
101+
91102

92-
-- | sequential sum-product elimination of graph factors
103+
104+
-- | sequential sum-product elimination of graph factors, given a vertex elimination order
93105
--
94-
-- >>> elims student "cdihg"
95-
-- fromList [(5,{ 'j' 'l' 's' }),(12,{ 'j' 'l' 's' })]
106+
-- >>> runLex $ elims student "cdihg"
107+
-- (fromList [(4,{ 'j' 'l' 's' }),(5,{ 'j' 'l' 's' })],Temp {varId = 5, maxFactorSize = 3})
96108
elims :: (TG.ToGraph g, Foldable t, Ord (TG.ToVertex g)) =>
97109
g
98110
-> t (TG.ToVertex g)
99-
-> IM.IntMap (Factor (TG.ToVertex g))
100-
elims g vs = runLex $ do
101-
im0 <- factorIM g
102-
foldlM (flip spe) im0 vs
111+
-> Lex (IM.IntMap (Factor (TG.ToVertex g)))
112+
elims g vs = do
113+
let im0 = factorIM g
114+
foldlM (flip sumProductElim) im0 vs
103115

104116
factorIM :: (TG.ToGraph g, Ord (TG.ToVertex g)) =>
105117
g
106-
-> Lex (IM.IntMap (Factor (TG.ToVertex g)))
107-
factorIM g = do
108-
im <- fromList $ TG.vertexList g
109-
pure $ IM.map (`moralFactor` g) im
118+
-> IM.IntMap (Factor (TG.ToVertex g))
119+
factorIM g = IM.map (`moralFactor` g) im where
120+
im = IM.fromList $ zip [0..] (TG.vertexList g)
110121

111122
-- | Sum-product elimination
112-
spe :: (Ord a) => a -> IM.IntMap (Factor a) -> Lex (IM.IntMap (Factor a))
113-
spe z pphi = insert tau pphiC where
114-
pphi' = factorsContaining z pphi
115-
pphiC = pphi `IM.difference` pphi'
116-
tau = eliminate z pphi'
123+
sumProductElim :: (Ord a) => a -> IM.IntMap (Factor a) -> Lex (IM.IntMap (Factor a))
124+
sumProductElim z pphi = insertFactor tau pphiC
125+
where
126+
pphi' = factorsContaining z pphi
127+
pphiC = pphi `IM.difference` pphi'
128+
tau = eliminate z pphi'
117129

118-
forgetIndices :: Ord a => IM.IntMap a -> S.Set a
119-
forgetIndices = S.fromList . map snd . IM.toList
130+
-- forgetIndices :: Ord a => IM.IntMap a -> S.Set a
131+
-- forgetIndices = S.fromList . map snd . IM.toList
120132

121133
-- | Factors containing a given variable
122134
factorsContaining :: Ord a => a -> IM.IntMap (Factor a) -> IM.IntMap (Factor a)
@@ -140,6 +152,10 @@ sumOver v f = Factor $ S.filter (/= v) $ scope f
140152
hasInScope :: Ord a => a -> Factor a -> Bool
141153
hasInScope v f = not $ null $ scope f `S.intersection` S.singleton v
142154

155+
factors :: (TG.ToGraph t, Ord (TG.ToVertex t)) =>
156+
t -> [Factor (TG.ToVertex t)]
157+
factors g = (`moralFactor` g) `map` TG.vertexList g
158+
143159
moralFactor :: (TG.ToGraph g, Ord (TG.ToVertex g)) =>
144160
TG.ToVertex g -> g -> Factor (TG.ToVertex g)
145161
moralFactor v g = Factor $ TG.preSet v g `S.union` S.singleton v

0 commit comments

Comments
 (0)