From 8e9d9fd3dafb711e2d6ec1a381a52e5c3cb1e0c6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gr=C3=A9gory=20Mazerand?=
Date: Thu, 14 Mar 2024 18:01:46 +0100
Subject: [PATCH 1/3] Add expressions for Interval

Add two Expression subclasses:

* DateTimePart upper-cases a datetime part keyword and an optional
  ending part, and renders as "PART" or "PART TO PART".
* Interval pairs a value expression with a datetime part and renders
  as "INTERVAL <value> <part>".

Both classes get a non-empty __repr__ so that failing assertions and
debugging output show what was parsed.

Also fix a typo in the Unnest comment.

---
 sqlvalidator/grammar/sql.py | 65 +++++++++++++++++++++++++++++++++++--
 1 file changed, 63 insertions(+), 2 deletions(-)

diff --git a/sqlvalidator/grammar/sql.py b/sqlvalidator/grammar/sql.py
index d2c3701..9b71bde 100644
--- a/sqlvalidator/grammar/sql.py
+++ b/sqlvalidator/grammar/sql.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, List, Optional, Set
+from typing import Any, List, Optional, Set, Union
 
 from sqlvalidator.grammar.tokeniser import lower
 
@@ -1163,7 +1163,7 @@ def known_fields(self) -> Set[_FieldInfo]:
 
 class Unnest(Expression):
     def __init__(self, unnest_expression, with_offset, with_offset_as, offset_alias):
-        # unnest_expression: can be functiion call or alias of function call
+        # unnest_expression: can be function call or alias of function call
         super().__init__(unnest_expression)
         self.with_offset = with_offset
         self.with_offset_as = with_offset_as
@@ -1598,3 +1598,64 @@ def __str__(self):
             case_str += "\n ELSE {}".format(transform(self.else_expression))
         case_str += "\nEND"
         return case_str
+
+
+class DateTimePart(Expression):
+    """Datetime part keyword (e.g. HOUR) or a range such as MONTH TO HOUR."""
+    PARTS = (
+        "year",
+        "quarter",
+        "month",
+        "week",
+        "day",
+        "hour",
+        "minute",
+        "second",
+        "millisecond",
+        "microsecond",
+    )
+
+    def __init__(self, expression: str, ending_datetime_part: Union[str, None] = None):
+        super().__init__(expression.upper())
+        self.ending_datetime_part = (
+            ending_datetime_part.upper() if ending_datetime_part is not None else None
+        )
+
+    def __repr__(self):
+        return "<DateTimePart: {!r} ending_datetime_part={!r}>".format(
+            self.value, self.ending_datetime_part
+        )
+
+    def __eq__(self, other):
+        return (
+            type(self) == type(other)
+            and self.value == other.value
+            and self.ending_datetime_part == other.ending_datetime_part
+        )
+
+    def __str__(self):
+        return (
+            "{}".format(self.value)
+            if self.ending_datetime_part is None
+            else "{} TO {}".format(self.value, self.ending_datetime_part)
+        )
+
+
+class Interval(Expression):
+    """An ``INTERVAL <value> <datetime part>`` expression."""
+    def __init__(self, interval, datetime_part):
+        self.interval = interval
+        self.datetime_part = datetime_part
+
+    def __repr__(self):
+        return "<Interval: {!r} {}>".format(self.interval, repr(self.datetime_part))
+
+    def __eq__(self, other):
+        return (
+            type(self) == type(other)
+            and self.interval == other.interval
+            and self.datetime_part == other.datetime_part
+        )
+
+    def __str__(self):
+        return "INTERVAL {} {}".format(self.interval, self.datetime_part)

From 04db338dd6078d9ba6a065504bb30e349da8379f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gr=C3=A9gory=20Mazerand?=
Date: Thu, 14 Mar 2024 18:02:19 +0100
Subject: [PATCH 2/3] Handle interval expression

When the main token is "interval", parse the following value expression
(without alias) and then the datetime part, and build an Interval.

A token whose lower-cased form is one of DateTimePart.PARTS is parsed as
a DateTimePart; if it is followed by "to" and another part keyword, that
keyword becomes the ending part.

NOTE(review): the DateTimePart branch sits before the Column fallback,
so an identifier named like a part (year, month, day, ...) appears to be
parsed as a DateTimePart instead of a Column - confirm this is intended.

NOTE(review): when "to" is not followed by a part keyword, the "to"
token has already been consumed - confirm it does not need restoring.

---
 sqlvalidator/grammar/lexer.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/sqlvalidator/grammar/lexer.py b/sqlvalidator/grammar/lexer.py
index 039d70c..e6061fd 100644
--- a/sqlvalidator/grammar/lexer.py
+++ b/sqlvalidator/grammar/lexer.py
@@ -17,6 +17,7 @@
     Condition,
     CountFunctionCall,
     DatePartExtraction,
