forked from swist/django-more
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdjango_db_models_sql_compiler.py
373 lines (314 loc) · 13.7 KB
/
django_db_models_sql_compiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
from collections import OrderedDict
from types import MethodType
from django.db.models import Field
from django.db.models import sql
from django.db.models.sql.compiler import *
class SQLCompiler:
    """Extra helper for a SQL compiler.

    NOTE(review): this class has no base and relies on ``self.select`` and
    ``self.connection``; presumably it is merged into Django's
    ``SQLCompiler`` elsewhere -- confirm against the code that applies it.
    """

    def get_return_fields(self, extra_select=(), with_col_aliases=False):
        """Return the output column names and parameters of the selection.

        Each entry of ``self.select`` (followed by ``extra_select``) is a
        ``(expression, (sql, params), alias)`` triple.

        :param extra_select: additional triples appended after
            ``self.select``. Defaults to none.
        :param with_col_aliases: when true, columns without an alias are
            named ``Col1``, ``Col2``, ... in order of appearance.
        :returns: ``(cols, params)`` where ``cols`` holds, per entry, the
            quoted alias, a generated ``ColN`` name, or the raw SQL, and
            ``params`` is the flat list of all entry parameters.
        """
        # Both names used to be read as undefined free variables, so every
        # call raised NameError; they are now optional arguments.
        cols, params = [], []
        col_idx = 1
        for _, (s_sql, s_params), alias in list(self.select) + list(extra_select):
            if alias:
                s_sql = self.connection.ops.quote_name(alias)
            elif with_col_aliases:
                s_sql = 'Col%d' % col_idx
                col_idx += 1
            params.extend(s_params)
            cols.append(s_sql)
        return cols, params
