-
Notifications
You must be signed in to change notification settings - Fork 11
/
active_alchemy.py
476 lines (392 loc) · 13.7 KB
/
active_alchemy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
# -*- coding: utf-8 -*-
"""
==================
Active-Alchemy
==================
A framework-agnostic wrapper for SQLAlchemy that makes it really easy
to use by implementing a simple active-record-like API, while still using db.session underneath.
:copyright: © 2014/2016 by `Mardix`.
:license: MIT, see LICENSE for more details.
"""
# Human-readable name of this library.
NAME = "Active-Alchemy"
# ------------------------------------------------------------------------------
import threading
import json
import datetime
import sqlalchemy
from sqlalchemy import *
from sqlalchemy.orm import scoped_session, sessionmaker, Query
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.schema import MetaData
from paginator import Paginator
import inflection
import sqlalchemy_utils as sa_utils
import arrow
# Default page size. Not referenced elsewhere in the code visible here --
# presumably intended as the pagination default; TODO confirm.
DEFAULT_PER_PAGE = 10
# Alias for arrow's "current UTC time" callable; used below as the
# default/onupdate callable of the timestamp columns.
utcnow = arrow.utcnow
def _create_scoped_session(db, query_cls):
    """Build the scoped session used by a database object.

    :param db: object exposing an ``engine`` attribute; the session factory
        is bound to that engine.
    :param query_cls: query class handed to the session factory.
    :returns: a ``scoped_session`` wrapping the configured ``sessionmaker``.
    """
    factory = sessionmaker(
        bind=db.engine,
        query_cls=query_cls,
        autocommit=False,
        autoflush=True,
    )
    return scoped_session(factory)
def _tablemaker(db):
    """Return a ``Table`` factory that lets callers omit the metadata.

    The returned function forwards to ``sqlalchemy.Table``. When the second
    positional argument is a ``db.Column`` (i.e. no metadata was given),
    ``db.metadata`` is inserted right after the table name.

    :param db: object providing ``Column`` and ``metadata`` attributes.
    :returns: a function with the calling convention of ``sqlalchemy.Table``.
    """
    def make_sa_table(*args, **kwargs):
        metadata_omitted = len(args) > 1 and isinstance(args[1], db.Column)
        if metadata_omitted:
            args = (args[0], db.metadata) + args[1:]

        # NOTE(review): ``bind_key`` is forwarded to ``sqlalchemy.Table`` as a
        # keyword argument as well as stored in ``info`` -- confirm that the
        # targeted SQLAlchemy version accepts it.
        kwargs.setdefault('bind_key', None)

        # ``info`` may be missing or explicitly None; either way use a dict
        # that always carries a ``bind_key`` entry.
        table_info = kwargs.pop('info', None) or {}
        table_info.setdefault('bind_key', None)
        kwargs['info'] = table_info

        return sqlalchemy.Table(*args, **kwargs)

    return make_sa_table
def _include_sqlalchemy(db):
    """Expose the SQLAlchemy API and a few extras as attributes of ``db``.

    Every name listed in ``sqlalchemy.__all__`` and ``sqlalchemy.orm.__all__``
    is copied onto ``db`` unless ``db`` already has an attribute of that name
    (so ``sqlalchemy`` wins over ``sqlalchemy.orm`` for duplicate names).
    Afterwards a handful of attributes are set unconditionally.

    :param db: the object to decorate; it is modified in place.
    """
    for module in (sqlalchemy, sqlalchemy.orm):
        for key in module.__all__:
            if hasattr(db, key):
                continue
            setattr(db, key, getattr(module, key))

    # Helpers and third-party shortcuts.
    db.Table = _tablemaker(db)
    db.event = sqlalchemy.event
    db.utils = sa_utils
    db.arrow = arrow
    db.utcnow = utcnow

    # Keep whatever ``db.DateTime`` currently is reachable as ``SADateTime``
    # *before* replacing ``DateTime`` with the Arrow-based column type.
    db.SADateTime = db.DateTime
    db.DateTime = sa_utils.ArrowType

    db.JSONType = sa_utils.JSONType
    db.EmailType = sa_utils.EmailType
