Skip to content
This repository was archived by the owner on Apr 3, 2019. It is now read-only.

Commit 19275de

Browse files
author
Eddie Mishelevich
committed
Better binary values handling in python3
1 parent 5745430 commit 19275de

11 files changed

+225
-193
lines changed

tornadoredis/client.py

+55-24
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,28 @@
2727
PY3 = sys.version > '3'
2828

2929

30+
if not PY3:
31+
from itertools import imap
32+
basestring = basestring
33+
unicode = unicode
34+
bytes = str
35+
long = long
36+
b = lambda x: x
37+
else:
38+
imap = map
39+
40+
basestring = str
41+
unicode = str
42+
bytes = bytes
43+
long = int
44+
b = lambda x: x.encode('latin-1') if not isinstance(x, bytes) else x
45+
46+
SYM_STAR = b('*')
47+
SYM_DOLLAR = b('$')
48+
SYM_CRLF = b('\r\n')
49+
SYM_EMPTY = b('')
50+
51+
3052
class CmdLine(object):
3153
def __init__(self, cmd, *args, **kwargs):
3254
self.cmd = cmd
@@ -136,7 +158,7 @@ def get_value(value):
136158
sub_dict[k] = v
137159
return sub_dict
138160
for line in response.splitlines():
139-
line = line.strip()
161+
line = line.strip().decode()
140162
if line and not line.startswith('#'):
141163
key, value = line.split(':')
142164
try:
@@ -159,7 +181,7 @@ def reply_fn(r, *args, **kwargs):
159181

160182

161183
def to_list(source):
162-
if isinstance(source, str):
184+
if isinstance(source, (bytes, str)):
163185
return [source]
164186
else:
165187
return list(source)
@@ -225,7 +247,7 @@ class Client(object):
225247