+    DateTimePart,
     ExceptClause,
     Expression,
     FilteredFunctionCall,
@@ -26,6 +27,7 @@
     HavingClause,
     Index,
     Integer,
+    Interval,
     Join,
     LimitClause,
     Negation,
@@ -595,6 +597,12 @@ def parse(
             argument_tokens, next_token = get_tokens_until_one_of(tokens, [])
             next_token = next(tokens, None)
             expression = SelectStatementParser.parse(iter(argument_tokens))
+        elif lower(main_token) == "interval":
+            interval, next_token = ExpressionParser.parse(tokens, can_alias=False)
+            datetime_part, next_token = ExpressionParser.parse(
+                tokens, first_token=next_token
+            )
+            expression = Interval(interval, datetime_part)
         else:
             expression = None
 
@@ -750,6 +758,17 @@ def parse(
                 expression = StringParser.parse(
                     tokens, start_quote=next_token, prefix=main_token
                 )
+            elif lower(main_token) in DateTimePart.PARTS:
+                cast_to = None
+                if next_token is not None and lower(next_token) == "to":
+                    next_token = next(tokens, None)
+                    if (
+                        next_token is not None
+                        and lower(next_token) in DateTimePart.PARTS
+                    ):
+                        cast_to = next_token
+                        next_token = next(tokens, None)
+                expression = DateTimePart(main_token, cast_to)
             else:
                 expression = Column(main_token)
 

From a330b3b1b09b98483b620a73ce26caa26581bc71 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gr=C3=A9gory=20Mazerand?=
Date: Thu, 14 Mar 2024 18:02:32 +0100
Subject: [PATCH 3/3] Add some tests on the intervals

Cover formatting of an interval with a single datetime part and with a
"part TO part" range, and parsing of both forms plus an interval used as
a function argument.

---
 tests/integration/test_formatting.py | 24 ++++++++++++++++++++++++
 tests/unit/test_lexer.py             | 26 ++++++++++++++++++++++++++
 2 files changed, 50 insertions(+)

diff --git a/tests/integration/test_formatting.py b/tests/integration/test_formatting.py
index 80e6199..2e6998d 100644
--- a/tests/integration/test_formatting.py
+++ b/tests/integration/test_formatting.py
@@ -2205,3 +2205,27 @@ def test_nesting_case_expr():
 FROM table
 """
     assert format_sql(sql) == expected.strip()
+
+
+def test_interval_with_single_datetime_part():
+    sql = """select value
+    from table
+    where date(date) >= date_sub(current_date(), interval -5 hour)"""
+    expected = """
+SELECT value
+FROM table
+WHERE DATE(date) >= DATE_SUB(CURRENT_DATE(), INTERVAL -5 HOUR)
+"""
+    assert format_sql(sql) == expected.strip()
+
+
+def test_interval_with_datetime_part_range():
+    sql = """select value
+    from table
+    where date(date) >= date_sub(current_date(), interval '8 20 17' month to hour)"""
+    expected = """
+SELECT value
+FROM table
+WHERE DATE(date) >= DATE_SUB(CURRENT_DATE(), INTERVAL '8 20 17' MONTH TO HOUR)
+"""
+    assert format_sql(sql) == expected.strip()
diff --git a/tests/unit/test_lexer.py b/tests/unit/test_lexer.py
index b3732fc..0d4c63d 100644
--- a/tests/unit/test_lexer.py
+++ b/tests/unit/test_lexer.py
@@ -16,11 +16,13 @@
     Column,
     Condition,
     CountFunctionCall,
+    DateTimePart,
     ExceptClause,
     FunctionCall,
     GroupByClause,
     Index,
     Integer,
+    Interval,
     Join,
     LimitClause,
     Null,
@@ -966,3 +968,27 @@ def test_function_with_single_comma_string_param():
     actual, _ = ExpressionParser.parse(to_tokens("test(',')"))
     expected = FunctionCall("test", String(",", quotes="'"))
     assert actual == expected
+
+
+def test_interval_with_single_datetime_part():
+    actual, _ = ExpressionParser.parse(to_tokens("INTERVAL -1 MONTH"))
+    expected = Interval(Integer(-1), DateTimePart("MONTH"))
+    assert actual == expected
+
+
+def test_interval_with_datetime_part_range():
+    actual, _ = ExpressionParser.parse(to_tokens("INTERVAL '8 -20 17' MONTH TO HOUR"))
+    expected = Interval(String("8 -20 17", quotes="'"), DateTimePart("MONTH", "HOUR"))
+    assert actual == expected
+
+
+def test_interval_in_function():
+    actual, _ = ExpressionParser.parse(
+        to_tokens("TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 30 SECOND)")
+    )
+    expected = FunctionCall(
+        "TIMESTAMP_ADD",
+        FunctionCall("CURRENT_TIMESTAMP"),
+        Interval(Integer(30), DateTimePart("SECOND")),
+    )
+    assert actual == expected