Skip to content

Commit

Permalink
feat: upgrade sqlize_value with tortoise.converters, and replace it…
Browse files Browse the repository at this point in the history
… with `escape`
  • Loading branch information
BryanLee authored and NightMarcher committed Jan 13, 2025
1 parent 81db3bf commit b2d390c
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 46 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ Generate sql and execute
```sql
UPDATE `account` SET extend =
JSON_MERGE_PATCH(JSON_SET(JSON_REMOVE(COALESCE(extend, '{}'), '$.deprecated'), '$.last_login',CAST('{"ipv4": "209.182.101.161"}' AS JSON), '$.uuid','fd04f7f2-24fc-4a73-a1d7-b6e99a464c5f'), '{"updated_at": "2022-10-30 21:34:15", "info": {"online_sec": 636}}')
, active=True, name='new_name'
, active=1, name='new_name'
WHERE `id`=8
```

Expand Down Expand Up @@ -162,7 +162,7 @@ Generate sql and execute
```sql
INSERT INTO `account_bak`
(gender, locale, active, name, extend)
SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, False active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend
SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, 0 active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend
FROM `account`
WHERE `id` IN (4,5,6)
```
Expand All @@ -185,8 +185,8 @@ Generate sql and execute
JOIN (
SELECT * FROM (
VALUES
ROW(7, False, False, 1, '{"test": 1, "debug": 0}'),
ROW(15, False, True, 0, '{"test": 1, "debug": 0}')
ROW(7, 0, 0, 1, '{"test": 1, "debug": 0}'),
ROW(15, 0, 1, 0, '{"test": 1, "debug": 0}')
) AS fly_table (id, deleted, active, gender, extend)
) tmp ON `account`.id=tmp.id AND `account`.deleted=tmp.deleted
SET `account`.active=tmp.active, `account`.gender=tmp.gender, `account`.extend=JSON_MERGE_PATCH(COALESCE(`account`.extend, '{}'), tmp.extend)
Expand Down
5 changes: 3 additions & 2 deletions fastapi_esql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from logging.config import dictConfig

from tortoise.converters import escape_string
from tortoise.converters import escape_item, escape_string
from tortoise.queryset import Q

from .const import (
Expand All @@ -23,7 +23,7 @@
wrap_backticks,
)

__version__ = "0.0.14"
__version__ = "0.0.15"

