Skip to content

Commit

Permalink
Merge pull request #17 from henadzit/feat/array-parametrization
Browse files Browse the repository at this point in the history
Implement Array parametrization
  • Loading branch information
henadzit authored Nov 28, 2024
2 parents 469fc1c + 59e9205 commit a303518
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 12 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

## 0.3

## 0.3.0
### 0.3.1
- `Array` can be parametrized

### 0.3.0
- Add `Parameterizer`
- Uppdate `Parameter` to be dialect-aware
- Remove `ListParameter`, `DictParameter`, `QmarkParameter`, etc.
Expand Down
28 changes: 19 additions & 9 deletions pypika/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def fields_(self) -> set["Field"]:
@staticmethod
def wrap_constant(
val, wrapper_cls: Type["Term"] | None = None
) -> ValueError | NodeT | "LiteralValue" | "Array" | "Tuple" | "ValueWrapper":
) -> NodeT | "LiteralValue" | "Array" | "Tuple" | "ValueWrapper":
"""
Used for wrapping raw inputs such as numbers in Criterions and Operator.
Expand Down Expand Up @@ -210,7 +210,7 @@ def all_(self) -> "All":

def isin(self, arg: list | tuple | set | "Term") -> "ContainsCriterion":
if isinstance(arg, (list, tuple, set)):
return ContainsCriterion(self, Tuple(*[self.wrap_constant(value) for value in arg]))
return ContainsCriterion(self, Tuple(*arg))
return ContainsCriterion(self, arg)

def notin(self, arg: list | tuple | set | "Term") -> "ContainsCriterion":
Expand Down Expand Up @@ -776,15 +776,25 @@ def replace_table( # type:ignore[return]


class Array(Tuple):
def get_sql(self, **kwargs: Any) -> str:
dialect = kwargs.get("dialect", None)
values = ",".join(term.get_sql(**kwargs) for term in self.values) # type:ignore[union-attr]
def __init__(self, *values: Any) -> None:
super().__init__(*values)
self.original_value = list(values)

sql = "[{}]".format(values)
if dialect in (Dialects.POSTGRESQL, Dialects.REDSHIFT):
sql = "ARRAY[{}]".format(values) if len(values) > 0 else "'{}'"
def get_sql(self, parameterizer: Parameterizer | None = None, **kwargs: Any) -> str:
if parameterizer is None or not parameterizer.should_parameterize(self.original_value):
dialect = kwargs.get("dialect", None)
values = ",".join(
term.get_sql(**kwargs) for term in self.values
) # type:ignore[union-attr]

return format_alias_sql(sql, self.alias, **kwargs)
sql = "[{}]".format(values)
if dialect in (Dialects.POSTGRESQL, Dialects.REDSHIFT):
sql = "ARRAY[{}]".format(values) if len(values) > 0 else "'{}'"

return format_alias_sql(sql, self.alias, **kwargs)

param = parameterizer.create_param(self.original_value)
return param.get_sql(**kwargs)


class Bracket(Tuple):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pypika-tortoise"
version = "0.3.0"
version = "0.3.1"
description = "Forked from pypika and streamline just for tortoise-orm"
authors = ["long2ice <[email protected]>"]
license = "Apache-2.0"
Expand Down
9 changes: 8 additions & 1 deletion tests/test_tuples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pypika import Array, Bracket, PostgreSQLQuery, Query, Table, Tables, Tuple
from pypika.functions import Coalesce, NullIf, Sum
from pypika.terms import Field
from pypika.terms import Field, Parameterizer, Star


class TupleTests(unittest.TestCase):
Expand Down Expand Up @@ -149,6 +149,13 @@ def test_render_alias_in_array_sql(self):
q = Query.from_(tb).select(Array(tb.col).as_("different_name"))
self.assertEqual(str(q), 'SELECT ["col"] "different_name" FROM "tb"')

def test_parametrization(self):
q = Query.from_(self.table_abc).select(Star()).where(self.table_abc.f == Array(1, 2, 3))

parameterizer = Parameterizer()
sql = q.get_sql(parameterizer=parameterizer)
self.assertEqual('SELECT * FROM "abc" WHERE "f"=?', sql)


class BracketTests(unittest.TestCase):
table_abc, table_efg = Tables("abc", "efg")
Expand Down

0 comments on commit a303518

Please sign in to comment.