@@ -5,8 +5,8 @@ module Numeric.Statistics.Inference.Bayes.Exact.VariableElimination where
5
5
-- import GHC.TypeNats
6
6
import Data.List (groupBy , sort , sortBy )
7
7
import Data.Ord (comparing )
8
- import Data.Foldable (foldlM , maximumBy )
9
- import Data.Monoid (Sum (.. ), Product (.. ))
8
+ import Data.Foldable (foldlM , maximumBy , minimumBy )
9
+ -- import Data.Monoid (Sum(..), Product(..))
10
10
11
11
-- algebraic-graphs
12
12
import Algebra.Graph (Graph (.. ), vertex , edge , overlay , connect )
@@ -17,106 +17,118 @@ import qualified Algebra.Graph.ToGraph as TG (ToGraph(..))
17
17
import qualified Data.Bimap as BM
18
18
-- containers
19
19
import qualified Data.IntMap as IM
20
- import qualified Data.Set as S
20
+ import qualified Data.Set as S (Set , empty , singleton , union , intersection , filter , toList )
21
+ import Data.Set ((\\) )
21
22
-- exceptions
22
23
import Control.Monad.Catch (MonadThrow (.. ))
23
24
-- massiv
24
25
-- import qualified Data.Massiv.Array as A (Array, all, Comp(..), makeArray, Construct(..), Sz(..))
25
26
-- import Data.Massiv.Array (Index, Ix1(..), D, (..:), ifoldlWithin', foldlWithin', Lower, Dim(..), Source)
26
27
-- mtl
27
- import Control.Monad.State (MonadState (.. ))
28
+ import Control.Monad.State (MonadState (.. ), gets )
28
29
-- permutation
29
30
-- import qualified Data.Permute as P (permute, next, elems)
30
31
-- transformers
31
- import Control.Monad.State (State (.. ), evalState )
32
- -- import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT)
32
+ import Control.Monad.State (State (.. ), runState , evalState , execState )
33
+ import Control.Monad.Trans.State (StateT (.. ), runStateT , evalStateT )
33
34
-- vector
34
35
import qualified Data.Vector as V
35
36
36
- import Prelude hiding (lookup )
37
-
37
+ import Data.Graph.Examples (student )
38
38
import Data.Permutation (Permutation , permutation , getPermutation , permutations )
39
39
40
40
41
- student :: Graph Char
42
- student =
43
- connect c d `overlay`
44
- connect d g `overlay`
45
- connect g h `overlay`
46
- connect g l `overlay`
47
- connect l j `overlay`
48
- connect j h `overlay`
49
- connect s j `overlay`
50
- connect i g `overlay`
51
- connect i s
52
- where
53
- c = vertex ' c'
54
- d = vertex ' d'
55
- g = vertex ' g'
56
- h = vertex ' h'
57
- i = vertex ' i'
58
- j = vertex ' j'
59
- l = vertex ' l'
60
- s = vertex ' s'
41
+ import Prelude hiding (lookup )
61
42
62
- {- |
43
+ minimumMaxDegOrdering
44
+ :: (TG. ToGraph g , Ord (TG. ToVertex g )) =>
45
+ g -> TG. ToVertex g -> Permutation (TG. ToVertex g )
46
+ minimumMaxDegOrdering g v =
47
+ minimumBy (compareOrderings g) $ permutations vs where
48
+ vs = verticesWithout g $ S. singleton v
63
49
64
- λ> permutations <$> TG.topSort student
50
+ compareOrderings :: (TG. ToGraph g , Ord (TG. ToVertex g )) =>
51
+ g
52
+ -> Permutation (TG. ToVertex g )
53
+ -> Permutation (TG. ToVertex g )
54
+ -> Ordering
55
+ compareOrderings g vp1 vp2 =
56
+ compare (maxFS g $ getPermutation vp1) (maxFS g $ getPermutation vp2)
65
57
66
- Just ["iscdgljh","sicdgljh","csidgljh","dscigljh","gscdiljh","lscdgijh","jscdglih","hscdglji","icsdgljh","idcsgljh","igcdsljh","ilcdgsjh","ijcdglsh","ihcdgljs","isdcgljh","isgdcljh","isldgcjh","isjdglch","ishdgljc","iscgdljh","isclgdjh","iscjgldh","ischgljd","iscdlgjh","iscdjlgh","iscdhljg","iscdgjlh","iscdghjl","iscdglhj"]
67
- -}
58
+ maxFS :: (TG. ToGraph g , Foldable t , Ord (TG. ToVertex g )) =>
59
+ g -> t (TG. ToVertex g ) -> Int
60
+ maxFS g vs = maxFactorSize $ snd $ runLex (elims g vs)
61
+
62
+
63
+ -- | All vertices in the graph but a given subset
64
+ verticesWithout :: (TG. ToGraph g , Ord (TG. ToVertex g )) =>
65
+ g -> S. Set (TG. ToVertex g ) -> [TG. ToVertex g ]
66
+ verticesWithout g vs = S. toList $ TG. vertexSet g \\ vs
68
67
69
68
70
69
-- data SumProduct a =
71
70
-- SumOver a (Factor a)
72
71
-- | Product (S.Set (Factor a))
73
72
74
73
type VarId = Int
74
+ data Temp = Temp { varId :: VarId , maxFactorSize :: Int } deriving (Eq , Show )
75
75
-- newtype Lex m a = Lex { unLex :: StateT VarId m a } deriving (Functor, Applicative, Monad, MonadState VarId)
76
76
-- runLex :: Monad m => Lex m a -> m a
77
77
-- runLex lx = evalStateT (unLex lx) 0
78
- newtype Lex a = Lex { unLex :: State VarId a } deriving (Functor , Applicative , Monad , MonadState VarId )
79
- runLex :: Lex a -> a
80
- runLex lx = evalState (unLex lx) 0
78
+ newtype Lex a = Lex { unLex :: State Temp a } deriving (Functor , Applicative , Monad , MonadState Temp )
79
+
80
+ initTemp :: Temp
81
+ initTemp = Temp 0 0
81
82
82
- insert :: a -> IM. IntMap a -> Lex (IM. IntMap a )
83
- insert x mx = do
84
- k <- get
83
+ -- | Run a 'Lex' computation and return its result along with the final 'Temp' value
84
+ runLex :: Lex a -> (a , Temp )
85
+ runLex lx = runState (unLex lx) initTemp
86
+
87
+ -- | Insert a new factor into scope and update the maximum size of scopes seen so far
88
+ insertFactor :: Factor a -> IM. IntMap (Factor a ) -> Lex (IM. IntMap (Factor a ))
89
+ insertFactor x mx = do
90
+ Temp k mfs <- get
85
91
let mx' = IM. insert k x mx
86
- put $ succ k
92
+ sz = factorSize x
93
+ put $ Temp (succ k) (max sz mfs)
87
94
pure mx'
88
95
89
- fromList :: Foldable t => t a -> Lex (IM. IntMap a )
90
- fromList xs = foldlM (flip insert) IM. empty xs
96
+ factorSize :: Factor a -> Int
97
+ factorSize = length
98
+
99
+ fromList :: Foldable t => t (Factor a ) -> Lex (IM. IntMap (Factor a ))
100
+ fromList xs = foldlM (flip insertFactor) IM. empty xs
101
+
91
102
92
- -- | sequential sum-product elimination of graph factors
103
+
104
+ -- | sequential sum-product elimination of graph factors, given a vertex elimination order
93
105
--
94
- -- >>> elims student "cdihg"
95
- -- fromList [(5 ,{ 'j' 'l' 's' }),(12 ,{ 'j' 'l' 's' })]
106
+ -- >>> runLex $ elims student "cdihg"
107
+ -- ( fromList [(4 ,{ 'j' 'l' 's' }),(5 ,{ 'j' 'l' 's' })],Temp {varId = 5, maxFactorSize = 3})
96
108
elims :: (TG. ToGraph g , Foldable t , Ord (TG. ToVertex g )) =>
97
109
g
98
110
-> t (TG. ToVertex g )
99
- -> IM. IntMap (Factor (TG. ToVertex g ))
100
- elims g vs = runLex $ do
101
- im0 <- factorIM g
102
- foldlM (flip spe ) im0 vs
111
+ -> Lex ( IM. IntMap (Factor (TG. ToVertex g ) ))
112
+ elims g vs = do
113
+ let im0 = factorIM g
114
+ foldlM (flip sumProductElim ) im0 vs
103
115
104
116
factorIM :: (TG. ToGraph g , Ord (TG. ToVertex g )) =>
105
117
g
106
- -> Lex (IM. IntMap (Factor (TG. ToVertex g )))
107
- factorIM g = do
108
- im <- fromList $ TG. vertexList g
109
- pure $ IM. map (`moralFactor` g) im
118
+ -> IM. IntMap (Factor (TG. ToVertex g ))
119
+ factorIM g = IM. map (`moralFactor` g) im where
120
+ im = IM. fromList $ zip [0 .. ] (TG. vertexList g)
110
121
111
122
-- | Sum-product elimination
112
- spe :: (Ord a ) => a -> IM. IntMap (Factor a ) -> Lex (IM. IntMap (Factor a ))
113
- spe z pphi = insert tau pphiC where
114
- pphi' = factorsContaining z pphi
115
- pphiC = pphi `IM.difference` pphi'
116
- tau = eliminate z pphi'
123
+ sumProductElim :: (Ord a ) => a -> IM. IntMap (Factor a ) -> Lex (IM. IntMap (Factor a ))
124
+ sumProductElim z pphi = insertFactor tau pphiC
125
+ where
126
+ pphi' = factorsContaining z pphi
127
+ pphiC = pphi `IM.difference` pphi'
128
+ tau = eliminate z pphi'
117
129
118
- forgetIndices :: Ord a => IM. IntMap a -> S. Set a
119
- forgetIndices = S. fromList . map snd . IM. toList
130
+ -- forgetIndices :: Ord a => IM.IntMap a -> S.Set a
131
+ -- forgetIndices = S.fromList . map snd . IM.toList
120
132
121
133
-- | Factors containing a given variable
122
134
factorsContaining :: Ord a => a -> IM. IntMap (Factor a ) -> IM. IntMap (Factor a )
@@ -140,6 +152,10 @@ sumOver v f = Factor $ S.filter (/= v) $ scope f
140
152
hasInScope :: Ord a => a -> Factor a -> Bool
141
153
hasInScope v f = not $ null $ scope f `S.intersection` S. singleton v
142
154
155
+ factors :: (TG. ToGraph t , Ord (TG. ToVertex t )) =>
156
+ t -> [Factor (TG. ToVertex t )]
157
+ factors g = (`moralFactor` g) `map` TG. vertexList g
158
+
143
159
moralFactor :: (TG. ToGraph g , Ord (TG. ToVertex g )) =>
144
160
TG. ToVertex g -> g -> Factor (TG. ToVertex g )
145
161
moralFactor v g = Factor $ TG. preSet v g `S.union` S. singleton v
0 commit comments