Skip to content

Commit f2ffe70

Browse files
authored
Merge pull request #173 from simple-crypto/fix_sasca
Fix misc sasca bugs with PUB and/or generic factors
2 parents 57895bb + 3b961ad commit f2ffe70

File tree

4 files changed

+144
-51
lines changed

4 files changed

+144
-51
lines changed

src/scalib_ext/scalib/src/sasca/belief_propagation.rs

+41-32
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::sync::Arc;
33
use itertools::Itertools;
44
use thiserror::Error;
55

6+
use crate::sasca::factor_graph::GenFactorOperand;
7+
68
use super::factor_graph as fg;
79
use super::factor_graph::{
810
EdgeId, EdgeSlice, EdgeVec, ExprFactor, Factor, FactorGraph, FactorId, FactorKind, FactorVec,
@@ -298,7 +300,11 @@ impl BPState {
298300
prop_factor!(factor_gen_and, &self.pub_reduced[factor_id])
299301
}
300302
ExprFactor::XOR => prop_factor!(factor_xor, &self.pub_reduced[factor_id]),
301-
ExprFactor::NOT => prop_factor!(factor_not, (self.graph.nc - 1) as u32),
303+
ExprFactor::NOT => prop_factor!(
304+
factor_not,
305+
&self.pub_reduced[factor_id],
306+
(self.graph.nc - 1) as u32
307+
),
302308
ExprFactor::ADD { .. } => {
303309
prop_factor!(factor_add, &self.pub_reduced[factor_id], &self.plans)
304310
}
@@ -638,14 +644,15 @@ fn factor_not<'a>(
638644
belief_from_var: &'a mut EdgeSlice<Distribution>,
639645
dest: &'a [VarId],
640646
clear_incoming: bool,
647+
pub_reduced: &PublicValue,
641648
inv_cst: u32,
642649
) -> impl Iterator<Item = Distribution> + 'a {
643650
factor_xor(
644651
factor,
645652
belief_from_var,
646653
dest,
647654
clear_incoming,
648-
&PublicValue::Single(inv_cst),
655+
&pub_reduced.map(|x| x ^ inv_cst),
649656
)
650657
}
651658

