diff --git a/.gitignore b/.gitignore index 5ba2bf6..4c0378c 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist/ tests/generator/_models.py ./models.py poetry.lock +*~ \ No newline at end of file diff --git a/omymodels/from_ddl.py b/omymodels/from_ddl.py index f5e673c..6273a7f 100644 --- a/omymodels/from_ddl.py +++ b/omymodels/from_ddl.py @@ -1,3 +1,4 @@ +import copy import os import re import sys @@ -21,9 +22,11 @@ def get_tables_information( "contains ddl or ddl_file that contains path to ddl file to parse" ) if ddl: - tables = DDLParser(ddl).run(group_by_type=True) + tables = DDLParser(ddl, normalize_names=True).run(group_by_type=True) elif ddl_file: - tables = parse_from_file(ddl_file, group_by_type=True) + tables = parse_from_file( + ddl_file, parser_settings={"normalize_names": True}, group_by_type=True + ) return tables @@ -69,11 +72,31 @@ def snake_case(string: str) -> str: def convert_ddl_to_models(data: Dict, no_auto_snake_case: bool) -> Dict[str, list]: final_data = {"tables": [], "types": []} + refs = {} tables = [] for table in data["tables"]: + for ref in table.get("constraints", {}).get("references", []): + # References can be compopund references. Here we split into one + # reference per column and then attach it to the column in the next + # loop. + for i in range(len(ref["columns"])): + ref_name = ( + ref["name"].split(",")[i] + if isinstance(ref["name"], str) + else ref["name"][i] + ) + if not no_auto_snake_case: + ref_name = snake_case(ref_name) + single_ref = copy.deepcopy(ref) + single_ref["column"] = ref["columns"][i] + del single_ref["columns"] + ref_name = ref_name.replace('"', "") + refs[ref_name] = single_ref for column in table["columns"]: if not no_auto_snake_case: column["name"] = snake_case(column["name"]) + if column["name"] in refs: + column["references"] = refs[column["name"]] tables.append(TableMeta(**table)) final_data["tables"] = tables _types = [] diff --git a/pyproject.toml b/pyproject.toml index 3d66a5b..5d8400d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.7,<4.0" -simple-ddl-parser = "^0.28" +simple-ddl-parser = "^1.0.0" Jinja2 = "^3.0.1" py-models-parser = "^0.7.0" pydantic = "^1.8.2" diff --git a/tests/functional/generator/test_sqlalchemy.py b/tests/functional/generator/test_sqlalchemy.py index 7765a05..95dc0c4 100644 --- a/tests/functional/generator/test_sqlalchemy.py +++ b/tests/functional/generator/test_sqlalchemy.py @@ -124,6 +124,135 @@ class Attachments(Base): assert result == expected +def test_foreign_keys_defined_inline(): + """ + This should be the same output as test_foreign_keys, but with a slightly + different, yet valid input DDL. + """ + expected = """import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base + + +Base = declarative_base() + + +class Materials(Base): + + __tablename__ = 'materials' + + id = sa.Column(sa.Integer(), primary_key=True) + title = sa.Column(sa.String(), nullable=False) + description = sa.Column(sa.String()) + link = sa.Column(sa.String()) + created_at = sa.Column(sa.TIMESTAMP()) + updated_at = sa.Column(sa.TIMESTAMP()) + + +class MaterialAttachments(Base): + + __tablename__ = 'material_attachments' + + material_id = sa.Column(sa.Integer(), sa.ForeignKey('materials.id')) + attachment_id = sa.Column(sa.Integer(), sa.ForeignKey('attachments.id')) + + +class Attachments(Base): + + __tablename__ = 'attachments' + + id = sa.Column(sa.Integer(), primary_key=True) + title = sa.Column(sa.String()) + description = sa.Column(sa.String()) + created_at = sa.Column(sa.TIMESTAMP()) + updated_at = sa.Column(sa.TIMESTAMP()) +""" + ddl = """ + + CREATE TABLE "materials" ( + "id" int PRIMARY KEY, + "title" varchar NOT NULL, + "description" varchar, + "link" varchar, + "created_at" timestamp, + "updated_at" timestamp + ); + + CREATE TABLE "material_attachments" ( + "material_id" int, + "attachment_id" int, + CONSTRAINT "material_id_ibfk" FOREIGN KEY ("material_id") REFERENCES "materials" ("id"), + CONSTRAINT "attachment_id_ibfk" FOREIGN KEY ("attachment_id") REFERENCES "attachments" ("id") + ); + + CREATE TABLE "attachments" ( + "id" int PRIMARY KEY, + "title" varchar, + "description" varchar, + "created_at" timestamp, + "updated_at" timestamp + ); + """ + result = create_models(ddl, models_type="sqlalchemy")["code"] + assert result == expected + + +def test_multi_col_pk_and_fk(): + """ + This should test that we can properly setup tables with compound PRIMARY and + FOREIGN keys. + """ + expected = """import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.sql import func + + +Base = declarative_base() + + +class Complexpk(Base): + + __tablename__ = 'complexpk' + + complex_id = sa.Column(int unsigned(), primary_key=True) + date_part = sa.Column(sa.DateTime(), server_default=func.now(), primary_key=True) + title = sa.Column(sa.String(), nullable=False) + description = sa.Column(sa.String()) + + +class LinkedTo(Base): + + __tablename__ = 'linked_to' + + id = sa.Column(sa.Integer(), primary_key=True) + complexpk_complex_id = sa.Column(sa.Integer(), sa.ForeignKey('complexpk.complex_id')) + complexpk_date_part = sa.Column(sa.Integer(), sa.ForeignKey('complexpk.date_part')) + comment = sa.Column(sa.String()) +""" + + ddl = """ + + CREATE TABLE "complexpk" ( + "complex_id" int unsigned NOT NULL, + "date_part" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, + "title" varchar NOT NULL, + "description" varchar, + PRIMARY KEY ("complex_id","date_part") + ); + + CREATE TABLE "linked_to" ( + "id" int PRIMARY KEY, + "complexpk_complex_id" int, + "complexpk_date_part" int, + "comment" varchar, + CONSTRAINT "id_date_part_ibfk" FOREIGN KEY ("complexpk_complex_id", "complexpk_date_part") + REFERENCES "complexpk" ("complex_id", "date_part") + ); + + """ + result = create_models(ddl, models_type="sqlalchemy")["code"] + assert result == expected + + def test_upper_name_produces_the_same_result(): expected = """import sqlalchemy as sa from sqlalchemy.ext.declarative import declarative_base diff --git a/tests/unit/test_common.py b/tests/unit/test_common.py index e20018a..be3b924 100644 --- a/tests/unit/test_common.py +++ b/tests/unit/test_common.py @@ -41,7 +41,6 @@ def test_mssql_brackets_removed(): """ result = create_models(ddl) expected = """from gino import Gino -from sqlalchemy.dialects.postgresql import ARRAY db = Gino(schema="dbo") @@ -50,7 +49,7 @@ class UsersWorkSchedule(db.Model): __tablename__ = 'users_WorkSchedule' - id = db.Column(ARRAY((1,1)), primary_key=True) + id = db.Column(db.Integer(), primary_key=True) request_drop_date = db.Column(smalldatetime()) shift_class = db.Column(db.String(5)) start_history = db.Column(datetime2(7), nullable=False)