class BaseQuery(Query):
    """Query class adding "or error" lookups and a pagination shortcut."""

    @staticmethod
    def _result_or_error(rv, error):
        """Return ``rv``, or apply ``error`` when ``rv`` is ``None``.

        :param rv: the lookup result.
        :param error: an exception instance or exception class (raised), or
            any other callable (called with no arguments; its return value
            is returned).
        """
        if rv is not None:
            return rv
        if isinstance(error, BaseException):
            raise error
        # An exception *class* must be raised too; merely calling it would
        # hand an exception instance back to the caller as if it were a row.
        if isinstance(error, type) and issubclass(error, BaseException):
            raise error()
        return error()

    def get_or_error(self, uid, error):
        """Like :meth:`get` but raises an error if not found instead of
        returning `None`.

        :param uid: identifier passed to :meth:`get`.
        :param error: exception instance or exception class to raise, or a
            callable invoked with no arguments whose result is returned.
        :returns: the record found, or the result of calling ``error``.
        """
        return self._result_or_error(self.get(uid), error)

    def first_or_error(self, error):
        """Like :meth:`first` but raises an error if not found instead of
        returning `None`.

        :param error: exception instance or exception class to raise, or a
            callable invoked with no arguments whose result is returned.
        :returns: the first record, or the result of calling ``error``.
        """
        return self._result_or_error(self.first(), error)

    def paginate(self, **kwargs):
        """Paginate this results.

        Returns an :class:`Paginator` object built from this query and the
        given keyword arguments.
        """
        return Paginator(self, **kwargs)
class ModelTableNameDescriptor(object):
    """
    Descriptor supplying a default ``__tablename__`` for a class.

    If the owning class defines a truthy ``__tablename__`` in its own
    ``__dict__`` that value is returned; otherwise the underscored class
    name is computed, stored on the class and returned.
    """

    def __get__(self, obj, type):
        explicit = type.__dict__.get('__tablename__')
        if explicit:
            return explicit

        # Nothing set on this very class: derive the name from the class name
        # and store it on the class.
        derived = inflection.underscore(type.__name__)
        setattr(type, '__tablename__', derived)
        return derived
class EngineConnector(object):
    """Creates the engine for a database object on demand and caches it.

    The cached engine is reused as long as the owner's ``uri`` and its
    ``echo`` option are unchanged.
    """

    def __init__(self, sa_obj):
        """
        :param sa_obj: object providing ``uri``, ``info`` and ``options``.
        """
        self._lock = threading.Lock()
        self._sa_obj = sa_obj
        self._engine = None
        # (uri, echo) pair the cached engine was created for.
        self._connected_for = None

    def get_engine(self):
        """Return the cached engine, creating it first when necessary."""
        with self._lock:
            owner = self._sa_obj
            uri, info, options = owner.uri, owner.info, owner.options
            cache_key = (uri, options.get('echo'))
            if self._connected_for != cache_key:
                self._engine = sqlalchemy.create_engine(info, **options)
                self._connected_for = cache_key
            return self._engine
class BaseModel(object):
    """
    Baseclass for custom user models.

    The methods rely on two attributes that are attached to the class from
    outside: ``db`` (used for add/commit/rollback and ``db.session``) and
    ``_query`` (called to build queries).
    """
    __tablename__ = ModelTableNameDescriptor()
    __primary_key__ = "id"  # String

    def __iter__(self):
        """Yield ``(name, value)`` pairs for every instance attribute whose
        name does not start with an underscore, so we can do
        dict(sa_instance).
        """
        # Iterate over a snapshot of the keys so the loop is unaffected if
        # attribute access alters ``__dict__`` along the way.
        for k in list(self.__dict__):
            if not k.startswith('_'):
                yield (k, getattr(self, k))

    def __repr__(self):
        return '<%s>' % self.__class__.__name__

    def to_dict(self):
        """
        Return an entity as dict, keyed by column name
        :returns dict:
        """
        return {c.name: getattr(self, c.name) for c in self.__table__.columns}

    def to_json(self):
        """
        Convert the entity to JSON. Date, time, datetime and Arrow values
        are serialized through their ``isoformat()``.
        :returns str:
        """
        data = {}
        for k, v in self.to_dict().items():
            # ``datetime.datetime`` is a subclass of ``datetime.date``, so it
            # is covered here as well.
            if isinstance(v, (datetime.date, datetime.time,
                              sa_utils.ArrowType, arrow.Arrow)):
                v = v.isoformat()
            data[k] = v
        return json.dumps(data)

    @classmethod
    def get(cls, pk):
        """
        Select entry by its primary key. It must be define as
        __primary_key__ (string)
        :param pk: value to match against the primary key column
        :returns: the first matching record, or None
        """
        return cls._query(cls).filter(getattr(cls, cls.__primary_key__) == pk).first()

    @classmethod
    def create(cls, **kwargs):
        """
        To create a new record
        :param kwargs: passed to the model constructor
        :returns object: The new record
        """
        record = cls(**kwargs).save()
        return record

    def update(self, **kwargs):
        """
        Update an entry: set every keyword as an attribute, then save
        :returns object: self
        """
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.save()
        return self

    @classmethod
    def query(cls, *args):
        """
        Build a query for ``args``, or for this class when none are given
        :returns query:
        """
        if not args:
            query = cls._query(cls)
        else:
            query = cls._query(*args)
        return query

    def save(self):
        """
        Shortcut to add and save + rollback
        On failure the session is rolled back and the exception re-raised.
        :returns object: self
        """
        try:
            self.db.add(self)
            self.db.commit()
            return self
        except Exception:
            self.db.rollback()
            raise

    def delete(self, delete=True, hard_delete=False):
        """
        Permanently delete the record and commit. On failure the session is
        rolled back and the exception re-raised.
        :param delete: Bool - *** Not applicable under BaseModel (ignored)
        :param hard_delete: Bool - *** Not applicable under BaseModel (ignored)
        :returns: the result of ``db.commit()``
        """
        try:
            self.db.session.delete(self)
            return self.db.commit()
        except Exception:
            self.db.rollback()
            raise
