From 4893e38f07472c3553736c1947144004cbaf38c8 Mon Sep 17 00:00:00 2001
From: cfhowes
Date: Sat, 10 Feb 2024 23:56:01 -0800
Subject: [PATCH] Sqlmodel support (#58)

* Improved handling of different ways of specifying references in the DDL.
* Flake fixes
* Initial SQLModel support, basic tests pass, some extra types added.
* Add preliminary support for more types.
* Some fixes based on trying to run the generated code.
* Ensure snake_case of indexes and constraints.
* Updated tests.
---
 omymodels/from_ddl.py                         |  14 +-
 omymodels/generators.py                       |   2 +
 omymodels/models/sqlmodel/core.py             | 169 ++++++++
 omymodels/models/sqlmodel/sqlmodel.jinja2     |   7 +
 omymodels/models/sqlmodel/templates.py        |  54 +++
 omymodels/models/sqlmodel/types.py            |  68 ++++
 .../generator/test_sqlalchemy_core.py         |   2 +-
 tests/functional/generator/test_sqlmodel.py   | 373 ++++++++++++++++++
 8 files changed, 687 insertions(+), 2 deletions(-)
 create mode 100644 omymodels/models/sqlmodel/core.py
 create mode 100644 omymodels/models/sqlmodel/sqlmodel.jinja2
 create mode 100644 omymodels/models/sqlmodel/templates.py
 create mode 100644 omymodels/models/sqlmodel/types.py
 create mode 100644 tests/functional/generator/test_sqlmodel.py

diff --git a/omymodels/from_ddl.py b/omymodels/from_ddl.py
index 6273a7f..3207b7c 100644
--- a/omymodels/from_ddl.py
+++ b/omymodels/from_ddl.py
@@ -70,7 +70,9 @@ def snake_case(string: str) -> str:
     return re.sub(r"(?<!^)(?=[A-Z])", "_", string).lower()
 
 
-def convert_ddl_to_models(data: Dict, no_auto_snake_case: bool) -> Dict[str, list]:
+def convert_ddl_to_models(  # noqa: C901
+    data: Dict, no_auto_snake_case: bool
+) -> Dict[str, list]:
     final_data = {"tables": [], "types": []}
     refs = {}
     tables = []
@@ -97,6 +99,16 @@ def convert_ddl_to_models(data: Dict, no_auto_snake_case: bool) -> Dict[str, lis
                 column["name"] = snake_case(column["name"])
             if column["name"] in refs:
                 column["references"] = refs[column["name"]]
+        if not no_auto_snake_case:
+            table["primary_key"] = [snake_case(pk) for pk in table["primary_key"]]
+            for uniq in table.get("constraints", {}).get("uniques", []):
+                uniq["columns"] = [snake_case(c) for c in uniq["columns"]]
+            # NOTE: We are not going to try to parse check constraint statements
+            # and update the snake case.
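+            # Indexes carry column names both as a flat list and as
+            # per-column detail dicts; both copies get snake_cased below.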
+            for idx in table.get("index", []):
+                idx["columns"] = [snake_case(c) for c in idx["columns"]]
+                for col_detail in idx["detailed_columns"]:
+                    col_detail["name"] = snake_case(col_detail["name"])
         tables.append(TableMeta(**table))
     final_data["tables"] = tables
     _types = []
diff --git a/omymodels/generators.py b/omymodels/generators.py
index f4f515f..d7c4ebb 100644
--- a/omymodels/generators.py
+++ b/omymodels/generators.py
@@ -8,6 +8,7 @@
 from omymodels.models.pydantic import core as p
 from omymodels.models.sqlalchemy import core as s
 from omymodels.models.sqlalchemy_core import core as sc
+from omymodels.models.sqlmodel import core as sm
 
 models = {
     "gino": g,
@@ -15,6 +16,7 @@
     "dataclass": d,
     "sqlalchemy": s,
     "sqlalchemy_core": sc,
+    "sqlmodel": sm,
 }
 
 
diff --git a/omymodels/models/sqlmodel/core.py b/omymodels/models/sqlmodel/core.py
new file mode 100644
index 0000000..70d63cd
--- /dev/null
+++ b/omymodels/models/sqlmodel/core.py
@@ -0,0 +1,169 @@
+from typing import Dict, List, Optional
+
+from table_meta.model import Column
+
+import omymodels.models.sqlmodel.templates as st
+from omymodels import logic, types
+from omymodels.helpers import create_class_name, datetime_now_check
+from omymodels.models.sqlmodel.types import types_mapping
+from omymodels.types import datetime_types
+
+
+class GeneratorBase:
+    def __init__(self):
+        self.custom_types = {}
+
+
+class ModelGenerator(GeneratorBase):
+    def __init__(self):
+        self.state = set()
+        self.postgresql_dialect_cols = set()
+        self.constraint = False
+        self.im_index = False
+        self.types_mapping = types_mapping
+        self.templates = st
+        self.prefix = "sa."
+        super().__init__()
+
+    def prepare_column_default(self, column_data: Dict, column: str) -> str:
+        if isinstance(column_data.default, str):
+            if column_data.type.upper() in datetime_types:
+                if datetime_now_check(column_data.default.lower()):
+                    # todo: need to add other popular PostgreSQL & MySQL functions
+                    column_data.default = "func.now()"
+                    self.state.add("func")
+                elif "'" not in column_data.default:
+                    column_data.default = f"'{column_data.default}'"
+            else:
+                if "'" not in column_data.default:
+                    column_data.default = f"'{column_data.default}'"
+        else:
+            column_data.default = f"'{str(column_data.default)}'"
+        column += st.default.format(default=column_data.default)
+        return column
+
+    def add_custom_type_orm(
+        self, custom_types: Dict, column_data_type: str, column_type: str
+    ) -> dict:
+        column_type_data = None
+        if "." in column_data_type:
+            column_data_type = column_data_type.split(".")[1]
+        column_type = custom_types.get(column_data_type, column_type)
+
+        if isinstance(column_type, tuple):
+            column_data_type = column_type[1]
+            column_type = column_type[0]
+        if column_type is not None:
+            column_type_data = {
+                "pydantic": column_data_type,
+                "sa": f"{column_type}({column_data_type})",
+            }
+        return column_type_data
+
+    def prepare_column_type(self, column_data: Column) -> str:
+        column_type = None
+        column_data = types.prepare_column_data(column_data)
+
+        if self.custom_types:
+            column_type = self.add_custom_type_orm(
+                self.custom_types, column_data.type, column_type
+            )
+
+        if not column_type:
+            column_type = types.prepare_type(column_data, self.types_mapping)
+        if column_type["sa"] in types.postgresql_dialect:
+            self.postgresql_dialect_cols.add(column_type["sa"])
+
+        if "[" in column_data.type and column_data.type not in types.json_types:
+            # @TODO: How do we handle arrays for SQLModel?
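+            # NOTE: column_type here is a {"pydantic": ..., "sa": ...} mapping,
+            # so the f-string below folds the whole dict into a string; proper
+            # ARRAY support for SQLModel is still an open TODO.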
+            self.postgresql_dialect_cols.add("ARRAY")
+            column_type = f"ARRAY({column_type})"
+        return column_type
+
+    def add_table_args(
+        self, model: str, table: Dict, schema_global: bool = True
+    ) -> str:
+        statements = []
+        t = self.templates
+        if table.indexes:
+            for index in table.indexes:
+                if not index["unique"]:
+                    self.im_index = True
+                    statements.append(
+                        t.index_template.format(
+                            columns="', '".join(index["columns"]),
+                            name=f"'{index['index_name']}'",
+                        )
+                    )
+                else:
+                    self.constraint = True
+                    statements.append(
+                        t.unique_index_template.format(
+                            columns=",".join(index["columns"]),
+                            name=f"'{index['index_name']}'",
+                        )
+                    )
+        if not schema_global and table.table_schema:
+            statements.append(t.schema.format(schema_name=table.table_schema))
+        if statements:
+            model += t.table_args.format(statements=",".join(statements))
+        return model
+
+    def generate_model(
+        self,
+        table: Dict,
+        singular: bool = True,
+        exceptions: Optional[List] = None,
+        schema_global: Optional[bool] = True,
+        *args,
+        **kwargs,
+    ) -> str:
+        """method to prepare one Model definition - name & tablename & columns"""
+        model = ""
+
+        model = st.model_template.format(
+            model_name=create_class_name(table.name, singular, exceptions),
+            table_name=table.name,
+        )
+        for column in table.columns:
+            column_type = self.prepare_column_type(column)
+            pydantic_type_str = column_type["pydantic"]
+            if column.nullable or column.name in table.primary_key:
+                pydantic_type_str = f"Optional[{pydantic_type_str}]"
+            col_str = st.column_template.format(
+                column_name=column.name.replace(" ", "_"), column_type=pydantic_type_str
+            )
+
+            col_str = logic.setup_column_attributes(
+                column, table.primary_key, col_str, table, schema_global, st, self
+            )
+            if column_type["sa"]:
+                sa_type = types.add_size_to_orm_column(column_type["sa"], column)
+                col_str += st.sa_type.format(satype=sa_type)
+            col_str += ")\n"
+
+            col_str = col_str.replace("(, ", "(")
+
+            model += col_str
+        if table.indexes or table.alter or table.checks or not schema_global:
+            model = self.add_table_args(model, table, schema_global)
+        return model
+
+    def create_header(self, tables: List[Dict], schema: bool = False) -> str:
+        """header of the file - imports & sqlalchemy init"""
+        header = ""
+        header += st.sqlalchemy_import  # Do we always need this import?
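+        # Imports below are appended only when the generated models need them,
+        # based on flags collected while rendering columns.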
+ if "func" in self.state: + header += st.sql_alchemy_func_import + "\n" + if self.postgresql_dialect_cols: + header += ( + st.postgresql_dialect_import.format( + types=",".join(self.postgresql_dialect_cols) + ) + + "\n" + ) + if self.constraint: + header += st.unique_cons_import + "\n" + if self.im_index: + header += st.index_import + "\n" + return header diff --git a/omymodels/models/sqlmodel/sqlmodel.jinja2 b/omymodels/models/sqlmodel/sqlmodel.jinja2 new file mode 100644 index 0000000..a0867d8 --- /dev/null +++ b/omymodels/models/sqlmodel/sqlmodel.jinja2 @@ -0,0 +1,7 @@ +import datetime +import decimal +from typing import Optional +from sqlmodel import Field, SQLModel + +{{ headers }} +{{ models }} \ No newline at end of file diff --git a/omymodels/models/sqlmodel/templates.py b/omymodels/models/sqlmodel/templates.py new file mode 100644 index 0000000..3c68247 --- /dev/null +++ b/omymodels/models/sqlmodel/templates.py @@ -0,0 +1,54 @@ +# imports +postgresql_dialect_import = """from sqlalchemy.dialects.postgresql import {types} +from pydantic import Json, UUID4""" +sql_alchemy_func_import = "from sqlalchemy.sql import func" +index_import = "from sqlalchemy import Index" + +sqlalchemy_import = """import sqlalchemy as sa +""" + +unique_cons_import = "from sqlalchemy.schema import UniqueConstraint" +enum_import = "from enum import {enums}" + +# model defenition +model_template = """\n +class {model_name}(SQLModel, table=True):\n + __tablename__ = \'{table_name}\'\n +""" + +# columns defenition +column_template = """ {column_name}: {column_type} = Field(""" +required = "" +default = ", sa_column_kwargs={{'server_default': {default}}}" +pk_template = ", default=None, primary_key=True" +unique = ", unique=True" +autoincrement = "" +index = ", index=True" +sa_type = ", sa_type={satype}" + +# tables properties + +table_args = """ + __table_args__ = ( + {statements}, + ) + +""" +fk_constraint_template = """ + {fk_name} = sa.ForeignKeyConstraint( + [{fk_columns}], [{fk_references_columns}]) +""" +fk_in_column = ", foreign_key='{ref_schema}.{ref_table}.{ref_column}'" +fk_in_column_without_schema = ", foreign_key='{ref_table}.{ref_column}'" + +unique_index_template = """ + UniqueConstraint({columns}, name={name})""" + +index_template = """ + Index({name}, '{columns}')""" + +schema = """ + dict(schema="{schema_name}")""" + +on_delete = ', ondelete="{mode}"' +on_update = ', onupdate="{mode}"' diff --git a/omymodels/models/sqlmodel/types.py b/omymodels/models/sqlmodel/types.py new file mode 100644 index 0000000..852ded6 --- /dev/null +++ b/omymodels/models/sqlmodel/types.py @@ -0,0 +1,68 @@ +from omymodels.types import ( + big_integer_types, + boolean_types, + datetime_types, + float_types, + integer_types, + json_types, + numeric_types, + populate_types_mapping, + string_types, +) + +postgresql_dialect = ["ARRAY", "JSON", "JSONB", "UUID"] + +mapper = { + string_types: {"pydantic": "str", "sa": None}, + integer_types: {"pydantic": "int", "sa": None}, + big_integer_types: {"pydantic": "int", "sa": "sa.BigInteger"}, + float_types: {"pydantic": "float", "sa": None}, + numeric_types: {"pydantic": "decimal.Decimal", "sa": "sa.Numeric"}, + boolean_types: {"pydantic": "bool", "sa": None}, + datetime_types: {"pydantic": "datetime.datetime", "sa": None}, + json_types: {"pydantic": "Json", "sa": "JSON"}, +} + +types_mapping = populate_types_mapping(mapper) + +direct_types = { + "date": {"pydantic": "datetime.date", "sa": None}, + "timestamp": {"pydantic": "datetime.datetime", "sa": None}, + "time": {"pydantic": 
"datetime.time", "sa": "sa.Time"}, + "text": {"pydantic": "str", "sa": "sa.Text"}, + "longtext": {"pydantic": "str", "sa": "sa.Text"}, # confirm this is proper SA type. + "mediumtext": { + "pydantic": "str", + "sa": "sa.Text", + }, # confirm this is proper SA type. + "tinytext": {"pydantic": "str", "sa": "sa.Text"}, # confirm this is proper SA type. + "smallint": {"pydantic": "int", "sa": "sa.SmallInteger"}, + "jsonb": {"pydantic": "Json", "sa": "JSONB"}, + "uuid": {"pydantic": "UUID4", "sa": "UUID"}, + "real": {"pydantic": "float", "sa": "sa.REAL"}, + "int unsigned": { + "pydantic": "int", + "sa": None, + }, # use mysql INTEGER(unsigned=True) for sa? + "tinyint": {"pydantic": "int", "sa": None}, # what's the proper type for this? + "tinyint unsigned": { + "pydantic": "int", + "sa": None, + }, # what's the proper type for this? + "smallint unsigned": { + "pydantic": "int", + "sa": None, + }, # what's the proper type for this? + "bigint unsigned": {"pydantic": "int", "sa": "sa.BigInteger"}, + # see https://sqlmodel.tiangolo.com/advanced/decimal/#decimals-in-sqlmodel + "decimal unsigned": {"pydantic": "decimal.Decimal", "sa": None}, + "decimalunsigned": {"pydantic": "decimal.Decimal", "sa": None}, + # Points need extensions: + # geojson_pydantic and this import: from geojson_pydantic.geometries import Point + # geoalchemy2 and this import: from geoalchemy2 import Geometry + # NOTE: SRID is not parsed currently. Default values likely not correct. + "point": {"pydantic": "Point", "sa": "Geography(geometry_type='POINT')"}, + "blob": {"pydantic": "bytes", "sa": "sa.LargeBinary"}, +} + +types_mapping.update(direct_types) diff --git a/tests/functional/generator/test_sqlalchemy_core.py b/tests/functional/generator/test_sqlalchemy_core.py index 68a5381..f638124 100644 --- a/tests/functional/generator/test_sqlalchemy_core.py +++ b/tests/functional/generator/test_sqlalchemy_core.py @@ -118,7 +118,7 @@ class Products(Base): __tablename__ = 'products' - id = sa.Column(sa.Integer(), nullable=False) + id = sa.Column(sa.Integer(), primary_key=True) merchant_id = sa.Column(sa.Integer(), sa.ForeignKey('merchants.id'), nullable=False) """ diff --git a/tests/functional/generator/test_sqlmodel.py b/tests/functional/generator/test_sqlmodel.py new file mode 100644 index 0000000..42d0bd3 --- /dev/null +++ b/tests/functional/generator/test_sqlmodel.py @@ -0,0 +1,373 @@ +from omymodels import create_models + + +def test_with_enums(): + expected = """import datetime +import decimal +from typing import Optional +from sqlmodel import Field, SQLModel + +from enum import Enum +import sqlalchemy as sa +from sqlalchemy.sql import func +from sqlalchemy.dialects.postgresql import JSON +from pydantic import Json, UUID4 + + + +class MaterialType(str, Enum): + + article = 'article' + video = 'video' + + +class Material(SQLModel, table=True): + + __tablename__ = 'material' + + id: Optional[int] = Field(default=None, primary_key=True) + title: str = Field() + description: Optional[str] = Field(sa_type=sa.Text()) + link: str = Field() + type: Optional[MaterialType] = Field(sa_type=sa.Enum(MaterialType)) + additional_properties: Optional[Json] = Field(sa_column_kwargs={'server_default': '{"key": "value"}'}, sa_type=JSON()) + created_at: Optional[datetime.datetime] = Field(sa_column_kwargs={'server_default': func.now()}) + updated_at: Optional[datetime.datetime] = Field() +""" # noqa: E501 + ddl = """ +CREATE TYPE "material_type" AS ENUM ( + 'video', + 'article' +); + +CREATE TABLE "material" ( + "id" SERIAL PRIMARY KEY, + "title" 
+  "description" text,
+  "link" varchar NOT NULL,
+  "type" material_type,
+  "additional_properties" json DEFAULT '{"key": "value"}',
+  "created_at" timestamp DEFAULT (now()),
+  "updated_at" timestamp
+);
+"""
+    result = create_models(ddl, models_type="sqlmodel")
+    assert expected == result["code"]
+
+
+def test_foreign_keys():
+    expected = """import datetime
+import decimal
+from typing import Optional
+from sqlmodel import Field, SQLModel
+
+import sqlalchemy as sa
+
+
+
+class Materials(SQLModel, table=True):
+
+    __tablename__ = 'materials'
+
+    id: Optional[int] = Field(default=None, primary_key=True)
+    title: str = Field()
+    description: Optional[str] = Field()
+    link: Optional[str] = Field()
+    created_at: Optional[datetime.datetime] = Field()
+    updated_at: Optional[datetime.datetime] = Field()
+
+
+class MaterialAttachments(SQLModel, table=True):
+
+    __tablename__ = 'material_attachments'
+
+    material_id: Optional[int] = Field(foreign_key='materials.id')
+    attachment_id: Optional[int] = Field(foreign_key='attachments.id')
+
+
+class Attachments(SQLModel, table=True):
+
+    __tablename__ = 'attachments'
+
+    id: Optional[int] = Field(default=None, primary_key=True)
+    title: Optional[str] = Field()
+    description: Optional[str] = Field()
+    created_at: Optional[datetime.datetime] = Field()
+    updated_at: Optional[datetime.datetime] = Field()
+"""
+    ddl = """
+
+    CREATE TABLE "materials" (
+    "id" int PRIMARY KEY,
+    "title" varchar NOT NULL,
+    "description" varchar,
+    "link" varchar,
+    "created_at" timestamp,
+    "updated_at" timestamp
+    );
+
+    CREATE TABLE "material_attachments" (
+    "material_id" int,
+    "attachment_id" int
+    );
+
+    CREATE TABLE "attachments" (
+    "id" int PRIMARY KEY,
+    "title" varchar,
+    "description" varchar,
+    "created_at" timestamp,
+    "updated_at" timestamp
+    );
+
+
+    ALTER TABLE "material_attachments" ADD FOREIGN KEY ("material_id") REFERENCES "materials" ("id");
+
+    ALTER TABLE "material_attachments" ADD FOREIGN KEY ("attachment_id") REFERENCES "attachments" ("id");
+
+    """
+    result = create_models(ddl, models_type="sqlmodel")["code"]
+    assert result == expected
+
+
+def test_foreign_keys_defined_inline():
+    """
+    This should be the same output as test_foreign_keys, but with a slightly
+    different, yet valid, input DDL.
+ """ + expected = """import datetime +import decimal +from typing import Optional +from sqlmodel import Field, SQLModel + +import sqlalchemy as sa + + + +class Materials(SQLModel, table=True): + + __tablename__ = 'materials' + + id: Optional[int] = Field(default=None, primary_key=True) + title: str = Field() + description: Optional[str] = Field() + link: Optional[str] = Field() + created_at: Optional[datetime.datetime] = Field() + updated_at: Optional[datetime.datetime] = Field() + + +class MaterialAttachments(SQLModel, table=True): + + __tablename__ = 'material_attachments' + + material_id: Optional[int] = Field(foreign_key='materials.id') + attachment_id: Optional[int] = Field(foreign_key='attachments.id') + + +class Attachments(SQLModel, table=True): + + __tablename__ = 'attachments' + + id: Optional[int] = Field(default=None, primary_key=True) + title: Optional[str] = Field() + description: Optional[str] = Field() + created_at: Optional[datetime.datetime] = Field() + updated_at: Optional[datetime.datetime] = Field() +""" + ddl = """ + + CREATE TABLE "materials" ( + "id" int PRIMARY KEY, + "title" varchar NOT NULL, + "description" varchar, + "link" varchar, + "created_at" timestamp, + "updated_at" timestamp + ); + + CREATE TABLE "material_attachments" ( + "material_id" int, + "attachment_id" int, + CONSTRAINT "material_id_ibfk" FOREIGN KEY ("material_id") REFERENCES "materials" ("id"), + CONSTRAINT "attachment_id_ibfk" FOREIGN KEY ("attachment_id") REFERENCES "attachments" ("id") + ); + + CREATE TABLE "attachments" ( + "id" int PRIMARY KEY, + "title" varchar, + "description" varchar, + "created_at" timestamp, + "updated_at" timestamp + ); + """ + result = create_models(ddl, models_type="sqlmodel")["code"] + assert result == expected + + +def test_multi_col_pk_and_fk(): + """ + This should test that we can properly setup tables with compound PRIMARY and + FOREIGN keys. 
+ """ + expected = """import datetime +import decimal +from typing import Optional +from sqlmodel import Field, SQLModel + +import sqlalchemy as sa +from sqlalchemy.sql import func + + + +class Complexpk(SQLModel, table=True): + + __tablename__ = 'complexpk' + + complex_id: Optional[int] = Field(default=None, primary_key=True) + date_part: Optional[datetime.datetime] = Field(sa_column_kwargs={'server_default': func.now()}, default=None, primary_key=True) + title: str = Field() + description: Optional[str] = Field() + + +class LinkedTo(SQLModel, table=True): + + __tablename__ = 'linked_to' + + id: Optional[int] = Field(default=None, primary_key=True) + complexpk_complex_id: Optional[int] = Field(foreign_key='complexpk.complex_id') + complexpk_date_part: Optional[int] = Field(foreign_key='complexpk.date_part') + comment: Optional[str] = Field() +""" # noqa: E501 + + ddl = """ + + CREATE TABLE "complexpk" ( + "complex_id" int unsigned NOT NULL, + "date_part" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, + "title" varchar NOT NULL, + "description" varchar, + PRIMARY KEY ("complex_id","date_part") + ); + + CREATE TABLE "linked_to" ( + "id" int PRIMARY KEY, + "complexpk_complex_id" int, + "complexpk_date_part" int, + "comment" varchar, + CONSTRAINT "id_date_part_ibfk" FOREIGN KEY ("complexpk_complex_id", "complexpk_date_part") + REFERENCES "complexpk" ("complex_id", "date_part") + ); + + """ + result = create_models(ddl, models_type="sqlmodel")["code"] + assert result == expected + + +def test_upper_name_produces_the_same_result(): + expected = """import datetime +import decimal +from typing import Optional +from sqlmodel import Field, SQLModel + +from enum import Enum +import sqlalchemy as sa +from sqlalchemy.sql import func +from sqlalchemy.dialects.postgresql import JSON +from pydantic import Json, UUID4 + + + +class MaterialType(str, Enum): + + article = 'article' + video = 'video' + + +class Material(SQLModel, table=True): + + __tablename__ = 'material' + + id: Optional[int] = Field(default=None, primary_key=True) + title: str = Field() + description: Optional[str] = Field(sa_type=sa.Text()) + link: str = Field() + type: Optional[MaterialType] = Field(sa_type=sa.Enum(MaterialType)) + additional_properties: Optional[Json] = Field(sa_column_kwargs={'server_default': '{"key": "value"}'}, sa_type=JSON()) + created_at: Optional[datetime.datetime] = Field(sa_column_kwargs={'server_default': func.now()}) + updated_at: Optional[datetime.datetime] = Field() +""" # noqa: E501 + ddl = """ +CREATE TYPE "material_type" AS ENUM ( + 'video', + 'article' +); + +CREATE TABLE "material" ( + "id" SERIAL PRIMARY KEY, + "title" varchar NOT NULL, + "description" text, + "link" varchar NOT NULL, + "type" material_type, + "additional_properties" json DEFAULT '{"key": "value"}', + "created_at" timestamp DEFAULT (NOW()), + "updated_at" timestamp +); +""" + result = create_models(ddl, models_type="sqlmodel") + assert expected == result["code"] + + +def test_foreign_keys_in_different_schema(): + expected = """import datetime +import decimal +from typing import Optional +from sqlmodel import Field, SQLModel + +import sqlalchemy as sa + + + +class Table1(SQLModel, table=True): + + __tablename__ = 'table1' + + id: Optional[int] = Field(default=None, primary_key=True) + reference_to_table_in_another_schema: int = Field(foreign_key='schema2.table2.id') + + __table_args__ = ( + + dict(schema="schema1"), + ) + + + +class Table2(SQLModel, table=True): + + __tablename__ = 'table2' + + id: Optional[int] = Field(default=None, 
+
+    __table_args__ = (
+
+        dict(schema="schema2"),
+    )
+
+"""
+    ddl = """
+CREATE SCHEMA "schema1";
+
+CREATE SCHEMA "schema2";
+
+CREATE TABLE "schema1"."table1" (
+  "id" int PRIMARY KEY,
+  "reference_to_table_in_another_schema" int NOT NULL
+);
+
+CREATE TABLE "schema2"."table2" (
+  "id" int PRIMARY KEY
+);
+
+ALTER TABLE "schema1"."table1" ADD FOREIGN KEY
+("reference_to_table_in_another_schema") REFERENCES "schema2"."table2" ("id");
+"""
+    result = create_models(ddl, schema_global=False, models_type="sqlmodel")["code"]
+    assert result == expected
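
A minimal usage sketch for the generator this patch adds, assuming the patch is
applied. The "users" table is a hypothetical example; create_models,
models_type="sqlmodel", and the "code" key are the same public API the tests
above exercise.

    # Select the SQLModel generator registered in omymodels/generators.py.
    from omymodels import create_models

    ddl = '''
    CREATE TABLE "users" (
      "id" int PRIMARY KEY,
      "name" varchar NOT NULL
    );
    '''

    result = create_models(ddl, models_type="sqlmodel")
    print(result["code"])  # the generated SQLModel classes as a string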