@@ -910,9 +917,11 @@ fn factor_gen_factor<'a>(
910917
};
911918
let res: Vec<Distribution> = dest.iter().map(|dest| {
912919
let dest_idx = factor.edges.get_index_of(dest).unwrap();
920+
let op_dest_idx = operands.iter().position(|op| if let GenFactorOperand::Var { factor_edge_id, .. } = op { *factor_edge_id == dest_idx } else { false }).expect("must have dest operand");
913921
let mut distr = belief_from_var[factor.edges[dest_idx]].clone();
914922
distr.ensure_full();
915-
for i in 0..nmulti {
923+
let nmulti_actual = if factor.multi { nmulti } else { 1 };
924+
for i in 0..nmulti_actual {
916925
let gen_factor = match gen_factor {
917926
GenFactor::Single(x) => x,
918927
GenFactor::Multi(x) => &x[i],
@@ -922,10 +931,10 @@ fn factor_gen_factor<'a>(
922931
assert_eq!(gen_factor.shape().len(), operands.len());
923932
// First slice the array with the constants.
924933
let gen_factor = gen_factor.slice_each_axis(|ax| match operands[ax.axis.index()] {
925-
fg::GenFactorOperand::Var(_, _) => ndarray::Slice::new(0, None, 1),
926-
fg::GenFactorOperand::Pub(pub_idx) => {
927-
let mut pub_val = public_values[factor.publics[pub_idx].0].get(i) as isize;
928-
if factor.publics[pub_idx].1 {
934+
fg::GenFactorOperand::Var { ..} => ndarray::Slice::new(0, None, 1),
935+
fg::GenFactorOperand::Pub { pub_id } => {
936+
let mut pub_val = public_values[factor.publics[pub_id].0].get(i) as isize;
937+
if factor.publics[pub_id].1 {
929938
if nc.is_power_of_two() {
930939
pub_val = !pub_val;
931940
} else {
@@ -938,12 +947,12 @@ fn factor_gen_factor<'a>(
938947
});
939948
let mut gen_factor = gen_factor.to_owned();
940949
for (op_idx, op) in operands.iter().enumerate() {
941-
if op_idx != dest_idx {
942-
if let fg::GenFactorOperand::Var(var_idx, neg) = op {
943-
if *neg {
950+
if let fg::GenFactorOperand::Var { factor_edge_id, negated } = op {
951+
if *factor_edge_id != dest_idx {
952+
if *negated {
944953
todo!("Negated operands on generalized factors not yet implemented.");
945954
}
946-
let distr = &belief_from_var[factor.edges[*var_idx]];
955+
let distr = &belief_from_var[factor.edges[*factor_edge_id]];
947956
let mut new_gen_factor: ndarray::ArrayD<f64> = ndarray::ArrayD::zeros(gen_factor.slice_axis(ndarray::Axis(op_idx), ndarray::Slice::new(0, Some(1), 1)).shape());
948957
if let Some(distr) = distr.value() {
949958
for (d, gf) in distr.slice(s![i,..]).iter().zip(gen_factor.axis_chunks_iter(ndarray::Axis(op_idx), 1)) {
@@ -959,10 +968,10 @@ fn factor_gen_factor<'a>(
959968
}
960969
}
961970
// Drop useless axes.
962-
for _ in 0..dest_idx {
971+
for _ in 0..op_dest_idx {
963972
gen_factor.index_axis_inplace(ndarray::Axis(0), 0);
964973
}
965-
for _ in (dest_idx+1)..operands.len() {
974+
for _ in (op_dest_idx+1)..operands.len() {
966975
gen_factor.index_axis_inplace(ndarray::Axis(1), 0);
967976
}
968977
distr.value_mut().unwrap().slice_mut(s![i,..]).assign(&gen_factor);
@@ -974,43 +983,43 @@ fn factor_gen_factor<'a>(
974983
dest.fill(0.0);
975984
for op_values in gen_factor.outer_iter() {
976985
let mut res = 1.0;
977-
for (op_idx, (op, val)) in operands.iter().zip(op_values.iter()).enumerate() {
978-
if op_idx != dest_idx {
979-
match op {
980-
fg::GenFactorOperand::Var(var_idx, neg) => {
986+
for (op, val) in operands.iter().zip(op_values.iter()) {
987+
match op {
988+
fg::GenFactorOperand::Var { factor_edge_id, negated} => {
989+
if *factor_edge_id != dest_idx {
981990
let mut val = *val;
982-
if *neg {
991+
if *negated {
983992
if nc.is_power_of_two() {
984993
val = !val & ((nc - 1) as ClassVal);
985994
} else {
986995
// TODO Check that we enforce this at graph creation time and return a proper error.
987996
panic!("Cannot negate operands with non-power-of-two number of classes.");
988997
}
989998
}
990-
let distr = &belief_from_var[factor.edges[*var_idx]];
999+
let distr = &belief_from_var[factor.edges[*factor_edge_id]];
9911000
// For uniform, we implicitly multiply by 1.0
9921001
if let Some(distr) = distr.value() {
9931002
res *= distr[(i, val as usize)];
9941003
}
9951004
}
996-
fg::GenFactorOperand::Pub(pub_idx) => {
997-
let mut pub_val = public_values[factor.publics[*pub_idx].0].get(i);
998-
if factor.publics[*pub_idx].1 {
999-
if nc.is_power_of_two() {
1000-
pub_val = !pub_val & ((nc - 1) as ClassVal);
1001-
} else {
1002-
// TODO Check that we enforce this at graph creation time and return a proper error.
1003-
panic!("Cannot negate operands with non-power-of-two number of classes.");
1004-
}
1005-
}
1006-
if pub_val != *val {
1007-
res = 0.0;
1005+
}
1006+
fg::GenFactorOperand::Pub{pub_id} => {
1007+
let mut pub_val = public_values[factor.publics[*pub_id].0].get(i);
1008+
if factor.publics[*pub_id].1 {
1009+
if nc.is_power_of_two() {
1010+
pub_val = !pub_val & ((nc - 1) as ClassVal);
1011+
} else {
1012+
// TODO Check that we enforce this at graph creation time and return a proper error.
1013+
panic!("Cannot negate operands with non-power-of-two number of classes.");
10081014
}
10091015
}
1016+
if pub_val != *val {
1017+
res = 0.0;
1018+
}
10101019
}
10111020
}
10121021
}
1013-
dest[op_values[dest_idx] as usize] += res;
1022+
dest[op_values[op_dest_idx] as usize] += res;
10141023
}
10151024
}
10161025
}

src/scalib_ext/scalib/src/sasca/factor_graph.rs

+23-16
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,13 @@ pub(super) enum ExprFactor {
7070

7171
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
7272
pub(super) enum GenFactorOperand {
73-
Var(usize, bool),
74-
Pub(usize),
73+
Var {
74+
factor_edge_id: usize,
75+
negated: bool,
76+
},
77+
Pub {
78+
pub_id: usize,
79+
},
7580
}
7681

7782
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
@@ -374,8 +379,10 @@ impl FactorGraph {
374379
let ops: Vec<&PublicValue> = operands
375380
.iter()
376381
.map(|op| match op {
377-
GenFactorOperand::Var(idx, ..) => &var_assignments[*idx],
378-
GenFactorOperand::Pub(idx) => &public_values[*idx],
382+
GenFactorOperand::Var { factor_edge_id, .. } => {
383+
&var_assignments[*factor_edge_id]
384+
}
385+
GenFactorOperand::Pub { pub_id } => &public_values[*pub_id],
379386
})
380387
.collect();
381388
let nmulti_ops = ops.iter().find_map(|op| {
@@ -451,25 +458,25 @@ impl FactorGraph {
451458
self.factors
452459
.values()
453460
.map(|factor| {
461+
let mut pubs = factor
462+
.publics
463+
.iter()
464+
.map(|(pub_id, nv)| (&public_values[*pub_id], *nv));
454465
match &factor.kind {
455466
// Not used
456467
FactorKind::Assign {
457-
expr: ExprFactor::NOT,
458-
..
459-
}
460-
| FactorKind::Assign {
461468
expr: ExprFactor::LOOKUP { .. },
462469
..
463470
}
464471
| FactorKind::GenFactor { .. } => PublicValue::Single(0),
465-
FactorKind::Assign { expr, has_res } => self.merge_pubs(
466-
expr,
467-
!has_res,
468-
factor
469-
.publics
470-
.iter()
471-
.map(|(pub_id, nv)| (&public_values[*pub_id], *nv)),
472-
),
472+
FactorKind::Assign {
473+
expr: ExprFactor::NOT,
474+
..
475+
} => pubs
476+
.next()
477+
.map(|(val, _)| val.clone())
478+
.unwrap_or(PublicValue::Single(0)),
479+
FactorKind::Assign { expr, has_res } => self.merge_pubs(expr, !has_res, pubs),
473480
}
474481
})
475482
.collect()

src/scalib_ext/scalib/src/sasca/fg_build.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,13 @@ impl fg::FactorGraph {
200200
let mut operands = Vec::new();
201201
for (i, p) in is_pub.iter().enumerate() {
202202
if *p {
203-
operands.push(fg::GenFactorOperand::Pub(n_pubs));
203+
operands.push(fg::GenFactorOperand::Pub { pub_id: n_pubs });
204204
n_pubs += 1;
205205
} else {
206-
operands.push(fg::GenFactorOperand::Var(n_vars, vars[i].neg));
206+
operands.push(fg::GenFactorOperand::Var {
207+
factor_edge_id: n_vars,
208+
negated: vars[i].neg,
209+
});
207210
n_vars += 1;
208211
}
209212
}

tests/test_factorgraph.py

+75-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import pytest
2-
from scalib.attacks import FactorGraph, BPState
2+
from scalib.attacks import FactorGraph, BPState, GenFactor
33
import numpy as np
44
import os
55
import copy
6+
import itertools as it
67

78

89
def normalize_distr(x):
@@ -1778,3 +1779,76 @@ def test_cycle_detection_single_factor_with_multi():
17781779
fg = FactorGraph(graph_desc)
17791780
bp = BPState(fg, 2)
17801781
assert bp.is_cyclic()
1782+
1783+
1784+
def test_generic_single_multi():
1785+
nc = 2
1786+
n_exec = 2
1787+
graph_desc = f"""
1788+
NC {nc}
1789+
1790+
VAR SINGLE A
1791+
VAR SINGLE B
1792+
VAR SINGLE C
1793+
1794+
GENERIC SINGLE XOR
1795+
1796+
PROPERTY XOR(A,B,C)
1797+
"""
1798+
1799+
def xor(a, b):
1800+
return a ^ b
1801+
1802+
fg = FactorGraph(graph_desc)
1803+
1804+
XOR = np.array(
1805+
[[a, b, a ^ b] for a, b in it.product(range(nc), repeat=2)],
1806+
dtype=np.uint32,
1807+
)
1808+
gen_factors = {
1809+
"XOR": GenFactor.sparse_functional(XOR),
1810+
}
1811+
1812+
bp = BPState(fg, n_exec, gen_factors=gen_factors)
1813+
bp.bp_loopy(1, True)
1814+
1815+
1816+
def test_factor_not_pub():
1817+
nc = 4
1818+
graph_desc = f"""
1819+
NC {nc}
1820+
1821+
PUB SINGLE A
1822+
VAR SINGLE B
1823+
PROPERTY B = !A
1824+
"""
1825+
fg = FactorGraph(graph_desc)
1826+
for a in range(nc):
1827+
bp = BPState(fg, 1, public_values={"A": a})
1828+
bp.bp_acyclic("B")
1829+
result = bp.get_distribution("B")
1830+
assert np.argmax(result) == (nc - 1) ^ a
1831+
1832+
1833+
def test_factor_gen_pub():
1834+
nc = 2
1835+
graph_desc = f"""
1836+
NC {nc}
1837+
PUB SINGLE A
1838+
VAR SINGLE B
1839+
GENERIC SINGLE NOT
1840+
PROPERTY NOT(A,B)
1841+
"""
1842+
fg = FactorGraph(graph_desc)
1843+
not_factors = [
1844+
GenFactor.sparse_functional(
1845+
np.array([(a, (nc - 1) ^ a) for a in range(nc)], dtype=np.uint32)
1846+
),
1847+
GenFactor.dense(np.array([[0, 1], [1, 0]], dtype=np.float64)),
1848+
]
1849+
for nf in not_factors:
1850+
for a in range(nc):
1851+
bp = BPState(fg, 1, public_values={"A": a}, gen_factors={"NOT": nf})
1852+
bp.bp_acyclic("B")
1853+
result = bp.get_distribution("B")
1854+
assert np.argmax(result) == (nc - 1) ^ a

0 commit comments

Comments
 (0)