Skip to content

Commit

Permalink
Sqlmodel support (#58)
Browse files Browse the repository at this point in the history
* Improved handling of different ways of specifying references in the DDL.

* Flake fixes

* Initial SQLModel support, basic tests pass, some extra types added.

* Add preliminary support for more types.

* Some fixes based on trying to run the generated code.

* Ensure snake_case of indexes and constraints.

* Updated tests.
  • Loading branch information
cfhowes authored Feb 11, 2024
1 parent aeee7b3 commit 4893e38
Show file tree
Hide file tree
Showing 8 changed files with 687 additions and 2 deletions.
14 changes: 13 additions & 1 deletion omymodels/from_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def snake_case(string: str) -> str:
return re.sub(r"(?<!^)(?=[A-Z])", "_", string).lower()


def convert_ddl_to_models(data: Dict, no_auto_snake_case: bool) -> Dict[str, list]:
def convert_ddl_to_models( # noqa: C901
data: Dict, no_auto_snake_case: bool
) -> Dict[str, list]:
final_data = {"tables": [], "types": []}
refs = {}
tables = []
Expand All @@ -97,6 +99,16 @@ def convert_ddl_to_models(data: Dict, no_auto_snake_case: bool) -> Dict[str, lis
column["name"] = snake_case(column["name"])
if column["name"] in refs:
column["references"] = refs[column["name"]]
if not no_auto_snake_case:
table["primary_key"] = [snake_case(pk) for pk in table["primary_key"]]
for uniq in table.get("constraints", {}).get("uniques", []):
uniq["columns"] = [snake_case(c) for c in uniq["columns"]]
# NOTE: We are not going to try and parse check constraint statements
# and update the snake case.
for idx in table.get("index", []):
idx["columns"] = [snake_case(c) for c in idx["columns"]]
for col_detail in idx["detailed_columns"]:
col_detail["name"] = snake_case(col_detail["name"])
tables.append(TableMeta(**table))
final_data["tables"] = tables
_types = []
Expand Down
2 changes: 2 additions & 0 deletions omymodels/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from omymodels.models.pydantic import core as p
from omymodels.models.sqlalchemy import core as s
from omymodels.models.sqlalchemy_core import core as sc
from omymodels.models.sqlmodel import core as sm

# Registry mapping the user-facing generator key (the `models_type` string
# passed by callers) to the core module that implements that generator.
models = {
    "gino": g,
    "pydantic": p,
    "dataclass": d,
    "sqlalchemy": s,
    "sqlalchemy_core": sc,
    "sqlmodel": sm,
}


Expand Down
169 changes: 169 additions & 0 deletions omymodels/models/sqlmodel/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from typing import Dict, List, Optional

from table_meta.model import Column

import omymodels.models.sqlmodel.templates as st
from omymodels import logic, types
from omymodels.helpers import create_class_name, datetime_now_check
from omymodels.models.sqlmodel.types import types_mapping
from omymodels.types import datetime_types


class GeneratorBase:
    """Base class holding state shared by all model generators."""

    def __init__(self):
        # Mapping of user-defined type name -> generator-specific type info;
        # populated externally before generation runs.
        self.custom_types = dict()


class ModelGenerator(GeneratorBase):
    """Generates SQLModel class definitions from parsed DDL table metadata."""

    def __init__(self):
        self.state = set()  # extra header imports needed (e.g. "func")
        self.postgresql_dialect_cols = set()  # postgres-only SA types used
        self.constraint = False  # True once a UniqueConstraint is emitted
        self.im_index = False  # True once an Index is emitted
        self.types_mapping = types_mapping
        self.templates = st
        self.prefix = "sa."
        super().__init__()

    def prepare_column_default(self, column_data: Dict, column: str) -> str:
        """Append a ``server_default`` clause for the column's DDL default.

        Datetime "now()"-style defaults become ``func.now()`` (and register the
        ``func`` import); other defaults are wrapped in single quotes unless
        they already contain one.
        """
        if isinstance(column_data.default, str):
            if column_data.type.upper() in datetime_types:
                if datetime_now_check(column_data.default.lower()):
                    # todo: need to add other popular PostgreSQL & MySQL functions
                    column_data.default = "func.now()"
                    self.state.add("func")
                elif "'" not in column_data.default:
                    column_data.default = f"'{column_data.default}'"
            else:
                if "'" not in column_data.default:
                    column_data.default = f"'{column_data.default}'"
        else:
            column_data.default = f"'{str(column_data.default)}'"
        column += st.default.format(default=column_data.default)
        return column

    def add_custom_type_orm(
        self, custom_types: Dict, column_data_type: str, column_type: str
    ) -> dict:
        """Resolve a user-defined (enum/custom) type to a pydantic/SA pair.

        Returns ``None`` when *column_data_type* is not a known custom type,
        so the caller can fall back to the standard type mapping.
        """
        column_type_data = None
        if "." in column_data_type:
            # Strip a schema qualifier such as "myschema.mytype".
            column_data_type = column_data_type.split(".")[1]
        column_type = custom_types.get(column_data_type, column_type)

        if isinstance(column_type, tuple):
            column_data_type = column_type[1]
            column_type = column_type[0]
        if column_type is not None:
            column_type_data = {
                "pydantic": column_data_type,
                "sa": f"{column_type}({column_data_type})",
            }
        return column_type_data

    def prepare_column_type(self, column_data: Column) -> dict:
        """Map a DDL column type to a ``{"pydantic": ..., "sa": ...}`` pair.

        Custom types win over the standard mapping; postgres-only SA types are
        recorded so the header can import them from the postgresql dialect.
        """
        column_type = None
        column_data = types.prepare_column_data(column_data)

        if self.custom_types:
            column_type = self.add_custom_type_orm(
                self.custom_types, column_data.type, column_type
            )

        if not column_type:
            column_type = types.prepare_type(column_data, self.types_mapping)
        if column_type["sa"] in types.postgresql_dialect:
            self.postgresql_dialect_cols.add(column_type["sa"])

        if "[" in column_data.type and column_data.type not in types.json_types:
            # @TODO: How do we handle arrays for SQLModel?
            self.postgresql_dialect_cols.add("ARRAY")
            # BUGFIX: previously the whole dict was formatted into a string
            # (f"ARRAY({column_type})"), which produced garbage and broke the
            # later column_type["pydantic"] lookup in generate_model. Only the
            # SQLAlchemy side needs the ARRAY wrapper.
            column_type["sa"] = f"ARRAY({column_type['sa']})"
        return column_type

    def add_table_args(
        self, model: str, table: Dict, schema_global: bool = True
    ) -> str:
        """Append a ``__table_args__`` tuple with indexes/constraints/schema."""
        statements = []
        t = self.templates
        if table.indexes:
            for index in table.indexes:
                if not index["unique"]:
                    self.im_index = True
                    statements.append(
                        t.index_template.format(
                            columns="', '".join(index["columns"]),
                            name=f"'{index['index_name']}'",
                        )
                    )
                else:
                    self.constraint = True
                    statements.append(
                        t.unique_index_template.format(
                            columns=",".join(index["columns"]),
                            name=f"'{index['index_name']}'",
                        )
                    )
        if not schema_global and table.table_schema:
            statements.append(t.schema.format(schema_name=table.table_schema))
        if statements:
            model += t.table_args.format(statements=",".join(statements))
        return model

    def generate_model(
        self,
        table: Dict,
        singular: bool = True,
        exceptions: Optional[List] = None,
        schema_global: Optional[bool] = True,
        *args,
        **kwargs,
    ) -> str:
        """method to prepare one Model definition - name & tablename & columns"""
        model = ""

        model = st.model_template.format(
            model_name=create_class_name(table.name, singular, exceptions),
            table_name=table.name,
        )
        for column in table.columns:
            column_type = self.prepare_column_type(column)
            pydantic_type_str = column_type["pydantic"]
            if column.nullable or column.name in table.primary_key:
                # Primary keys are Optional so SQLModel can autogenerate them.
                pydantic_type_str = f"Optional[{pydantic_type_str}]"
            col_str = st.column_template.format(
                column_name=column.name.replace(" ", "_"), column_type=pydantic_type_str
            )

            col_str = logic.setup_column_attributes(
                column, table.primary_key, col_str, table, schema_global, st, self
            )
            if column_type["sa"]:
                sa_type = types.add_size_to_orm_column(column_type["sa"], column)
                col_str += st.sa_type.format(satype=sa_type)
            col_str += ")\n"

            # Drop the dangling comma when Field() received no leading args.
            col_str = col_str.replace("(, ", "(")

            model += col_str
        if table.indexes or table.alter or table.checks or not schema_global:
            model = self.add_table_args(model, table, schema_global)
        return model

    def create_header(self, tables: List[Dict], schema: bool = False) -> str:
        """header of the file - imports & sqlalchemy init"""
        header = ""
        header += st.sqlalchemy_import  # Do we always need this import?
        if "func" in self.state:
            header += st.sql_alchemy_func_import + "\n"
        if self.postgresql_dialect_cols:
            header += (
                st.postgresql_dialect_import.format(
                    types=",".join(self.postgresql_dialect_cols)
                )
                + "\n"
            )
        if self.constraint:
            header += st.unique_cons_import + "\n"
        if self.im_index:
            header += st.index_import + "\n"
        return header
7 changes: 7 additions & 0 deletions omymodels/models/sqlmodel/sqlmodel.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import datetime
import decimal
from typing import Optional
from sqlmodel import Field, SQLModel

{{ headers }}
{{ models }}
54 changes: 54 additions & 0 deletions omymodels/models/sqlmodel/templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# imports
postgresql_dialect_import = """from sqlalchemy.dialects.postgresql import {types}
from pydantic import Json, UUID4"""
sql_alchemy_func_import = "from sqlalchemy.sql import func"
index_import = "from sqlalchemy import Index"

sqlalchemy_import = """import sqlalchemy as sa
"""

unique_cons_import = "from sqlalchemy.schema import UniqueConstraint"
enum_import = "from enum import {enums}"

# model defenition
model_template = """\n
class {model_name}(SQLModel, table=True):\n
__tablename__ = \'{table_name}\'\n
"""

# columns defenition
column_template = """ {column_name}: {column_type} = Field("""
required = ""
default = ", sa_column_kwargs={{'server_default': {default}}}"
pk_template = ", default=None, primary_key=True"
unique = ", unique=True"
autoincrement = ""
index = ", index=True"
sa_type = ", sa_type={satype}"

# tables properties

table_args = """
__table_args__ = (
{statements},
)
"""
fk_constraint_template = """
{fk_name} = sa.ForeignKeyConstraint(
[{fk_columns}], [{fk_references_columns}])
"""
fk_in_column = ", foreign_key='{ref_schema}.{ref_table}.{ref_column}'"
fk_in_column_without_schema = ", foreign_key='{ref_table}.{ref_column}'"

unique_index_template = """
UniqueConstraint({columns}, name={name})"""

index_template = """
Index({name}, '{columns}')"""

schema = """
dict(schema="{schema_name}")"""

on_delete = ', ondelete="{mode}"'
on_update = ', onupdate="{mode}"'
68 changes: 68 additions & 0 deletions omymodels/models/sqlmodel/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from omymodels.types import (
    big_integer_types,
    boolean_types,
    datetime_types,
    float_types,
    integer_types,
    json_types,
    numeric_types,
    populate_types_mapping,
    string_types,
)

# SA types that must be imported from sqlalchemy.dialects.postgresql.
postgresql_dialect = ["ARRAY", "JSON", "JSONB", "UUID"]

# Maps groups of DDL type names to a {"pydantic": ..., "sa": ...} pair.
# "pydantic" is the Python annotation on the SQLModel field; "sa" is an
# explicit SQLAlchemy type override, or None to let SQLModel infer it.
mapper = {
    string_types: {"pydantic": "str", "sa": None},
    integer_types: {"pydantic": "int", "sa": None},
    big_integer_types: {"pydantic": "int", "sa": "sa.BigInteger"},
    float_types: {"pydantic": "float", "sa": None},
    numeric_types: {"pydantic": "decimal.Decimal", "sa": "sa.Numeric"},
    boolean_types: {"pydantic": "bool", "sa": None},
    datetime_types: {"pydantic": "datetime.datetime", "sa": None},
    json_types: {"pydantic": "Json", "sa": "JSON"},
}

# Flattened mapping: one entry per individual DDL type name.
types_mapping = populate_types_mapping(mapper)

# Single-name overrides applied on top of the grouped mapping above.
direct_types = {
    "date": {"pydantic": "datetime.date", "sa": None},
    "timestamp": {"pydantic": "datetime.datetime", "sa": None},
    "time": {"pydantic": "datetime.time", "sa": "sa.Time"},
    "text": {"pydantic": "str", "sa": "sa.Text"},
    "longtext": {"pydantic": "str", "sa": "sa.Text"},  # confirm this is proper SA type.
    "mediumtext": {
        "pydantic": "str",
        "sa": "sa.Text",
    },  # confirm this is proper SA type.
    "tinytext": {"pydantic": "str", "sa": "sa.Text"},  # confirm this is proper SA type.
    "smallint": {"pydantic": "int", "sa": "sa.SmallInteger"},
    "jsonb": {"pydantic": "Json", "sa": "JSONB"},
    "uuid": {"pydantic": "UUID4", "sa": "UUID"},
    "real": {"pydantic": "float", "sa": "sa.REAL"},
    "int unsigned": {
        "pydantic": "int",
        "sa": None,
    },  # use mysql INTEGER(unsigned=True) for sa?
    "tinyint": {"pydantic": "int", "sa": None},  # what's the proper type for this?
    "tinyint unsigned": {
        "pydantic": "int",
        "sa": None,
    },  # what's the proper type for this?
    "smallint unsigned": {
        "pydantic": "int",
        "sa": None,
    },  # what's the proper type for this?
    "bigint unsigned": {"pydantic": "int", "sa": "sa.BigInteger"},
    # see https://sqlmodel.tiangolo.com/advanced/decimal/#decimals-in-sqlmodel
    "decimal unsigned": {"pydantic": "decimal.Decimal", "sa": None},
    "decimalunsigned": {"pydantic": "decimal.Decimal", "sa": None},
    # Points need extensions:
    # geojson_pydantic and this import: from geojson_pydantic.geometries import Point
    # geoalchemy2 and this import: from geoalchemy2 import Geometry
    # NOTE: SRID is not parsed currently. Default values likely not correct.
    "point": {"pydantic": "Point", "sa": "Geography(geometry_type='POINT')"},
    "blob": {"pydantic": "bytes", "sa": "sa.LargeBinary"},
}

types_mapping.update(direct_types)
2 changes: 1 addition & 1 deletion tests/functional/generator/test_sqlalchemy_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class Products(Base):
__tablename__ = 'products'
id = sa.Column(sa.Integer(), nullable=False)
id = sa.Column(sa.Integer(), primary_key=True)
merchant_id = sa.Column(sa.Integer(), sa.ForeignKey('merchants.id'), nullable=False)
"""

Expand Down
Loading

0 comments on commit 4893e38

Please sign in to comment.