Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add conversion to list before padding, extend length of result #38

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions munkres.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def pad_matrix(self, matrix: Matrix, pad_value: int=0) -> Matrix:
new_matrix = []
for row in matrix:
row_len = len(row)
new_row = row[:]
new_row = list(row[:])
if total_rows > row_len:
# Row too short. Pad it.
new_row += [pad_value] * (total_rows - row_len)
Expand Down Expand Up @@ -163,8 +163,8 @@ def compute(self, cost_matrix: Matrix) -> Sequence[Tuple[int, int]]:

# Look for the starred columns
results = []
for i in range(self.original_length):
for j in range(self.original_width):
for i in range(self.n):
for j in range(self.n):
if self.marked[i][j] == 1:
results += [(i, j)]

Expand Down
89 changes: 64 additions & 25 deletions test/test_munkres.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

m = Munkres()


def _get_cost(matrix):
indices = m.compute(matrix)
return sum([matrix[row][column] for row, column in indices])
return sum([matrix[row][column] for row, column in indices if row < len(matrix) and column < len(matrix[0])])


def test_documented_example():
'''
Expand All @@ -18,6 +20,7 @@ def test_documented_example():
cost = _get_cost(matrix)
assert cost == 12


def float_example():
'''
Test a matrix with float values
Expand All @@ -28,6 +31,7 @@ def float_example():
cost = _get_cost(matrix)
assert cost == pytest.approx(13.5)


