-
Notifications
You must be signed in to change notification settings - Fork 2
/
transfer.lua
94 lines (87 loc) · 2.68 KB
/
transfer.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
--- transfer: helpers for exchanging Torch-serialized objects over TCP
-- sockets created with LuaSocket.
-- NOTE(review): `torch` is used by send/receive but is never required in this
-- file; the caller is presumably expected to have loaded Torch7 -- confirm.
local transfer = {}
local socket = require 'socket'
--- Look up the local machine's host name.
-- @return whatever LuaSocket's `socket.dns.gethostname()` returns
--   (all of its return values are forwarded unchanged).
function transfer.hostname()
   local dns = socket.dns
   return dns.gethostname()
end
--- Create a TCP server socket bound to all interfaces and start listening.
-- @param port local port to bind
-- @param backlog queue length passed to `listen` (LuaSocket's default if nil)
-- @param ipv6 when truthy, create the socket with socket.tcp6() instead of
--   socket.tcp()
-- @return the listening server socket
-- Raises LuaSocket's error message if creation, bind or listen fails.
function transfer.bind(port, backlog, ipv6)
   -- An explicit if/else is used here. The `ipv6 and tcp6() or tcp()` idiom
   -- would silently fall back to IPv4 whenever tcp6() fails.
   local master
   if ipv6 then
      master = assert(socket.tcp6())
   else
      master = assert(socket.tcp())
   end

   local ok, err = master:bind("*", port)
   if ok then
      ok, err = master:listen(backlog)
   end
   if not ok then
      -- Release the descriptor now instead of leaving it open until the
      -- garbage collector finds it; then raise the original error.
      master:close()
      assert(ok, err)
   end
   return master
end
--- Connect to a remote listener, retrying on failure.
-- @param hostname remote host name or address
-- @param port remote port
-- @param ipv6 when truthy, use socket.tcp6() instead of socket.tcp()
-- @param retry maximum number of connection attempts (default: unlimited);
--   any value <= 1 means a single attempt
-- @param delay seconds to sleep between attempts (default: 10)
-- @return the connected client socket
-- Raises "Failed to connect <host:port> (err)" once every attempt has failed.
function transfer.connect(hostname, port, ipv6, retry, delay)
   retry = retry or math.huge
   delay = delay or 10

   -- One connection attempt. Raises on failure (caught by pcall below).
   local function connect()
      local master
      if ipv6 then
         master = assert(socket.tcp6())
      else
         master = assert(socket.tcp())
      end
      local ok, err = master:connect(hostname, port)
      if not ok then
         -- Close the failed socket right away: with unlimited retries,
         -- leaked descriptors would otherwise pile up until the next GC.
         master:close()
      end
      assert(ok, err)
      return master
   end

   local status, client
   repeat
      status, client = pcall(connect)
      retry = retry - 1
      -- Only announce a retry (and wait) when another attempt will follow.
      if not status and retry > 0 then
         print(
            string.format("Failed to connect <%s:%s> (%s). [Retrying]",
               tostring(hostname), tostring(port), tostring(client))
         )
         socket.sleep(delay)
      end
      -- `<= 0` (not `== 0`) so zero, negative or fractional counts terminate.
   until status or retry <= 0
   if not status then
      error(
         string.format("Failed to connect <%s:%s> (%s)",
            tostring(hostname), tostring(port), tostring(client))
      )
   end
   return client
end
-- split transfer into several blocks to circumvent luajit issues
-- NOTE(review): presumably LuaJIT's limits on very large strings/allocations;
-- confirm. transfer.send and transfer.receive both use BLOCKSZ to cap the
-- size of each string handed to, or read from, the socket.
local BLOCKSZ = 2^24 -- 16MB
--- Serialize an object with Torch and send it over a connected socket.
-- Wire format:
-- * an 18-byte ASCII header, "0x" followed by 16 hex digits holding the byte
--   length of the serialized payload;
-- * then the payload itself, written in pieces of at most BLOCKSZ bytes.
-- @param c connected socket (LuaSocket-style `send`)
-- @param data object to transmit; serialized via torch.serializeToStorage
-- Raises "send error" (with LuaSocket's message appended when available) if
-- any write, including the header, does not complete.
function transfer.send(c, data)
   data = torch.CharTensor(torch.serializeToStorage(data))
   local size = data:size(1)

   -- Write `str` completely or raise, keeping LuaSocket's error text.
   local function sendall(str)
      local last, err = c:send(str)
      if last ~= #str then
         error('send error' .. (err and (': ' .. tostring(err)) or ''), 0)
      end
   end

   sendall(string.format("0x%0.16x", size))

   local nfull = math.floor(size / BLOCKSZ)
   local remainder = size % BLOCKSZ

   -- Full blocks are copied into one reusable BLOCKSZ tensor. The copy is
   -- presumably needed because storage():string() serializes a tensor's
   -- whole storage, not just a narrowed view of it.
   local buffer
   for i = 0, nfull - 1 do
      buffer = buffer or torch.CharTensor(BLOCKSZ)
      buffer:copy(data:narrow(1, i * BLOCKSZ + 1, BLOCKSZ))
      sendall(buffer:storage():string())
   end

   -- Skip the tail when size is an exact multiple of BLOCKSZ: narrow() with
   -- a zero length is rejected by Torch ("out of range").
   if remainder > 0 then
      local tail =
         data:narrow(1, nfull * BLOCKSZ + 1, remainder):clone():storage():string()
      sendall(tail)
   end
end
--- Receive one object written by transfer.send and deserialize it.
-- Reads an 18-byte size header (parsed with tonumber, e.g. "0x" plus 16 hex
-- digits), then exactly that many payload bytes in pieces of at most BLOCKSZ.
-- @param c connected socket (LuaSocket-style `receive`)
-- @return the object produced by torch.deserializeFromStorage
-- Raises "receive error" (with LuaSocket's message or a header description
-- appended when available) on a short read or a malformed header.
function transfer.receive(c)
   -- Read exactly `count` bytes or raise, keeping LuaSocket's error text.
   local function recvexact(count)
      local str, err = c:receive(count)
      if not str or #str ~= count then
         error('receive error' .. (err and (': ' .. tostring(err)) or ''), 0)
      end
      return str
   end

   local header = recvexact(18)
   local size = tonumber(header)
   if not size then
      error('receive error: malformed size header ' ..
         string.format('%q', header), 0)
   end

   local data = torch.CharTensor(size)
   local nfull = math.floor(size / BLOCKSZ)
   local remainder = size % BLOCKSZ

   -- Full blocks go through one reusable BLOCKSZ tensor before being copied
   -- into place in `data`.
   local buffer
   for i = 0, nfull - 1 do
      buffer = buffer or torch.CharTensor(BLOCKSZ)
      buffer:storage():string(recvexact(BLOCKSZ))
      data:narrow(1, i * BLOCKSZ + 1, BLOCKSZ):copy(buffer)
   end

   -- When size is an exact multiple of BLOCKSZ there is no tail. Receiving
   -- zero bytes and narrowing with a zero length would fail in Torch.
   if remainder > 0 then
      local tail =
         torch.CharTensor(torch.CharStorage():string(recvexact(remainder)))
      data:narrow(1, nfull * BLOCKSZ + 1, remainder):copy(tail)
   end

   return torch.deserializeFromStorage(data:storage())
end
-- Export the module table.
return transfer