226248
def __init__(self, host='localhost', port=6379, unix_socket_path=None,
227249
password=None, selected_db=None, io_loop=None,
228-
connection_pool=None):
250+
connection_pool=None, encoding='utf-8', encoding_errors='strict'):
229251
self._io_loop = io_loop or IOLoop.current()
230252
self._connection_pool = connection_pool
231253
self._weak = weakref.proxy(self)
@@ -244,6 +266,8 @@ def __init__(self, host='localhost', port=6379, unix_socket_path=None,
244266
self.password = password
245267
self.selected_db = selected_db or 0
246268
self._pipeline = None
269+
self.encoding = encoding
270+
self.encoding_errors = encoding_errors
247271

248272
def __del__(self):
249273
try:
@@ -351,24 +375,31 @@ def disconnect(self, callback=None):
351375
if callback:
352376
callback(False)
353377

354-
#### formatting
355378
def encode(self, value):
356-
if not isinstance(value, str):
357-
if not PY3 and isinstance(value, unicode):
358-
value = value.encode('utf-8')
359-
else:
360-
value = str(value)
361-
if PY3:
362-
value = value.encode('utf-8')
379+
"Return a bytestring representation of the value"
380+
if isinstance(value, bytes):
381+
return value
382+
elif isinstance(value, (int, long)):
383+
value = b(str(value))
384+
elif isinstance(value, float):
385+
value = b(repr(value))
386+
elif not isinstance(value, basestring):
387+
value = str(value)
388+
if isinstance(value, unicode):
389+
value = value.encode(self.encoding, self.encoding_errors)
363390
return value
364391

365392
def format_command(self, *tokens, **kwargs):
366393
cmds = []
367-
for t in tokens:
368-
e_t = self.encode(t)
369-
e_t_s = to_basestring(e_t)
370-
cmds.append('$%s\r\n%s\r\n' % (len(e_t), e_t_s))
371-
return '*%s\r\n%s' % (len(tokens), ''.join(cmds))
394+
395+
buff = SYM_EMPTY.join((SYM_STAR, b(str(len(tokens))), SYM_CRLF))
396+
397+
for arg in imap(self.encode, tokens):
398+
buff = SYM_EMPTY.join(
399+
(buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF, arg, SYM_CRLF))
400+
401+
return buff
402+
372403

373404
def format_reply(self, cmd_line, data):
374405
if cmd_line.cmd not in REPLY_MAP:
@@ -458,7 +489,6 @@ def _consume_bulk(self, tail, callback=None):
458489
if not response:
459490
raise ResponseError('EmptyResponse')
460491
else:
461-
response = to_basestring(response)
462492
response = response[:-2]
463493
callback(response)
464494

@@ -486,6 +516,7 @@ def process_data(self, data, cmd_line):
486516
tail = tail[4:]
487517
response = ResponseError(tail, cmd_line)
488518
else:
519+
print("data: {} cmd_line: {}".format(data, cmd_line))
489520
raise ResponseError('Unknown response type %s' % head,
490521
cmd_line)
491522
return response
@@ -1033,14 +1064,14 @@ def psubscribe(self, channels, callback=None):
10331064
self._subscribe('PSUBSCRIBE', channels, callback=callback)
10341065

10351066
def _subscribe(self, cmd, channels, callback=None):
1036-
if isinstance(channels, str) or (not PY3 and isinstance(channels, unicode)):
1067+
if isinstance(channels, (bytes, str)) or (not PY3 and isinstance(channels, unicode)):
10371068
channels = [channels]
10381069
if not self.subscribed:
10391070
listen_callback = None
10401071
original_cb = stack_context.wrap(callback) if callback else None
10411072

10421073
def _cb(*args, **kwargs):
1043-
self.on_subscribed(Message(kind='subscribe',
1074+
self.on_subscribed(Message(kind=b'subscribe',
10441075
channel=channels[0],
10451076
body=None,
10461077
pattern=None))
@@ -1076,7 +1107,7 @@ def punsubscribe(self, channels, callback=None):
10761107
self._unsubscribe('PUNSUBSCRIBE', channels, callback=callback)
10771108

10781109
def _unsubscribe(self, cmd, channels, callback=None):
1079-
if isinstance(channels, str) or (not PY3 and isinstance(channels, unicode)):
1110+
if isinstance(channels, (bytes, str)) or (not PY3 and isinstance(channels, unicode)):
10801111
channels = [channels]
10811112
if callback:
10821113
cb = stack_context.wrap(callback)
@@ -1137,7 +1168,7 @@ def error_wrapper(e):
11371168
self.subscribed = set()
11381169
# send a message to caller:
11391170
# Message(kind='disconnect', channel=set(channel1, ...))
1140-
callback(reply_pubsub_message(('disconnect', channels)))
1171+
callback(reply_pubsub_message((b'disconnect', channels)))
11411172
return
11421173

11431174
response = self.process_data(data, cmd_listen)
@@ -1149,7 +1180,7 @@ def error_wrapper(e):
11491180

11501181
result = self.format_reply(cmd_listen, response)
11511182

1152-
if result and result.kind in ('subscribe', 'psubscribe'):
1183+
if result and result.kind in (b'subscribe', b'psubscribe'):
11531184
self.on_subscribed(result)
11541185
try:
11551186
__, cb = self.subscribe_callbacks.popleft()
@@ -1158,7 +1189,7 @@ def error_wrapper(e):
11581189
if cb:
11591190
cb(True)
11601191

1161-
if result and result.kind in ('unsubscribe', 'punsubscribe'):
1192+
if result and result.kind in (b'unsubscribe', b'punsubscribe'):
11621193
self.on_unsubscribed([result.channel])
11631194

11641195
callback(result)
@@ -1275,7 +1306,7 @@ def format_replies(self, cmd_lines, responses):
12751306
return results
12761307

12771308
def format_pipeline_request(self, command_stack):
1278-
return ''.join(self.format_command(c.cmd, *c.args, **c.kwargs)
1309+
return SYM_EMPTY.join(self.format_command(c.cmd, *c.args, **c.kwargs)
12791310
for c in command_stack)
12801311

12811312
@gen.engine

tornadoredis/connection.py

-2
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,6 @@ def write(self, data, callback=None):
120120
else:
121121
cb = None
122122
try:
123-
if PY3:
124-
data = bytes(data, encoding='utf-8')
125123
self._stream.write(data, callback=cb)
126124
except IOError as e:
127125
self.disconnect()

tornadoredis/pubsub.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def subscribe(self, channel_name, subscriber, callback=None):
3939
callback - a callback function
4040
"""
4141
if isinstance(channel_name, list) or isinstance(channel_name, tuple):
42+
channel_name = [self.redis.encode(c) for c in channel_name]
4243
if len(channel_name) > 1:
4344
_cb = lambda *args, **kwargs: self.subscribe(channel_name[1:],
4445
subscriber,
@@ -47,6 +48,7 @@ def subscribe(self, channel_name, subscriber, callback=None):
4748
_cb = callback
4849
self.subscribe(channel_name[0], subscriber, callback=_cb)
4950
else:
51+
channel_name = self.redis.encode(channel_name)
5052
self.subscribers[channel_name][subscriber] += 1
5153
self.subscriber_count[channel_name] += 1
5254
if self.subscriber_count[channel_name] == 1:
@@ -73,6 +75,9 @@ def unsubscribe(self, channel_name, subscriber):
7375
Unsubscribes the redis client from the channel
7476
if there are no subscribers left.
7577
"""
78+
79+
channel_name = self.redis.encode(channel_name)
80+
7681
self.subscribers[channel_name][subscriber] -= 1
7782
if self.subscribers[channel_name][subscriber] <= 0:
7883
del self.subscribers[channel_name][subscriber]
@@ -93,7 +98,7 @@ def on_message(self, msg):
9398
if not msg:
9499
return
95100

96-
if msg.kind == 'disconnect':
101+
if msg.kind == b'disconnect':
97102
# Disconnected from the Redis server
98103
# Close the redis connection
99104
self.close()
@@ -139,7 +144,7 @@ class SockJSSubscriber(BaseSubscriber):
139144
def on_message(self, msg):
140145
if not msg:
141146
return
142-
if msg.kind == 'message' and msg.body:
147+
if msg.kind == b'message' and msg.body:
143148
# Get the list of subscribers for this channel
144149
subscribers = list(self.subscribers[msg.channel].keys())
145150
if subscribers:
@@ -160,7 +165,7 @@ class SocketIOSubscriber(BaseSubscriber):
160165
def on_message(self, msg):
161166
if not msg:
162167
return
163-
if msg.kind == 'message' and msg.body:
168+
if msg.kind == b'message' and msg.body:
164169
# Get the list of subscribers for this channel
165170
subscribers = list(self.subscribers[msg.channel].keys())
166171
if subscribers:

tornadoredis/tests/__init__.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from .test_leaks import *
2-
from .test_commands import *
1+
#from .test_leaks import *
2+
#from .test_commands import *
33
from .test_pubsub import *
4-
from .test_pipeline import *
5-
from .test_scripting import *
6-
from .test_reconnect import *
7-
from .test_pool import *
8-
from .test_locks import *
9-
from .test_ipv6 import *
4+
#from .test_pipeline import *
5+
#from .test_scripting import *
6+
#from .test_reconnect import *
7+
#from .test_pool import *
8+
#from .test_locks import *
9+
#from .test_ipv6 import *

0 commit comments

Comments
 (0)