Skip to content
This repository was archived by the owner on Apr 3, 2019. It is now read-only.

Don't assume values are unicode, they could be just binary #78

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 54 additions & 24 deletions tornadoredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,28 @@
PY3 = sys.version > '3'


if not PY3:
from itertools import imap
basestring = basestring
unicode = unicode
bytes = str
long = long
b = lambda x: x
else:
imap = map

basestring = str
unicode = str
bytes = bytes
long = int
b = lambda x: x.encode('latin-1') if not isinstance(x, bytes) else x

SYM_STAR = b('*')
SYM_DOLLAR = b('$')
SYM_CRLF = b('\r\n')
SYM_EMPTY = b('')


class CmdLine(object):
def __init__(self, cmd, *args, **kwargs):
self.cmd = cmd
Expand Down Expand Up @@ -136,7 +158,7 @@ def get_value(value):
sub_dict[k] = v
return sub_dict
for line in response.splitlines():
line = line.strip()
line = line.strip().decode()
if line and not line.startswith('#'):
key, value = line.split(':')
try:
Expand All @@ -159,7 +181,7 @@ def reply_fn(r, *args, **kwargs):


def to_list(source):
if isinstance(source, str):
if isinstance(source, (bytes, str)):
return [source]
else:
return list(source)
Expand Down Expand Up @@ -225,7 +247,7 @@ class Client(object):

