Skip to content

Commit

Permalink
Replace kwargs with SqlContext in get_sql (#23)
Browse files Browse the repository at this point in the history
* Use SqlContext instead of kwargs

* Move SqlContext to Query

* Remove unused type:ignore

* Bump version to 0.5.0

* Add a comment
  • Loading branch information
henadzit authored Jan 10, 2025
1 parent 4a263cb commit 1ab383c
Show file tree
Hide file tree
Showing 27 changed files with 751 additions and 835 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# ChangeLog

## 0.5

### 0.5.0

- Replace `get_sql` kwargs with `SqlContext` to improve performance

## 0.4

### 0.4.0
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ The original repository includes many databases that Tortoise ORM doesn’t requ

## What changed?

Deleted unnecessary code that Tortoise ORM doesn’t require, and added features tailored specifically for Tortoise ORM.
Deleted unnecessary code that Tortoise ORM doesn’t require, added features tailored specifically for Tortoise ORM,
and modified to improve query generation performance.

## ThanksTo

- [pypika](https://github.com/kayak/pypika), a Python SQL query builder that exposes the full expressiveness of SQL,
- [pypika](https://github.com/kayak/pypika), a Python SQL query builder that exposes the full expressiveness of SQL,
using a syntax that mirrors the resulting query structure.

## License
Expand Down
1 change: 1 addition & 0 deletions pypika_tortoise/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .context import SqlContext
from .dialects import MSSQLQuery, MySQLQuery, OracleQuery, PostgreSQLQuery, SQLLiteQuery
from .enums import DatePart, Dialects, JoinType, Order
from .exceptions import (
Expand Down
50 changes: 50 additions & 0 deletions pypika_tortoise/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class SqlContext:
"""Represents the context for get_sql() methods to determine how to render SQL."""

quote_char: str
secondary_quote_char: str
alias_quote_char: str
dialect: "Dialects"
as_keyword: bool = False
subquery: bool = False
with_alias: bool = False
with_namespace: bool = False
subcriterion: bool = False
parameterizer: "Parameterizer" | None = None
groupby_alias: bool = True
orderby_alias: bool = True

def copy(self, **kwargs) -> SqlContext:
return SqlContext(
quote_char=kwargs.get("quote_char", self.quote_char),
secondary_quote_char=kwargs.get("secondary_quote_char", self.secondary_quote_char),
alias_quote_char=kwargs.get("alias_quote_char", self.alias_quote_char),
dialect=kwargs.get("dialect", self.dialect),
as_keyword=kwargs.get("as_keyword", self.as_keyword),
subquery=kwargs.get("subquery", self.subquery),
with_alias=kwargs.get("with_alias", self.with_alias),
with_namespace=kwargs.get("with_namespace", self.with_namespace),
subcriterion=kwargs.get("subcriterion", self.subcriterion),
parameterizer=kwargs.get("parameterizer", self.parameterizer),
groupby_alias=kwargs.get("groupby_alias", self.groupby_alias),
orderby_alias=kwargs.get("orderby_alias", self.orderby_alias),
)


from .enums import Dialects # noqa: E402

DEFAULT_SQL_CONTEXT = SqlContext(
quote_char='"',
secondary_quote_char="'",
alias_quote_char="",
as_keyword=False,
dialect=Dialects.SQLITE,
)

from .terms import Parameterizer # noqa: E402
34 changes: 19 additions & 15 deletions pypika_tortoise/dialects/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Any, cast

from ..context import DEFAULT_SQL_CONTEXT, SqlContext
from ..enums import Dialects
from ..exceptions import QueryException
from ..queries import Query, QueryBuilder
Expand All @@ -14,6 +15,8 @@ class MSSQLQuery(Query):
Defines a query class for use with Microsoft SQL Server.
"""

SQL_CONTEXT = DEFAULT_SQL_CONTEXT.copy(dialect=Dialects.MSSQL)

@classmethod
def _builder(cls, **kwargs: Any) -> "MSSQLQueryBuilder":
return MSSQLQueryBuilder(**kwargs)
Expand All @@ -23,7 +26,7 @@ class MSSQLQueryBuilder(QueryBuilder):
QUERY_CLS = MSSQLQuery

def __init__(self, **kwargs: Any) -> None:
super().__init__(dialect=Dialects.MSSQL, **kwargs)
super().__init__(**kwargs)
self._top: int | None = None

@builder
Expand All @@ -45,44 +48,45 @@ def fetch_next(self, limit: int) -> MSSQLQueryBuilder: # type:ignore[return]
# Overridden to provide a more domain-specific API for T-SQL users
self._limit = cast(ValueWrapper, self.wrap_constant(limit))

def _offset_sql(self, **kwargs) -> str:
def _offset_sql(self, ctx: SqlContext) -> str:
order_by = ""
if not self._orderbys:
order_by = " ORDER BY (SELECT 0)"
return order_by + " OFFSET {offset} ROWS".format(
offset=self._offset.get_sql(**kwargs) if self._offset is not None else 0
offset=self._offset.get_sql(ctx) if self._offset is not None else 0
)

def _limit_sql(self, **kwargs) -> str:
def _limit_sql(self, ctx: SqlContext) -> str:
if self._limit is None:
return ""
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs))
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(ctx))

def _apply_pagination(self, querystring: str, **kwargs) -> str:
def _apply_pagination(self, querystring: str, ctx: SqlContext) -> str:
# Note: Overridden as MSSQL specifies offset before the fetch next limit
if self._limit is not None or self._offset:
# Offset has to be present if fetch next is specified in a MSSQL query
querystring += self._offset_sql(**kwargs)
querystring += self._offset_sql(ctx)

if self._limit is not None:
querystring += self._limit_sql(**kwargs)
querystring += self._limit_sql(ctx)

return querystring

def get_sql(self, *args: Any, **kwargs: Any) -> str:
def get_sql(self, ctx: SqlContext | None = None) -> str:
if not ctx:
ctx = MSSQLQuery.SQL_CONTEXT
# MSSQL does not support group by a field alias.
# Note: set directly in kwargs as they are re-used down the tree in the case of subqueries!
kwargs["groupby_alias"] = False
return super().get_sql(*args, **kwargs)
ctx = ctx.copy(groupby_alias=False)
return super().get_sql(ctx)

def _top_sql(self) -> str:
return "TOP ({}) ".format(self._top) if self._top else ""

def _select_sql(self, **kwargs: Any) -> str:
def _select_sql(self, ctx: SqlContext) -> str:
ctx = ctx.copy(with_alias=True, subquery=True)
return "SELECT {distinct}{top}{select}".format(
top=self._top_sql(),
distinct="DISTINCT " if self._distinct else "",
select=",".join(
term.get_sql(with_alias=True, subquery=True, **kwargs) for term in self._selects
),
select=",".join(term.get_sql(ctx) for term in self._selects),
)
77 changes: 38 additions & 39 deletions pypika_tortoise/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import time
from typing import Any, cast

from ..context import DEFAULT_SQL_CONTEXT, SqlContext
from ..enums import Dialects
from ..queries import Query, QueryBuilder, Table
from ..terms import ValueWrapper
Expand All @@ -15,6 +16,8 @@ class MySQLQuery(Query):
Defines a query class for use with MySQL.
"""

SQL_CONTEXT = DEFAULT_SQL_CONTEXT.copy(dialect=Dialects.MYSQL, quote_char="`")

@classmethod
def _builder(cls, **kwargs: Any) -> "MySQLQueryBuilder":
return MySQLQueryBuilder(**kwargs)
Expand All @@ -25,8 +28,8 @@ def load(cls, fp: str) -> "MySQLLoadQueryBuilder":


class MySQLValueWrapper(ValueWrapper):
def get_value_sql(self, **kwargs: Any) -> str:
quote_char = kwargs.get("secondary_quote_char") or ""
def get_value_sql(self, ctx: SqlContext) -> str:
quote_char = ctx.secondary_quote_char or ""
if isinstance(value := self.value, str):
value = value.replace(quote_char, quote_char * 2)
value = value.replace("\\", "\\\\")
Expand All @@ -37,60 +40,54 @@ def get_value_sql(self, **kwargs: Any) -> str:
elif isinstance(value, (dict, list)):
value = format_quotes(json.dumps(value), quote_char)
return value.replace("\\", "\\\\")
return super().get_value_sql(**kwargs)
return super().get_value_sql(ctx)


class MySQLQueryBuilder(QueryBuilder):
QUOTE_CHAR = "`"
QUERY_CLS = MySQLQuery

def __init__(self, **kwargs: Any) -> None:
super().__init__(
dialect=Dialects.MYSQL,
wrapper_cls=MySQLValueWrapper,
wrap_set_operation_queries=False,
**kwargs,
)
self._modifiers: list[str] = []

def _on_conflict_sql(self, **kwargs: Any) -> str:
kwargs["alias_quote_char"] = (
self.ALIAS_QUOTE_CHAR
if self.QUERY_ALIAS_QUOTE_CHAR is None
else self.QUERY_ALIAS_QUOTE_CHAR
def _on_conflict_sql(self, ctx: SqlContext) -> str:
ctx = ctx.copy(
as_keyword=True,
)
kwargs["as_keyword"] = True
querystring = format_alias_sql("", self.alias, **kwargs)
return querystring
return format_alias_sql("", self.alias, ctx)

def get_sql(self, **kwargs: Any) -> str: # type:ignore[override]
self._set_kwargs_defaults(kwargs)
querystring = super().get_sql(**kwargs)
def get_sql(self, ctx: SqlContext | None = None) -> str:
ctx = ctx or MySQLQuery.SQL_CONTEXT
querystring = super().get_sql(ctx)
if querystring and self._update_table:
if self._orderbys:
querystring += self._orderby_sql(**kwargs)
querystring += self._orderby_sql(ctx)
if self._limit:
querystring += self._limit_sql()
querystring += self._limit_sql(ctx)
return querystring

def _on_conflict_action_sql(self, **kwargs: Any) -> str:
kwargs.pop("with_namespace", None)
def _on_conflict_action_sql(self, ctx: SqlContext) -> str:
on_conflict_ctx = ctx.copy(with_namespace=False)
if len(self._on_conflict_do_updates) > 0:
updates = []
for field, value in self._on_conflict_do_updates:
if value:
updates.append(
"{field}={value}".format(
field=field.get_sql(**kwargs),
value=value.get_sql(**kwargs),
field=field.get_sql(on_conflict_ctx),
value=value.get_sql(on_conflict_ctx),
)
)
else:
updates.append(
"{field}={alias}.{value}".format(
field=field.get_sql(**kwargs),
alias=format_quotes(self.alias, self.QUOTE_CHAR),
value=field.get_sql(**kwargs),
field=field.get_sql(on_conflict_ctx),
alias=format_quotes(self.alias, ctx.quote_char),
value=field.get_sql(on_conflict_ctx),
)
)
action_sql = " ON DUPLICATE KEY UPDATE {updates}".format(updates=",".join(updates))
Expand All @@ -107,23 +104,22 @@ def modifier(self, value: str) -> MySQLQueryBuilder: # type:ignore[return]
"""
self._modifiers.append(value)

def _select_sql(self, **kwargs: Any) -> str:
def _select_sql(self, ctx: SqlContext) -> str:
"""
Overridden function to generate the SELECT part of the SQL statement,
with the addition of the a modifier if present.
"""
ctx = ctx.copy(with_alias=True, subquery=True)
return "SELECT {distinct}{modifier}{select}".format(
distinct="DISTINCT " if self._distinct else "",
modifier="{} ".format(" ".join(self._modifiers)) if self._modifiers else "",
select=",".join(
term.get_sql(with_alias=True, subquery=True, **kwargs) for term in self._selects
),
select=",".join(term.get_sql(ctx) for term in self._selects),
)

def _insert_sql(self, **kwargs: Any) -> str:
def _insert_sql(self, ctx: SqlContext) -> str:
insert_table = cast(Table, self._insert_table)
return "INSERT {ignore}INTO {table}".format(
table=insert_table.get_sql(**kwargs),
table=insert_table.get_sql(ctx),
ignore="IGNORE " if self._on_conflict_do_nothing else "",
)

Expand All @@ -143,23 +139,26 @@ def load(self, fp: str) -> MySQLLoadQueryBuilder: # type:ignore[return]
def into(self, table: str | Table) -> MySQLLoadQueryBuilder: # type:ignore[return]
self._into_table = table if isinstance(table, Table) else Table(table)

def get_sql(self, *args: Any, **kwargs: Any) -> str:
def get_sql(self, ctx: SqlContext | None = None) -> str:
if not ctx:
ctx = MySQLQuery.SQL_CONTEXT

querystring = ""
if self._load_file and self._into_table:
querystring += self._load_file_sql(**kwargs)
querystring += self._into_table_sql(**kwargs)
querystring += self._options_sql(**kwargs)
querystring += self._load_file_sql(ctx)
querystring += self._into_table_sql(ctx)
querystring += self._options_sql(ctx)

return querystring

def _load_file_sql(self, **kwargs: Any) -> str:
def _load_file_sql(self, ctx: SqlContext) -> str:
return "LOAD DATA LOCAL INFILE '{}'".format(self._load_file)

def _into_table_sql(self, **kwargs: Any) -> str:
def _into_table_sql(self, ctx: SqlContext) -> str:
table = cast(Table, self._into_table)
return " INTO TABLE `{}`".format(table.get_sql(**kwargs))
return " INTO TABLE {}".format(table.get_sql(ctx))

def _options_sql(self, **kwargs: Any) -> str:
def _options_sql(self, ctx: SqlContext) -> str:
return " FIELDS TERMINATED BY ','"

def __str__(self) -> str:
Expand Down
28 changes: 16 additions & 12 deletions pypika_tortoise/dialects/oracle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from typing import Any

from ..context import DEFAULT_SQL_CONTEXT, SqlContext
from ..enums import Dialects
from ..queries import Query, QueryBuilder

Expand All @@ -9,31 +12,32 @@ class OracleQuery(Query):
Defines a query class for use with Oracle.
"""

SQL_CONTEXT = DEFAULT_SQL_CONTEXT.copy(dialect=Dialects.ORACLE, alias_quote_char='"')

@classmethod
def _builder(cls, **kwargs: Any) -> "OracleQueryBuilder":
return OracleQueryBuilder(**kwargs)


class OracleQueryBuilder(QueryBuilder):
QUOTE_CHAR = '"'
QUERY_CLS = OracleQuery
ALIAS_QUOTE_CHAR = '"'

def __init__(self, **kwargs: Any) -> None:
super().__init__(dialect=Dialects.ORACLE, **kwargs)
super().__init__(**kwargs)

def get_sql(self, *args: Any, **kwargs: Any) -> str:
# Oracle does not support group by a field alias
# Note: set directly in kwargs as they are re-used down the tree in the case of subqueries!
kwargs["groupby_alias"] = False
return super().get_sql(*args, **kwargs)
def get_sql(self, ctx: SqlContext | None = None) -> str:
if not ctx:
ctx = OracleQuery.SQL_CONTEXT
# Oracle does not support group by a field alias.
ctx = ctx.copy(groupby_alias=False)
return super().get_sql(ctx)

def _offset_sql(self, **kwargs) -> str:
def _offset_sql(self, ctx: SqlContext) -> str:
if self._offset is None:
return ""
return " OFFSET {offset} ROWS".format(offset=self._offset.get_sql(**kwargs))
return " OFFSET {offset} ROWS".format(offset=self._offset.get_sql(ctx))

def _limit_sql(self, **kwargs) -> str:
def _limit_sql(self, ctx: SqlContext) -> str:
if self._limit is None:
return ""
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs))
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(ctx))
Loading

0 comments on commit 1ab383c

Please sign in to comment.