Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
Merge pull request #369 from datafold/jan17
Browse files Browse the repository at this point in the history
Bugfix: Add brackets around WHERE clause
  • Loading branch information
erezsh authored Jan 17, 2023
2 parents 2a4ea5d + 1ed7ce0 commit 910c98d
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
metavar="COUNT",
)
@click.option(
"-w", "--where", default=None, help="An additional 'where' expression to restrict the search space.", metavar="EXPR"
"-w", "--where", default=None, help="An additional 'where' expression to restrict the search space. Beware of SQL Injection!", metavar="EXPR"
)
@click.option("-a", "--algorithm", default=Algorithm.AUTO.value, type=click.Choice([i.value for i in Algorithm]))
@click.option(
Expand Down
2 changes: 1 addition & 1 deletion data_diff/joindiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def _test_null_keys(self, table1, table2):
q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns))
nulls = ts.database.query(q, list)
if nulls:
raise ValueError("NULL values in one or more primary keys")
raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}")

def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
logger.debug(f"Collecting stats for table #{i}")
Expand Down
7 changes: 5 additions & 2 deletions data_diff/table_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ def __post_init__(self):
f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})"
)

def _where(self):
return f"({self.where})" if self.where else None

def _with_raw_schema(self, raw_schema: dict) -> "TableSegment":
schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self.where)
schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where())
return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))

def with_schema(self) -> "TableSegment":
Expand Down Expand Up @@ -100,7 +103,7 @@ def source_table(self):

def make_select(self):
return self.source_table.where(
*self._make_key_range(), *self._make_update_range(), Code(self.where) if self.where else SKIP
*self._make_key_range(), *self._make_update_range(), Code(self._where()) if self.where else SKIP
)

def get_values(self) -> list:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_api(self):

# test where
diff_id = diff[0][1][0]
where = f"id != {diff_id}"
where = f"id != {diff_id} OR id = 90000000"

t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name, where=where)
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_dst_name, where=where)
Expand Down

0 comments on commit 910c98d

Please sign in to comment.