Skip to content

Commit

Permalink
+ Bind variable counter
Browse files Browse the repository at this point in the history
  • Loading branch information
Andreas Kosubek committed Dec 7, 2023
1 parent da85652 commit e206d6b
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 57 deletions.
4 changes: 2 additions & 2 deletions pygeofilter/backends/oraclesql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .evaluate import to_sql_where, to_sql_where_with_binds
from .evaluate import to_sql_where, to_sql_where_with_bind_variables

__all__ = ["to_sql_where", "to_sql_where_with_binds"]
__all__ = ["to_sql_where", "to_sql_where_with_bind_variables"]
106 changes: 67 additions & 39 deletions pygeofilter/backends/oraclesql/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,6 @@
ast.SpatialComparisonOp.EQUALS: "EQUAL",
}

WITH_BINDS = False

BIND_VARIABLES = {}


class OracleSQLEvaluator(Evaluator):
def __init__(
Expand All @@ -72,6 +68,11 @@ def __init__(
self.attribute_map = attribute_map
self.function_map = function_map

self.with_bind_variables = False
self.bind_variables = {}
# Counter for bind variables
self.b_cnt = 0

@handle(ast.Not)
def not_(self, node, sub):
return f"NOT {sub}"
Expand All @@ -82,21 +83,29 @@ def combination(self, node, lhs, rhs):

@handle(ast.Comparison, subclasses=True)
def comparison(self, node, lhs, rhs):
if WITH_BINDS:
BIND_VARIABLES[f"{lhs}"] = rhs
sql = f"({lhs} {COMPARISON_OP_MAP[node.op]} :{lhs})"
if self.with_bind_variables:
self.bind_variables[f"{lhs}_{self.b_cnt}"] = rhs
sql = f"({lhs} {COMPARISON_OP_MAP[node.op]} :{lhs}_{self.b_cnt})"
self.b_cnt += 1
else:
sql = f"({lhs} {COMPARISON_OP_MAP[node.op]} {rhs})"
return sql

@handle(ast.Between)
def between(self, node, lhs, low, high):
if WITH_BINDS:
BIND_VARIABLES[f"{lhs}_high"] = high
BIND_VARIABLES[f"{lhs}_low"] = low
sql = f"({lhs} {'NOT ' if node.not_ else ''}BETWEEN :{lhs}_low AND :{lhs}_high)"
if self.with_bind_variables:
self.bind_variables[f"{lhs}_high_{self.b_cnt}"] = high
self.bind_variables[f"{lhs}_low_{self.b_cnt}"] = low
sql = (
f"({lhs} {'NOT ' if node.not_ else ''}BETWEEN "
f":{lhs}_low_{self.b_cnt} AND :{lhs}_high_{self.b_cnt})"
)
self.b_cnt += 1
else:
sql = f"({lhs} {'NOT ' if node.not_ else ''}BETWEEN {low} AND {high})"
sql = (
f"({lhs} {'NOT ' if node.not_ else ''}BETWEEN "
f"{low} AND {high})"
)
return sql

@handle(ast.Like)
Expand All @@ -107,10 +116,10 @@ def like(self, node, lhs):
if node.singlechar != "_":
pattern = pattern.replace(node.singlechar, "_")

if WITH_BINDS:
BIND_VARIABLES[f"{lhs}"] = pattern
if self.with_bind_variables:
self.bind_variables[f"{lhs}_{self.b_cnt}"] = pattern
sql = f"{lhs} {'NOT ' if node.not_ else ''}LIKE "
sql += f":{lhs} ESCAPE '{node.escapechar}'"
sql += f":{lhs}_{self.b_cnt} ESCAPE '{node.escapechar}'"

else:
sql = f"{lhs} {'NOT ' if node.not_ else ''}LIKE "
Expand Down Expand Up @@ -147,12 +156,19 @@ def bbox(self, node, lhs):
srid = 4326
param = "mask=ANYINTERACT"

if WITH_BINDS:
BIND_VARIABLES["geo_json"] = geo_json
BIND_VARIABLES["srid"] = srid
geom_sql = "SDO_UTIL.FROM_JSON(geometry => :geo_json, srid => :srid)"
if self.with_bind_variables:
self.bind_variables[f"geo_json_{self.b_cnt}"] = geo_json
self.bind_variables[f"srid_{self.b_cnt}"] = srid
geom_sql = (
f"SDO_UTIL.FROM_JSON(geometry => :geo_json_{self.b_cnt}, "
f"srid => :srid_{self.b_cnt})"
)
self.b_cnt += 1
else:
geom_sql = f"SDO_UTIL.FROM_JSON(geometry => '{geo_json}', srid => {srid})"
geom_sql = (
f"SDO_UTIL.FROM_JSON(geometry => '{geo_json}', "
f"srid => {srid})"
)

sql = f"SDO_RELATE({lhs}, {geom_sql}, '{param}') = 'TRUE'"
return sql
Expand Down Expand Up @@ -185,12 +201,19 @@ def geometry(self, node: values.Geometry):
srid = 4326
geo_json = json.dumps(node.geometry)
print(geo_json)
if WITH_BINDS:
BIND_VARIABLES["geo_json"] = geo_json
BIND_VARIABLES["srid"] = srid
sql = "SDO_UTIL.FROM_JSON(geometry => :geo_json, srid => :srid)"
if self.with_bind_variables:
self.bind_variables[f"geo_json_{self.b_cnt}"] = geo_json
self.bind_variables[f"srid_{self.b_cnt}"] = srid
sql = (
f"SDO_UTIL.FROM_JSON(geometry => :geo_json_{self.b_cnt}, "
f"srid => :srid_{self.b_cnt})"
)
self.b_cnt += 1
else:
sql = f"SDO_UTIL.FROM_JSON(geometry => '{geo_json}', srid => {srid})"
sql = (
f"SDO_UTIL.FROM_JSON(geometry => '{geo_json}', "
f"srid => {srid})"
)
return sql

@handle(values.Envelope)
Expand All @@ -199,12 +222,19 @@ def envelope(self, node: values.Envelope):
# node and translate to SRID
srid = 4326
geo_json = json.dumps(node.geometry)
if WITH_BINDS:
BIND_VARIABLES["geo_json"] = geo_json
BIND_VARIABLES["srid"] = srid
sql = "SDO_UTIL.FROM_JSON(geometry => :geo_json, srid => :srid)"
if self.with_bind_variables:
self.bind_variables[f"geo_json_{self.b_cnt}"] = geo_json
self.bind_variables[f"srid_{self.b_cnt}"] = srid
sql = (
f"SDO_UTIL.FROM_JSON(geometry => :geo_json_{self.b_cnt}, "
f"srid => :srid_{self.b_cnt})"
)
self.b_cnt += 1
else:
sql = f"SDO_UTIL.FROM_JSON(geometry => '{geo_json}', srid => {srid})"
sql = (
f"SDO_UTIL.FROM_JSON(geometry => '{geo_json}', "
f"srid => {srid})"
)
return sql


Expand All @@ -213,19 +243,17 @@ def to_sql_where(
field_mapping: Dict[str, str],
function_map: Optional[Dict[str, str]] = None,
) -> str:
global WITH_BINDS
WITH_BINDS = False
return OracleSQLEvaluator(field_mapping, function_map or {}).evaluate(root)
orcle = OracleSQLEvaluator(field_mapping, function_map or {})
orcle.with_bind_variables = False
return orcle.evaluate(root)


def to_sql_where_with_binds(
def to_sql_where_with_bind_variables(
root: ast.Node,
field_mapping: Dict[str, str],
function_map: Optional[Dict[str, str]] = None,
) -> str:
orcle = OracleSQLEvaluator(field_mapping, function_map or {})
global WITH_BINDS
WITH_BINDS = True
global BIND_VARIABLES
BIND_VARIABLES = {}
return orcle.evaluate(root), BIND_VARIABLES
orcle.with_bind_variables = True
orcle.bind_variables = {}
return orcle.evaluate(root), orcle.bind_variables
32 changes: 16 additions & 16 deletions tests/backends/oraclesql/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pygeofilter.backends.oraclesql import (
to_sql_where,
to_sql_where_with_binds,
to_sql_where_with_bind_variables,
)
from pygeofilter.parsers.ecql import parse

Expand All @@ -24,13 +24,13 @@ def test_between():


def test_between_with_binds():
where, binds = to_sql_where_with_binds(
where, binds = to_sql_where_with_bind_variables(
parse("int_attr NOT BETWEEN 4 AND 6"),
FIELD_MAPPING,
FUNCTION_MAP
)
assert where == "(int_attr NOT BETWEEN :int_attr_low AND :int_attr_high)"
assert binds == {"int_attr_low": 4, "int_attr_high": 6}
assert where == "(int_attr NOT BETWEEN :int_attr_low_0 AND :int_attr_high_0)"
assert binds == {"int_attr_low_0": 4, "int_attr_high_0": 6}


def test_like():
Expand All @@ -43,13 +43,13 @@ def test_like():


def test_like_with_binds():
where, binds = to_sql_where_with_binds(
where, binds = to_sql_where_with_bind_variables(
parse("str_attr LIKE 'foo%'"),
FIELD_MAPPING,
FUNCTION_MAP
)
assert where == "str_attr LIKE :str_attr ESCAPE '\\'"
assert binds == {"str_attr": "foo%"}
assert where == "str_attr LIKE :str_attr_0 ESCAPE '\\'"
assert binds == {"str_attr_0": "foo%"}


def test_combination():
Expand All @@ -62,13 +62,13 @@ def test_combination():


def test_combination_with_binds():
where, binds = to_sql_where_with_binds(
where, binds = to_sql_where_with_bind_variables(
parse("int_attr = 5 AND float_attr < 6.0"),
FIELD_MAPPING,
FUNCTION_MAP
)
assert where == "((int_attr = :int_attr) AND (float_attr < :float_attr))"
assert binds == {"int_attr": 5, "float_attr": 6.0}
assert where == "((int_attr = :int_attr_0) AND (float_attr < :float_attr_1))"
assert binds == {"int_attr_0": 5, "float_attr_1": 6.0}


def test_spatial():
Expand All @@ -89,7 +89,7 @@ def test_spatial():


def test_spatial_with_binds():
where, binds = to_sql_where_with_binds(
where, binds = to_sql_where_with_bind_variables(
parse("INTERSECTS(point_attr, ENVELOPE (0 1 0 1))"),
FIELD_MAPPING,
FUNCTION_MAP,
Expand All @@ -100,10 +100,10 @@ def test_spatial_with_binds():
)
assert where == (
"SDO_RELATE(geometry_attr, "
"SDO_UTIL.FROM_JSON(geometry => :geo_json, srid => :srid), "
"SDO_UTIL.FROM_JSON(geometry => :geo_json_0, srid => :srid_0), "
"'mask=ANYINTERACT') = 'TRUE'"
)
assert binds == {"geo_json": geo_json, "srid": 4326}
assert binds == {"geo_json_0": geo_json, "srid_0": 4326}


def test_bbox():
Expand All @@ -128,7 +128,7 @@ def test_bbox():


def test_bbox_with_binds():
where, binds = to_sql_where_with_binds(
where, binds = to_sql_where_with_bind_variables(
parse("BBOX(point_attr,-140.99778,41.6751050889,-52.6480987209,83.23324)"),
FIELD_MAPPING,
FUNCTION_MAP,
Expand All @@ -143,7 +143,7 @@ def test_bbox_with_binds():
)
assert where == (
"SDO_RELATE(geometry_attr, "
"SDO_UTIL.FROM_JSON(geometry => :geo_json, srid => :srid), "
"SDO_UTIL.FROM_JSON(geometry => :geo_json_0, srid => :srid_0), "
"'mask=ANYINTERACT') = 'TRUE'"
)
assert binds == {"geo_json": geo_json, "srid": 4326}
assert binds == {"geo_json_0": geo_json, "srid_0": 4326}

0 comments on commit e206d6b

Please sign in to comment.