From c4c89537cde01e550eba8fa0221d5a379447a582 Mon Sep 17 00:00:00 2001 From: Giulia Baldini Date: Mon, 21 Sep 2020 17:01:42 +0200 Subject: [PATCH 1/2] Add conversion to list before padding, extend length of result to padding --- munkres.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/munkres.py b/munkres.py index 2f2edbc..1e8b2ca 100644 --- a/munkres.py +++ b/munkres.py @@ -100,7 +100,7 @@ def pad_matrix(self, matrix: Matrix, pad_value: int=0) -> Matrix: new_matrix = [] for row in matrix: row_len = len(row) - new_row = row[:] + new_row = list(row[:]) if total_rows > row_len: # Row too short. Pad it. new_row += [pad_value] * (total_rows - row_len) @@ -163,8 +163,8 @@ def compute(self, cost_matrix: Matrix) -> Sequence[Tuple[int, int]]: # Look for the starred columns results = [] - for i in range(self.original_length): - for j in range(self.original_width): + for i in range(self.n): + for j in range(self.n): if self.marked[i][j] == 1: results += [(i, j)] From 5ed8f02a587c4ffc73fce5baffef0ff70503bd8c Mon Sep 17 00:00:00 2001 From: Giulia Baldini Date: Tue, 22 Sep 2020 14:27:40 +0200 Subject: [PATCH 2/2] Fix tests by adding constraints in get cost function --- test/test_munkres.py | 89 +++++++++++++++++++++++++++++++------------- 1 file changed, 64 insertions(+), 25 deletions(-) diff --git a/test/test_munkres.py b/test/test_munkres.py index 23796dd..8aed588 100644 --- a/test/test_munkres.py +++ b/test/test_munkres.py @@ -4,9 +4,11 @@ m = Munkres() + def _get_cost(matrix): indices = m.compute(matrix) - return sum([matrix[row][column] for row, column in indices]) + return sum([matrix[row][column] for row, column in indices if row < len(matrix) and column < len(matrix[0])]) + def test_documented_example(): ''' @@ -18,6 +20,7 @@ def test_documented_example(): cost = _get_cost(matrix) assert cost == 12 + def float_example(): ''' Test a matrix with float values @@ -28,6 +31,7 @@ def float_example(): cost = _get_cost(matrix) assert cost == pytest.approx(13.5) + def test_5_x_5(): matrix = [[12, 9, 27, 10, 23], [7, 13, 13, 30, 19], @@ -37,6 +41,7 @@ def test_5_x_5(): cost = _get_cost(matrix) assert cost == 51 + def test_5_x_5_float(): matrix = [[12.01, 9.02, 27.03, 10.04, 23.05], [7.06, 13.07, 13.08, 30.09, 19.1], @@ -61,6 +66,7 @@ def test_10_x_10(): cost = _get_cost(matrix) assert cost == 66 + def test_10_x_10_float(): matrix = [[37.001, 34.002, 29.003, 26.004, 19.005, 8.006, 9.007, 23.008, 19.009, 29.01], [9.011, 28.012, 20.013, 8.014, 18.015, 20.016, 14.017, 33.018, 23.019, 14.02], @@ -75,6 +81,7 @@ def test_10_x_10_float(): cost = _get_cost(matrix) assert cost == pytest.approx(66.505) + def test_20_x_20(): matrix = [[5, 4, 3, 9, 8, 9, 3, 5, 6, 9, 4, 10, 3, 5, 6, 6, 1, 8, 10, 2], [10, 9, 9, 2, 8, 3, 9, 9, 10, 1, 7, 10, 8, 4, 2, 1, 4, 8, 4, 8], @@ -99,27 +106,49 @@ def test_20_x_20(): cost = _get_cost(matrix) assert cost == 22 + def test_20_x_20_float(): - matrix = [[5.0001, 4.0002, 3.0003, 9.0004, 8.0005, 9.0006, 3.0007, 5.0008, 6.0009, 9.001, 4.0011, 10.0012, 3.0013, 5.0014, 6.0015, 6.0016, 1.0017, 8.0018, 10.0019, 2.002], - [10.0021, 9.0022, 9.0023, 2.0024, 8.0025, 3.0026, 9.0027, 9.0028, 10.0029, 1.003, 7.0031, 10.0032, 8.0033, 4.0034, 2.0035, 1.0036, 4.0037, 8.0038, 4.0039, 8.004], - [10.0041, 4.0042, 4.0043, 3.0044, 1.0045, 3.0046, 5.0047, 10.0048, 6.0049, 8.005, 6.0051, 8.0052, 4.0053, 10.0054, 7.0055, 2.0056, 4.0057, 5.0058, 1.0059, 8.006], - [2.0061, 1.0062, 4.0063, 2.0064, 3.0065, 9.0066, 3.0067, 4.0068, 7.0069, 3.007, 4.0071, 1.0072, 3.0073, 2.0074, 9.0075, 8.0076, 6.0077, 5.0078, 7.0079, 8.008], - [3.0081, 4.0082, 4.0083, 1.0084, 4.0085, 10.0086, 1.0087, 2.0088, 6.0089, 4.009, 5.0091, 10.0092, 2.0093, 2.0094, 3.0095, 9.0096, 10.0097, 9.0098, 9.0099, 10.01], - [1.0101, 10.0102, 1.0103, 8.0104, 1.0105, 3.0106, 1.0107, 7.0108, 1.0109, 1.011, 2.0111, 1.0112, 2.0113, 6.0114, 3.0115, 3.0116, 4.0117, 4.0118, 8.0119, 6.012], - [1.0121, 8.0122, 7.0123, 10.0124, 10.0125, 3.0126, 4.0127, 6.0128, 1.0129, 6.013, 6.0131, 4.0132, 9.0133, 6.0134, 9.0135, 6.0136, 4.0137, 5.0138, 4.0139, 7.014], - [8.0141, 10.0142, 3.0143, 9.0144, 4.0145, 9.0146, 3.0147, 3.0148, 4.0149, 6.015, 4.0151, 2.0152, 6.0153, 7.0154, 7.0155, 4.0156, 4.0157, 3.0158, 4.0159, 7.016], - [1.0161, 3.0162, 8.0163, 2.0164, 6.0165, 9.0166, 2.0167, 7.0168, 4.0169, 8.017, 10.0171, 8.0172, 10.0173, 5.0174, 1.0175, 3.0176, 10.0177, 10.0178, 2.0179, 9.018], - [2.0181, 4.0182, 1.0183, 9.0184, 2.0185, 9.0186, 7.0187, 8.0188, 2.0189, 1.019, 4.0191, 10.0192, 5.0193, 2.0194, 7.0195, 6.0196, 5.0197, 7.0198, 2.0199, 6.02], - [4.0201, 5.0202, 1.0203, 4.0204, 2.0205, 3.0206, 3.0207, 4.0208, 1.0209, 8.021, 8.0211, 2.0212, 6.0213, 9.0214, 5.0215, 9.0216, 6.0217, 3.0218, 9.0219, 3.022], - [3.0221, 1.0222, 1.0223, 8.0224, 6.0225, 8.0226, 8.0227, 7.0228, 9.0229, 3.023, 2.0231, 1.0232, 8.0233, 2.0234, 4.0235, 7.0236, 3.0237, 1.0238, 2.0239, 4.024], - [5.0241, 9.0242, 8.0243, 6.0244, 10.0245, 4.0246, 10.0247, 3.0248, 4.0249, 10.025, 10.0251, 10.0252, 1.0253, 7.0254, 8.0255, 8.0256, 7.0257, 7.0258, 8.0259, 8.026], - [1.0261, 4.0262, 6.0263, 1.0264, 6.0265, 1.0266, 2.0267, 10.0268, 5.0269, 10.027, 2.0271, 6.0272, 2.0273, 4.0274, 5.0275, 5.0276, 3.0277, 5.0278, 1.0279, 5.028], - [5.0281, 6.0282, 9.0283, 10.0284, 6.0285, 6.0286, 10.0287, 6.0288, 4.0289, 1.029, 5.0291, 3.0292, 9.0293, 5.0294, 2.0295, 10.0296, 9.0297, 9.0298, 5.0299, 1.03], - [10.0301, 9.0302, 4.0303, 6.0304, 9.0305, 5.0306, 3.0307, 7.0308, 10.0309, 1.031, 6.0311, 8, 1.0312, 1.0313, 10.0314, 9.0315, 5.0316, 7.0317, 7.0318, 5.0319, 1.032], - [2.0321, 6.0322, 6.0323, 6.0324, 6.0325, 2.0326, 9.0327, 4.0328, 7.0329, 5.033, 3.0331, 2.0332, 10.0333, 3.0334, 4.0335, 5.0336, 10.0337, 9.0338, 1.0339, 7.034], - [5.0341, 2.0342, 4.0343, 9.0344, 8.0345, 4.0346, 8.0347, 2.0348, 4.0349, 1.035, 3.0351, 7.0352, 6.0353, 8.0354, 1.0355, 6.0356, 8.0357, 8.0358, 10.0359, 10.036], - [9.0361, 6.0362, 3.0363, 1.0364, 8.0365, 5.0366, 7.0367, 8.0368, 7.0369, 2.037, 1.0371, 8.0372, 2.0373, 8.0374, 3.0375, 7.0376, 4.0377, 8.0378, 7.0379, 7.038], - [8.0381, 4.0382, 4.0383, 9.0384, 7.0385, 10.0386, 6.0387, 2.0388, 1.0389, 5.039, 8.0391, 5.0392, 1.0393, 1.0394, 1.0395, 9.0396, 1.0397, 3.0398, 5.0399, 3.04]] + matrix = [ + [5.0001, 4.0002, 3.0003, 9.0004, 8.0005, 9.0006, 3.0007, 5.0008, 6.0009, 9.001, 4.0011, 10.0012, 3.0013, 5.0014, + 6.0015, 6.0016, 1.0017, 8.0018, 10.0019, 2.002], + [10.0021, 9.0022, 9.0023, 2.0024, 8.0025, 3.0026, 9.0027, 9.0028, 10.0029, 1.003, 7.0031, 10.0032, 8.0033, + 4.0034, 2.0035, 1.0036, 4.0037, 8.0038, 4.0039, 8.004], + [10.0041, 4.0042, 4.0043, 3.0044, 1.0045, 3.0046, 5.0047, 10.0048, 6.0049, 8.005, 6.0051, 8.0052, 4.0053, + 10.0054, 7.0055, 2.0056, 4.0057, 5.0058, 1.0059, 8.006], + [2.0061, 1.0062, 4.0063, 2.0064, 3.0065, 9.0066, 3.0067, 4.0068, 7.0069, 3.007, 4.0071, 1.0072, 3.0073, 2.0074, + 9.0075, 8.0076, 6.0077, 5.0078, 7.0079, 8.008], + [3.0081, 4.0082, 4.0083, 1.0084, 4.0085, 10.0086, 1.0087, 2.0088, 6.0089, 4.009, 5.0091, 10.0092, 2.0093, + 2.0094, 3.0095, 9.0096, 10.0097, 9.0098, 9.0099, 10.01], + [1.0101, 10.0102, 1.0103, 8.0104, 1.0105, 3.0106, 1.0107, 7.0108, 1.0109, 1.011, 2.0111, 1.0112, 2.0113, 6.0114, + 3.0115, 3.0116, 4.0117, 4.0118, 8.0119, 6.012], + [1.0121, 8.0122, 7.0123, 10.0124, 10.0125, 3.0126, 4.0127, 6.0128, 1.0129, 6.013, 6.0131, 4.0132, 9.0133, + 6.0134, 9.0135, 6.0136, 4.0137, 5.0138, 4.0139, 7.014], + [8.0141, 10.0142, 3.0143, 9.0144, 4.0145, 9.0146, 3.0147, 3.0148, 4.0149, 6.015, 4.0151, 2.0152, 6.0153, 7.0154, + 7.0155, 4.0156, 4.0157, 3.0158, 4.0159, 7.016], + [1.0161, 3.0162, 8.0163, 2.0164, 6.0165, 9.0166, 2.0167, 7.0168, 4.0169, 8.017, 10.0171, 8.0172, 10.0173, + 5.0174, 1.0175, 3.0176, 10.0177, 10.0178, 2.0179, 9.018], + [2.0181, 4.0182, 1.0183, 9.0184, 2.0185, 9.0186, 7.0187, 8.0188, 2.0189, 1.019, 4.0191, 10.0192, 5.0193, 2.0194, + 7.0195, 6.0196, 5.0197, 7.0198, 2.0199, 6.02], + [4.0201, 5.0202, 1.0203, 4.0204, 2.0205, 3.0206, 3.0207, 4.0208, 1.0209, 8.021, 8.0211, 2.0212, 6.0213, 9.0214, + 5.0215, 9.0216, 6.0217, 3.0218, 9.0219, 3.022], + [3.0221, 1.0222, 1.0223, 8.0224, 6.0225, 8.0226, 8.0227, 7.0228, 9.0229, 3.023, 2.0231, 1.0232, 8.0233, 2.0234, + 4.0235, 7.0236, 3.0237, 1.0238, 2.0239, 4.024], + [5.0241, 9.0242, 8.0243, 6.0244, 10.0245, 4.0246, 10.0247, 3.0248, 4.0249, 10.025, 10.0251, 10.0252, 1.0253, + 7.0254, 8.0255, 8.0256, 7.0257, 7.0258, 8.0259, 8.026], + [1.0261, 4.0262, 6.0263, 1.0264, 6.0265, 1.0266, 2.0267, 10.0268, 5.0269, 10.027, 2.0271, 6.0272, 2.0273, + 4.0274, 5.0275, 5.0276, 3.0277, 5.0278, 1.0279, 5.028], + [5.0281, 6.0282, 9.0283, 10.0284, 6.0285, 6.0286, 10.0287, 6.0288, 4.0289, 1.029, 5.0291, 3.0292, 9.0293, + 5.0294, 2.0295, 10.0296, 9.0297, 9.0298, 5.0299, 1.03], + [10.0301, 9.0302, 4.0303, 6.0304, 9.0305, 5.0306, 3.0307, 7.0308, 10.0309, 1.031, 6.0311, 8, 1.0312, 1.0313, + 10.0314, 9.0315, 5.0316, 7.0317, 7.0318, 5.0319, 1.032], + [2.0321, 6.0322, 6.0323, 6.0324, 6.0325, 2.0326, 9.0327, 4.0328, 7.0329, 5.033, 3.0331, 2.0332, 10.0333, 3.0334, + 4.0335, 5.0336, 10.0337, 9.0338, 1.0339, 7.034], + [5.0341, 2.0342, 4.0343, 9.0344, 8.0345, 4.0346, 8.0347, 2.0348, 4.0349, 1.035, 3.0351, 7.0352, 6.0353, 8.0354, + 1.0355, 6.0356, 8.0357, 8.0358, 10.0359, 10.036], + [9.0361, 6.0362, 3.0363, 1.0364, 8.0365, 5.0366, 7.0367, 8.0368, 7.0369, 2.037, 1.0371, 8.0372, 2.0373, 8.0374, + 3.0375, 7.0376, 4.0377, 8.0378, 7.0379, 7.038], + [8.0381, 4.0382, 4.0383, 9.0384, 7.0385, 10.0386, 6.0387, 2.0388, 1.0389, 5.039, 8.0391, 5.0392, 1.0393, 1.0394, + 1.0395, 9.0396, 1.0397, 3.0398, 5.0399, 3.04]] cost = _get_cost(matrix) ''' Here, it becomes mandatory to set "places" argument, otherwise test might @@ -128,6 +157,7 @@ def test_20_x_20_float(): ''' assert cost == pytest.approx(20.362, rel=1e-3) + def test_disallowed(): matrix = [[5, 9, DISALLOWED], [10, DISALLOWED, 2], @@ -135,6 +165,7 @@ def test_disallowed(): cost = _get_cost(matrix) assert cost == 19 + def test_disallowed_float(): matrix = [[5.1, 9.2, DISALLOWED], [10.3, DISALLOWED, 2.4], @@ -142,6 +173,7 @@ def test_disallowed_float(): cost = _get_cost(matrix) assert cost == pytest.approx(20.1) + def test_profit(): profit_matrix = [[94, 66, 100, 18, 48], [51, 63, 97, 79, 11], @@ -156,6 +188,7 @@ def test_profit(): profit = sum([profit_matrix[row][column] for row, column in indices]) assert profit == 392 + def test_profit_float(): profit_matrix = [[94.01, 66.02, 100.03, 18.04, 48.05], [51.06, 63.07, 97.08, 79.09, 11.1], @@ -170,6 +203,7 @@ def test_profit_float(): profit = sum([profit_matrix[row][column] for row, column in indices]) assert profit == pytest.approx(362.65) + def test_irregular(): matrix = [[12, 26, 17], [49, 43, 36, 10, 5], @@ -180,6 +214,7 @@ def test_irregular(): cost = _get_cost(matrix) assert cost == 43 + def test_irregular_float(): matrix = [[12.01, 26.02, 17.03], [49.04, 43.05, 36.06, 10.07, 5.08], @@ -190,6 +225,7 @@ def test_irregular_float(): cost = _get_cost(matrix) assert cost == pytest.approx(43.42) + def test_rectangular(): matrix = [[34, 26, 17, 12], [43, 43, 36, 10], @@ -203,6 +239,7 @@ def test_rectangular(): assert padded_cost == cost assert cost == 70 + def test_rectangular_float(): matrix = [[34.01, 26.02, 17.03, 12.04], [43.05, 43.06, 36.07, 10.08], @@ -216,16 +253,18 @@ def test_rectangular_float(): assert padded_cost == pytest.approx(cost) assert cost == pytest.approx(70.42) + def test_unsolvable(): with pytest.raises(UnsolvableMatrix): matrix = [[5, 9, DISALLOWED], - [10, DISALLOWED, 2], - [DISALLOWED, DISALLOWED, DISALLOWED]] + [10, DISALLOWED, 2], + [DISALLOWED, DISALLOWED, DISALLOWED]] m.compute(matrix) + def test_unsolvable_float(): with pytest.raises(UnsolvableMatrix): matrix = [[5.1, 9.2, DISALLOWED], - [10.3, DISALLOWED, 2.4], - [DISALLOWED, DISALLOWED, DISALLOWED]] + [10.3, DISALLOWED, 2.4], + [DISALLOWED, DISALLOWED, DISALLOWED]] m.compute(matrix)