From 728029210c61b0135bfe35690a3cfa035ad252cb Mon Sep 17 00:00:00 2001 From: Alexey Romanoff Date: Mon, 25 Apr 2022 13:46:10 +0300 Subject: [PATCH 1/5] src/websocket: copy {client_sync, sync}.lua to {client_poller, poller}.lua --- src/websocket/client_poller.lua | 38 ++++++ src/websocket/poller.lua | 199 ++++++++++++++++++++++++++++++++ 2 files changed, 237 insertions(+) create mode 100644 src/websocket/client_poller.lua create mode 100644 src/websocket/poller.lua diff --git a/src/websocket/client_poller.lua b/src/websocket/client_poller.lua new file mode 100644 index 0000000..71b89e0 --- /dev/null +++ b/src/websocket/client_poller.lua @@ -0,0 +1,38 @@ +local socket = require'socket' +local sync = require'websocket.sync' +local tools = require'websocket.tools' + +local new = function(ws) + ws = ws or {} + local self = {} + + self.sock_connect = function(self,host,port) + self.sock = socket.tcp() + if ws.timeout ~= nil then + self.sock:settimeout(ws.timeout) + end + local _,err = self.sock:connect(host,port) + if err then + self.sock:close() + return nil,err + end + end + + self.sock_send = function(self,...) + return self.sock:send(...) + end + + self.sock_receive = function(self,...) + return self.sock:receive(...) + end + + self.sock_close = function(self) + self.sock:shutdown() + self.sock:close() + end + + self = sync.extend(self) + return self +end + +return new diff --git a/src/websocket/poller.lua b/src/websocket/poller.lua new file mode 100644 index 0000000..17ed2ca --- /dev/null +++ b/src/websocket/poller.lua @@ -0,0 +1,199 @@ +local frame = require'websocket.frame' +local handshake = require'websocket.handshake' +local tools = require'websocket.tools' +local tinsert = table.insert +local tconcat = table.concat + + +local receive = function(self) + if self.state ~= 'OPEN' and not self.is_closing then + return nil,nil,false,1006,'wrong state' + end + local first_opcode + local frames + local bytes = 3 + local encoded = '' + local clean = function(was_clean,code,reason) + self.state = 'CLOSED' + self:sock_close() + if self.on_close then + self:on_close() + end + return nil,nil,was_clean,code,reason or 'closed' + end + while true do + local chunk,err = self:sock_receive(bytes) + if err then + return clean(false,1006,err) + end + encoded = encoded..chunk + local decoded,fin,opcode,_,masked = frame.decode(encoded) + if not self.is_server and masked then + return clean(false,1006,'Websocket receive failed: frame was not masked') + end + if decoded then + if opcode == frame.CLOSE then + if not self.is_closing then + local code,reason = frame.decode_close(decoded) + -- echo code + local msg = frame.encode_close(code) + local encoded = frame.encode(msg,frame.CLOSE,not self.is_server) + local n,err = self:sock_send(encoded) + if n == #encoded then + return clean(true,code,reason) + else + return clean(false,code,err) + end + else + return decoded,opcode + end + end + if not first_opcode then + first_opcode = opcode + end + if not fin then + if not frames then + frames = {} + elseif opcode ~= frame.CONTINUATION then + return clean(false,1002,'protocol error') + end + bytes = 3 + encoded = '' + tinsert(frames,decoded) + elseif not frames then + return decoded,first_opcode + else + tinsert(frames,decoded) + return tconcat(frames),first_opcode + end + else + assert(type(fin) == 'number' and fin > 0) + bytes = fin + end + end + assert(false,'never reach here') +end + +local send = function(self,data,opcode) + if self.state ~= 'OPEN' then + return nil,false,1006,'wrong state' + end + local encoded = frame.encode(data,opcode or frame.TEXT,not self.is_server) + local n,err = self:sock_send(encoded) + if n ~= #encoded then + return nil,self:close(1006,err) + end + return true +end + +local close = function(self,code,reason) + if self.state ~= 'OPEN' then + return false,1006,'wrong state' + end + if self.state == 'CLOSED' then + return false,1006,'wrong state' + end + local msg = frame.encode_close(code or 1000,reason) + local encoded = frame.encode(msg,frame.CLOSE,not self.is_server) + local n,err = self:sock_send(encoded) + local was_clean = false + local code = 1005 + local reason = '' + if n == #encoded then + self.is_closing = true + local rmsg,opcode = self:receive() + if rmsg and opcode == frame.CLOSE then + code,reason = frame.decode_close(rmsg) + was_clean = true + end + else + reason = err + end + self:sock_close() + if self.on_close then + self:on_close() + end + self.state = 'CLOSED' + return was_clean,code,reason or '' +end + +local connect = function(self,ws_url,ws_protocol) + if self.state ~= 'CLOSED' then + return nil,'wrong state' + end + local protocol,host,port,uri = tools.parse_url(ws_url) + if protocol ~= 'ws' then + return nil,'bad protocol' + end + local _,err = self:sock_connect(host,port) + if err then + return nil,err + end + local ws_protocols_tbl = {''} + if type(ws_protocol) == 'string' then + ws_protocols_tbl = {ws_protocol} + elseif type(ws_protocol) == 'table' then + ws_protocols_tbl = ws_protocol + end + local key = tools.generate_key() + local req = handshake.upgrade_request + { + key = key, + host = host, + port = port, + protocols = ws_protocols_tbl, + uri = uri + } + local n,err = self:sock_send(req) + if n ~= #req then + return nil,err + end + local resp = {} + repeat + local line,err = self:sock_receive('*l') + resp[#resp+1] = line + if err then + return nil,err + end + until line == '' + local response = table.concat(resp,'\r\n') + local headers = handshake.http_headers(response) + local expected_accept = handshake.sec_websocket_accept(key) + if headers['sec-websocket-accept'] ~= expected_accept then + local msg = 'Websocket Handshake failed: Invalid Sec-Websocket-Accept (expected %s got %s)' + return nil,msg:format(expected_accept,headers['sec-websocket-accept'] or 'nil') + end + self.state = 'OPEN' + return true +end + +local extend = function(obj) + assert(obj.sock_send) + assert(obj.sock_receive) + assert(obj.sock_close) + + assert(obj.is_closing == nil) + assert(obj.receive == nil) + assert(obj.send == nil) + assert(obj.close == nil) + assert(obj.connect == nil) + + if not obj.is_server then + assert(obj.sock_connect) + end + + if not obj.state then + obj.state = 'CLOSED' + end + + obj.receive = receive + obj.send = send + obj.close = close + obj.connect = connect + + return obj +end + +return { + extend = extend +} From d8ddb86b82da51da5dd69931f7db701f3cfa867d Mon Sep 17 00:00:00 2001 From: Alexey Romanoff Date: Mon, 25 Apr 2022 13:48:12 +0300 Subject: [PATCH 2/5] src/websocket/client_poller.lua: fix the extending module to `poller` --- src/websocket/client_poller.lua | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/websocket/client_poller.lua b/src/websocket/client_poller.lua index 71b89e0..bf29a85 100644 --- a/src/websocket/client_poller.lua +++ b/src/websocket/client_poller.lua @@ -1,6 +1,5 @@ local socket = require'socket' -local sync = require'websocket.sync' -local tools = require'websocket.tools' +local poller = require 'websocket.poller' local new = function(ws) ws = ws or {} @@ -31,7 +30,7 @@ local new = function(ws) self.sock:close() end - self = sync.extend(self) + self = poller.extend(self) return self end From 5f6a694583056c898e59ce47674eef5ae7699f35 Mon Sep 17 00:00:00 2001 From: Alexey Romanoff Date: Mon, 25 Apr 2022 13:48:46 +0300 Subject: [PATCH 3/5] src/websocket/client_poller.lua: sock_connect: set zero timeout after connect --- src/websocket/client_poller.lua | 1 + 1 file changed, 1 insertion(+) diff --git a/src/websocket/client_poller.lua b/src/websocket/client_poller.lua index bf29a85..2c3f347 100644 --- a/src/websocket/client_poller.lua +++ b/src/websocket/client_poller.lua @@ -15,6 +15,7 @@ local new = function(ws) self.sock:close() return nil,err end + self.sock:settimeout(0) end self.sock_send = function(self,...) From 50f3ecb88c114f4626e22ca6b37798173a6039a2 Mon Sep 17 00:00:00 2001 From: Alexey Romanoff Date: Mon, 25 Apr 2022 13:58:13 +0300 Subject: [PATCH 4/5] src/websocket/poller.lua: replace sync mechanics with poller mechanics --- src/websocket/poller.lua | 127 +++++++++++++++++++++------------------ 1 file changed, 69 insertions(+), 58 deletions(-) diff --git a/src/websocket/poller.lua b/src/websocket/poller.lua index 17ed2ca..7e9fdc6 100644 --- a/src/websocket/poller.lua +++ b/src/websocket/poller.lua @@ -4,74 +4,83 @@ local tools = require'websocket.tools' local tinsert = table.insert local tconcat = table.concat +local clean = function(self, was_clean,code,reason) + self.state = 'CLOSED' + self:sock_close() + if self.on_close then + self:on_close() + end + return nil,nil,was_clean,code,reason or 'closed' +end -local receive = function(self) +local receive_init = function(self) + self.first_opcode = nil + self.frames = nil + self.bytes = 3 + self.encoded = '' +end + +local poll = function(self) if self.state ~= 'OPEN' and not self.is_closing then + receive_init(self) return nil,nil,false,1006,'wrong state' end - local first_opcode - local frames - local bytes = 3 - local encoded = '' - local clean = function(was_clean,code,reason) - self.state = 'CLOSED' - self:sock_close() - if self.on_close then - self:on_close() - end - return nil,nil,was_clean,code,reason or 'closed' - end - while true do - local chunk,err = self:sock_receive(bytes) - if err then - return clean(false,1006,err) - end - encoded = encoded..chunk - local decoded,fin,opcode,_,masked = frame.decode(encoded) - if not self.is_server and masked then - return clean(false,1006,'Websocket receive failed: frame was not masked') - end - if decoded then - if opcode == frame.CLOSE then - if not self.is_closing then - local code,reason = frame.decode_close(decoded) - -- echo code - local msg = frame.encode_close(code) - local encoded = frame.encode(msg,frame.CLOSE,not self.is_server) - local n,err = self:sock_send(encoded) - if n == #encoded then - return clean(true,code,reason) - else - return clean(false,code,err) - end + + local chunk,err = self:sock_receive(self.bytes) + if err and err ~= 'timeout' then + receive_init(self) + return clean(self, false,1006,err) + end + self.encoded = self.encoded..(chunk or '') + local decoded,fin,opcode,_,masked = frame.decode(self.encoded) + if not self.is_server and masked then + receive_init(self) + return clean(self, false,1006,'Websocket receive failed: frame was not masked') + end + if decoded then + if opcode == frame.CLOSE then + if not self.is_closing then + local code,reason = frame.decode_close(decoded) + -- echo code + local msg = frame.encode_close(code) + local encoded = frame.encode(msg,frame.CLOSE,not self.is_server) + local n,err = self:sock_send(encoded) + receive_init(self) + if n == #encoded then + return clean(self, true,code,reason) else - return decoded,opcode - end - end - if not first_opcode then - first_opcode = opcode - end - if not fin then - if not frames then - frames = {} - elseif opcode ~= frame.CONTINUATION then - return clean(false,1002,'protocol error') + return clean(self, false,code,err) end - bytes = 3 - encoded = '' - tinsert(frames,decoded) - elseif not frames then - return decoded,first_opcode else - tinsert(frames,decoded) - return tconcat(frames),first_opcode + receive_init(self) + return decoded,opcode end + end + if not self.first_opcode then + self.first_opcode = opcode + end + if not fin then + if not self.frames then + self.frames = {} + elseif opcode ~= frame.CONTINUATION then + receive_init(self) + return clean(self, false,1002,'protocol error') + end + self.bytes = 3 + self.encoded = '' + tinsert(self.frames,decoded) + elseif not self.frames then + receive_init(self) + return decoded,self.first_opcode else - assert(type(fin) == 'number' and fin > 0) - bytes = fin + tinsert(self.frames,decoded) + receive_init(self) + return tconcat(self.frames),self.first_opcode end + else + assert(type(fin) == 'number' and fin > 0) + self.bytes = fin end - assert(false,'never reach here') end local send = function(self,data,opcode) @@ -186,11 +195,13 @@ local extend = function(obj) obj.state = 'CLOSED' end - obj.receive = receive + obj.poll = poll obj.send = send obj.close = close obj.connect = connect + receive_init(obj) + return obj end From b7480161330c60cc00762879b52f49b55d3f64da Mon Sep 17 00:00:00 2001 From: Alexey Romanoff Date: Mon, 25 Apr 2022 13:58:49 +0300 Subject: [PATCH 5/5] src/websocket/poller.lua: add ssl --- src/websocket/poller.lua | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/websocket/poller.lua b/src/websocket/poller.lua index 7e9fdc6..ca68f5c 100644 --- a/src/websocket/poller.lua +++ b/src/websocket/poller.lua @@ -1,6 +1,7 @@ local frame = require'websocket.frame' local handshake = require'websocket.handshake' local tools = require'websocket.tools' +local ssl = require'ssl' local tinsert = table.insert local tconcat = table.concat @@ -126,17 +127,21 @@ local close = function(self,code,reason) return was_clean,code,reason or '' end -local connect = function(self,ws_url,ws_protocol) +local connect = function(self,ws_url,ws_protocol,ssl_params) if self.state ~= 'CLOSED' then - return nil,'wrong state' + return nil,'wrong state',nil end local protocol,host,port,uri = tools.parse_url(ws_url) - if protocol ~= 'ws' then - return nil,'bad protocol' - end + -- Preconnect (for SSL if needed) local _,err = self:sock_connect(host,port) if err then - return nil,err + return nil,err,nil + end + if protocol == 'wss' then + self.sock = ssl.wrap(self.sock, ssl_params) + self.sock:dohandshake() + elseif protocol ~= "ws" then + return nil, 'bad protocol' end local ws_protocols_tbl = {''} if type(ws_protocol) == 'string' then @@ -155,14 +160,14 @@ local connect = function(self,ws_url,ws_protocol) } local n,err = self:sock_send(req) if n ~= #req then - return nil,err + return nil,err,nil end local resp = {} repeat local line,err = self:sock_receive('*l') resp[#resp+1] = line if err then - return nil,err + return nil,err,nil end until line == '' local response = table.concat(resp,'\r\n') @@ -170,10 +175,10 @@ local connect = function(self,ws_url,ws_protocol) local expected_accept = handshake.sec_websocket_accept(key) if headers['sec-websocket-accept'] ~= expected_accept then local msg = 'Websocket Handshake failed: Invalid Sec-Websocket-Accept (expected %s got %s)' - return nil,msg:format(expected_accept,headers['sec-websocket-accept'] or 'nil') + return nil,msg:format(expected_accept,headers['sec-websocket-accept'] or 'nil'),headers end self.state = 'OPEN' - return true + return true,headers['sec-websocket-protocol'],headers end local extend = function(obj)