diff --git a/.bumpversion.cfg b/.bumpversion.cfg new file mode 100644 index 0000000..e456d68 --- /dev/null +++ b/.bumpversion.cfg @@ -0,0 +1,6 @@ +[bumpversion] +current_version = 1.2.0 +commit = True +tag = True + +[bumpversion:file:pyproject.toml] diff --git a/munkres.py b/munkres.py index 2f2edbc..4053ea9 100644 --- a/munkres.py +++ b/munkres.py @@ -133,6 +133,7 @@ def compute(self, cost_matrix: Matrix) -> Sequence[Tuple[int, int]]: A list of `(row, column)` tuples that describe the lowest cost path through the matrix """ + self.__check_unsolvability(cost_matrix) self.C = self.pad_matrix(cost_matrix) self.n = len(self.C) self.original_length = len(cost_matrix) @@ -181,6 +182,44 @@ def __make_matrix(self, n: int, val: AnyNum) -> Matrix: matrix += [[val for j in range(n)]] return matrix + def __transpose_matrix(self, matrix: Matrix): + return [list(row) for row in zip(*matrix)] + + def __check_unsolvability(self, matrix: Matrix): + """Checks additional conditions to see if an input Munkres cost matrix is unsolvable. + + This check identifies potential infinite loop edge cases and raises a ``munkres.UnsolvableMatrix`` error. + + + Args: + matrix (Matrix): a Matrix object, representing a cost matrix (or a profit matrix) suitable + for Munkres optimization. + + Raises: + munkres.UnsolvableMatrix: if the matrix is unsolvable. + """ + + def _check_one_dimension_solvability(matrix): + + non_disallowed_indices = [] # (in possibly-offending rows) + + for row in matrix: + + # check to see if all but 1 cell in the row are DISALLOWED + + indices = [i for i, val in enumerate(row) if not isinstance(val, type(DISALLOWED))] + + if len(indices) == 1: + + if indices[0] in non_disallowed_indices: + raise UnsolvableMatrix( + "This matrix cannot be solved and will loop infinitely" + ) + + non_disallowed_indices.append(indices[0]) + + _check_one_dimension_solvability(matrix) + def __step1(self) -> int: """ For each row of the matrix, find the smallest element and @@ -342,11 +381,17 @@ def __step6(self) -> int: def __find_smallest(self) -> AnyNum: """Find the smallest uncovered value in the matrix.""" minval = sys.maxsize + uncovered_vals = [] for i in range(self.n): for j in range(self.n): if (not self.row_covered[i]) and (not self.col_covered[j]): + uncovered_vals.append(self.C[i][j]) if self.C[i][j] is not DISALLOWED and minval > self.C[i][j]: minval = self.C[i][j] + if all(map(lambda val: type(val) is DISALLOWED_OBJ, uncovered_vals)): + raise UnsolvableMatrix( + "The only uncovered values are disallowed. This matrix will loop infinitely!" + ) return minval diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..72480f2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,23 @@ +# pyproject.toml + +[build-system] +requires = ["setuptools>=61.2.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "jk-munkres" +version = "1.2.0" +description = "A fork of munkres" +readme = "README.md" +authors = [{ name = "Jonathan Kerr" }] +license = { file = "LICENSE" } +classifiers = [ + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", +] +keywords = ["munkres", "assignment problem"] +requires-python = ">=3.9" + +[project.optional-dependencies] +dev = ["black", "bumpver", "isort", "pip-tools", "pytest"] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 094303d..0000000 --- a/setup.cfg +++ /dev/null @@ -1,5 +0,0 @@ -[sdist] -formats: gztar - -[bdist_wheel] -universal=1 diff --git a/setup.py b/setup.py deleted file mode 100644 index f3ead93..0000000 --- a/setup.py +++ /dev/null @@ -1,112 +0,0 @@ -#!/usr/bin/env python -# -# Distutils setup script for Munkres -# --------------------------------------------------------------------------- - -from setuptools import setup -import re -import os -import sys -from distutils.cmd import Command -from abc import abstractmethod - -if sys.version_info[0:2] < (3, 5): - columns = int(os.environ.get('COLUMNS', '80')) - 1 - msg = ('As of version 1.1.0, this munkres package no longer supports ' + - 'Python 2. Either upgrade to Python 3.5 or better, or use an ' + - 'older version of munkres (e.g., 1.0.12).') - sys.stderr.write(msg + '\n') - raise Exception(msg) - -# Load the module. - -here = os.path.dirname(os.path.abspath(sys.argv[0])) - -def import_from_file(file, name): - # See https://stackoverflow.com/a/19011259/53495 - import importlib.machinery - import importlib.util - loader = importlib.machinery.SourceFileLoader(name, file) - spec = importlib.util.spec_from_loader(loader.name, loader) - mod = importlib.util.module_from_spec(spec) - loader.exec_module(mod) - return mod - -mf = os.path.join(here, 'munkres.py') -munkres = import_from_file(mf, 'munkres') -long_description = munkres.__doc__ -version = str(munkres.__version__) -(author, email) = re.match('^(.*),\s*(.*)$', munkres.__author__).groups() -url = munkres.__url__ -license = munkres.__license__ - -API_DOCS_BUILD = 'apidocs' - -class CommandHelper(Command): - user_options = [] - - def __init__(self, dist): - Command.__init__(self, dist) - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - @abstractmethod - def run(self): - pass - -class Doc(CommandHelper): - description = 'create the API docs' - - def run(self): - os.environ['PYTHONPATH'] = '.' - cmd = 'pdoc --html --html-dir {} --overwrite --html-no-source munkres'.format( - API_DOCS_BUILD - ) - print('+ {}'.format(cmd)) - rc = os.system(cmd) - if rc != 0: - raise Exception("Failed to run pdoc. rc={}".format(rc)) - -class Test(CommandHelper): - - def run(self): - import pytest - os.environ['PYTHONPATH'] = '.' - rc = pytest.main(['-W', 'ignore', '-ra', '--cache-clear', 'test', '.']) - if rc != 0: - raise Exception('*** Tests failed.') - -# Run setup - -setup( - name="munkres", - version=version, - description="Munkres (Hungarian) algorithm for the Assignment Problem", - long_description=long_description, - long_description_content_type='text/markdown', - url=url, - license=license, - author=author, - author_email=email, - py_modules=["munkres"], - cmdclass = { - 'doc': Doc, - 'docs': Doc, - 'apidoc': Doc, - 'apidocs': Doc, - 'test': Test - }, - classifiers = [ - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Software Development :: Libraries :: Python Modules' - ] -) diff --git a/test/test_munkres.py b/test/test_munkres.py index 23796dd..570277a 100644 --- a/test/test_munkres.py +++ b/test/test_munkres.py @@ -216,12 +216,27 @@ def test_rectangular_float(): assert padded_cost == pytest.approx(cost) assert cost == pytest.approx(70.42) -def test_unsolvable(): - with pytest.raises(UnsolvableMatrix): - matrix = [[5, 9, DISALLOWED], + +@pytest.mark.parametrize( + "unsolvable_matrix", + ( + [ + [5, 9, DISALLOWED], [10, DISALLOWED, 2], - [DISALLOWED, DISALLOWED, DISALLOWED]] - m.compute(matrix) + [DISALLOWED, DISALLOWED, DISALLOWED] + ], + + [ + [DISALLOWED, 161, DISALLOWED], + [DISALLOWED, 1, DISALLOWED], + [DISALLOWED, 157, DISALLOWED], + [37, DISALLOWED, 5] + ] + ) +) +def test_unsolvable(unsolvable_matrix): + with pytest.raises(UnsolvableMatrix): + m.compute(unsolvable_matrix) def test_unsolvable_float(): with pytest.raises(UnsolvableMatrix):