5
5
from .engine import Engine , ensure_transaction , _signals , _signal_rv
6
6
from .sqlfunc import is_sqlfunc , sqlfunc , fetchall , fetchone , execute , update
7
7
from .resultset import ResultSet , CompositeResultSet
8
- from .types import SQLType
8
+ from .types import SQLType , Integer
9
9
from .mapper import (
10
10
Mapper ,
11
11
MappedColumnMixin ,
17
17
18
18
class ModelMetaclass (abc .ABCMeta ):
19
19
def __new__ (cls , name , bases , dct ):
20
- if not bases or abc .ABC in bases :
20
+ if len ( bases ) == 1 and bases [ 0 ] is abc .ABC : # BaseModel
21
21
return super ().__new__ (cls , name , bases , dct )
22
- dct = cls .pre_process_model_class_dict (name , bases , dct )
22
+
23
+ model_registry = cls .find_model_registry (bases )
24
+ mapped_attrs = cls .process_mapped_attributes (dct )
25
+ cls .process_sql_methods (dct , model_registry )
23
26
model_class = super ().__new__ (cls , name , bases , dct )
24
27
cls .process_meta_inheritance (model_class )
25
- return cls .post_process_model_class (model_class )
28
+ if abc .ABC not in bases :
29
+ cls .create_mapper (model_class , mapped_attrs )
30
+ model_class .__model_registry__ .register (model_class )
31
+ return model_class
26
32
27
- @classmethod
28
- def pre_process_model_class_dict (cls , name , bases , dct ):
29
- model_registry = {}
33
+ def find_model_registry (bases ):
30
34
for base in bases :
31
- if issubclass (base , BaseModel ):
32
- model_registry = base .__model_registry__
33
- break
34
-
35
- dct ["table" ] = SQL .Id (dct .get ("__table__" , dct .get ("table" , name .lower ())))
35
+ if hasattr (base , "__model_registry__" ):
36
+ return base .__model_registry__
37
+ return ModelRegistry ()
36
38
39
+ @staticmethod
40
+ def process_mapped_attributes (dct ):
37
41
mapped_attrs = {}
38
42
for name , annotation in dct .get ("__annotations__" , {}).items ():
39
43
primary_key = False
@@ -45,11 +49,11 @@ def pre_process_model_class_dict(cls, name, bases, dct):
45
49
dct [name ] = mapped_attrs [name ] = Column (name , annotation , primary_key = primary_key )
46
50
elif isinstance (dct [name ], Column ):
47
51
mapped_attrs [name ] = dct [name ]
48
- dct [name ].type = SQLType .from_pytype (annotation )
52
+ if dct [name ].type is None :
53
+ dct [name ].type = SQLType .from_pytype (annotation )
49
54
elif isinstance (dct [name ], Relationship ):
50
55
# add now to keep the declaration order
51
56
mapped_attrs [name ] = dct [name ]
52
-
53
57
for attr_name , attr in dct .items ():
54
58
if isinstance (attr , Column ) and not attr .name :
55
59
# in the case of models, we allow column object to be initialized without names
@@ -58,27 +62,28 @@ def pre_process_model_class_dict(cls, name, bases, dct):
58
62
if isinstance (attr , (Column , Relationship )) and attr_name not in mapped_attrs :
59
63
# not annotated attributes
60
64
mapped_attrs [attr_name ] = attr
61
- continue
62
-
65
+ return mapped_attrs
66
+
67
+ @classmethod
68
+ def process_sql_methods (cls , dct , model_registry = None ):
69
+ for attr_name , attr in dct .items ():
63
70
wrapper = type (attr ) if isinstance (attr , (staticmethod , classmethod )) else False
64
71
if wrapper :
65
72
# the only way to replace the wrapped function for a class/static method is before the class initialization.
66
73
attr = attr .__wrapped__
67
- if callable (attr ):
68
- if is_sqlfunc (attr ):
69
- dct [attr_name ] = cls .make_sqlfunc_from_method (attr , wrapper , model_registry )
70
-
71
- dct ["__mapper__" ] = mapped_attrs
72
- return dct
74
+ if callable (attr ) and is_sqlfunc (attr ):
75
+ # the model registry is passed as template locals to sql func methods
76
+ # so model classes are available in the evaluation scope of SQLTemplate
77
+ dct [attr_name ] = cls .make_sqlfunc_from_method (attr , wrapper , model_registry )
73
78
74
79
@staticmethod
75
- def make_sqlfunc_from_method (func , decorator , model_registry ):
80
+ def make_sqlfunc_from_method (func , decorator , template_locals = None ):
76
81
doc = inspect .getdoc (func )
77
82
accessor = "cls" if decorator is classmethod else "self"
78
83
if doc .upper ().startswith ("SELECT WHERE" ):
79
84
doc = doc [7 :]
80
85
if doc .upper ().startswith ("WHERE" ):
81
- func . __doc__ = "{%s.select_from()} %s" % (accessor , doc )
86
+ doc = "{%s.select_from()} %s" % (accessor , doc )
82
87
if doc .upper ().startswith ("INSERT INTO (" ):
83
88
doc = "INSERT INTO {%s.table} %s" % (accessor , doc [12 :])
84
89
if doc .upper ().startswith ("UPDATE SET" ):
@@ -87,21 +92,26 @@ def make_sqlfunc_from_method(func, decorator, model_registry):
87
92
doc = "DELETE FROM {%s.table} %s" % (accessor , doc [7 :])
88
93
if "WHERE SELF" in doc .upper ():
89
94
doc = doc .replace ("WHERE SELF" , "WHERE {self.__mapper__.primary_key_condition(self)}" )
95
+ func .__doc__ = doc
90
96
if not getattr (func , "query_decorator" , None ) and ".select_from(" in doc :
91
97
# because the statement does not start with SELECT, it would default to execute when using .select_from()
92
98
func = fetchall (func )
93
- # the model registry is passed as template locals to sql func methods
94
- # so model classes are available in the evaluation scope of SQLTemplate
95
- method = sqlfunc (func , is_method = True , template_locals = model_registry )
99
+ method = sqlfunc (func , is_method = True , template_locals = template_locals )
96
100
return decorator (method ) if decorator else method
97
101
98
102
@staticmethod
99
- def post_process_model_class (cls ):
100
- mapped_attrs = cls . __mapper__
103
+ def create_mapper (cls , mapped_attrs = None ):
104
+ cls . table = SQL . Id ( getattr ( cls , "__table__" , getattr ( cls , "table" , cls . __name__ . lower ())))
101
105
cls .__mapper__ = ModelMapper (
102
106
cls , cls .table .name , allow_unknown_columns = cls .Meta .allow_unknown_columns
103
107
)
104
- cls .__mapper__ .map (mapped_attrs )
108
+
109
+ for attr_name in dir (cls ):
110
+ if isinstance (getattr (cls , attr_name ), (Column , Relationship )) and attr_name not in mapped_attrs :
111
+ cls .__mapper__ .map (attr_name , getattr (cls , attr_name ))
112
+ if mapped_attrs :
113
+ cls .__mapper__ .map (mapped_attrs )
114
+
105
115
cls .c = cls .__mapper__ .columns # handy shortcut
106
116
107
117
auto_primary_key = cls .Meta .auto_primary_key
@@ -110,14 +120,11 @@ def post_process_model_class(cls):
110
120
# we force the usage of SELECT * as we auto add a primary key without any other mapped columns
111
121
# without doing this, only the primary key would be selected
112
122
cls .__mapper__ .force_select_wildcard = True
113
- cls .__mapper__ .map (auto_primary_key , Column (auto_primary_key , primary_key = True ))
114
-
115
- cls .__model_registry__ .register (cls )
116
- return cls
123
+ cls .__mapper__ .map (auto_primary_key , Column (auto_primary_key , type = cls .Meta .auto_primary_key_type , primary_key = True ))
117
124
118
125
@staticmethod
119
126
def process_meta_inheritance (cls ):
120
- if getattr (cls .Meta , "__inherit__" , True ):
127
+ if hasattr ( cls , "Meta" ) and getattr (cls .Meta , "__inherit__" , True ):
121
128
bases_meta = ModelMetaclass .aggregate_bases_meta_attrs (cls )
122
129
for key , value in bases_meta .items ():
123
130
if not hasattr (cls .Meta , key ):
@@ -130,7 +137,7 @@ def process_meta_inheritance(cls):
130
137
def aggregate_bases_meta_attrs (cls ):
131
138
meta = {}
132
139
for base in cls .__bases__ :
133
- if issubclass (base , BaseModel ):
140
+ if hasattr (base , "Meta" ):
134
141
if getattr (base .Meta , "__inherit__" , True ):
135
142
meta .update (ModelMetaclass .aggregate_bases_meta_attrs (base ))
136
143
meta .update (
@@ -331,6 +338,7 @@ class Meta:
331
338
auto_primary_key : t .Optional [str ] = (
332
339
"id" # auto generate a primary key with this name if no primary key are declared
333
340
)
341
+ auto_primary_key_type : SQLType = Integer
334
342
allow_unknown_columns : bool = True # hydrate() will set attributes for unknown columns
335
343
336
344
@classmethod
0 commit comments