__all__ = [
"QsParsingError",
Expand All @@ -38,6 +38,7 @@
"SQLizer",
"Singleton",
"convert_dicts",
"escape_item",
"escape_string",
"timing",
"wrap_backticks",
Expand Down
69 changes: 50 additions & 19 deletions fastapi_esql/utils/sqlizer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from enum import Enum
from logging import getLogger
from json import dumps
from typing import Any, Dict, List, Optional, Union

from tortoise import Model, __version__ as tortoise_version
from tortoise.converters import escape_item
from tortoise.queryset import Q
from tortoise.query_utils import QueryModifier

Expand Down Expand Up @@ -35,10 +37,10 @@ def __init__(self, field: str, whens: dict, default=None):
@property
def sql(self):
whens = " ".join(
f"WHEN {k} THEN {SQLizer.sqlize_value(v)}"
f"WHEN {k} THEN {SQLizer.escape(v)}"
for k, v in self.whens.items()
)
else_ = " ELSE " + SQLizer.sqlize_value(self.default) if self.default is not None else ""
else_ = " ELSE " + SQLizer.escape(self.default) if self.default is not None else ""
return f"CASE {self.field} {whens}{else_} END"


Expand Down Expand Up @@ -89,25 +91,51 @@ def resolve_orders(cls, orders: List[str]) -> str:
return ", ".join(orders_)

@classmethod
def sqlize_value(cls, value, to_json=False) -> str:
def escape(cls, obj, to_json=False, ver=2) -> str:
if ver == 1:
return cls._escape_v1(obj, to_json)
elif ver == 2:
return cls._escape_v2(obj, to_json)
return cls._escape_v1(obj, to_json)

@classmethod
def _escape_v1(cls, obj, to_json=False):
"""
Works like aiomysql.connection.Connection.escape
Original DIY `escape` method
"""
if value is None:
if obj is None:
return "NULL"
elif isinstance(value, (Cases, RawSQL)):
return value.sql
elif isinstance(value, (int, float, bool)):
return f"{value}"
elif isinstance(value, (dict, list, tuple)):
dumped = dumps(value, ensure_ascii=False)
elif isinstance(obj, (Cases, RawSQL)):
return obj.sql
elif isinstance(obj, (int, float, bool)):
return f"{obj}"
elif isinstance(obj, (dict, list, tuple)):
dumped = dumps(obj, ensure_ascii=False)
if to_json:
return f"CAST('{dumped}' AS JSON)"
# Same with above line
# return f"JSON_EXTRACT('{dumped}', '$')"
return f"'{dumped}'"
else:
return f"'{value}'"
return f"'{obj}'"

@classmethod
def _escape_v2(cls, obj, to_json=False):
"""
Escape whatever value you pass to it.
Partially copied from aiomysql.connection.Connection.escape
"""
if isinstance(obj, (Cases, RawSQL)):
return obj.sql
elif isinstance(obj, Enum):
return cls._escape_v2(obj.value)
elif isinstance(obj, (dict, list, tuple)):
dumped = dumps(obj, ensure_ascii=False)
if to_json:
return f"CAST('{dumped}' AS JSON)"
return f"'{dumped}'"
else:
return escape_item(obj, "utf8mb4")

@classmethod
def select_custom_fields(
Expand Down Expand Up @@ -178,17 +206,17 @@ def update_json_field(
json_obj = f"JSON_REMOVE({json_obj}, {rps})"
if path_value_dict:
pvs = [
f"'{path}',{cls.sqlize_value(value, to_json=True)}"
f"'{path}',{cls.escape(value, to_json=True)}"
for (path, value) in path_value_dict.items()
]
json_obj = f"JSON_SET({json_obj}, {', '.join(pvs)})"
if merge_dict:
json_obj = f"JSON_MERGE_PATCH({json_obj}, {cls.sqlize_value(merge_dict)})"
json_obj = f"JSON_MERGE_PATCH({json_obj}, {cls.escape(merge_dict)})"

assign_field_dict = assign_field_dict or {}
assign_fields = []
for k, v in assign_field_dict.items():
assign_fields.append(f"{k}={cls.sqlize_value(v)}")
assign_fields.append(f"{k}={cls.escape(v)}")
assign_field = ", ".join(assign_fields) if assign_fields else None

sql = """
Expand Down Expand Up @@ -220,7 +248,7 @@ def upsert_on_duplicate(
raise WrongParamsError("Parameters `table`, `dicts`, `insert_fields` are required")

values = [
f" ({', '.join(cls.sqlize_value(d.get(f)) for f in insert_fields)})"
f" ({', '.join(cls.escape(d.get(f)) for f in insert_fields)})"
for d in dicts
]
# NOTE Beginning with MySQL 8.0.19, it is possible to use an alias for the row
Expand Down Expand Up @@ -279,7 +307,7 @@ def insert_into_select(
assign_fields = []
for k, v in assign_field_dict.items():
fields.append(k)
assign_fields.append(f"{cls.sqlize_value(v)} {k}")
assign_fields.append(f"{cls.escape(v)} {k}")

sql = f"""
INSERT INTO {wrap_backticks(to_table or table)}
Expand All @@ -304,14 +332,14 @@ def build_fly_table(

if using_values:
rows = [
f" ROW({', '.join(cls.sqlize_value(d.get(f)) for f in fields)})"
f" ROW({', '.join(cls.escape(d.get(f)) for f in fields)})"
for d in dicts
]
values = "VALUES\n" + ",\n".join(rows)
table = f"fly_table ({', '.join(fields)})"
else:
rows = [
f"SELECT {', '.join(f'{cls.sqlize_value(d.get(f))} {f}' for f in fields)}"
f"SELECT {', '.join(f'{cls.escape(d.get(f))} {f}' for f in fields)}"
for d in dicts
]
values = "\n UNION\n ".join(rows)
Expand Down Expand Up @@ -354,3 +382,6 @@ def bulk_update_from_dicts(
"""
logger.debug(sql)
return sql


SQLizer.sqlize_value = SQLizer.escape
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "fastapi-efficient-sql"
version = "0.0.14"
version = "0.0.15"
description = "Generate bulk DML SQL and execute them based on Tortoise ORM and mysql8.0+, and integrated with FastAPI."
authors = ["BryanLee <[email protected]>"]
keywords = ["sql", "fastapi", "tortoise-orm", "mysql8", "bulk-operation"]
Expand Down
67 changes: 47 additions & 20 deletions tests/test_sqlizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,53 @@ def test_resolve_orders(self):
orders = SQLizer.resolve_orders(["-created_at", "name"])
assert orders == "created_at DESC, name ASC"

def test_sqlize_value(self):
assert SQLizer.sqlize_value(None) == "NULL"
def test_sqlize_value_v1(self):
assert SQLizer.sqlize_value(None, ver=1) == "NULL"

raw_sql = RawSQL("statement")
assert SQLizer.sqlize_value(raw_sql) == raw_sql.sql
assert SQLizer.sqlize_value(raw_sql, ver=1) == raw_sql.sql
cases = Cases("is_ok", {0: "No", 1: "Yes"})
assert SQLizer.sqlize_value(cases) == cases.sql
assert SQLizer.sqlize_value(cases, ver=1) == cases.sql

assert SQLizer.sqlize_value(1024) == "1024"
assert SQLizer.sqlize_value(0.125) == "0.125"
assert SQLizer.sqlize_value(True) == "True"
assert SQLizer.sqlize_value(GenderEnum.unknown, ver=1) == "0"
assert SQLizer.sqlize_value(LocaleEnum.zh_CN, ver=1) == "'zh_CN'"

assert SQLizer.sqlize_value(1024, ver=1) == "1024"
assert SQLizer.sqlize_value(0.125, ver=1) == "0.125"
assert SQLizer.sqlize_value(True, ver=1) == "True"

assert (
SQLizer.sqlize_value({"gender": 0, "name": "羊淑兰"}, to_json=True, ver=1)
== """CAST('{"gender": 0, "name": "羊淑兰"}' AS JSON)"""
)
assert SQLizer.sqlize_value([1, 2, 4], ver=1) == "'[1, 2, 4]'"
assert SQLizer.sqlize_value(("a", "b", "c"), ver=1) == """'["a", "b", "c"]'"""

assert SQLizer.sqlize_value(datetime(2023, 1, 1, 12, 30), ver=1) == "'2023-01-01 12:30:00'"

def test_escape_v2(self):
assert SQLizer.escape(None, ver=2) == "NULL"

raw_sql = RawSQL("statement")
assert SQLizer.escape(raw_sql, ver=2) == raw_sql.sql
cases = Cases("is_ok", {0: "No", 1: "Yes"})
assert SQLizer.escape(cases, ver=2) == cases.sql

assert SQLizer.escape(GenderEnum.unknown, ver=2) == "0"
assert SQLizer.escape(LocaleEnum.zh_CN, ver=2) == "'zh_CN'"

assert SQLizer.escape(1024, ver=2) == "1024"
assert SQLizer.escape(0.125, ver=2) == "0.125"
assert SQLizer.escape(True, ver=2) == "1"

assert (
SQLizer.sqlize_value({"gender": 0, "name": "羊淑兰"}, to_json=True)
SQLizer.escape({"gender": 0, "name": "羊淑兰"}, to_json=True, ver=2)
== """CAST('{"gender": 0, "name": "羊淑兰"}' AS JSON)"""
)
assert SQLizer.sqlize_value([1, 2, 4]) == "'[1, 2, 4]'"
assert SQLizer.sqlize_value(("a", "b", "c")) == """'["a", "b", "c"]'"""
assert SQLizer.escape([1, 2, 4], ver=2) == "'[1, 2, 4]'"
assert SQLizer.escape(("a", "b", "c"), ver=2) == """'["a", "b", "c"]'"""

assert SQLizer.sqlize_value(datetime(2023, 1, 1, 12, 30)) == "'2023-01-01 12:30:00'"
assert SQLizer.escape(datetime(2023, 1, 1, 12, 30), ver=2) == "'2023-01-01T12:30:00'"

def test_select_custom_fields(self):
with self.assertRaises(WrongParamsError):
Expand Down Expand Up @@ -227,7 +254,7 @@ def test_update_json_field(self):
assert sql == """
UPDATE `account` SET extend =
JSON_MERGE_PATCH(JSON_SET(JSON_REMOVE(COALESCE(extend, '{}'), '$.deprecated'), '$.last_login',CAST('{"ipv4": "209.182.101.161"}' AS JSON), '$.uuid','fd04f7f2-24fc-4a73-a1d7-b6e99a464c5f'), '{"updated_at": "2022-10-30 21:34:15", "info": {"online_sec": 636}}')
, active=True, name='new_name'
, active=1, name='new_name'
WHERE `id`=8
"""

Expand Down Expand Up @@ -326,7 +353,7 @@ def test_insert_into_select(self):
assert archive_sql == """
INSERT INTO `account_bak`
(gender, locale, active, name, extend)
SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, False active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend
SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, 0 active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend
FROM `account`
WHERE `id` IN (4,5,6)
"""
Expand All @@ -346,7 +373,7 @@ def test_insert_into_select(self):
assert copy_sql == """
INSERT INTO `account`
(gender, locale, active, name, extend)
SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, False active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend
SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, 0 active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend
FROM `account`
WHERE `id` IN (4,5,6)
"""
Expand All @@ -368,9 +395,9 @@ def test_build_fly_table(self):
)
assert old_sql == """
SELECT * FROM (
SELECT 7 id, False active, 1 gender
SELECT 7 id, 0 active, 1 gender
UNION
SELECT 15 id, True active, 0 gender
SELECT 15 id, 1 active, 0 gender
) AS fly_table"""

new_sql = SQLizer.build_fly_table(
Expand All @@ -384,8 +411,8 @@ def test_build_fly_table(self):
assert new_sql == """
SELECT * FROM (
VALUES
ROW(7, False, 1),
ROW(15, True, 0)
ROW(7, 0, 1),
ROW(15, 1, 0)
) AS fly_table (id, active, gender)"""

def test_bulk_update_from_dicts(self):
Expand All @@ -412,8 +439,8 @@ def test_bulk_update_from_dicts(self):
JOIN (
SELECT * FROM (
VALUES
ROW(7, False, False, 1, '{"test": 1, "debug": 0}'),
ROW(15, False, True, 0, '{"test": 1, "debug": 0}')
ROW(7, 0, 0, 1, '{"test": 1, "debug": 0}'),
ROW(15, 0, 1, 0, '{"test": 1, "debug": 0}')
) AS fly_table (id, deleted, active, gender, extend)
) tmp ON `account`.id=tmp.id AND `account`.deleted=tmp.deleted
SET `account`.active=tmp.active, `account`.gender=tmp.gender, `account`.extend=JSON_MERGE_PATCH(COALESCE(`account`.extend, '{}'), tmp.extend)
Expand Down

0 comments on commit b2d390c

Please sign in to comment.