diff --git a/cpm/models.py b/cpm/models.py index 93fc6b2..4136abf 100644 --- a/cpm/models.py +++ b/cpm/models.py @@ -25,6 +25,8 @@ def __init__(self, matrix: list[list[Optional[float]]], columns: list[str], inst if instigator not in ['row', 'column']: raise ValueError('instigator argument needs to be either "row" or "column".') + self.validate_matrix() + @staticmethod def clean_matrix(matrix) -> list[list[float]]: cleaned_matrix = [] @@ -42,6 +44,14 @@ def clean_matrix(matrix) -> list[list[float]]: return cleaned_matrix + def validate_matrix(self): + try: + assert len(self.matrix) == len(self.columns) + for row in self.matrix: + assert len(row) == len(self.columns) + except AssertionError: + raise ValueError('Matrix dimensions are inconsistent with provided columns.') + def __str__(self): return f'{self.columns}\n{self.matrix}' @@ -184,6 +194,9 @@ def __init__(self, start_index: int, target_index: int, dsm_impact: DSM, dsm_lik start_index = target_index target_index = temp + if len(dsm_impact.matrix) != len(dsm_likelihood.matrix): + raise ValueError('Impact and Likelihood matrices need to have the same dimensions.') + self.dsm_impact: DSM = dsm_impact self.dsm_likelihood: DSM = dsm_likelihood self.start_index: int = start_index diff --git a/tests/test_input_validation.py b/tests/test_input_validation.py new file mode 100644 index 0000000..ec0f388 --- /dev/null +++ b/tests/test_input_validation.py @@ -0,0 +1,54 @@ +import pytest +from cpm.models import ChangePropagationTree, DSM +from cpm.parse import parse_csv +from cpm.utils import calculate_risk_matrix + + +def test_throws_if_dsm_instigator_mismatch_1(): + dsm_p = parse_csv('./tests/test-assets/dsm-cpx-probs.csv', instigator="row") + dsm_i = parse_csv('./tests/test-assets/dsm-cpx-imps.csv') + + with pytest.raises(ValueError): + calculate_risk_matrix(dsm_i, dsm_p, search_depth=4) + + +def test_throws_if_dsm_instigator_mismatch_2(): + dsm_p = parse_csv('./tests/test-assets/dsm-cpx-probs.csv', instigator="row") + dsm_i = parse_csv('./tests/test-assets/dsm-cpx-imps.csv') + + with pytest.raises(ValueError): + ChangePropagationTree(0, 4, dsm_i, dsm_p) + + +def test_throws_if_dsm_incomplete(): + mtx = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9], + ] + cols = ["a", "b", "c", "d"] + with pytest.raises(ValueError): + dsm = DSM(mtx, cols) + + +def test_throws_if_dsm_size_mismatch(): + cols_p = ["a", "b", "c"] + mtx_p = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9], + ] + cols_i = ["a", "b", "c", "d"] + mtx_i = [ + [0.1, 0.2, 0.3, 0.4], + [0.4, 0.5, 0.6, 0.7], + [0.7, 0.8, 0.9, 0.10], + [0.11, 0.12, 0.13, 0.14] + ] + + dsm_p = DSM(mtx_p, cols_p) + dsm_i = DSM(mtx_i, cols_i) + + with pytest.raises(ValueError): + # If DSMs are of different size, then input validation should prevent execution. + ChangePropagationTree(0, 2, dsm_i, dsm_p) diff --git a/tests/test_propagation.py b/tests/test_propagation.py index 75926bd..1564da6 100644 --- a/tests/test_propagation.py +++ b/tests/test_propagation.py @@ -122,26 +122,34 @@ def test_probability_calculation(): def test_dsm_input_robustness(): instigator = 'column' cols = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] + # Purposefully poorly formatted input matrix mtx_i = [['-', '0', 0.3, None, None, None, 0.7, 0], [None, '-', 0, 0.4, 0.5, None, None, None], - ["0.1", 0, '-', '', '', '', 0.8], - [0, 0, 0, 'D', 0, 0.7, 0], + ["0.1", 0, '-', '', '', '0.6', 0, 0], + [0, 0, 0, 'D', 0, 0, None, 0.8], [0, 0, 0, 0, None, None, 0.7, 0], ['0.1', '0.2', 0, 0, None, 'F', 0.7, 0], [0, 0, 0, 0, 0, 0.6, 99, 0], [0, 0, 0.3, 0.4, 0, 0, 0, 'H']] dsm_i = DSM(mtx_i, cols, instigator) + # Purposefully poorly formatted input matrix + mtx_l = [[None, None, 0.1, None, None, None, 0.1, None], + [None, 'B', 0, 0.2, "0.2", 0, 0, None], + ["0.3", None, 'C', None, None, 0.3, None, 0], + [0, 0, 0, "D", 0, 0, 0, 0.4], + [0, 0, 0, 0, "E", 0, 0.5, 0], + [0.6, 0.6, None, None, None, "F", "0.6", None], + [None, None, None, None, None, 0.7, "G", 0], + [0, 0, 0.8, 0.8, 0, 0, 0, "H"]] + + dsm_p = DSM(mtx_l, cols, instigator) - dsm_l = [[], - [], - [], - [], - [], - [], - [], - []] - + dsm_r = parse_csv('./tests/test-assets/dsm-cpx-answers-risks.csv') + res_mtx = calculate_risk_matrix(dsm_i, dsm_p, search_depth=4) -def test_input_throws_if_dsm_instigator_mismatch(): - pass + for i, col_i in enumerate(dsm_r.columns): + for j, col_j in enumerate(dsm_r.columns): + if i == j: + continue + assert abs(res_mtx[i][j] - dsm_r.matrix[i][j]) < 0.001, f"Failed for index i={i} (row {col_i}), j={j} (col {col_j})"