diff --git a/edgedb/datatypes/relative_duration.pyx b/edgedb/datatypes/relative_duration.pyx index cf2c0ab9..27c0c648 100644 --- a/edgedb/datatypes/relative_duration.pyx +++ b/edgedb/datatypes/relative_duration.pyx @@ -108,7 +108,12 @@ cdef class RelativeDuration: buf.append(f'{min}M') if sec or fsec: - sign = '-' if min < 0 or fsec < 0 else '' + # If the original microseconds are negative we expect '-' in front + # of all non-zero hour/min/second components. The hour/min sign + # can be taken as is, but seconds are constructed out of sec and + # fsec parts, both of which have their own sign and thus we cannot + # just use their string representations directly. + sign = '-' if self.microseconds < 0 else '' buf.append(f'{sign}{abs(sec)}') if fsec: diff --git a/tests/test_datetime.py b/tests/test_datetime.py index 08199077..ff7dfbc0 100644 --- a/tests/test_datetime.py +++ b/tests/test_datetime.py @@ -25,6 +25,11 @@ from edgedb.datatypes.datatypes import RelativeDuration, DateDuration +USECS_PER_HOUR = 3600000000 +USECS_PER_MINUTE = 60000000 +USECS_PER_SEC = 1000000 + + class TestDatetimeTypes(tb.SyncQueryTestCase): async def test_duration_01(self): @@ -60,6 +65,57 @@ async def test_duration_01(self): ''', durs) self.assertEqual(list(durs_from_db), durs) + async def test_duration_02(self): + # Make sure that when we break down the microseconds into the bigger + # components we still get consistent values. + tdn1h = timedelta(microseconds=-USECS_PER_HOUR) + tdn1m = timedelta(microseconds=-USECS_PER_MINUTE) + tdn1s = timedelta(microseconds=-USECS_PER_SEC) + tdn1us = timedelta(microseconds=-1) + durs = [ + ( + tdn1h, tdn1m, + timedelta(microseconds=-USECS_PER_HOUR - USECS_PER_MINUTE), + ), + ( + tdn1h, tdn1s, + timedelta(microseconds=-USECS_PER_HOUR - USECS_PER_SEC), + ), + ( + tdn1m, tdn1s, + timedelta(microseconds=-USECS_PER_MINUTE - USECS_PER_SEC), + ), + ( + tdn1h, tdn1us, + timedelta(microseconds=-USECS_PER_HOUR - 1), + ), + ( + tdn1m, tdn1us, + timedelta(microseconds=-USECS_PER_MINUTE - 1), + ), + ( + tdn1s, tdn1us, + timedelta(microseconds=-USECS_PER_SEC - 1), + ), + ] + + # Test encode + durs_enc = self.client.query(''' + WITH args := array_unpack( + >>$0) + SELECT args.0 + args.1 = args.2; + ''', durs) + + # Test decode + durs_dec = self.client.query(''' + WITH args := array_unpack( + >>$0) + SELECT (args.0 + args.1, args.2); + ''', durs) + + self.assertEqual(durs_enc, [True] * len(durs)) + self.assertEqual(list(durs_dec), [(d[2], d[2]) for d in durs]) + async def test_relative_duration_01(self): try: self.client.query("SELECT '1y'") @@ -124,6 +180,41 @@ async def test_relative_duration_02(self): self.assertEqual(repr(d1), '') + async def test_relative_duration_03(self): + # Make sure that when we break down the microseconds into the bigger + # components we still get the sign correctly in string + # representation. + durs = [ + RelativeDuration(microseconds=-USECS_PER_HOUR), + RelativeDuration(microseconds=-USECS_PER_MINUTE), + RelativeDuration(microseconds=-USECS_PER_SEC), + RelativeDuration(microseconds=-USECS_PER_HOUR - USECS_PER_MINUTE), + RelativeDuration(microseconds=-USECS_PER_HOUR - USECS_PER_SEC), + RelativeDuration(microseconds=-USECS_PER_MINUTE - USECS_PER_SEC), + RelativeDuration(microseconds=-USECS_PER_HOUR - USECS_PER_MINUTE - + USECS_PER_SEC), + RelativeDuration(microseconds=-USECS_PER_HOUR - 1), + RelativeDuration(microseconds=-USECS_PER_MINUTE - 1), + RelativeDuration(microseconds=-USECS_PER_SEC - 1), + RelativeDuration(microseconds=-1), + ] + + # Test that RelativeDuration.__str__ formats the + # same as + durs_as_text = self.client.query(''' + WITH args := array_unpack(>$0) + SELECT args; + ''', durs) + + # Test encode/decode roundtrip + durs_from_db = self.client.query(''' + WITH args := array_unpack(>$0) + SELECT args; + ''', durs) + + self.assertEqual(durs_as_text, [str(d) for d in durs]) + self.assertEqual(list(durs_from_db), durs) + async def test_date_duration_01(self): try: self.client.query("SELECT '1y'") @@ -168,3 +259,32 @@ async def test_date_duration_01(self): self.assertEqual(db_dur, str(client_dur)) self.assertEqual(list(durs_from_db), durs) + + async def test_date_duration_02(self): + # Make sure that when we break down the microseconds into the bigger + # components we still get the sign correctly in string + # representation. + durs = [ + DateDuration(months=11), + DateDuration(months=12), + DateDuration(months=13), + DateDuration(months=-11), + DateDuration(months=-12), + DateDuration(months=-13), + ] + + # Test that DateDuration.__str__ formats the + # same as + durs_as_text = self.client.query(''' + WITH args := array_unpack(>$0) + SELECT args; + ''', durs) + + # Test encode/decode roundtrip + durs_from_db = self.client.query(''' + WITH args := array_unpack(>$0) + SELECT args; + ''', durs) + + self.assertEqual(durs_as_text, [str(d) for d in durs]) + self.assertEqual(list(durs_from_db), durs)