Skip to content

Commit

Permalink
Fix compatibility with mypyc
Browse files Browse the repository at this point in the history
  • Loading branch information
akaihola committed Apr 25, 2024
1 parent 4a555ae commit 1f981d9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 24 deletions.
17 changes: 8 additions & 9 deletions pgtricks/pg_dump_splitsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,23 @@
MEMORY_UNITS = {"": 1, "k": KIBIBYTE, "m": MEBIBYTE, "g": GIBIBYTE}


def try_float(s1: str, s2: str) -> tuple[str, str] | tuple[float, float]:
def try_float(s1: str, s2: str) -> tuple[float, float]:
"""Convert two strings to floats. Return original ones on conversion error."""
if not s1 or not s2 or s1[0] not in '0123456789.-' or s2[0] not in '0123456789.-':
# optimization
return s1, s2
try:
return float(s1), float(s2)
except ValueError:
return s1, s2
raise ValueError
return float(s1), float(s2)


def linecomp(l1: str, l2: str) -> int:
p1 = l1.split('\t', 1)
p2 = l2.split('\t', 1)
# TODO: unquote cast after support for Python 3.8 is dropped
v1, v2 = cast("tuple[float, float]", try_float(p1[0], p2[0]))
result = (v1 > v2) - (v1 < v2)
# modifying a line to see whether Darker works:
try:
v1, v2 = try_float(p1[0], p2[0])
result = (v1 > v2) - (v1 < v2)
except ValueError:
result = (p1[0] > p2[0]) - (p1[0] < p2[0])
if not result and len(p1) == len(p2) == 2:
return linecomp(p1[1], p2[1])
return result
Expand Down
34 changes: 19 additions & 15 deletions pgtricks/tests/test_pg_dump_splitsort.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import nullcontext
from functools import cmp_to_key
from textwrap import dedent

Expand Down Expand Up @@ -36,29 +37,32 @@ def test_sql_copy_regular_expression(test_input, expected):
@pytest.mark.parametrize(
's1, s2, expect',
[
('', '', ('', '')),
('foo', '', ('foo', '')),
('foo', 'bar', ('foo', 'bar')),
('', '', ValueError),
('foo', '', ValueError),
('foo', 'bar', ValueError),
('0', '1', (0.0, 1.0)),
('0', 'one', ('0', 'one')),
('0', 'one', ValueError),
('0.0', '0.0', (0.0, 0.0)),
('0.0', 'one point zero', ('0.0', 'one point zero')),
('0.0', 'one point zero', ValueError),
('0.', '1.', (0.0, 1.0)),
('0.', 'one', ('0.', 'one')),
('0.', 'one', ValueError),
('4.2', '0.42', (4.2, 0.42)),
('4.2', 'four point two', ('4.2', 'four point two')),
('4.2', 'four point two', ValueError),
('-.42', '-0.042', (-0.42, -0.042)),
('-.42', 'minus something', ('-.42', 'minus something')),
(r'\N', r'\N', (r'\N', r'\N')),
('foo', r'\N', ('foo', r'\N')),
('-4.2', r'\N', ('-4.2', r'\N')),
('-.42', 'minus something', ValueError),
(r'\N', r'\N', ValueError),
('foo', r'\N', ValueError),
('-4.2', r'\N', ValueError),
],
)
def test_try_float(s1, s2, expect):
result1, result2 = try_float(s1, s2)
assert type(result1) is type(expect[0])
assert type(result2) is type(expect[1])
assert (result1, result2) == expect
with pytest.raises(expect) if expect is ValueError else nullcontext():

result1, result2 = try_float(s1, s2)

assert type(result1) is type(expect[0])
assert type(result2) is type(expect[1])
assert (result1, result2) == expect


@pytest.mark.parametrize(
Expand Down

0 comments on commit 1f981d9

Please sign in to comment.