-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Improved handling of different ways of specifing references in the DDL. * Flake fixes * Initial SQLModel support, basic tests pass, some extra types added. * Add preliminary support for more types. * Some fixes based on trying to run the generated code. * Ensure snake_case of indexes and constraints. * Updated tests.
- Loading branch information
Showing
8 changed files
with
687 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
from typing import Dict, List, Optional | ||
|
||
from table_meta.model import Column | ||
|
||
import omymodels.models.sqlmodel.templates as st | ||
from omymodels import logic, types | ||
from omymodels.helpers import create_class_name, datetime_now_check | ||
from omymodels.models.sqlmodel.types import types_mapping | ||
from omymodels.types import datetime_types | ||
|
||
|
||
class GeneratorBase: | ||
def __init__(self): | ||
self.custom_types = {} | ||
|
||
|
||
class ModelGenerator(GeneratorBase): | ||
def __init__(self): | ||
self.state = set() | ||
self.postgresql_dialect_cols = set() | ||
self.constraint = False | ||
self.im_index = False | ||
self.types_mapping = types_mapping | ||
self.templates = st | ||
self.prefix = "sa." | ||
super().__init__() | ||
|
||
def prepare_column_default(self, column_data: Dict, column: str) -> str: | ||
if isinstance(column_data.default, str): | ||
if column_data.type.upper() in datetime_types: | ||
if datetime_now_check(column_data.default.lower()): | ||
# todo: need to add other popular PostgreSQL & MySQL functions | ||
column_data.default = "func.now()" | ||
self.state.add("func") | ||
elif "'" not in column_data.default: | ||
column_data.default = f"'{column_data.default}'" | ||
else: | ||
if "'" not in column_data.default: | ||
column_data.default = f"'{column_data.default}'" | ||
else: | ||
column_data.default = f"'{str(column_data.default)}'" | ||
column += st.default.format(default=column_data.default) | ||
return column | ||
|
||
def add_custom_type_orm( | ||
self, custom_types: Dict, column_data_type: str, column_type: str | ||
) -> dict: | ||
column_type_data = None | ||
if "." in column_data_type: | ||
column_data_type = column_data_type.split(".")[1] | ||
column_type = custom_types.get(column_data_type, column_type) | ||
|
||
if isinstance(column_type, tuple): | ||
column_data_type = column_type[1] | ||
column_type = column_type[0] | ||
if column_type is not None: | ||
column_type_data = { | ||
"pydantic": column_data_type, | ||
"sa": f"{column_type}({column_data_type})", | ||
} | ||
return column_type_data | ||
|
||
def prepare_column_type(self, column_data: Column) -> str: | ||
column_type = None | ||
column_data = types.prepare_column_data(column_data) | ||
|
||
if self.custom_types: | ||
column_type = self.add_custom_type_orm( | ||
self.custom_types, column_data.type, column_type | ||
) | ||
|
||
if not column_type: | ||
column_type = types.prepare_type(column_data, self.types_mapping) | ||
if column_type["sa"] in types.postgresql_dialect: | ||
self.postgresql_dialect_cols.add(column_type["sa"]) | ||
|
||
if "[" in column_data.type and column_data.type not in types.json_types: | ||
# @TODO: How do we handle arrays for SQLModel? | ||
self.postgresql_dialect_cols.add("ARRAY") | ||
column_type = f"ARRAY({column_type})" | ||
return column_type | ||
|
||
def add_table_args( | ||
self, model: str, table: Dict, schema_global: bool = True | ||
) -> str: | ||
statements = [] | ||
t = self.templates | ||
if table.indexes: | ||
for index in table.indexes: | ||
if not index["unique"]: | ||
self.im_index = True | ||
statements.append( | ||
t.index_template.format( | ||
columns="', '".join(index["columns"]), | ||
name=f"'{index['index_name']}'", | ||
) | ||
) | ||
else: | ||
self.constraint = True | ||
statements.append( | ||
t.unique_index_template.format( | ||
columns=",".join(index["columns"]), | ||
name=f"'{index['index_name']}'", | ||
) | ||
) | ||
if not schema_global and table.table_schema: | ||
statements.append(t.schema.format(schema_name=table.table_schema)) | ||
if statements: | ||
model += t.table_args.format(statements=",".join(statements)) | ||
return model | ||
|
||
def generate_model( | ||
self, | ||
table: Dict, | ||
singular: bool = True, | ||
exceptions: Optional[List] = None, | ||
schema_global: Optional[bool] = True, | ||
*args, | ||
**kwargs, | ||
) -> str: | ||
"""method to prepare one Model defention - name & tablename & columns""" | ||
model = "" | ||
|
||
model = st.model_template.format( | ||
model_name=create_class_name(table.name, singular, exceptions), | ||
table_name=table.name, | ||
) | ||
for column in table.columns: | ||
column_type = self.prepare_column_type(column) | ||
pydantic_type_str = column_type["pydantic"] | ||
if column.nullable or column.name in table.primary_key: | ||
pydantic_type_str = f"Optional[{pydantic_type_str}]" | ||
col_str = st.column_template.format( | ||
column_name=column.name.replace(" ", "_"), column_type=pydantic_type_str | ||
) | ||
|
||
col_str = logic.setup_column_attributes( | ||
column, table.primary_key, col_str, table, schema_global, st, self | ||
) | ||
if column_type["sa"]: | ||
sa_type = types.add_size_to_orm_column(column_type["sa"], column) | ||
col_str += st.sa_type.format(satype=sa_type) | ||
col_str += ")\n" | ||
|
||
col_str = col_str.replace("(, ", "(") | ||
|
||
model += col_str | ||
if table.indexes or table.alter or table.checks or not schema_global: | ||
model = self.add_table_args(model, table, schema_global) | ||
return model | ||
|
||
def create_header(self, tables: List[Dict], schema: bool = False) -> str: | ||
"""header of the file - imports & sqlalchemy init""" | ||
header = "" | ||
header += st.sqlalchemy_import # Do we always need this import? | ||
if "func" in self.state: | ||
header += st.sql_alchemy_func_import + "\n" | ||
if self.postgresql_dialect_cols: | ||
header += ( | ||
st.postgresql_dialect_import.format( | ||
types=",".join(self.postgresql_dialect_cols) | ||
) | ||
+ "\n" | ||
) | ||
if self.constraint: | ||
header += st.unique_cons_import + "\n" | ||
if self.im_index: | ||
header += st.index_import + "\n" | ||
return header |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import datetime | ||
import decimal | ||
from typing import Optional | ||
from sqlmodel import Field, SQLModel | ||
|
||
{{ headers }} | ||
{{ models }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# imports | ||
postgresql_dialect_import = """from sqlalchemy.dialects.postgresql import {types} | ||
from pydantic import Json, UUID4""" | ||
sql_alchemy_func_import = "from sqlalchemy.sql import func" | ||
index_import = "from sqlalchemy import Index" | ||
|
||
sqlalchemy_import = """import sqlalchemy as sa | ||
""" | ||
|
||
unique_cons_import = "from sqlalchemy.schema import UniqueConstraint" | ||
enum_import = "from enum import {enums}" | ||
|
||
# model defenition | ||
model_template = """\n | ||
class {model_name}(SQLModel, table=True):\n | ||
__tablename__ = \'{table_name}\'\n | ||
""" | ||
|
||
# columns defenition | ||
column_template = """ {column_name}: {column_type} = Field(""" | ||
required = "" | ||
default = ", sa_column_kwargs={{'server_default': {default}}}" | ||
pk_template = ", default=None, primary_key=True" | ||
unique = ", unique=True" | ||
autoincrement = "" | ||
index = ", index=True" | ||
sa_type = ", sa_type={satype}" | ||
|
||
# tables properties | ||
|
||
table_args = """ | ||
__table_args__ = ( | ||
{statements}, | ||
) | ||
""" | ||
fk_constraint_template = """ | ||
{fk_name} = sa.ForeignKeyConstraint( | ||
[{fk_columns}], [{fk_references_columns}]) | ||
""" | ||
fk_in_column = ", foreign_key='{ref_schema}.{ref_table}.{ref_column}'" | ||
fk_in_column_without_schema = ", foreign_key='{ref_table}.{ref_column}'" | ||
|
||
unique_index_template = """ | ||
UniqueConstraint({columns}, name={name})""" | ||
|
||
index_template = """ | ||
Index({name}, '{columns}')""" | ||
|
||
schema = """ | ||
dict(schema="{schema_name}")""" | ||
|
||
on_delete = ', ondelete="{mode}"' | ||
on_update = ', onupdate="{mode}"' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from omymodels.types import ( | ||
big_integer_types, | ||
boolean_types, | ||
datetime_types, | ||
float_types, | ||
integer_types, | ||
json_types, | ||
numeric_types, | ||
populate_types_mapping, | ||
string_types, | ||
) | ||
|
||
postgresql_dialect = ["ARRAY", "JSON", "JSONB", "UUID"] | ||
|
||
mapper = { | ||
string_types: {"pydantic": "str", "sa": None}, | ||
integer_types: {"pydantic": "int", "sa": None}, | ||
big_integer_types: {"pydantic": "int", "sa": "sa.BigInteger"}, | ||
float_types: {"pydantic": "float", "sa": None}, | ||
numeric_types: {"pydantic": "decimal.Decimal", "sa": "sa.Numeric"}, | ||
boolean_types: {"pydantic": "bool", "sa": None}, | ||
datetime_types: {"pydantic": "datetime.datetime", "sa": None}, | ||
json_types: {"pydantic": "Json", "sa": "JSON"}, | ||
} | ||
|
||
types_mapping = populate_types_mapping(mapper) | ||
|
||
direct_types = { | ||
"date": {"pydantic": "datetime.date", "sa": None}, | ||
"timestamp": {"pydantic": "datetime.datetime", "sa": None}, | ||
"time": {"pydantic": "datetime.time", "sa": "sa.Time"}, | ||
"text": {"pydantic": "str", "sa": "sa.Text"}, | ||
"longtext": {"pydantic": "str", "sa": "sa.Text"}, # confirm this is proper SA type. | ||
"mediumtext": { | ||
"pydantic": "str", | ||
"sa": "sa.Text", | ||
}, # confirm this is proper SA type. | ||
"tinytext": {"pydantic": "str", "sa": "sa.Text"}, # confirm this is proper SA type. | ||
"smallint": {"pydantic": "int", "sa": "sa.SmallInteger"}, | ||
"jsonb": {"pydantic": "Json", "sa": "JSONB"}, | ||
"uuid": {"pydantic": "UUID4", "sa": "UUID"}, | ||
"real": {"pydantic": "float", "sa": "sa.REAL"}, | ||
"int unsigned": { | ||
"pydantic": "int", | ||
"sa": None, | ||
}, # use mysql INTEGER(unsigned=True) for sa? | ||
"tinyint": {"pydantic": "int", "sa": None}, # what's the proper type for this? | ||
"tinyint unsigned": { | ||
"pydantic": "int", | ||
"sa": None, | ||
}, # what's the proper type for this? | ||
"smallint unsigned": { | ||
"pydantic": "int", | ||
"sa": None, | ||
}, # what's the proper type for this? | ||
"bigint unsigned": {"pydantic": "int", "sa": "sa.BigInteger"}, | ||
# see https://sqlmodel.tiangolo.com/advanced/decimal/#decimals-in-sqlmodel | ||
"decimal unsigned": {"pydantic": "decimal.Decimal", "sa": None}, | ||
"decimalunsigned": {"pydantic": "decimal.Decimal", "sa": None}, | ||
# Points need extensions: | ||
# geojson_pydantic and this import: from geojson_pydantic.geometries import Point | ||
# geoalchemy2 and this import: from geoalchemy2 import Geometry | ||
# NOTE: SRID is not parsed currently. Default values likely not correct. | ||
"point": {"pydantic": "Point", "sa": "Geography(geometry_type='POINT')"}, | ||
"blob": {"pydantic": "bytes", "sa": "sa.LargeBinary"}, | ||
} | ||
|
||
types_mapping.update(direct_types) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.