-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
105 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from dataclasses import dataclass, field | ||
import re | ||
|
||
from .database_item import DatabaseItem | ||
|
||
|
||
@dataclass(frozen=True) | ||
class Domain(DatabaseItem): | ||
""" | ||
A data class for PostgreSQL domains. | ||
""" | ||
_item_type: str = 'domain' | ||
_name_pattern: str = field(default=r'CREATE\s+DOMAIN\s+(?:(\w+)\.)?(".*?"|\w+)') | ||
_schema_pattern: str = field(default=r'CREATE\s+DOMAIN\s+(\w+)\.') | ||
|
||
def __post_init__(self): | ||
create_statement = self.create.strip() | ||
|
||
# Schema parsing | ||
schema_match = re.search(self._schema_pattern, create_statement, re.IGNORECASE) | ||
schema = schema_match.group(1) if schema_match else 'public' | ||
object.__setattr__(self, '_schema', schema) | ||
|
||
# Name parsing | ||
name_match = re.search(self._name_pattern, create_statement, re.IGNORECASE) | ||
if name_match: | ||
name = name_match.group(2).strip('"') | ||
object.__setattr__(self, '_name', name) | ||
else: | ||
raise ValueError("Could not parse the name from the create statement") | ||
|
||
def full_sql(self, exists=False) -> str: | ||
sql_parts = [self.create.strip()] | ||
|
||
if exists: | ||
sql_parts[0] = sql_parts[0].replace("CREATE DOMAIN", "CREATE DOMAIN IF NOT EXISTS", 1) | ||
|
||
if self.comment: | ||
sql_parts.append(self.comment.strip()) | ||
|
||
return "\n\n".join(sql_parts) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import unittest | ||
from postnormalism.schema import Domain | ||
|
||
|
||
class TestDomain(unittest.TestCase): | ||
def test_domain_full_sql(self): | ||
"""Test that the full_sql method of the Domain class returns the expected SQL statement.""" | ||
create_domain = """ | ||
CREATE DOMAIN "text/html" AS TEXT; | ||
""" | ||
|
||
domain = Domain(create=create_domain) | ||
expected_sql = create_domain.strip() | ||
|
||
self.assertEqual(domain.full_sql(), expected_sql) | ||
|
||
def test_domain_full_sql_with_exists(self): | ||
"""Test that the full_sql method of the Domain class includes the IF NOT EXISTS clause if exists flag is set.""" | ||
create_domain = """ | ||
CREATE DOMAIN "text/html" AS TEXT; | ||
""" | ||
domain = Domain(create=create_domain) | ||
|
||
output_sql = domain.full_sql(exists=True) | ||
|
||
self.assertIn("IF NOT EXISTS", output_sql) | ||
|
||
def test_invalid_create_domain_statement(self): | ||
"""Test that an invalid CREATE DOMAIN statement raises a ValueError.""" | ||
invalid_create_domain = """ | ||
CREATE "text/html" AS TEXT; | ||
""" | ||
with self.assertRaises(ValueError): | ||
Domain(create=invalid_create_domain) | ||
|
||
def test_domain_in_non_public_schema(self): | ||
"""Test that a domain is correctly created inside a non-public schema.""" | ||
create_domain = """ | ||
CREATE DOMAIN custom_schema."text/html" AS TEXT; | ||
""" | ||
|
||
domain = Domain(create=create_domain) | ||
expected_sql = create_domain.strip() | ||
self.assertEqual(domain.full_sql(), expected_sql) | ||
self.assertEqual(domain.schema, "custom_schema") | ||
self.assertEqual(domain.name, "text/html") | ||
|
||
def test_domain_in_non_public_schema_with_exists(self): | ||
"""Test that a domain in a non-public schema includes IF NOT EXISTS when exists=True.""" | ||
create_domain = """ | ||
CREATE DOMAIN custom_schema."text/html" AS TEXT; | ||
""" | ||
|
||
domain = Domain(create=create_domain) | ||
output_sql = domain.full_sql(exists=True) | ||
self.assertIn("IF NOT EXISTS", output_sql) | ||
self.assertEqual(domain.schema, "custom_schema") | ||
self.assertEqual(domain.name, "text/html") | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |