Skip to content

Commit

Permalink
handle composite primary keys and foreign keys in Tables
Browse files Browse the repository at this point in the history
  • Loading branch information
jzmiller1 committed Aug 18, 2024
1 parent 10ed039 commit 056bffb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
6 changes: 4 additions & 2 deletions postnormalism/schema/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,14 @@ def _extract_columns(self):
for part in parts:
part = part.strip()

if part.upper().startswith("UNIQUE") or part.upper().startswith("CHECK"):
if part.upper().startswith(("UNIQUE", "CHECK", "PRIMARY KEY", "FOREIGN")):
continue
match = re.match(self._pattern_create, part)
if match:
columns.append(match.group(1))
else:
column_and_constraint = re.match(r"^\s*(\w+)\s+.*(?:UNIQUE|CHECK)\s*\(.*\)", part,
# Handle cases where constraints are included with column definitions
column_and_constraint = re.match(r"^\s*(\w+)\s+.*(?:UNIQUE|CHECK|PRIMARY KEY)\s*\(.*\)", part,
re.IGNORECASE)
if column_and_constraint:
columns.append(column_and_constraint.group(1))
Expand All @@ -69,6 +70,7 @@ def _extract_columns(self):
match = re.search(self._pattern_alter, line.strip(), re.IGNORECASE)
if match:
columns.append(match.group(1))
columns = list(dict.fromkeys(columns))
object.__setattr__(self, '_columns', columns)

def _extract_inherited_columns(self):
Expand Down
37 changes: 37 additions & 0 deletions tests/items/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,40 @@ def test_exclude_check_and_unique_constraints(self):
]
process_action_table = Table(create=create_statement)
self.assertEqual(process_action_table.columns, expected_columns)

def test_table_with_composite_primary_key(self):
create_table = """
CREATE TABLE material_attribute_map (
material UUID NOT NULL,
attribute UUID NOT NULL,
spread NUMERIC[],
PRIMARY KEY (material, attribute)
);
"""

table = Table(create=create_table)
expected_columns = ["material", "attribute", "spread"]

self.assertEqual(table.columns, expected_columns)

def test_table_with_foreign_keys(self):
create_table = """
CREATE TABLE order_items (
order_id UUID NOT NULL,
product_id UUID NOT NULL,
quantity INT NOT NULL,
price NUMERIC(10, 2) NOT NULL,
discount NUMERIC(5, 2),
PRIMARY KEY (order_id, product_id),
FOREIGN KEY (order_id) REFERENCES orders(id) ON DELETE CASCADE,
FOREIGN KEY (product_id) REFERENCES products(id) ON DELETE CASCADE
);
"""

table = Table(create=create_table)
expected_columns = [
"order_id", "product_id", "quantity", "price", "discount"
]

self.assertEqual(table.columns, expected_columns)

0 comments on commit 056bffb

Please sign in to comment.