def test_5_x_5():
matrix = [[12, 9, 27, 10, 23],
[7, 13, 13, 30, 19],
Expand All @@ -37,6 +41,7 @@ def test_5_x_5():
cost = _get_cost(matrix)
assert cost == 51


def test_5_x_5_float():
matrix = [[12.01, 9.02, 27.03, 10.04, 23.05],
[7.06, 13.07, 13.08, 30.09, 19.1],
Expand All @@ -61,6 +66,7 @@ def test_10_x_10():
cost = _get_cost(matrix)
assert cost == 66


def test_10_x_10_float():
matrix = [[37.001, 34.002, 29.003, 26.004, 19.005, 8.006, 9.007, 23.008, 19.009, 29.01],
[9.011, 28.012, 20.013, 8.014, 18.015, 20.016, 14.017, 33.018, 23.019, 14.02],
Expand All @@ -75,6 +81,7 @@ def test_10_x_10_float():
cost = _get_cost(matrix)
assert cost == pytest.approx(66.505)


def test_20_x_20():
matrix = [[5, 4, 3, 9, 8, 9, 3, 5, 6, 9, 4, 10, 3, 5, 6, 6, 1, 8, 10, 2],
[10, 9, 9, 2, 8, 3, 9, 9, 10, 1, 7, 10, 8, 4, 2, 1, 4, 8, 4, 8],
Expand All @@ -99,27 +106,49 @@ def test_20_x_20():
cost = _get_cost(matrix)
assert cost == 22


def test_20_x_20_float():
matrix = [[5.0001, 4.0002, 3.0003, 9.0004, 8.0005, 9.0006, 3.0007, 5.0008, 6.0009, 9.001, 4.0011, 10.0012, 3.0013, 5.0014, 6.0015, 6.0016, 1.0017, 8.0018, 10.0019, 2.002],
[10.0021, 9.0022, 9.0023, 2.0024, 8.0025, 3.0026, 9.0027, 9.0028, 10.0029, 1.003, 7.0031, 10.0032, 8.0033, 4.0034, 2.0035, 1.0036, 4.0037, 8.0038, 4.0039, 8.004],
[10.0041, 4.0042, 4.0043, 3.0044, 1.0045, 3.0046, 5.0047, 10.0048, 6.0049, 8.005, 6.0051, 8.0052, 4.0053, 10.0054, 7.0055, 2.0056, 4.0057, 5.0058, 1.0059, 8.006],
[2.0061, 1.0062, 4.0063, 2.0064, 3.0065, 9.0066, 3.0067, 4.0068, 7.0069, 3.007, 4.0071, 1.0072, 3.0073, 2.0074, 9.0075, 8.0076, 6.0077, 5.0078, 7.0079, 8.008],
[3.0081, 4.0082, 4.0083, 1.0084, 4.0085, 10.0086, 1.0087, 2.0088, 6.0089, 4.009, 5.0091, 10.0092, 2.0093, 2.0094, 3.0095, 9.0096, 10.0097, 9.0098, 9.0099, 10.01],
[1.0101, 10.0102, 1.0103, 8.0104, 1.0105, 3.0106, 1.0107, 7.0108, 1.0109, 1.011, 2.0111, 1.0112, 2.0113, 6.0114, 3.0115, 3.0116, 4.0117, 4.0118, 8.0119, 6.012],
[1.0121, 8.0122, 7.0123, 10.0124, 10.0125, 3.0126, 4.0127, 6.0128, 1.0129, 6.013, 6.0131, 4.0132, 9.0133, 6.0134, 9.0135, 6.0136, 4.0137, 5.0138, 4.0139, 7.014],
[8.0141, 10.0142, 3.0143, 9.0144, 4.0145, 9.0146, 3.0147, 3.0148, 4.0149, 6.015, 4.0151, 2.0152, 6.0153, 7.0154, 7.0155, 4.0156, 4.0157, 3.0158, 4.0159, 7.016],
[1.0161, 3.0162, 8.0163, 2.0164, 6.0165, 9.0166, 2.0167, 7.0168, 4.0169, 8.017, 10.0171, 8.0172, 10.0173, 5.0174, 1.0175, 3.0176, 10.0177, 10.0178, 2.0179, 9.018],
[2.0181, 4.0182, 1.0183, 9.0184, 2.0185, 9.0186, 7.0187, 8.0188, 2.0189, 1.019, 4.0191, 10.0192, 5.0193, 2.0194, 7.0195, 6.0196, 5.0197, 7.0198, 2.0199, 6.02],
[4.0201, 5.0202, 1.0203, 4.0204, 2.0205, 3.0206, 3.0207, 4.0208, 1.0209, 8.021, 8.0211, 2.0212, 6.0213, 9.0214, 5.0215, 9.0216, 6.0217, 3.0218, 9.0219, 3.022],
[3.0221, 1.0222, 1.0223, 8.0224, 6.0225, 8.0226, 8.0227, 7.0228, 9.0229, 3.023, 2.0231, 1.0232, 8.0233, 2.0234, 4.0235, 7.0236, 3.0237, 1.0238, 2.0239, 4.024],
[5.0241, 9.0242, 8.0243, 6.0244, 10.0245, 4.0246, 10.0247, 3.0248, 4.0249, 10.025, 10.0251, 10.0252, 1.0253, 7.0254, 8.0255, 8.0256, 7.0257, 7.0258, 8.0259, 8.026],
[1.0261, 4.0262, 6.0263, 1.0264, 6.0265, 1.0266, 2.0267, 10.0268, 5.0269, 10.027, 2.0271, 6.0272, 2.0273, 4.0274, 5.0275, 5.0276, 3.0277, 5.0278, 1.0279, 5.028],
[5.0281, 6.0282, 9.0283, 10.0284, 6.0285, 6.0286, 10.0287, 6.0288, 4.0289, 1.029, 5.0291, 3.0292, 9.0293, 5.0294, 2.0295, 10.0296, 9.0297, 9.0298, 5.0299, 1.03],
[10.0301, 9.0302, 4.0303, 6.0304, 9.0305, 5.0306, 3.0307, 7.0308, 10.0309, 1.031, 6.0311, 8, 1.0312, 1.0313, 10.0314, 9.0315, 5.0316, 7.0317, 7.0318, 5.0319, 1.032],
[2.0321, 6.0322, 6.0323, 6.0324, 6.0325, 2.0326, 9.0327, 4.0328, 7.0329, 5.033, 3.0331, 2.0332, 10.0333, 3.0334, 4.0335, 5.0336, 10.0337, 9.0338, 1.0339, 7.034],
[5.0341, 2.0342, 4.0343, 9.0344, 8.0345, 4.0346, 8.0347, 2.0348, 4.0349, 1.035, 3.0351, 7.0352, 6.0353, 8.0354, 1.0355, 6.0356, 8.0357, 8.0358, 10.0359, 10.036],
[9.0361, 6.0362, 3.0363, 1.0364, 8.0365, 5.0366, 7.0367, 8.0368, 7.0369, 2.037, 1.0371, 8.0372, 2.0373, 8.0374, 3.0375, 7.0376, 4.0377, 8.0378, 7.0379, 7.038],
[8.0381, 4.0382, 4.0383, 9.0384, 7.0385, 10.0386, 6.0387, 2.0388, 1.0389, 5.039, 8.0391, 5.0392, 1.0393, 1.0394, 1.0395, 9.0396, 1.0397, 3.0398, 5.0399, 3.04]]
matrix = [
[5.0001, 4.0002, 3.0003, 9.0004, 8.0005, 9.0006, 3.0007, 5.0008, 6.0009, 9.001, 4.0011, 10.0012, 3.0013, 5.0014,
6.0015, 6.0016, 1.0017, 8.0018, 10.0019, 2.002],
[10.0021, 9.0022, 9.0023, 2.0024, 8.0025, 3.0026, 9.0027, 9.0028, 10.0029, 1.003, 7.0031, 10.0032, 8.0033,
4.0034, 2.0035, 1.0036, 4.0037, 8.0038, 4.0039, 8.004],
[10.0041, 4.0042, 4.0043, 3.0044, 1.0045, 3.0046, 5.0047, 10.0048, 6.0049, 8.005, 6.0051, 8.0052, 4.0053,
10.0054, 7.0055, 2.0056, 4.0057, 5.0058, 1.0059, 8.006],
[2.0061, 1.0062, 4.0063, 2.0064, 3.0065, 9.0066, 3.0067, 4.0068, 7.0069, 3.007, 4.0071, 1.0072, 3.0073, 2.0074,
9.0075, 8.0076, 6.0077, 5.0078, 7.0079, 8.008],
[3.0081, 4.0082, 4.0083, 1.0084, 4.0085, 10.0086, 1.0087, 2.0088, 6.0089, 4.009, 5.0091, 10.0092, 2.0093,
2.0094, 3.0095, 9.0096, 10.0097, 9.0098, 9.0099, 10.01],
[1.0101, 10.0102, 1.0103, 8.0104, 1.0105, 3.0106, 1.0107, 7.0108, 1.0109, 1.011, 2.0111, 1.0112, 2.0113, 6.0114,
3.0115, 3.0116, 4.0117, 4.0118, 8.0119, 6.012],
[1.0121, 8.0122, 7.0123, 10.0124, 10.0125, 3.0126, 4.0127, 6.0128, 1.0129, 6.013, 6.0131, 4.0132, 9.0133,
6.0134, 9.0135, 6.0136, 4.0137, 5.0138, 4.0139, 7.014],
[8.0141, 10.0142, 3.0143, 9.0144, 4.0145, 9.0146, 3.0147, 3.0148, 4.0149, 6.015, 4.0151, 2.0152, 6.0153, 7.0154,
7.0155, 4.0156, 4.0157, 3.0158, 4.0159, 7.016],
[1.0161, 3.0162, 8.0163, 2.0164, 6.0165, 9.0166, 2.0167, 7.0168, 4.0169, 8.017, 10.0171, 8.0172, 10.0173,
5.0174, 1.0175, 3.0176, 10.0177, 10.0178, 2.0179, 9.018],
[2.0181, 4.0182, 1.0183, 9.0184, 2.0185, 9.0186, 7.0187, 8.0188, 2.0189, 1.019, 4.0191, 10.0192, 5.0193, 2.0194,
7.0195, 6.0196, 5.0197, 7.0198, 2.0199, 6.02],
[4.0201, 5.0202, 1.0203, 4.0204, 2.0205, 3.0206, 3.0207, 4.0208, 1.0209, 8.021, 8.0211, 2.0212, 6.0213, 9.0214,
5.0215, 9.0216, 6.0217, 3.0218, 9.0219, 3.022],
[3.0221, 1.0222, 1.0223, 8.0224, 6.0225, 8.0226, 8.0227, 7.0228, 9.0229, 3.023, 2.0231, 1.0232, 8.0233, 2.0234,
4.0235, 7.0236, 3.0237, 1.0238, 2.0239, 4.024],
[5.0241, 9.0242, 8.0243, 6.0244, 10.0245, 4.0246, 10.0247, 3.0248, 4.0249, 10.025, 10.0251, 10.0252, 1.0253,
7.0254, 8.0255, 8.0256, 7.0257, 7.0258, 8.0259, 8.026],
[1.0261, 4.0262, 6.0263, 1.0264, 6.0265, 1.0266, 2.0267, 10.0268, 5.0269, 10.027, 2.0271, 6.0272, 2.0273,
4.0274, 5.0275, 5.0276, 3.0277, 5.0278, 1.0279, 5.028],
[5.0281, 6.0282, 9.0283, 10.0284, 6.0285, 6.0286, 10.0287, 6.0288, 4.0289, 1.029, 5.0291, 3.0292, 9.0293,
5.0294, 2.0295, 10.0296, 9.0297, 9.0298, 5.0299, 1.03],
[10.0301, 9.0302, 4.0303, 6.0304, 9.0305, 5.0306, 3.0307, 7.0308, 10.0309, 1.031, 6.0311, 8, 1.0312, 1.0313,
10.0314, 9.0315, 5.0316, 7.0317, 7.0318, 5.0319, 1.032],
[2.0321, 6.0322, 6.0323, 6.0324, 6.0325, 2.0326, 9.0327, 4.0328, 7.0329, 5.033, 3.0331, 2.0332, 10.0333, 3.0334,
4.0335, 5.0336, 10.0337, 9.0338, 1.0339, 7.034],
[5.0341, 2.0342, 4.0343, 9.0344, 8.0345, 4.0346, 8.0347, 2.0348, 4.0349, 1.035, 3.0351, 7.0352, 6.0353, 8.0354,
1.0355, 6.0356, 8.0357, 8.0358, 10.0359, 10.036],
[9.0361, 6.0362, 3.0363, 1.0364, 8.0365, 5.0366, 7.0367, 8.0368, 7.0369, 2.037, 1.0371, 8.0372, 2.0373, 8.0374,
3.0375, 7.0376, 4.0377, 8.0378, 7.0379, 7.038],
[8.0381, 4.0382, 4.0383, 9.0384, 7.0385, 10.0386, 6.0387, 2.0388, 1.0389, 5.039, 8.0391, 5.0392, 1.0393, 1.0394,
1.0395, 9.0396, 1.0397, 3.0398, 5.0399, 3.04]]
cost = _get_cost(matrix)
'''
Here, it becomes mandatory to set "places" argument, otherwise test might
Expand All @@ -128,20 +157,23 @@ def test_20_x_20_float():
'''
assert cost == pytest.approx(20.362, rel=1e-3)


def test_disallowed():
matrix = [[5, 9, DISALLOWED],
[10, DISALLOWED, 2],
[8, DISALLOWED, 4]]
cost = _get_cost(matrix)
assert cost == 19


def test_disallowed_float():
matrix = [[5.1, 9.2, DISALLOWED],
[10.3, DISALLOWED, 2.4],
[8.5, DISALLOWED, 4.6]]
cost = _get_cost(matrix)
assert cost == pytest.approx(20.1)


def test_profit():
profit_matrix = [[94, 66, 100, 18, 48],
[51, 63, 97, 79, 11],
Expand All @@ -156,6 +188,7 @@ def test_profit():
profit = sum([profit_matrix[row][column] for row, column in indices])
assert profit == 392


def test_profit_float():
profit_matrix = [[94.01, 66.02, 100.03, 18.04, 48.05],
[51.06, 63.07, 97.08, 79.09, 11.1],
Expand All @@ -170,6 +203,7 @@ def test_profit_float():
profit = sum([profit_matrix[row][column] for row, column in indices])
assert profit == pytest.approx(362.65)


def test_irregular():
matrix = [[12, 26, 17],
[49, 43, 36, 10, 5],
Expand All @@ -180,6 +214,7 @@ def test_irregular():
cost = _get_cost(matrix)
assert cost == 43


def test_irregular_float():
matrix = [[12.01, 26.02, 17.03],
[49.04, 43.05, 36.06, 10.07, 5.08],
Expand All @@ -190,6 +225,7 @@ def test_irregular_float():
cost = _get_cost(matrix)
assert cost == pytest.approx(43.42)


def test_rectangular():
matrix = [[34, 26, 17, 12],
[43, 43, 36, 10],
Expand All @@ -203,6 +239,7 @@ def test_rectangular():
assert padded_cost == cost
assert cost == 70


def test_rectangular_float():
matrix = [[34.01, 26.02, 17.03, 12.04],
[43.05, 43.06, 36.07, 10.08],
Expand All @@ -216,16 +253,18 @@ def test_rectangular_float():
assert padded_cost == pytest.approx(cost)
assert cost == pytest.approx(70.42)


def test_unsolvable():
with pytest.raises(UnsolvableMatrix):
matrix = [[5, 9, DISALLOWED],
[10, DISALLOWED, 2],
[DISALLOWED, DISALLOWED, DISALLOWED]]
[10, DISALLOWED, 2],
[DISALLOWED, DISALLOWED, DISALLOWED]]
m.compute(matrix)


def test_unsolvable_float():
with pytest.raises(UnsolvableMatrix):
matrix = [[5.1, 9.2, DISALLOWED],
[10.3, DISALLOWED, 2.4],
[DISALLOWED, DISALLOWED, DISALLOWED]]
[10.3, DISALLOWED, 2.4],
[DISALLOWED, DISALLOWED, DISALLOWED]]
m.compute(matrix)