From b2d390c0644920250b0b336d92a71ff4b665867e Mon Sep 17 00:00:00 2001 From: BryanLee Date: Fri, 10 Jan 2025 20:32:08 +0800 Subject: [PATCH] feat: upgrade `sqlize_value` with tortoise.converters, and replace it with `escape` --- README.md | 8 ++-- fastapi_esql/__init__.py | 5 ++- fastapi_esql/utils/sqlizer.py | 69 +++++++++++++++++++++++++---------- pyproject.toml | 2 +- tests/test_sqlizer.py | 67 ++++++++++++++++++++++++---------- 5 files changed, 105 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index fc8a1ce..61140dd 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ Generate sql and execute ```sql UPDATE `account` SET extend = JSON_MERGE_PATCH(JSON_SET(JSON_REMOVE(COALESCE(extend, '{}'), '$.deprecated'), '$.last_login',CAST('{"ipv4": "209.182.101.161"}' AS JSON), '$.uuid','fd04f7f2-24fc-4a73-a1d7-b6e99a464c5f'), '{"updated_at": "2022-10-30 21:34:15", "info": {"online_sec": 636}}') - , active=True, name='new_name' + , active=1, name='new_name' WHERE `id`=8 ``` @@ -162,7 +162,7 @@ Generate sql and execute ```sql INSERT INTO `account_bak` (gender, locale, active, name, extend) - SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, False active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend + SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, 0 active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend FROM `account` WHERE `id` IN (4,5,6) ``` @@ -185,8 +185,8 @@ Generate sql and execute JOIN ( SELECT * FROM ( VALUES - ROW(7, False, False, 1, '{"test": 1, "debug": 0}'), - ROW(15, False, True, 0, '{"test": 1, "debug": 0}') + ROW(7, 0, 0, 1, '{"test": 1, "debug": 0}'), + ROW(15, 0, 1, 0, '{"test": 1, "debug": 0}') ) AS fly_table (id, deleted, active, gender, extend) ) tmp ON `account`.id=tmp.id AND `account`.deleted=tmp.deleted SET `account`.active=tmp.active, `account`.gender=tmp.gender, `account`.extend=JSON_MERGE_PATCH(COALESCE(`account`.extend, '{}'), tmp.extend) diff --git a/fastapi_esql/__init__.py b/fastapi_esql/__init__.py index 281367d..2948813 100644 --- a/fastapi_esql/__init__.py +++ b/fastapi_esql/__init__.py @@ -1,6 +1,6 @@ from logging.config import dictConfig -from tortoise.converters import escape_string +from tortoise.converters import escape_item, escape_string from tortoise.queryset import Q from .const import ( @@ -23,7 +23,7 @@ wrap_backticks, ) -__version__ = "0.0.14" +__version__ = "0.0.15" __all__ = [ "QsParsingError", @@ -38,6 +38,7 @@ "SQLizer", "Singleton", "convert_dicts", + "escape_item", "escape_string", "timing", "wrap_backticks", diff --git a/fastapi_esql/utils/sqlizer.py b/fastapi_esql/utils/sqlizer.py index 94a547c..08b0f30 100644 --- a/fastapi_esql/utils/sqlizer.py +++ b/fastapi_esql/utils/sqlizer.py @@ -1,8 +1,10 @@ +from enum import Enum from logging import getLogger from json import dumps from typing import Any, Dict, List, Optional, Union from tortoise import Model, __version__ as tortoise_version +from tortoise.converters import escape_item from tortoise.queryset import Q from tortoise.query_utils import QueryModifier @@ -35,10 +37,10 @@ def __init__(self, field: str, whens: dict, default=None): @property def sql(self): whens = " ".join( - f"WHEN {k} THEN {SQLizer.sqlize_value(v)}" + f"WHEN {k} THEN {SQLizer.escape(v)}" for k, v in self.whens.items() ) - else_ = " ELSE " + SQLizer.sqlize_value(self.default) if self.default is not None else "" + else_ = " ELSE " + SQLizer.escape(self.default) if self.default is not None else "" return f"CASE {self.field} {whens}{else_} END" @@ -89,25 +91,51 @@ def resolve_orders(cls, orders: List[str]) -> str: return ", ".join(orders_) @classmethod - def sqlize_value(cls, value, to_json=False) -> str: + def escape(cls, obj, to_json=False, ver=2) -> str: + if ver == 1: + return cls._escape_v1(obj, to_json) + elif ver == 2: + return cls._escape_v2(obj, to_json) + return cls._escape_v1(obj, to_json) + + @classmethod + def _escape_v1(cls, obj, to_json=False): """ - Works like aiomysql.connection.Connection.escape + Original DIY `escape` method """ - if value is None: + if obj is None: return "NULL" - elif isinstance(value, (Cases, RawSQL)): - return value.sql - elif isinstance(value, (int, float, bool)): - return f"{value}" - elif isinstance(value, (dict, list, tuple)): - dumped = dumps(value, ensure_ascii=False) + elif isinstance(obj, (Cases, RawSQL)): + return obj.sql + elif isinstance(obj, (int, float, bool)): + return f"{obj}" + elif isinstance(obj, (dict, list, tuple)): + dumped = dumps(obj, ensure_ascii=False) if to_json: return f"CAST('{dumped}' AS JSON)" # Same with above line # return f"JSON_EXTRACT('{dumped}', '$')" return f"'{dumped}'" else: - return f"'{value}'" + return f"'{obj}'" + + @classmethod + def _escape_v2(cls, obj, to_json=False): + """ + Escape whatever value you pass to it. + Partially copied from aiomysql.connection.Connection.escape + """ + if isinstance(obj, (Cases, RawSQL)): + return obj.sql + elif isinstance(obj, Enum): + return cls._escape_v2(obj.value) + elif isinstance(obj, (dict, list, tuple)): + dumped = dumps(obj, ensure_ascii=False) + if to_json: + return f"CAST('{dumped}' AS JSON)" + return f"'{dumped}'" + else: + return escape_item(obj, "utf8mb4") @classmethod def select_custom_fields( @@ -178,17 +206,17 @@ def update_json_field( json_obj = f"JSON_REMOVE({json_obj}, {rps})" if path_value_dict: pvs = [ - f"'{path}',{cls.sqlize_value(value, to_json=True)}" + f"'{path}',{cls.escape(value, to_json=True)}" for (path, value) in path_value_dict.items() ] json_obj = f"JSON_SET({json_obj}, {', '.join(pvs)})" if merge_dict: - json_obj = f"JSON_MERGE_PATCH({json_obj}, {cls.sqlize_value(merge_dict)})" + json_obj = f"JSON_MERGE_PATCH({json_obj}, {cls.escape(merge_dict)})" assign_field_dict = assign_field_dict or {} assign_fields = [] for k, v in assign_field_dict.items(): - assign_fields.append(f"{k}={cls.sqlize_value(v)}") + assign_fields.append(f"{k}={cls.escape(v)}") assign_field = ", ".join(assign_fields) if assign_fields else None sql = """ @@ -220,7 +248,7 @@ def upsert_on_duplicate( raise WrongParamsError("Parameters `table`, `dicts`, `insert_fields` are required") values = [ - f" ({', '.join(cls.sqlize_value(d.get(f)) for f in insert_fields)})" + f" ({', '.join(cls.escape(d.get(f)) for f in insert_fields)})" for d in dicts ] # NOTE Beginning with MySQL 8.0.19, it is possible to use an alias for the row @@ -279,7 +307,7 @@ def insert_into_select( assign_fields = [] for k, v in assign_field_dict.items(): fields.append(k) - assign_fields.append(f"{cls.sqlize_value(v)} {k}") + assign_fields.append(f"{cls.escape(v)} {k}") sql = f""" INSERT INTO {wrap_backticks(to_table or table)} @@ -304,14 +332,14 @@ def build_fly_table( if using_values: rows = [ - f" ROW({', '.join(cls.sqlize_value(d.get(f)) for f in fields)})" + f" ROW({', '.join(cls.escape(d.get(f)) for f in fields)})" for d in dicts ] values = "VALUES\n" + ",\n".join(rows) table = f"fly_table ({', '.join(fields)})" else: rows = [ - f"SELECT {', '.join(f'{cls.sqlize_value(d.get(f))} {f}' for f in fields)}" + f"SELECT {', '.join(f'{cls.escape(d.get(f))} {f}' for f in fields)}" for d in dicts ] values = "\n UNION\n ".join(rows) @@ -354,3 +382,6 @@ def bulk_update_from_dicts( """ logger.debug(sql) return sql + + +SQLizer.sqlize_value = SQLizer.escape diff --git a/pyproject.toml b/pyproject.toml index 8f9d7d6..599f522 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "fastapi-efficient-sql" -version = "0.0.14" +version = "0.0.15" description = "Generate bulk DML SQL and execute them based on Tortoise ORM and mysql8.0+, and integrated with FastAPI." authors = ["BryanLee "] keywords = ["sql", "fastapi", "tortoise-orm", "mysql8", "bulk-operation"] diff --git a/tests/test_sqlizer.py b/tests/test_sqlizer.py index 485a9cb..a97f7cb 100644 --- a/tests/test_sqlizer.py +++ b/tests/test_sqlizer.py @@ -68,26 +68,53 @@ def test_resolve_orders(self): orders = SQLizer.resolve_orders(["-created_at", "name"]) assert orders == "created_at DESC, name ASC" - def test_sqlize_value(self): - assert SQLizer.sqlize_value(None) == "NULL" + def test_sqlize_value_v1(self): + assert SQLizer.sqlize_value(None, ver=1) == "NULL" raw_sql = RawSQL("statement") - assert SQLizer.sqlize_value(raw_sql) == raw_sql.sql + assert SQLizer.sqlize_value(raw_sql, ver=1) == raw_sql.sql cases = Cases("is_ok", {0: "No", 1: "Yes"}) - assert SQLizer.sqlize_value(cases) == cases.sql + assert SQLizer.sqlize_value(cases, ver=1) == cases.sql - assert SQLizer.sqlize_value(1024) == "1024" - assert SQLizer.sqlize_value(0.125) == "0.125" - assert SQLizer.sqlize_value(True) == "True" + assert SQLizer.sqlize_value(GenderEnum.unknown, ver=1) == "0" + assert SQLizer.sqlize_value(LocaleEnum.zh_CN, ver=1) == "'zh_CN'" + + assert SQLizer.sqlize_value(1024, ver=1) == "1024" + assert SQLizer.sqlize_value(0.125, ver=1) == "0.125" + assert SQLizer.sqlize_value(True, ver=1) == "True" + + assert ( + SQLizer.sqlize_value({"gender": 0, "name": "羊淑兰"}, to_json=True, ver=1) + == """CAST('{"gender": 0, "name": "羊淑兰"}' AS JSON)""" + ) + assert SQLizer.sqlize_value([1, 2, 4], ver=1) == "'[1, 2, 4]'" + assert SQLizer.sqlize_value(("a", "b", "c"), ver=1) == """'["a", "b", "c"]'""" + + assert SQLizer.sqlize_value(datetime(2023, 1, 1, 12, 30), ver=1) == "'2023-01-01 12:30:00'" + + def test_escape_v2(self): + assert SQLizer.escape(None, ver=2) == "NULL" + + raw_sql = RawSQL("statement") + assert SQLizer.escape(raw_sql, ver=2) == raw_sql.sql + cases = Cases("is_ok", {0: "No", 1: "Yes"}) + assert SQLizer.escape(cases, ver=2) == cases.sql + + assert SQLizer.escape(GenderEnum.unknown, ver=2) == "0" + assert SQLizer.escape(LocaleEnum.zh_CN, ver=2) == "'zh_CN'" + + assert SQLizer.escape(1024, ver=2) == "1024" + assert SQLizer.escape(0.125, ver=2) == "0.125" + assert SQLizer.escape(True, ver=2) == "1" assert ( - SQLizer.sqlize_value({"gender": 0, "name": "羊淑兰"}, to_json=True) + SQLizer.escape({"gender": 0, "name": "羊淑兰"}, to_json=True, ver=2) == """CAST('{"gender": 0, "name": "羊淑兰"}' AS JSON)""" ) - assert SQLizer.sqlize_value([1, 2, 4]) == "'[1, 2, 4]'" - assert SQLizer.sqlize_value(("a", "b", "c")) == """'["a", "b", "c"]'""" + assert SQLizer.escape([1, 2, 4], ver=2) == "'[1, 2, 4]'" + assert SQLizer.escape(("a", "b", "c"), ver=2) == """'["a", "b", "c"]'""" - assert SQLizer.sqlize_value(datetime(2023, 1, 1, 12, 30)) == "'2023-01-01 12:30:00'" + assert SQLizer.escape(datetime(2023, 1, 1, 12, 30), ver=2) == "'2023-01-01T12:30:00'" def test_select_custom_fields(self): with self.assertRaises(WrongParamsError): @@ -227,7 +254,7 @@ def test_update_json_field(self): assert sql == """ UPDATE `account` SET extend = JSON_MERGE_PATCH(JSON_SET(JSON_REMOVE(COALESCE(extend, '{}'), '$.deprecated'), '$.last_login',CAST('{"ipv4": "209.182.101.161"}' AS JSON), '$.uuid','fd04f7f2-24fc-4a73-a1d7-b6e99a464c5f'), '{"updated_at": "2022-10-30 21:34:15", "info": {"online_sec": 636}}') - , active=True, name='new_name' + , active=1, name='new_name' WHERE `id`=8 """ @@ -326,7 +353,7 @@ def test_insert_into_select(self): assert archive_sql == """ INSERT INTO `account_bak` (gender, locale, active, name, extend) - SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, False active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend + SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, 0 active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend FROM `account` WHERE `id` IN (4,5,6) """ @@ -346,7 +373,7 @@ def test_insert_into_select(self): assert copy_sql == """ INSERT INTO `account` (gender, locale, active, name, extend) - SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, False active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend + SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, 0 active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend FROM `account` WHERE `id` IN (4,5,6) """ @@ -368,9 +395,9 @@ def test_build_fly_table(self): ) assert old_sql == """ SELECT * FROM ( - SELECT 7 id, False active, 1 gender + SELECT 7 id, 0 active, 1 gender UNION - SELECT 15 id, True active, 0 gender + SELECT 15 id, 1 active, 0 gender ) AS fly_table""" new_sql = SQLizer.build_fly_table( @@ -384,8 +411,8 @@ def test_build_fly_table(self): assert new_sql == """ SELECT * FROM ( VALUES - ROW(7, False, 1), - ROW(15, True, 0) + ROW(7, 0, 1), + ROW(15, 1, 0) ) AS fly_table (id, active, gender)""" def test_bulk_update_from_dicts(self): @@ -412,8 +439,8 @@ def test_bulk_update_from_dicts(self): JOIN ( SELECT * FROM ( VALUES - ROW(7, False, False, 1, '{"test": 1, "debug": 0}'), - ROW(15, False, True, 0, '{"test": 1, "debug": 0}') + ROW(7, 0, 0, 1, '{"test": 1, "debug": 0}'), + ROW(15, 0, 1, 0, '{"test": 1, "debug": 0}') ) AS fly_table (id, deleted, active, gender, extend) ) tmp ON `account`.id=tmp.id AND `account`.deleted=tmp.deleted SET `account`.active=tmp.active, `account`.gender=tmp.gender, `account`.extend=JSON_MERGE_PATCH(COALESCE(`account`.extend, '{}'), tmp.extend)