class SQLUpdateCompiler:
    """UPDATE compilation that can reference extra tables via ``FROM``.

    NOTE(review): the class has no base of its own and calls
    ``super(sql.compiler.SQLUpdateCompiler, self)``; presumably its methods
    are installed onto Django's ``SQLUpdateCompiler`` -- confirm.
    """

    def as_sql(self):
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.

        The statement has the shape ``UPDATE t SET ... [FROM ...] [WHERE ...]``.
        The ``FROM`` part is only emitted when the query has
        ``extra_tables``. Parameters are returned in the same order as
        their placeholders appear: SET values, FROM clause, WHERE clause.

        Returns ``('', ())`` when the query has no values to update.
        Raises ``FieldError`` for aggregate expressions and ``TypeError``
        when a model instance is assigned to a non-relational field.
        """
        self.pre_sql_setup()
        if not self.query.values:
            return '', ()
        qn = self.quote_name_unless_alias
        values, update_params = [], []
        for field, model, val in self.query.values:
            if hasattr(val, 'resolve_expression'):
                val = val.resolve_expression(self.query, allow_joins=False, for_save=True)
                if val.contains_aggregate:
                    raise FieldError("Aggregate functions are not allowed in this query")
            elif hasattr(val, 'prepare_database_save'):
                if field.remote_field:
                    val = field.get_db_prep_save(
                        val.prepare_database_save(field),
                        connection=self.connection,
                    )
                else:
                    raise TypeError(
                        "Tried to update field %s with a model instance, %r. "
                        "Use a value compatible with %s."
                        % (field, val, field.__class__.__name__)
                    )
            else:
                val = field.get_db_prep_save(val, connection=self.connection)

            # Getting the placeholder for the field.
            if hasattr(field, 'get_placeholder'):
                placeholder = field.get_placeholder(val, self, self.connection)
            else:
                placeholder = '%s'
            name = field.column
            if hasattr(val, 'as_sql'):
                val_sql, val_params = self.compile(val)
                values.append('%s = %s' % (qn(name), val_sql))
                update_params.extend(val_params)
            elif val is not None:
                values.append('%s = %s' % (qn(name), placeholder))
                update_params.append(val)
            else:
                values.append('%s = NULL' % qn(name))
        table = self.query.tables[0]
        result = [
            'UPDATE %s SET' % qn(table),
            ', '.join(values),
        ]
        where, where_params = self.compile(self.query.where)
        from_params = ()
        if self.query.extra_tables:
            from_, f_params = self.get_from_clause()
            if from_:
                result.append(" FROM {tables}".format(
                    tables=", ".join(from_)))
                from_params = f_params
        if where:
            result.append('WHERE %s' % where)
        # FROM precedes WHERE in the statement, so its parameters must also
        # precede the WHERE parameters or placeholders bind to wrong values.
        all_params = tuple(update_params) + tuple(from_params) + tuple(where_params)
        return ' '.join(result), all_params

    def get_from_clause(self):
        """Return the FROM entries for the UPDATE, excluding the first one.

        Returns ``(clauses, params)``. ``clauses`` is empty when the
        inherited FROM clause has at most one entry, since UPDATE must not
        list its own table again.
        """
        # Kludgy syntax because method declared here has wrong context for direct super()
        result, params = super(sql.compiler.SQLUpdateCompiler, self).get_from_clause()
        if len(result) <= 1:
            # Nothing besides the first entry: no FROM clause is needed
            return [], tuple()
        # Strip stray separators (", ") from both ends of every clause
        result = [clause.strip(", ") for clause in result]
        # Remove first FROM clause, UPDATE must not specify own table
        return result[1:], params
class SQLReturningMixin:
    """Mixin that builds the column list for a ``RETURNING`` clause."""

    # Select-style setup, kept under its own name so that compilers which
    # define a different ``pre_sql_setup`` can still call it.
    select_setup = sql.compiler.SQLCompiler.pre_sql_setup

    def get_returning_clause(self):
        """Return ``(columns, params)`` for the selected expressions.

        Runs ``select_setup()`` and then walks ``self.select`` followed by
        the extra select entries it returned. Aliased expressions are
        rendered as ``<sql> AS <quoted alias>``; the others as their plain
        SQL. ``params`` is a tuple of all expression parameters in order.
        """
        extra_select, _order_by, _group_by = self.select_setup()
        quote = self.connection.ops.quote_name
        columns = []
        collected = []
        for _, (expr_sql, expr_params), alias in self.select + extra_select:
            rendered = '%s AS %s' % (expr_sql, quote(alias)) if alias else expr_sql
            collected.extend(expr_params)
            columns.append(rendered)
        return columns, tuple(collected)
class SQLInsertReturningCompiler(sql.compiler.SQLInsertCompiler, SQLReturningMixin):
    """INSERT compiler that appends a ``RETURNING`` clause."""

    def as_sql(self):
        """Return ``(sql, params)`` for the insert plus its RETURNING list.

        Only the first statement produced by the parent compiler is used.
        The RETURNING columns come from ``get_returning_clause()``; when it
        yields no columns the insert statement is returned unchanged.
        """
        i_sql, i_params = super().as_sql()[0]
        # A field list used to be compiled here from values_select or
        # default_cols and then discarded. It raised UnboundLocalError when
        # neither was set, so the dead computation has been removed.
        return_, r_params = self.get_returning_clause()
        if return_:
            # NOTE(review): the column list is wrapped in parentheses;
            # confirm the target database returns separate columns for
            # this form when more than one column is listed.
            i_sql += " RETURNING ({fields})".format(
                fields=", ".join(return_))
            i_params += r_params
        return i_sql, i_params
class SQLInsertSelectCompiler(sql.compiler.SQLCompiler):
    """Compile ``INSERT INTO <table> <select> [RETURNING (...)]``."""

    def as_sql(self):
        """Return ``(sql, params)`` for an insert fed by this select query.

        The target table is ``self.query.into_table`` when set, otherwise
        the table of the query's model. The RETURNING list is built from
        ``values_select`` or, failing that, the default columns; it is
        omitted when there is nothing to return.
        """
        # The select compiler returns one (sql, params) pair, not a list of
        # statements, so it must not be indexed before unpacking.
        i_sql, i_params = super().as_sql()
        # Needs aliases and colnames
        if self.query.values_select:
            # NOTE(review): every entry is passed to self.compile() below,
            # so entries must be compilable expressions -- confirm what
            # values_select holds for this query type.
            fields = self.query.values_select
        elif self.query.default_cols:
            fields = self.get_default_columns()
        else:
            fields = []
        fields = [
            col_sql for col_sql, _ in
            (self.compile(col, select_format=True) for col in fields)]
        # Model table names live on the model's _meta options object.
        into_table = self.query.into_table or self.query.model._meta.db_table
        i_sql = "INSERT INTO {into_table} {sql}".format(
            into_table=into_table,
            sql=i_sql,
        )
        if fields:
            i_sql += " RETURNING ({fields})".format(fields=", ".join(fields))
        return i_sql, i_params
class SQLUpdateReturningCompiler(sql.compiler.SQLUpdateCompiler, SQLReturningMixin):
    """UPDATE compiler that appends a ``RETURNING`` clause."""

    # Use the generic compiler's execute_sql rather than the update
    # compiler's own version.
    execute_sql = sql.compiler.SQLCompiler.execute_sql

    def as_sql(self):
        """Return ``(sql, params)`` for the update plus its RETURNING list.

        When the parent produces an empty statement (nothing to update) it
        is passed through untouched, so callers still see "no query"
        instead of a dangling ``RETURNING`` fragment.
        """
        i_sql, i_params = super().as_sql()
        if not i_sql:
            return i_sql, i_params
        return_, r_params = self.get_returning_clause()
        if return_:
            i_sql += " RETURNING ({fields})".format(
                fields=", ".join(return_))
            i_params += r_params
        return i_sql, i_params

    def pre_sql_setup(self):
        """Do nothing; select setup runs through ``select_setup`` instead."""
        pass
class SQLCTESelectCompiler:
    """Select compiler variant whose FROM clause omits its first entry."""

    def __init__(self):
        """Start with empty ``select`` and ``extra_select`` lists."""
        self.select, self.extra_select = [], []

    def pre_sql_setup(self):
        """Intentionally a no-op."""
        return None

    def get_from_clause(self):
        """Return the inherited FROM entries minus the first, plus params."""
        clauses, clause_params = super().get_from_clause()
        # chop off first from
        remaining = clauses[1:]
        return remaining, clause_params
class SQLWithCompiler:
    """Compile a ``WITH <ctes> <base query>`` statement.

    Instances are created with a dynamically built class that inherits from
    both this class and the compiler class of the base query, so the
    result also behaves like that compiler.
    """

    # Attributes owned by this compiler. They are never delegated, because
    # resolving the base compiler itself depends on them.
    _own_attrs = frozenset(("query", "connection", "using", "get_base_compiler"))

    def __new__(cls, query, connection, using, *args, **kwargs):
        # Retype as base query
        base_compiler = query.base_query.get_compiler(using, connection).__class__
        kls = type(cls.__name__ + base_compiler.__name__, (cls, base_compiler), {})
        return object.__new__(kls)

    def __init__(self, query, connection, using):
        """Store the query, connection and alias; bind the base factory."""
        self.query = query
        self.connection = connection
        self.using = using
        # The select, klass_info, and annotations are needed by QuerySet.iterator()
        # these are set as a side-effect of executing the query. Note that we calculate
        # separately a list of extra select columns needed for grammatical correctness
        # of the query, but these columns are not included in self.select.
        self.get_base_compiler = self.query.base_query.get_compiler

    def as_sql(self):
        """Return ``(sql, params)`` for all WITH queries plus the base query.

        Each query in ``self.query.queries`` is rendered as
        ``<alias>[ (<columns>)] AS (<sql>)``. Parameters are ordered to
        match: every WITH query in turn, then the base query.
        """
        # Collect all with queries to compile
        self.query.prepare_queries()
        result, params = [], []
        for query in self.query.queries:
            # Compile only the base of nested queries instead of their tree
            if hasattr(query, "get_base_compiler"):
                compiler = query.get_base_compiler(using=self.using, connection=self.connection)
            else:
                compiler = query.get_compiler(using=self.using, connection=self.connection)
            w_sql, w_params = compiler.as_sql()
            # Needs aliases and colnames
            if hasattr(query, "get_columns"):
                fields = " ({fields})".format(
                    fields=", ".join(query.get_columns()))
            else:
                fields = ""
            w_sql = "{alias}{fields} AS ({sql})".format(
                alias=query.with_alias,
                fields=fields,
                sql=w_sql
            )
            result.append(w_sql)
            params.extend(w_params)
        # Remember the compiler that produced the base SQL so attribute
        # delegation in __getattr__ sees the state set while compiling.
        base_compiler = self.get_base_compiler(using=self.using, connection=self.connection)
        self.base_compiler = base_compiler
        b_sql, b_params = base_compiler.as_sql()
        params.extend(b_params)
        return "WITH {withs} {base}".format(
            withs=", ".join(result),
            base=b_sql
        ), tuple(params)

    def __getattr__(self, attr):
        """Fall back to the base query's compiler for missing attributes.

        Callables are re-bound to this instance; other values are returned
        as found. Raises ``AttributeError`` when the attribute is one of
        this compiler's own or is missing on the base compiler too.
        """
        # __getattr__ only runs after normal lookup failed. For our own
        # attributes that means __init__ has not run yet; delegating would
        # recurse without end.
        if attr in self._own_attrs:
            raise AttributeError(attr)
        if attr == "base_compiler":
            # Nothing assigned it yet: create it lazily and cache it.
            base = self.get_base_compiler(using=self.using, connection=self.connection)
            self.base_compiler = base
            return base
        # Pretend to be the compiler of the base query unless it's specific to this
        base_attr = getattr(self.base_compiler, attr)
        if callable(base_attr):
            return MethodType(getattr(self.base_compiler.__class__, attr), self)
        return base_attr
class SQLLiteralCompiler(sql.compiler.SQLCompiler):
    """Compile a set of literal rows into a SQL ``VALUES`` list.

    Reads ``fields``, ``objs``, ``sample_obj``, ``enum_field`` and
    ``values_select`` from ``self.query``.
    """

    # Lambdas to do field savvy value conversions
    field_lambdas = {
        "val": lambda f, v: f.get_prep_value(v),
        "key": lambda f, v: v,
        "rel": lambda f, v: v.id,
    }
    # Lambdas to get values from objects
    value_lambdas = {
        "dict": lambda f, o: o.get(f),
        "attr": lambda f, o: getattr(o, f),
        "list": lambda f, o: o[f],
    }

    @staticmethod
    def _lookup_key(field):
        """Return the key used to read *field* from a row object."""
        return field.name if isinstance(field, Field) else field

    def pre_sql_setup(self, fields, obj):
        """ Based on the model supplied, build a database value prep_map for fields
            This dict of lambdas will apply model database prep conversions as per
            Django norm.
            As literal sets are simplistic, this can be generated once instead of
            checking per object.

            ``obj`` is a sample row; it decides whether values are read by
            dict key, by attribute, or by index.
        """
        # Determine appropriate mapping set
        if isinstance(obj, dict):
            self.get_value = self.value_lambdas["dict"]
        elif hasattr(obj, "_fields") or all(
                # hasattr() rejects non-string names with TypeError, so
                # positional (int) keys must fall through to index access
                # and Field objects must be checked by their name.
                isinstance(key, str) and hasattr(obj, key)
                for key in map(self._lookup_key, fields)):
            self.get_value = self.value_lambdas["attr"]
        else:
            self.get_value = self.value_lambdas["list"]

        if all(isinstance(field, Field) for field in fields):
            self.prep_mapping = OrderedDict()
            # Create type savvy field conversion list
            for field in fields:
                if not field.is_relation:
                    kind = "val"
                elif field.name.endswith("_id"):
                    kind = "key"
                else:
                    kind = "rel"
                self.prep_mapping[field] = self.field_lambdas[kind]
            # Promote more complex conversion iter
            self.obj_values = self.obj_values_prepped

    def obj_values(self, obj, fields):
        """Return the raw value of every field read from ``obj``."""
        return [
            self.get_value(field, obj)
            for field in fields]

    def obj_values_prepped(self, obj, fields):
        """Return every field's value from ``obj`` after its prep lambda."""
        return [
            self.prep_mapping[field](field, self.get_value(field.name, obj))
            for field in fields]

    def assemble_params(self, fields, objs, enum_field=None):
        """Yield the flat parameter sequence for all rows.

        When ``enum_field`` is truthy each row's values are preceded by its
        1-based row number.
        """
        if enum_field:
            yield from (value
                for row, obj in enumerate(objs, 1)
                for value in [row] + self.obj_values(obj, fields))
        else:
            yield from (value
                for obj in objs
                for value in self.obj_values(obj, fields))

    def assemble_as_sql(self, fields, objs):
        """
        Take a sequence of N fields and a sequence of M rows of values, and
        generate placeholder SQL and parameters for each field and value.

        Return a pair containing:
         * the SQL text of M comma separated placeholder rows, and
         * the flat list of parameter values for those rows.

        The first row carries ``::<db type>`` casts unless ``fields`` is a
        ``range``. Placeholders are derived from the first row's values and
        reused for every row. Returns ``("", [])`` when ``objs`` is empty.
        """
        if not objs:
            return "", []
        self.pre_sql_setup(fields, self.query.sample_obj)
        params = list(self.assemble_params(fields, objs, self.query.enum_field))
        enum_ph = ['%s'] if self.query.enum_field else []
        # With an enum column the first parameter of a row is the row
        # number, so field values start one position later.
        offset = len(enum_ph)
        field_ph = [
            self.get_field_placeholder(field, params[i + offset])
            for i, field in enumerate(fields)]
        row_ph = "({})".format(", ".join(enum_ph + field_ph))
        if not isinstance(fields, range):
            header_ph = "({})".format(
                ", ".join(enum_ph + [
                    "{ph}::{type}".format(
                        ph=ph,
                        type=field.db_type(self.connection))
                    for ph, field in zip(field_ph, fields)]))
        else:
            header_ph = row_ph
        values_sql = ", ".join([header_ph] + [row_ph] * (len(objs) - 1))
        assert (len(fields) + len(enum_ph)) * len(objs) == len(params), "Values set params not matched"
        return values_sql, params

    def get_field_placeholder(self, field, val):
        """Return the placeholder SQL for ``val`` stored in ``field``."""
        if hasattr(field, 'get_placeholder'):
            # Some fields (e.g. geo fields) need special munging before
            # they can be inserted.
            return field.get_placeholder(val, self, self.connection)
        # Return the common case for the placeholder
        return '%s'

    def as_sql(self):
        """ Create the SQL for a set of literal values used as a CTE """
        fields = self.query.fields
        if fields and not self.query.values_select:
            # Ensure values argument for WITH compilation
            self.query.values_select = [field.name for field in fields]
        if not fields:
            # No declared fields: address the columns by position instead
            fields = range(len(self.query.sample_obj))
        values_sql, params = self.assemble_as_sql(
            fields=fields, objs=self.query.objs)
        return "VALUES {}".format(values_sql), tuple(params)