Skip to content

Commit

Permalink
Fixed issues when requesting only probability metrics
Browse files Browse the repository at this point in the history
Also added initial cleaning of matrix to make input more uniform and predictable.

Also, DSMs will no longer contain None values.
  • Loading branch information
johnmartins committed Jan 8, 2025
1 parent 963f801 commit 2384e09
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 14 deletions.
33 changes: 28 additions & 5 deletions cpm/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional
import numbers


class DSM:
Expand All @@ -16,14 +17,31 @@ def __init__(self, matrix: list[list[Optional[float]]], columns: list[str], inst
:param instigator: Can either be **column** or **row**. Determines directionality of interactions in DSM.
By default, propagation travels from column to row
"""
self.matrix = matrix
self.matrix = DSM.clean_matrix(matrix)
self.columns = columns
self.node_network: dict[int, 'GraphNode'] = self.build_node_network(instigator)
self.instigator = instigator

if instigator not in ['row', 'column']:
raise ValueError('instigator argument needs to be either "row" or "column".')

@staticmethod
def clean_matrix(matrix) -> list[list[float]]:
cleaned_matrix = []
for i, row in enumerate(matrix):
cleaned_matrix.append([])
for j, val in enumerate(row):
if val is None:
val = 0
try:
cleaned_value = float(val)
except ValueError:
cleaned_value = 0

cleaned_matrix[i].append(cleaned_value)

return cleaned_matrix

def __str__(self):
return f'{self.columns}\n{self.matrix}'

Expand All @@ -47,8 +65,8 @@ def build_node_network(self, instigator: str) -> dict[int, 'GraphNode']:
# Ignore diagonal
if i == j:
continue
# Ignore empty cells
if col == "" or col is None:
# Ignore empty connections
if col == "" or col is None or col == 0:
continue

numerical_value = 0.0
Expand Down Expand Up @@ -96,8 +114,13 @@ def set_level(self):

return level

def get_probability(self):
def get_probability(self, stack=0):

# If this node is the single node in the chain, then the probability is none.
if len(self.next) == 0 and stack == 0:
return 0

# If final node in chain, set probability to 1 to complete the calculation
if len(self.next) == 0:
return 1

Expand All @@ -107,7 +130,7 @@ def get_probability(self):
# Likelihood of propagating to this node
from_this = self.node.neighbours[next_index]
# Likelihood of that node being propagated to:
to_next = self.next[next_index].get_probability()
to_next = self.next[next_index].get_probability(stack=stack+1)
prob = prob * (1 - from_this * to_next)

return 1 - prob
Expand Down
8 changes: 4 additions & 4 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def test_parse_dsm_header():
def test_parse_dsm_matrix():
dsm = parse_csv('./tests/test-assets/dsm-simple-symmetrical.csv')

should_be = [[None, 0.1, 0.2, 0.3],
[0.1, None, 0.4, 0.5],
[0.2, 0.4, None, 0.6],
[0.3, 0.5, 0.6, None]]
should_be = [[0, 0.1, 0.2, 0.3],
[0.1, 0, 0.4, 0.5],
[0.2, 0.4, 0, 0.6],
[0.3, 0.5, 0.6, 0]]

for i, row in enumerate(should_be):
for j, col in enumerate(row):
Expand Down
38 changes: 33 additions & 5 deletions tests/test_propagation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from cpm.models import ChangePropagationTree
from cpm.models import ChangePropagationTree, DSM
from cpm.parse import parse_csv
from cpm.utils import calculate_risk_matrix

Expand Down Expand Up @@ -112,8 +112,36 @@ def test_probability_calculation():

dsm_r = parse_csv('./tests/test-assets/dsm-cpx-answers-probs.csv')

for i, col in enumerate(dsm_r.columns):
for j, col in enumerate(dsm_r.columns):
if dsm_r.matrix[i][j] is None:
for i, col_i in enumerate(dsm_r.columns):
for j, col_j in enumerate(dsm_r.columns):
if i == j:
continue
assert abs(res_mtx[i][j] - dsm_r.matrix[i][j]) < 0.001
assert abs(res_mtx[i][j] - dsm_r.matrix[i][j]) < 0.001, f"Failed for index i={i} (row {col_i}), j={j} (col {col_j})"


def test_dsm_input_robustness():
instigator = 'column'
cols = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
mtx_i = [['-', '0', 0.3, None, None, None, 0.7, 0],
[None, '-', 0, 0.4, 0.5, None, None, None],
["0.1", 0, '-', '', '', '', 0.8],
[0, 0, 0, 'D', 0, 0.7, 0],
[0, 0, 0, 0, None, None, 0.7, 0],
['0.1', '0.2', 0, 0, None, 'F', 0.7, 0],
[0, 0, 0, 0, 0, 0.6, 99, 0],
[0, 0, 0.3, 0.4, 0, 0, 0, 'H']]

dsm_i = DSM(mtx_i, cols, instigator)

dsm_l = [[],
[],
[],
[],
[],
[],
[],
[]]


def test_input_throws_if_dsm_instigator_mismatch():
pass

0 comments on commit 2384e09

Please sign in to comment.