4
4
import logging
5
5
import inspect
6
6
import urllib .parse
7
+ import functools
7
8
from blinker import Namespace
8
9
from .sql import render , ParametrizedStmt
9
10
from .resultset import ResultSet , CompositeResultSet , CompositionMap
@@ -32,6 +33,8 @@ class Engine:
32
33
"""
33
34
34
35
connected = _signals .signal ("connected" )
36
+ pool_checkin = _signals .signal ("pool-checkin" )
37
+ pool_checkout = _signals .signal ("pool-checkout" )
35
38
disconnected = _signals .signal ("disconnected" )
36
39
37
40
@classmethod
@@ -82,46 +85,45 @@ def __init__(
82
85
83
86
def connect (self , from_pool = True ):
84
87
if not from_pool or self .pool is False :
85
- if self .logger :
86
- getattr (self .logger , self .logger_level )("New connection established" )
87
- return self .connection_factory (self .dbapi )
88
+ return self ._connect ()
88
89
89
90
if self .pool :
90
91
conn = self .pool .pop (0 )
91
92
if self .logger :
92
93
getattr (self .logger , self .logger_level )("Re-using connection from pool" )
93
- self .connected .send (self , conn = conn , from_pool = True )
94
+ self .pool_checkout .send (self , conn = conn )
94
95
elif not self .max_pool_conns or len (self .active_conns ) < self .max_pool_conns :
95
- if self .logger :
96
- getattr (self .logger , self .logger_level )("New connection established" )
97
- conn = self .connection_factory (self .dbapi )
98
- self .connected .send (self , conn = conn , from_pool = False )
96
+ conn = self ._connect ()
99
97
else :
100
98
raise EngineError ("Max number of connections reached" )
101
99
102
100
self .active_conns .append (conn )
103
101
return conn
102
+
103
+ def _connect (self ):
104
+ if self .logger :
105
+ getattr (self .logger , self .logger_level )("Creating new connection" )
106
+ conn = self .connection_factory (self .dbapi )
107
+ self .connected .send (self , conn = conn )
108
+ return conn
104
109
105
110
def disconnect (self , conn , force = False ):
106
- if conn in self .active_conns :
111
+ if force or self .pool is False :
112
+ self ._close (conn )
113
+ elif conn in self .active_conns :
107
114
self .active_conns .remove (conn )
108
- if force :
109
- if self .logger :
110
- getattr (self .logger , self .logger_level )("Closing connection (forced)" )
111
- conn .close ()
112
- self .disconnected .send (self , conn = conn , close_conn = True )
113
- else :
114
- if self .logger :
115
- getattr (self .logger , self .logger_level )("Connection returned to pool" )
116
- self .pool .append (conn )
117
- self .disconnected .send (self , conn = conn , close_conn = False )
118
- elif self .pool is False or force :
119
115
if self .logger :
120
- getattr (self .logger , self .logger_level )("Closing connection" )
121
- conn . close ( )
122
- self .disconnected .send (self , conn = conn , close_conn = True )
116
+ getattr (self .logger , self .logger_level )("Returning connection to pool " )
117
+ self . pool . append ( conn )
118
+ self .pool_checkin .send (self , conn = conn )
123
119
else :
124
120
raise EngineError ("Cannot close connection which is not part of pool" )
121
+
122
+ def _close (self , conn ):
123
+ if self .logger :
124
+ getattr (self .logger , self .logger_level )("Closing connection" )
125
+ conn .close ()
126
+ self .disconnected .send (self , conn = conn )
125
127
126
128
def disconnect_all (self ):
127
129
if self .pool is False :
@@ -130,7 +132,7 @@ def disconnect_all(self):
130
132
getattr (self .logger , self .logger_level )("Closing all connections from pool" )
131
133
for conn in self .pool + self .active_conns :
132
134
conn .close ()
133
- self .disconnected .send (self , conn = conn , close_conn = True )
135
+ self .disconnected .send (self , conn = conn )
134
136
self .pool = []
135
137
self .active_conns = []
136
138
@@ -375,7 +377,8 @@ class Transaction:
375
377
default_composite_separator = "__"
376
378
377
379
before_execute = _signals .signal ("before-execute" )
378
- before_executemany = _signals .signal ("before-executemany" )
380
+ after_execute = _signals .signal ("after-execute" )
381
+ handle_error = _signals .signal ("handle-error" )
379
382
380
383
def __init__ (self , session , virtual = False ):
381
384
self .session = session
@@ -403,8 +406,10 @@ def cursor(self, stmt=None, params=None):
403
406
return self .session .connect ().cursor ()
404
407
stmt , params = render (stmt , params )
405
408
406
- rv = _signal_rv (self .before_execute .send (self , stmt = stmt , params = params ))
407
- if rv :
409
+ rv = _signal_rv (self .before_execute .send (self , stmt = stmt , params = params , many = False ))
410
+ if isinstance (rv , tuple ):
411
+ stmt , params = rv
412
+ elif rv :
408
413
return rv
409
414
410
415
if self .session and self .session .logger :
@@ -413,10 +418,19 @@ def cursor(self, stmt=None, params=None):
413
418
)
414
419
415
420
cur = self .session .connect ().cursor ()
416
- if params :
417
- cur .execute (stmt , params )
418
- else :
419
- cur .execute (stmt )
421
+ try :
422
+ # because the default value of params may depend on some engine
423
+ if params :
424
+ cur .execute (stmt , params )
425
+ else :
426
+ cur .execute (stmt )
427
+ except Exception as e :
428
+ rv = _signal_rv (self .handle_error .send (self , cursor = cur , stmt = stmt , params = params , exc = e , many = False ))
429
+ if rv :
430
+ return rv
431
+ raise
432
+
433
+ self .after_execute .send (self , cursor = cur , stmt = stmt , params = params , many = False )
420
434
return cur
421
435
422
436
def execute (self , stmt , params = None ):
@@ -428,9 +442,11 @@ def execute(self, stmt, params=None):
428
442
429
443
def executemany (self , stmt , seq_of_parameters ):
430
444
rv = _signal_rv (
431
- self .before_executemany .send (self , stmt = stmt , seq_of_parameters = seq_of_parameters )
445
+ self .before_execute .send (self , stmt = stmt , params = seq_of_parameters , many = True )
432
446
)
433
- if rv is False :
447
+ if isinstance (rv , tuple ):
448
+ stmt , seq_of_parameters = rv
449
+ elif rv is False :
434
450
return
435
451
436
452
if self .session and self .session .logger :
@@ -439,7 +455,16 @@ def executemany(self, stmt, seq_of_parameters):
439
455
)
440
456
441
457
cur = self .cursor ()
442
- cur .executemany (str (stmt ), seq_of_parameters )
458
+
459
+ try :
460
+ cur .executemany (str (stmt ), seq_of_parameters )
461
+ except Exception as e :
462
+ if not _signal_rv (
463
+ self .handle_error .send (self , cursor = cur , stmt = stmt , params = seq_of_parameters , exc = e , many = True )
464
+ ):
465
+ raise
466
+
467
+ self .after_execute .send (self , cursor = cur , stmt = stmt , params = seq_of_parameters , many = True )
443
468
cur .close ()
444
469
445
470
def fetch (self , stmt , params = None , model = None , obj = None , loader = None ):
@@ -555,3 +580,21 @@ def _signal_rv(signal_rv):
555
580
if rv :
556
581
final_rv = rv
557
582
return final_rv
583
+
584
+
585
+ def connect_via_engine (engine , signal , func = None ):
586
+ def decorator (func ):
587
+ @functools .wraps (func )
588
+ def wrapper (sender , ** kw ):
589
+ matches = False
590
+ if isinstance (sender , Engine ):
591
+ matches = sender is engine
592
+ elif isinstance (sender , Session ):
593
+ matches = sender .engine is engine
594
+ elif isinstance (sender , Transaction ):
595
+ matches = sender .session .engine is engine
596
+ if matches :
597
+ return func (sender , ** kw )
598
+ signal .connect (wrapper , weak = False )
599
+ return wrapper
600
+ return decorator (func ) if func else decorator
0 commit comments