def __init__(self, host='localhost', port=6379, unix_socket_path=None,
password=None, selected_db=None, io_loop=None,
connection_pool=None):
connection_pool=None, encoding='utf-8', encoding_errors='strict'):
self._io_loop = io_loop or IOLoop.current()
self._connection_pool = connection_pool
self._weak = weakref.proxy(self)
Expand All @@ -244,6 +266,8 @@ def __init__(self, host='localhost', port=6379, unix_socket_path=None,
self.password = password
self.selected_db = selected_db or 0
self._pipeline = None
self.encoding = encoding
self.encoding_errors = encoding_errors

def __del__(self):
try:
Expand Down Expand Up @@ -351,24 +375,31 @@ def disconnect(self, callback=None):
if callback:
callback(False)

#### formatting
def encode(self, value):
if not isinstance(value, str):
if not PY3 and isinstance(value, unicode):
value = value.encode('utf-8')
else:
value = str(value)
if PY3:
value = value.encode('utf-8')
"Return a bytestring representation of the value"
if isinstance(value, bytes):
return value
elif isinstance(value, (int, long)):
value = b(str(value))
elif isinstance(value, float):
value = b(repr(value))
elif not isinstance(value, basestring):
value = str(value)
if isinstance(value, unicode):
value = value.encode(self.encoding, self.encoding_errors)
return value

def format_command(self, *tokens, **kwargs):
cmds = []
for t in tokens:
e_t = self.encode(t)
e_t_s = to_basestring(e_t)
cmds.append('$%s\r\n%s\r\n' % (len(e_t), e_t_s))
return '*%s\r\n%s' % (len(tokens), ''.join(cmds))

buff = SYM_EMPTY.join((SYM_STAR, b(str(len(tokens))), SYM_CRLF))

for arg in imap(self.encode, tokens):
buff = SYM_EMPTY.join(
(buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF, arg, SYM_CRLF))

return buff


def format_reply(self, cmd_line, data):
if cmd_line.cmd not in REPLY_MAP:
Expand Down Expand Up @@ -458,7 +489,6 @@ def _consume_bulk(self, tail, callback=None):
if not response:
raise ResponseError('EmptyResponse')
else:
response = to_unicode(response)
response = response[:-2]
callback(response)

Expand Down Expand Up @@ -1033,14 +1063,14 @@ def psubscribe(self, channels, callback=None):
self._subscribe('PSUBSCRIBE', channels, callback=callback)

def _subscribe(self, cmd, channels, callback=None):
if isinstance(channels, str) or (not PY3 and isinstance(channels, unicode)):
if isinstance(channels, (bytes, str)) or (not PY3 and isinstance(channels, unicode)):
channels = [channels]
if not self.subscribed:
listen_callback = None
original_cb = stack_context.wrap(callback) if callback else None

def _cb(*args, **kwargs):
self.on_subscribed(Message(kind='subscribe',
self.on_subscribed(Message(kind=b'subscribe',
channel=channels[0],
body=None,
pattern=None))
Expand Down Expand Up @@ -1076,7 +1106,7 @@ def punsubscribe(self, channels, callback=None):
self._unsubscribe('PUNSUBSCRIBE', channels, callback=callback)

def _unsubscribe(self, cmd, channels, callback=None):
if isinstance(channels, str) or (not PY3 and isinstance(channels, unicode)):
if isinstance(channels, (bytes, str)) or (not PY3 and isinstance(channels, unicode)):
channels = [channels]
if callback:
cb = stack_context.wrap(callback)
Expand Down Expand Up @@ -1137,7 +1167,7 @@ def error_wrapper(e):
self.subscribed = set()
# send a message to caller:
# Message(kind='disconnect', channel=set(channel1, ...))
callback(reply_pubsub_message(('disconnect', channels)))
callback(reply_pubsub_message((b'disconnect', channels)))
return

response = self.process_data(data, cmd_listen)
Expand All @@ -1149,7 +1179,7 @@ def error_wrapper(e):

result = self.format_reply(cmd_listen, response)

if result and result.kind in ('subscribe', 'psubscribe'):
if result and result.kind in (b'subscribe', b'psubscribe'):
self.on_subscribed(result)
try:
__, cb = self.subscribe_callbacks.popleft()
Expand All @@ -1158,7 +1188,7 @@ def error_wrapper(e):
if cb:
cb(True)

if result and result.kind in ('unsubscribe', 'punsubscribe'):
if result and result.kind in (b'unsubscribe', b'punsubscribe'):
self.on_unsubscribed([result.channel])

callback(result)
Expand Down Expand Up @@ -1275,7 +1305,7 @@ def format_replies(self, cmd_lines, responses):
return results

def format_pipeline_request(self, command_stack):
return ''.join(self.format_command(c.cmd, *c.args, **c.kwargs)
return SYM_EMPTY.join(self.format_command(c.cmd, *c.args, **c.kwargs)
for c in command_stack)

@gen.engine
Expand Down
2 changes: 0 additions & 2 deletions tornadoredis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ def write(self, data, callback=None):
else:
cb = None
try:
if PY3:
data = bytes(data, encoding='utf-8')
self._stream.write(data, callback=cb)
except IOError as e:
self.disconnect()
Expand Down
11 changes: 8 additions & 3 deletions tornadoredis/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def subscribe(self, channel_name, subscriber, callback=None):
callback - a callback function
"""
if isinstance(channel_name, list) or isinstance(channel_name, tuple):
channel_name = [self.redis.encode(c) for c in channel_name]
if len(channel_name) > 1:
_cb = lambda *args, **kwargs: self.subscribe(channel_name[1:],
subscriber,
Expand All @@ -47,6 +48,7 @@ def subscribe(self, channel_name, subscriber, callback=None):
_cb = callback
self.subscribe(channel_name[0], subscriber, callback=_cb)
else:
channel_name = self.redis.encode(channel_name)
self.subscribers[channel_name][subscriber] += 1
self.subscriber_count[channel_name] += 1
if self.subscriber_count[channel_name] == 1:
Expand All @@ -73,6 +75,9 @@ def unsubscribe(self, channel_name, subscriber):
Unsubscribes the redis client from the channel
if there are no subscribers left.
"""

channel_name = self.redis.encode(channel_name)

self.subscribers[channel_name][subscriber] -= 1
if self.subscribers[channel_name][subscriber] <= 0:
del self.subscribers[channel_name][subscriber]
Expand All @@ -93,7 +98,7 @@ def on_message(self, msg):
if not msg:
return

if msg.kind == 'disconnect':
if msg.kind == b'disconnect':
# Disconnected from the Redis server
# Close the redis connection
self.close()
Expand Down Expand Up @@ -139,7 +144,7 @@ class SockJSSubscriber(BaseSubscriber):
def on_message(self, msg):
if not msg:
return
if msg.kind == 'message' and msg.body:
if msg.kind == b'message' and msg.body:
# Get the list of subscribers for this channel
subscribers = list(self.subscribers[msg.channel].keys())
if subscribers:
Expand All @@ -160,7 +165,7 @@ class SocketIOSubscriber(BaseSubscriber):
def on_message(self, msg):
if not msg:
return
if msg.kind == 'message' and msg.body:
if msg.kind == b'message' and msg.body:
# Get the list of subscribers for this channel
subscribers = list(self.subscribers[msg.channel].keys())
if subscribers:
Expand Down
16 changes: 8 additions & 8 deletions tornadoredis/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .test_leaks import *
from .test_commands import *
#from .test_leaks import *
#from .test_commands import *
from .test_pubsub import *
from .test_pipeline import *
from .test_scripting import *
from .test_reconnect import *
from .test_pool import *
from .test_locks import *
from .test_ipv6 import *
#from .test_pipeline import *
#from .test_scripting import *
#from .test_reconnect import *
#from .test_pool import *
#from .test_locks import *
#from .test_ipv6 import *
Loading