Skip to content

Commit

Permalink
Initial support for domains
Browse files Browse the repository at this point in the history
  • Loading branch information
jzmiller1 committed Aug 22, 2024
1 parent 056bffb commit 51799b7
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 1 deletion.
1 change: 1 addition & 0 deletions postnormalism/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .database_item import DatabaseItem
from .domain import Domain
from .schema import Schema
from .function import Function
from .table import Table
Expand Down
2 changes: 1 addition & 1 deletion postnormalism/schema/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def add_items(self, *items: DatabaseItem, schema_loaded: set) -> None:
self.items_by_type[item_type].append(item)

def get_items_by_type(self, item_type: str) -> list:
allowed_database_items = ["table", "function", "schema", "view", "trigger"]
allowed_database_items = ["table", "function", "schema", "view", "trigger", "domain"]
item_type = item_type.lower()
if item_type not in allowed_database_items:
raise ValueError(f"Invalid item_type: {item_type}")
Expand Down
41 changes: 41 additions & 0 deletions postnormalism/schema/domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from dataclasses import dataclass, field
import re

from .database_item import DatabaseItem


@dataclass(frozen=True)
class Domain(DatabaseItem):
"""
A data class for PostgreSQL domains.
"""
_item_type: str = 'domain'
_name_pattern: str = field(default=r'CREATE\s+DOMAIN\s+(?:(\w+)\.)?(".*?"|\w+)')
_schema_pattern: str = field(default=r'CREATE\s+DOMAIN\s+(\w+)\.')

def __post_init__(self):
create_statement = self.create.strip()

# Schema parsing
schema_match = re.search(self._schema_pattern, create_statement, re.IGNORECASE)
schema = schema_match.group(1) if schema_match else 'public'
object.__setattr__(self, '_schema', schema)

# Name parsing
name_match = re.search(self._name_pattern, create_statement, re.IGNORECASE)
if name_match:
name = name_match.group(2).strip('"')
object.__setattr__(self, '_name', name)
else:
raise ValueError("Could not parse the name from the create statement")

def full_sql(self, exists=False) -> str:
sql_parts = [self.create.strip()]

if exists:
sql_parts[0] = sql_parts[0].replace("CREATE DOMAIN", "CREATE DOMAIN IF NOT EXISTS", 1)

if self.comment:
sql_parts.append(self.comment.strip())

return "\n\n".join(sql_parts)
62 changes: 62 additions & 0 deletions tests/items/test_domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import unittest
from postnormalism.schema import Domain


class TestDomain(unittest.TestCase):
def test_domain_full_sql(self):
"""Test that the full_sql method of the Domain class returns the expected SQL statement."""
create_domain = """
CREATE DOMAIN "text/html" AS TEXT;
"""

domain = Domain(create=create_domain)
expected_sql = create_domain.strip()

self.assertEqual(domain.full_sql(), expected_sql)

def test_domain_full_sql_with_exists(self):
"""Test that the full_sql method of the Domain class includes the IF NOT EXISTS clause if exists flag is set."""
create_domain = """
CREATE DOMAIN "text/html" AS TEXT;
"""
domain = Domain(create=create_domain)

output_sql = domain.full_sql(exists=True)

self.assertIn("IF NOT EXISTS", output_sql)

def test_invalid_create_domain_statement(self):
"""Test that an invalid CREATE DOMAIN statement raises a ValueError."""
invalid_create_domain = """
CREATE "text/html" AS TEXT;
"""
with self.assertRaises(ValueError):
Domain(create=invalid_create_domain)

def test_domain_in_non_public_schema(self):
"""Test that a domain is correctly created inside a non-public schema."""
create_domain = """
CREATE DOMAIN custom_schema."text/html" AS TEXT;
"""

domain = Domain(create=create_domain)
expected_sql = create_domain.strip()
self.assertEqual(domain.full_sql(), expected_sql)
self.assertEqual(domain.schema, "custom_schema")
self.assertEqual(domain.name, "text/html")

def test_domain_in_non_public_schema_with_exists(self):
"""Test that a domain in a non-public schema includes IF NOT EXISTS when exists=True."""
create_domain = """
CREATE DOMAIN custom_schema."text/html" AS TEXT;
"""

domain = Domain(create=create_domain)
output_sql = domain.full_sql(exists=True)
self.assertIn("IF NOT EXISTS", output_sql)
self.assertEqual(domain.schema, "custom_schema")
self.assertEqual(domain.name, "text/html")


if __name__ == '__main__':
unittest.main()

0 comments on commit 51799b7

Please sign in to comment.