class Model(BaseModel):
    """
    Model with an integer ``id`` primary key, creation/update timestamps
    and soft-delete columns.
    """
    id = Column(Integer, primary_key=True)
    created_at = Column(sa_utils.ArrowType, default=utcnow)
    updated_at = Column(sa_utils.ArrowType, default=utcnow, onupdate=utcnow)
    is_deleted = Column(Boolean, default=False, index=True)
    deleted_at = Column(sa_utils.ArrowType, default=None)

    @classmethod
    def query(cls, *args, **kwargs):
        """
        Build a query for ``args``, or for this class when none are given.
        :returns query:
        :**kwargs:
            - include_deleted bool: True To filter in deleted records.
                                    By default it is set to False
        """
        if not args:
            query = cls._query(cls)
        else:
            query = cls._query(*args)

        # Any falsy value (False, None, 0, ...) means "exclude deleted rows",
        # not only the literal ``False``.
        if not kwargs.get("include_deleted", False):
            # ``!= True`` is deliberate: this builds a SQL expression, so it
            # must not be rewritten as ``is not True``.
            query = query.filter(cls.is_deleted != True)  # noqa: E712
        return query

    @classmethod
    def get(cls, id, include_deleted=False):
        """
        Select entry by id
        :param id: The id of the entry
        :param include_deleted: It should not query deleted record. Set to True to get all
        :returns: the first matching record, or None
        """
        return cls.query(include_deleted=include_deleted)\
                  .filter(cls.id == id)\
                  .first()

    def delete(self, delete=True, hard_delete=False):
        """
        Soft delete a record
        :param delete: Bool - To soft-delete/soft-undelete a record
        :param hard_delete: Bool - If true it will completely delete the record
        :returns: self for a soft (un)delete; the result of ``db.commit()``
                  for a hard delete
        """
        # Hard delete
        if hard_delete:
            try:
                self.db.session.delete(self)
                return self.db.commit()
            except Exception:
                self.db.rollback()
                raise
        else:
            # Soft (un)delete: flip the flag and set/clear the timestamp.
            data = {
                "is_deleted": delete,
                "deleted_at": utcnow() if delete else None
            }
            self.update(**data)
        return self
