Skip to content

Commit 60fc72e

Browse files
committed
Fixed bug with inheritance and various other issues
1 parent 0dad52d commit 60fc72e

File tree

5 files changed

+79
-46
lines changed

5 files changed

+79
-46
lines changed

sqlorm/model.py

+44-36
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .engine import Engine, ensure_transaction, _signals, _signal_rv
66
from .sqlfunc import is_sqlfunc, sqlfunc, fetchall, fetchone, execute, update
77
from .resultset import ResultSet, CompositeResultSet
8-
from .types import SQLType
8+
from .types import SQLType, Integer
99
from .mapper import (
1010
Mapper,
1111
MappedColumnMixin,
@@ -17,23 +17,27 @@
1717

1818
class ModelMetaclass(abc.ABCMeta):
1919
def __new__(cls, name, bases, dct):
20-
if not bases or abc.ABC in bases:
20+
if len(bases) == 1 and bases[0] is abc.ABC: # BaseModel
2121
return super().__new__(cls, name, bases, dct)
22-
dct = cls.pre_process_model_class_dict(name, bases, dct)
22+
23+
model_registry = cls.find_model_registry(bases)
24+
mapped_attrs = cls.process_mapped_attributes(dct)
25+
cls.process_sql_methods(dct, model_registry)
2326
model_class = super().__new__(cls, name, bases, dct)
2427
cls.process_meta_inheritance(model_class)
25-
return cls.post_process_model_class(model_class)
28+
if abc.ABC not in bases:
29+
cls.create_mapper(model_class, mapped_attrs)
30+
model_class.__model_registry__.register(model_class)
31+
return model_class
2632

27-
@classmethod
28-
def pre_process_model_class_dict(cls, name, bases, dct):
29-
model_registry = {}
33+
def find_model_registry(bases):
3034
for base in bases:
31-
if issubclass(base, BaseModel):
32-
model_registry = base.__model_registry__
33-
break
34-
35-
dct["table"] = SQL.Id(dct.get("__table__", dct.get("table", name.lower())))
35+
if hasattr(base, "__model_registry__"):
36+
return base.__model_registry__
37+
return ModelRegistry()
3638

39+
@staticmethod
40+
def process_mapped_attributes(dct):
3741
mapped_attrs = {}
3842
for name, annotation in dct.get("__annotations__", {}).items():
3943
primary_key = False
@@ -45,11 +49,11 @@ def pre_process_model_class_dict(cls, name, bases, dct):
4549
dct[name] = mapped_attrs[name] = Column(name, annotation, primary_key=primary_key)
4650
elif isinstance(dct[name], Column):
4751
mapped_attrs[name] = dct[name]
48-
dct[name].type = SQLType.from_pytype(annotation)
52+
if dct[name].type is None:
53+
dct[name].type = SQLType.from_pytype(annotation)
4954
elif isinstance(dct[name], Relationship):
5055
# add now to keep the declaration order
5156
mapped_attrs[name] = dct[name]
52-
5357
for attr_name, attr in dct.items():
5458
if isinstance(attr, Column) and not attr.name:
5559
# in the case of models, we allow column object to be initialized without names
@@ -58,27 +62,28 @@ def pre_process_model_class_dict(cls, name, bases, dct):
5862
if isinstance(attr, (Column, Relationship)) and attr_name not in mapped_attrs:
5963
# not annotated attributes
6064
mapped_attrs[attr_name] = attr
61-
continue
62-
65+
return mapped_attrs
66+
67+
@classmethod
68+
def process_sql_methods(cls, dct, model_registry=None):
69+
for attr_name, attr in dct.items():
6370
wrapper = type(attr) if isinstance(attr, (staticmethod, classmethod)) else False
6471
if wrapper:
6572
# the only way to replace the wrapped function for a class/static method is before the class initialization.
6673
attr = attr.__wrapped__
67-
if callable(attr):
68-
if is_sqlfunc(attr):
69-
dct[attr_name] = cls.make_sqlfunc_from_method(attr, wrapper, model_registry)
70-
71-
dct["__mapper__"] = mapped_attrs
72-
return dct
74+
if callable(attr) and is_sqlfunc(attr):
75+
# the model registry is passed as template locals to sql func methods
76+
# so model classes are available in the evaluation scope of SQLTemplate
77+
dct[attr_name] = cls.make_sqlfunc_from_method(attr, wrapper, model_registry)
7378

7479
@staticmethod
75-
def make_sqlfunc_from_method(func, decorator, model_registry):
80+
def make_sqlfunc_from_method(func, decorator, template_locals=None):
7681
doc = inspect.getdoc(func)
7782
accessor = "cls" if decorator is classmethod else "self"
7883
if doc.upper().startswith("SELECT WHERE"):
7984
doc = doc[7:]
8085
if doc.upper().startswith("WHERE"):
81-
func.__doc__ = "{%s.select_from()} %s" % (accessor, doc)
86+
doc = "{%s.select_from()} %s" % (accessor, doc)
8287
if doc.upper().startswith("INSERT INTO ("):
8388
doc = "INSERT INTO {%s.table} %s" % (accessor, doc[12:])
8489
if doc.upper().startswith("UPDATE SET"):
@@ -87,21 +92,26 @@ def make_sqlfunc_from_method(func, decorator, model_registry):
8792
doc = "DELETE FROM {%s.table} %s" % (accessor, doc[7:])
8893
if "WHERE SELF" in doc.upper():
8994
doc = doc.replace("WHERE SELF", "WHERE {self.__mapper__.primary_key_condition(self)}")
95+
func.__doc__ = doc
9096
if not getattr(func, "query_decorator", None) and ".select_from(" in doc:
9197
# because the statement does not start with SELECT, it would default to execute when using .select_from()
9298
func = fetchall(func)
93-
# the model registry is passed as template locals to sql func methods
94-
# so model classes are available in the evaluation scope of SQLTemplate
95-
method = sqlfunc(func, is_method=True, template_locals=model_registry)
99+
method = sqlfunc(func, is_method=True, template_locals=template_locals)
96100
return decorator(method) if decorator else method
97101

98102
@staticmethod
99-
def post_process_model_class(cls):
100-
mapped_attrs = cls.__mapper__
103+
def create_mapper(cls, mapped_attrs=None):
104+
cls.table = SQL.Id(getattr(cls, "__table__", getattr(cls, "table", cls.__name__.lower())))
101105
cls.__mapper__ = ModelMapper(
102106
cls, cls.table.name, allow_unknown_columns=cls.Meta.allow_unknown_columns
103107
)
104-
cls.__mapper__.map(mapped_attrs)
108+
109+
for attr_name in dir(cls):
110+
if isinstance(getattr(cls, attr_name), (Column, Relationship)) and attr_name not in mapped_attrs:
111+
cls.__mapper__.map(attr_name, getattr(cls, attr_name))
112+
if mapped_attrs:
113+
cls.__mapper__.map(mapped_attrs)
114+
105115
cls.c = cls.__mapper__.columns # handy shortcut
106116

107117
auto_primary_key = cls.Meta.auto_primary_key
@@ -110,14 +120,11 @@ def post_process_model_class(cls):
110120
# we force the usage of SELECT * as we auto add a primary key without any other mapped columns
111121
# without doing this, only the primary key would be selected
112122
cls.__mapper__.force_select_wildcard = True
113-
cls.__mapper__.map(auto_primary_key, Column(auto_primary_key, primary_key=True))
114-
115-
cls.__model_registry__.register(cls)
116-
return cls
123+
cls.__mapper__.map(auto_primary_key, Column(auto_primary_key, type=cls.Meta.auto_primary_key_type, primary_key=True))
117124

118125
@staticmethod
119126
def process_meta_inheritance(cls):
120-
if getattr(cls.Meta, "__inherit__", True):
127+
if hasattr(cls, "Meta") and getattr(cls.Meta, "__inherit__", True):
121128
bases_meta = ModelMetaclass.aggregate_bases_meta_attrs(cls)
122129
for key, value in bases_meta.items():
123130
if not hasattr(cls.Meta, key):
@@ -130,7 +137,7 @@ def process_meta_inheritance(cls):
130137
def aggregate_bases_meta_attrs(cls):
131138
meta = {}
132139
for base in cls.__bases__:
133-
if issubclass(base, BaseModel):
140+
if hasattr(base, "Meta"):
134141
if getattr(base.Meta, "__inherit__", True):
135142
meta.update(ModelMetaclass.aggregate_bases_meta_attrs(base))
136143
meta.update(
@@ -331,6 +338,7 @@ class Meta:
331338
auto_primary_key: t.Optional[str] = (
332339
"id" # auto generate a primary key with this name if no primary key are declared
333340
)
341+
auto_primary_key_type: SQLType = Integer
334342
allow_unknown_columns: bool = True # hydrate() will set attributes for unknown columns
335343

336344
@classmethod

sqlorm/schema.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def create_all(model_registry=None, engine=None, check_missing=False, logger=Non
2626
missing = False
2727
with ensure_transaction(engine) as tx:
2828
try:
29-
tx.execute(model.find_one())
29+
model.find_one()
3030
except Exception:
3131
missing = True
3232
if missing:

sqlorm/sql_template.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@ def __init__(self, template, code):
1010
self.template = template
1111
self.code = code
1212

13-
def _render(self, params):
13+
def eval(self):
1414
return eval(self.code, self.template.eval_globals, self.template.locals)
1515

16+
def _render(self, params):
17+
return SQL(self.eval())._render(params)
18+
1619

1720
class ParametrizedEvalBlock(EvalBlock):
1821
def _render(self, params):
19-
return params.add(super()._render(params))
22+
return params.add(self.eval())
2023

2124

2225
class SQLTemplateError(Exception):

tests/test_engine.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def test_pool():
6060
assert conn in e.pool
6161

6262
with e as tx:
63+
tx.session.connect()
6364
assert not e.pool
6465
assert len(e.active_conns) == 2
6566
assert tx.session.conn is conn
@@ -95,7 +96,7 @@ def test_session():
9596
e = Engine.from_uri("sqlite://:memory:")
9697
with e.session() as sess:
9798
assert isinstance(sess, Session)
98-
assert sess.conn
99+
assert not sess.conn
99100
assert sess.engine is e
100101
assert not sess.virtual_tx
101102

@@ -104,4 +105,5 @@ def test_session():
104105
assert tx.session is sess
105106
assert not tx.virtual
106107

108+
sess.connect()
107109
assert sess.conn

tests/test_model.py

+26-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from sqlorm import Model, SQL, Relationship, is_dirty, PrimaryKey
1+
from sqlorm import Model, SQL, Relationship, is_dirty, PrimaryKey, Column
22
from sqlorm.mapper import Mapper
33
from models import *
4+
import abc
45

56

67
def test_model_registry():
@@ -68,6 +69,25 @@ def test_mapper():
6869
assert User.tasks in User.__mapper__.relationships
6970

7071

72+
def test_inheritance():
73+
class A(Model, abc.ABC):
74+
col1: str
75+
col2 = Column(type=int)
76+
col3 = Column(type=bool)
77+
78+
assert not hasattr(A, "__mapper__")
79+
assert isinstance(A.col1, Column)
80+
assert A.col3.type.sql_type == "boolean"
81+
82+
class B(A):
83+
col3: str
84+
col4 = Column(type=int)
85+
86+
assert B.__mapper__
87+
assert B.__mapper__.columns.names == ["col1", "col2", "col3", "col4", "id"]
88+
assert B.col3.type.sql_type == "text"
89+
90+
7191
def test_find_all(engine):
7292
listener_called = False
7393

@@ -194,10 +214,10 @@ def test_update(cls):
194214
def test_delete(cls):
195215
"DELETE WHERE col1 = 'foo'"
196216

197-
assert TestModel.find_all.sql(TestModel) == "SELECT test.id , test.col1 FROM test WHERE col1 = 'foo'"
198-
assert TestModel.test_insert.sql(TestModel) == "INSERT INTO test (col1) VALUES ('bar')"
199-
assert TestModel.test_insert.sql(TestModel) == "UPDATE test SET col1 = 'bar'"
200-
assert TestModel.test_insert.sql(TestModel) == "DELETE FROM test WHERE col1 = 'foo'"
217+
assert str(TestModel.find_all.sql(TestModel)) == "SELECT test.id , test.col1 FROM test WHERE col1 = 'foo'"
218+
assert str(TestModel.test_insert.sql(TestModel)) == "INSERT INTO test (col1) VALUES ('bar')"
219+
assert str(TestModel.test_update.sql(TestModel)) == "UPDATE test SET col1 = 'bar'"
220+
assert str(TestModel.test_delete.sql(TestModel)) == "DELETE FROM test WHERE col1 = 'foo'"
201221

202222

203223
def test_dirty_tracking(engine):
@@ -351,4 +371,4 @@ def on_after_delete(sender, obj):
351371
assert listener_called == 2
352372

353373
user = User.get(4)
354-
assert not user
374+
assert not user

0 commit comments

Comments
 (0)