diff --git a/pgtricks/pg_dump_splitsort.py b/pgtricks/pg_dump_splitsort.py index 0b9eee7..38ec5fc 100755 --- a/pgtricks/pg_dump_splitsort.py +++ b/pgtricks/pg_dump_splitsort.py @@ -16,24 +16,23 @@ MEMORY_UNITS = {"": 1, "k": KIBIBYTE, "m": MEBIBYTE, "g": GIBIBYTE} -def try_float(s1: str, s2: str) -> tuple[str, str] | tuple[float, float]: +def try_float(s1: str, s2: str) -> tuple[float, float]: """Convert two strings to floats. Return original ones on conversion error.""" if not s1 or not s2 or s1[0] not in '0123456789.-' or s2[0] not in '0123456789.-': # optimization - return s1, s2 - try: - return float(s1), float(s2) - except ValueError: - return s1, s2 + raise ValueError + return float(s1), float(s2) def linecomp(l1: str, l2: str) -> int: p1 = l1.split('\t', 1) p2 = l2.split('\t', 1) # TODO: unquote cast after support for Python 3.8 is dropped - v1, v2 = cast("tuple[float, float]", try_float(p1[0], p2[0])) - result = (v1 > v2) - (v1 < v2) - # modifying a line to see whether Darker works: + try: + v1, v2 = try_float(p1[0], p2[0]) + result = (v1 > v2) - (v1 < v2) + except ValueError: + result = (p1[0] > p2[0]) - (p1[0] < p2[0]) if not result and len(p1) == len(p2) == 2: return linecomp(p1[1], p2[1]) return result diff --git a/pgtricks/tests/test_pg_dump_splitsort.py b/pgtricks/tests/test_pg_dump_splitsort.py index e38e8d3..e885984 100644 --- a/pgtricks/tests/test_pg_dump_splitsort.py +++ b/pgtricks/tests/test_pg_dump_splitsort.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from functools import cmp_to_key from textwrap import dedent @@ -36,29 +37,32 @@ def test_sql_copy_regular_expression(test_input, expected): @pytest.mark.parametrize( 's1, s2, expect', [ - ('', '', ('', '')), - ('foo', '', ('foo', '')), - ('foo', 'bar', ('foo', 'bar')), + ('', '', ValueError), + ('foo', '', ValueError), + ('foo', 'bar', ValueError), ('0', '1', (0.0, 1.0)), - ('0', 'one', ('0', 'one')), + ('0', 'one', ValueError), ('0.0', '0.0', (0.0, 0.0)), - ('0.0', 'one point zero', ('0.0', 'one point zero')), + ('0.0', 'one point zero', ValueError), ('0.', '1.', (0.0, 1.0)), - ('0.', 'one', ('0.', 'one')), + ('0.', 'one', ValueError), ('4.2', '0.42', (4.2, 0.42)), - ('4.2', 'four point two', ('4.2', 'four point two')), + ('4.2', 'four point two', ValueError), ('-.42', '-0.042', (-0.42, -0.042)), - ('-.42', 'minus something', ('-.42', 'minus something')), - (r'\N', r'\N', (r'\N', r'\N')), - ('foo', r'\N', ('foo', r'\N')), - ('-4.2', r'\N', ('-4.2', r'\N')), + ('-.42', 'minus something', ValueError), + (r'\N', r'\N', ValueError), + ('foo', r'\N', ValueError), + ('-4.2', r'\N', ValueError), ], ) def test_try_float(s1, s2, expect): - result1, result2 = try_float(s1, s2) - assert type(result1) is type(expect[0]) - assert type(result2) is type(expect[1]) - assert (result1, result2) == expect + with pytest.raises(expect) if expect is ValueError else nullcontext(): + + result1, result2 = try_float(s1, s2) + + assert type(result1) is type(expect[0]) + assert type(result2) is type(expect[1]) + assert (result1, result2) == expect @pytest.mark.parametrize(