class ActiveAlchemy(object):
    """This class is used to instantiate a SQLAlchemy connection to
    a database.

        db = ActiveAlchemy(_uri_to_database_)

    The class also provides access to all the SQLAlchemy
    functions from the :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` modules.
    So you can declare models like this::

        class User(db.Model):
            login = db.Column(db.String(80), unique=True)
            passw_hash = db.Column(db.String(80))

    In a web application you need to call `db.session.remove()`
    after each response, and `db.session.rollback()` if an error occurs.
    If your application object has a `after_request` and `on_exception`
    decorators, just pass that object at creation::

        app = Flask(__name__)
        db = ActiveAlchemy('sqlite://', app=app)

    or later::

        db = ActiveAlchemy()
        app = Flask(__name__)
        db.init_app(app)

    .. admonition:: Check types carefully

       Don't perform type or `isinstance` checks against `db.Table`, which
       emulates `Table` behavior but is not a class. `db.Table` exposes the
       `Table` interface, but is a function which allows omission of metadata.
    """

    def __init__(self, uri='sqlite://',
                 app=None,
                 echo=False,
                 pool_size=None,
                 pool_timeout=None,
                 pool_recycle=None,
                 convert_unicode=True,
                 query_cls=BaseQuery):
        """
        :param uri: database URI.
        :param app: optional application object passed to :meth:`init_app`.
        :param echo: engine option; forwarded to ``create_engine``.
        :param pool_size: engine option; omitted when None.
        :param pool_timeout: engine option; omitted when None.
        :param pool_recycle: engine option; omitted when None.
        :param convert_unicode: engine option; forwarded to ``create_engine``.
        :param query_cls: query class used by the session.
        """
        self.uri = uri
        self.info = make_url(uri)
        self.options = self._cleanup_options(
            echo=echo,
            pool_size=pool_size,
            pool_timeout=pool_timeout,
            pool_recycle=pool_recycle,
            convert_unicode=convert_unicode,
        )

        # Both must exist before the session is created: building the session
        # reads ``self.engine``, which uses the lock and the connector.
        self.connector = None
        self._engine_lock = threading.Lock()
        self.session = _create_scoped_session(self, query_cls=query_cls)

        self.Model = declarative_base(cls=Model, name='Model')
        # Share a single MetaData between both bases so that ``metadata``,
        # ``create_all()``, ``drop_all()`` and ``db.Table`` also cover tables
        # declared on ``BaseModel`` subclasses.
        self.BaseModel = declarative_base(cls=BaseModel, name='BaseModel',
                                          metadata=self.Model.metadata)

        self.Model.db, self.BaseModel.db = self, self
        self.Model._query, self.BaseModel._query = self.session.query, self.session.query

        if app is not None:
            self.init_app(app)

        _include_sqlalchemy(self)

    def _cleanup_options(self, **kwargs):
        """Drop options whose value is None, then apply the driver hacks."""
        options = {
            key: val
            for key, val in kwargs.items()
            if val is not None
        }
        return self._apply_driver_hacks(options)

    def _apply_driver_hacks(self, options):
        """Adjust ``options`` (in place) for the driver named in the URI.

        :raises ValueError: for an in-memory SQLite database combined with
            ``pool_size == 0``.
        :returns: the same ``options`` dict.
        """
        if "mysql" in self.info.drivername:
            self.info.query.setdefault('charset', 'utf8')
            options.setdefault('pool_size', 10)
            options.setdefault('pool_recycle', 7200)

        elif self.info.drivername == 'sqlite':
            no_pool = options.get('pool_size') == 0
            memory_based = self.info.database in (None, '', ':memory:')
            if memory_based and no_pool:
                raise ValueError(
                    'SQLite in-memory database with an empty queue'
                    ' (pool_size = 0) is not possible due to data loss.'
                )
        return options

    def init_app(self, app):
        """This callback can be used to initialize an application for the
        use with this database setup. In a web application or a multithreaded
        environment, never use a database without initialize it first,
        or connections will leak.
        """
        if not hasattr(app, 'databases'):
            app.databases = []
        if isinstance(app.databases, list):
            # Already registered on this app: do not install the hooks twice.
            if self in app.databases:
                return
            app.databases.append(self)

        def shutdown(response=None):
            self.session.remove()
            return response

        def rollback(error=None):
            # Deliberately best-effort: a failing rollback is ignored.
            try:
                self.session.rollback()
            except Exception:
                pass

        self.set_flask_hooks(app, shutdown, rollback)

    def set_flask_hooks(self, app, shutdown, rollback):
        """Register ``shutdown``/``rollback`` with the hooks ``app`` offers.

        ``shutdown`` goes to ``app.after_request`` and ``rollback`` to
        ``app.on_exception``; a hook the app does not have is skipped.
        """
        if hasattr(app, 'after_request'):
            app.after_request(shutdown)
        if hasattr(app, 'on_exception'):
            app.on_exception(rollback)

    @property
    def engine(self):
        """Gives access to the engine. """
        with self._engine_lock:
            connector = self.connector
            if connector is None:
                connector = EngineConnector(self)
                self.connector = connector
            return connector.get_engine()

    @property
    def metadata(self):
        """Proxy for Model.metadata"""
        return self.Model.metadata

    @property
    def query(self):
        """Proxy for session.query"""
        return self.session.query

    def add(self, *args, **kwargs):
        """Proxy for session.add"""
        return self.session.add(*args, **kwargs)

    def flush(self, *args, **kwargs):
        """Proxy for session.flush"""
        return self.session.flush(*args, **kwargs)

    def commit(self):
        """Proxy for session.commit"""
        return self.session.commit()

    def rollback(self):
        """Proxy for session.rollback"""
        return self.session.rollback()

    def create_all(self):
        """Creates all tables. """
        self.Model.metadata.create_all(bind=self.engine)

    def drop_all(self):
        """Drops all tables. """
        self.Model.metadata.drop_all(bind=self.engine)

    def reflect(self, meta=None):
        """Reflects tables from the database.

        :param meta: MetaData to reflect into; a new one is created if omitted.
        :returns: the MetaData that was reflected into.
        """
        meta = meta or MetaData()
        meta.reflect(bind=self.engine)
        return meta

    def __repr__(self):
        return "<SQLAlchemy('{0}')>".format(self.uri)