@@ -3,6 +3,8 @@ use std::sync::Arc;
3
3
use itertools:: Itertools ;
4
4
use thiserror:: Error ;
5
5
6
+ use crate :: sasca:: factor_graph:: GenFactorOperand ;
7
+
6
8
use super :: factor_graph as fg;
7
9
use super :: factor_graph:: {
8
10
EdgeId , EdgeSlice , EdgeVec , ExprFactor , Factor , FactorGraph , FactorId , FactorKind , FactorVec ,
@@ -298,7 +300,11 @@ impl BPState {
298
300
prop_factor ! ( factor_gen_and, & self . pub_reduced[ factor_id] )
299
301
}
300
302
ExprFactor :: XOR => prop_factor ! ( factor_xor, & self . pub_reduced[ factor_id] ) ,
301
- ExprFactor :: NOT => prop_factor ! ( factor_not, ( self . graph. nc - 1 ) as u32 ) ,
303
+ ExprFactor :: NOT => prop_factor ! (
304
+ factor_not,
305
+ & self . pub_reduced[ factor_id] ,
306
+ ( self . graph. nc - 1 ) as u32
307
+ ) ,
302
308
ExprFactor :: ADD { .. } => {
303
309
prop_factor ! ( factor_add, & self . pub_reduced[ factor_id] , & self . plans)
304
310
}
@@ -638,14 +644,15 @@ fn factor_not<'a>(
638
644
belief_from_var : & ' a mut EdgeSlice < Distribution > ,
639
645
dest : & ' a [ VarId ] ,
640
646
clear_incoming : bool ,
647
+ pub_reduced : & PublicValue ,
641
648
inv_cst : u32 ,
642
649
) -> impl Iterator < Item = Distribution > + ' a {
643
650
factor_xor (
644
651
factor,
645
652
belief_from_var,
646
653
dest,
647
654
clear_incoming,
648
- & PublicValue :: Single ( inv_cst) ,
655
+ & pub_reduced . map ( |x| x ^ inv_cst) ,
649
656
)
650
657
}
651
658
@@ -910,9 +917,11 @@ fn factor_gen_factor<'a>(
910
917
} ;
911
918
let res: Vec < Distribution > = dest. iter ( ) . map ( |dest| {
912
919
let dest_idx = factor. edges . get_index_of ( dest) . unwrap ( ) ;
920
+ let op_dest_idx = operands. iter ( ) . position ( |op| if let GenFactorOperand :: Var { factor_edge_id, .. } = op { * factor_edge_id == dest_idx } else { false } ) . expect ( "must have dest operand" ) ;
913
921
let mut distr = belief_from_var[ factor. edges [ dest_idx] ] . clone ( ) ;
914
922
distr. ensure_full ( ) ;
915
- for i in 0 ..nmulti {
923
+ let nmulti_actual = if factor. multi { nmulti } else { 1 } ;
924
+ for i in 0 ..nmulti_actual {
916
925
let gen_factor = match gen_factor {
917
926
GenFactor :: Single ( x) => x,
918
927
GenFactor :: Multi ( x) => & x[ i] ,
@@ -922,10 +931,10 @@ fn factor_gen_factor<'a>(
922
931
assert_eq ! ( gen_factor. shape( ) . len( ) , operands. len( ) ) ;
923
932
// First slice the array with the constants.
924
933
let gen_factor = gen_factor. slice_each_axis ( |ax| match operands[ ax. axis . index ( ) ] {
925
- fg:: GenFactorOperand :: Var ( _ , _ ) => ndarray:: Slice :: new ( 0 , None , 1 ) ,
926
- fg:: GenFactorOperand :: Pub ( pub_idx ) => {
927
- let mut pub_val = public_values[ factor. publics [ pub_idx ] . 0 ] . get ( i) as isize ;
928
- if factor. publics [ pub_idx ] . 1 {
934
+ fg:: GenFactorOperand :: Var { .. } => ndarray:: Slice :: new ( 0 , None , 1 ) ,
935
+ fg:: GenFactorOperand :: Pub { pub_id } => {
936
+ let mut pub_val = public_values[ factor. publics [ pub_id ] . 0 ] . get ( i) as isize ;
937
+ if factor. publics [ pub_id ] . 1 {
929
938
if nc. is_power_of_two ( ) {
930
939
pub_val = !pub_val;
931
940
} else {
@@ -938,12 +947,12 @@ fn factor_gen_factor<'a>(
938
947
} ) ;
939
948
let mut gen_factor = gen_factor. to_owned ( ) ;
940
949
for ( op_idx, op) in operands. iter ( ) . enumerate ( ) {
941
- if op_idx != dest_idx {
942
- if let fg :: GenFactorOperand :: Var ( var_idx , neg ) = op {
943
- if * neg {
950
+ if let fg :: GenFactorOperand :: Var { factor_edge_id , negated } = op {
951
+ if * factor_edge_id != dest_idx {
952
+ if * negated {
944
953
todo ! ( "Negated operands on generalized factors not yet implemented." ) ;
945
954
}
946
- let distr = & belief_from_var[ factor. edges [ * var_idx ] ] ;
955
+ let distr = & belief_from_var[ factor. edges [ * factor_edge_id ] ] ;
947
956
let mut new_gen_factor: ndarray:: ArrayD < f64 > = ndarray:: ArrayD :: zeros ( gen_factor. slice_axis ( ndarray:: Axis ( op_idx) , ndarray:: Slice :: new ( 0 , Some ( 1 ) , 1 ) ) . shape ( ) ) ;
948
957
if let Some ( distr) = distr. value ( ) {
949
958
for ( d, gf) in distr. slice ( s ! [ i, ..] ) . iter ( ) . zip ( gen_factor. axis_chunks_iter ( ndarray:: Axis ( op_idx) , 1 ) ) {
@@ -959,10 +968,10 @@ fn factor_gen_factor<'a>(
959
968
}
960
969
}
961
970
// Drop useless axes.
962
- for _ in 0 ..dest_idx {
971
+ for _ in 0 ..op_dest_idx {
963
972
gen_factor. index_axis_inplace ( ndarray:: Axis ( 0 ) , 0 ) ;
964
973
}
965
- for _ in ( dest_idx +1 ) ..operands. len ( ) {
974
+ for _ in ( op_dest_idx +1 ) ..operands. len ( ) {
966
975
gen_factor. index_axis_inplace ( ndarray:: Axis ( 1 ) , 0 ) ;
967
976
}
968
977
distr. value_mut ( ) . unwrap ( ) . slice_mut ( s ! [ i, ..] ) . assign ( & gen_factor) ;
@@ -974,43 +983,43 @@ fn factor_gen_factor<'a>(
974
983
dest. fill ( 0.0 ) ;
975
984
for op_values in gen_factor. outer_iter ( ) {
976
985
let mut res = 1.0 ;
977
- for ( op_idx , ( op, val) ) in operands. iter ( ) . zip ( op_values. iter ( ) ) . enumerate ( ) {
978
- if op_idx != dest_idx {
979
- match op {
980
- fg :: GenFactorOperand :: Var ( var_idx , neg ) => {
986
+ for ( op, val) in operands. iter ( ) . zip ( op_values. iter ( ) ) {
987
+ match op {
988
+ fg :: GenFactorOperand :: Var { factor_edge_id , negated } => {
989
+ if * factor_edge_id != dest_idx {
981
990
let mut val = * val;
982
- if * neg {
991
+ if * negated {
983
992
if nc. is_power_of_two ( ) {
984
993
val = !val & ( ( nc - 1 ) as ClassVal ) ;
985
994
} else {
986
995
// TODO Check that we enforce this at graph creation time and return a proper error.
987
996
panic ! ( "Cannot negate operands with non-power-of-two number of classes." ) ;
988
997
}
989
998
}
990
- let distr = & belief_from_var[ factor. edges [ * var_idx ] ] ;
999
+ let distr = & belief_from_var[ factor. edges [ * factor_edge_id ] ] ;
991
1000
// For uniform, we implicitly multiply by 1.0
992
1001
if let Some ( distr) = distr. value ( ) {
993
1002
res *= distr[ ( i, val as usize ) ] ;
994
1003
}
995
1004
}
996
- fg:: GenFactorOperand :: Pub ( pub_idx) => {
997
- let mut pub_val = public_values[ factor. publics [ * pub_idx] . 0 ] . get ( i) ;
998
- if factor. publics [ * pub_idx] . 1 {
999
- if nc. is_power_of_two ( ) {
1000
- pub_val = !pub_val & ( ( nc - 1 ) as ClassVal ) ;
1001
- } else {
1002
- // TODO Check that we enforce this at graph creation time and return a proper error.
1003
- panic ! ( "Cannot negate operands with non-power-of-two number of classes." ) ;
1004
- }
1005
- }
1006
- if pub_val != * val {
1007
- res = 0.0 ;
1005
+ }
1006
+ fg:: GenFactorOperand :: Pub { pub_id} => {
1007
+ let mut pub_val = public_values[ factor. publics [ * pub_id] . 0 ] . get ( i) ;
1008
+ if factor. publics [ * pub_id] . 1 {
1009
+ if nc. is_power_of_two ( ) {
1010
+ pub_val = !pub_val & ( ( nc - 1 ) as ClassVal ) ;
1011
+ } else {
1012
+ // TODO Check that we enforce this at graph creation time and return a proper error.
1013
+ panic ! ( "Cannot negate operands with non-power-of-two number of classes." ) ;
1008
1014
}
1009
1015
}
1016
+ if pub_val != * val {
1017
+ res = 0.0 ;
1018
+ }
1010
1019
}
1011
1020
}
1012
1021
}
1013
- dest[ op_values[ dest_idx ] as usize ] += res;
1022
+ dest[ op_values[ op_dest_idx ] as usize ] += res;
1014
1023
}
1015
1024
}
1016
1025
}
0 commit comments