Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
Merge pull request #662 from kindly/master
Browse files Browse the repository at this point in the history
Fix for more than 50 fields in Postgres
  • Loading branch information
dlawin authored Aug 9, 2023
2 parents 5e9f1fb + 09683af commit 94814e6
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
5 changes: 5 additions & 0 deletions data_diff/sqeleton/databases/postgresql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from ..abcs.database_types import (
DbPath,
JSON,
Expand Down Expand Up @@ -92,6 +93,10 @@ def quote(self, s: str):
def to_string(self, s: str):
return f"{s}::varchar"

def concat(self, items: List[str]) -> str:
joined_exprs = " || ".join(items)
return f"({joined_exprs})"

def _convert_db_precision_to_digits(self, p: int) -> int:
# Subtracting 2 due to wierd precision issues in PostgreSQL
return super()._convert_db_precision_to_digits(p) - 2
Expand Down
44 changes: 44 additions & 0 deletions tests/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,47 @@ def test_uuid(self):
self.connection.query(self.table_src.drop(True))
self.connection.query(self.table_dst.drop(True))
mysql_conn.query(self.table_dst.drop(True))


class Test100Fields(unittest.TestCase):
def setUp(self) -> None:
self.connection = get_conn(db.PostgreSQL)

table_suffix = random_table_suffix()

self.table_src_name = f"src{table_suffix}"
self.table_dst_name = f"dst{table_suffix}"

self.table_src = table(self.table_src_name)
self.table_dst = table(self.table_dst_name)

def test_100_fields(self):
self.connection.query('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";', None)

columns = [f"col{i}" for i in range(100)]
fields = " ,".join(f'"{field}" TEXT' for field in columns)

queries = [
self.table_src.drop(True),
self.table_dst.drop(True),
f"CREATE TABLE {self.table_src_name} (id uuid DEFAULT uuid_generate_v4 (), {fields})",
commit,
self.table_src.insert_rows([[f"{x * y}" for x in range(100)] for y in range(10)], columns=columns),
commit,
self.table_dst.create(self.table_src),
commit,
self.table_src.insert_rows([[1 for x in range(100)]], columns=columns),
commit,
]

for query in queries:
self.connection.query(query)

a = TableSegment(self.connection, self.table_src.path, ("id",), extra_columns=tuple(columns))
b = TableSegment(self.connection, self.table_dst.path, ("id",), extra_columns=tuple(columns))

differ = HashDiffer()
diff = list(differ.diff_tables(a, b))
id_ = diff[0][1][0]
result = (id_,) + tuple("1" for x in range(100))
self.assertEqual(diff, [("-", result)])

0 comments on commit 94814e6

Please sign in to comment.