From f866d2bd5809c927881dbcc07893d1658b70cb26 Mon Sep 17 00:00:00 2001 From: "kepler155c@gmail.com" Date: Tue, 9 May 2017 01:57:00 -0400 Subject: [PATCH] socket improvement --- apps/mirrorHost.lua | 2 +- apps/telnet.lua | 12 ++-- apps/trust.lua | 2 +- sys/apis/blocks.lua | 1 + sys/apis/socket.lua | 115 ++++++++++++-------------------------- sys/apis/util.lua | 13 +++-- sys/network/snmp.lua | 2 +- sys/network/telnet.lua | 8 ++- sys/network/trust.lua | 4 +- sys/services/transort.lua | 97 ++++++++++++++++++++++++++++++++ 10 files changed, 158 insertions(+), 98 deletions(-) create mode 100644 sys/services/transort.lua diff --git a/apps/mirrorHost.lua b/apps/mirrorHost.lua index 74c0536..ea5802a 100644 --- a/apps/mirrorHost.lua +++ b/apps/mirrorHost.lua @@ -16,7 +16,7 @@ mon.setBackgroundColor(colors.black) mon.clear() while true do - local socket = Socket.server(5901, true) + local socket = Socket.server(5901) print('mirror: connection from ' .. socket.dhost) diff --git a/apps/telnet.lua b/apps/telnet.lua index a1af509..39ef850 100644 --- a/apps/telnet.lua +++ b/apps/telnet.lua @@ -46,13 +46,14 @@ process:newThread('telnet_read', function() ct[v.f](unpack(v.args)) end end + print('telnet_read exiting') end) ct.clear() ct.setCursorPos(1, 1) while true do - local e = { process:pullEvent() } + local e = { process:pullEvent(nil, true) } local event = e[1] if not socket.connected then @@ -71,10 +72,11 @@ while true do event == 'mouse_click' or event == 'mouse_drag' then - socket:write({ - type = 'shellRemote', - event = e, - }) + if not socket:write({ type = 'shellRemote', event = e }) then + socket:close() + break + end + elseif event == 'terminate' then socket:close() break diff --git a/apps/trust.lua b/apps/trust.lua index c04e7d3..f474724 100644 --- a/apps/trust.lua +++ b/apps/trust.lua @@ -53,7 +53,7 @@ local secretKey = os.getSecretKey() local publicKey = modexp(exchange.base, secretKey, exchange.primeMod) local password = SHA1.sha1(password) -socket:write(Crypto.encrypt({ pk = publicKey }, password)) +socket:write(Crypto.encrypt({ pk = publicKey, dh = os.getComputerID() }, password)) print(socket:read(2) or 'No response') diff --git a/sys/apis/blocks.lua b/sys/apis/blocks.lua index dbf3683..1f47309 100644 --- a/sys/apis/blocks.lua +++ b/sys/apis/blocks.lua @@ -298,6 +298,7 @@ function standardBlockDB:seedDB() [ '86:0' ] = 'pumpkin', [ '90:0' ] = 'air', [ '91:0' ] = 'pumpkin', + [ '92:0' ] = 'flatten', -- cake [ '93:0' ] = 'repeater', [ '94:0' ] = 'repeater', [ '96:0' ] = 'trapdoor', diff --git a/sys/apis/socket.lua b/sys/apis/socket.lua index 27fbfe9..6a9671e 100644 --- a/sys/apis/socket.lua +++ b/sys/apis/socket.lua @@ -31,41 +31,29 @@ function socketClass:read(timeout) return end - local timerId - local filter - - if timeout then - timerId = os.startTimer(timeout) - elseif self.keepAlive then - timerId = os.startTimer(3) - else - filter = 'modem_message' + local data, distance = transport.read(self) + if data then + return data, distance end - while true do - local e, s, dport, dhost, msg, distance = os.pullEvent(filter) - if e == 'modem_message' and - dport == self.sport and dhost == self.shost and - msg then + local timerId = os.startTimer(timeout or 5) - if msg.type == 'DISC' then - -- received disconnect from other end - self.connected = false - self:close() - return - elseif msg.type == 'DATA' then - if msg.data then - if timerId then - os.cancelTimer(timerId) - end - return msg.data, distance - end + while true do + local e, id = os.pullEvent() + + if e == 'transport_' .. self.dport then + + data, distance = transport.read(self) + if data then + os.cancelTimer(timerId) + return data, distance end - elseif e == 'timer' and s == timerId then + + elseif e == 'timer' and id == timerId then if timeout or not self.connected then break end - timerId = os.startTimer(3) + timerId = os.startTimer(5) end end end @@ -75,8 +63,9 @@ function socketClass:write(data) Logger.log('socket', 'write: No connection') return false end - self.transmit(self.dport, self.dhost, { + transport.write(self, { type = 'DATA', + seq = self.wseq, data = data, }) return true @@ -91,45 +80,7 @@ function socketClass:close() self.connected = false end device.wireless_modem.close(self.sport) -end - --- write a ping every second (too much traffic!) -local function pinger(socket) - - local process = require('process') - - socket.keepAlive = true - - Logger.log('socket', 'keepAlive enabled') - - process:newThread('socket_ping', function() - local timerId = os.startTimer(1) - local timeStamp = os.clock() - - while true do - local e, id, dport, dhost, msg = os.pullEvent() - - if e == 'modem_message' then - if dport == socket.sport and - dhost == socket.shost and - msg and - msg.type == 'PING' then - - timeStamp = os.clock() - end - elseif e == 'timer' and id == timerId then - if os.clock() - timeStamp > 3 then - Logger.log('socket', 'Connection timed out') - socket:close() - break - end - timerId = os.startTimer(1) - socket.transmit(socket.dport, socket.dhost, { - type = 'PING', - }) - end - end - end) + transport.close(self) end local Socket = { } @@ -139,12 +90,16 @@ local function loopback(port, sport, msg) end local function newSocket(isLoopback) - for i = 16384, 32768 do + for i = 16384, 32767 do if not device.wireless_modem.isOpen(i) then local socket = { shost = os.getComputerID(), sport = i, transmit = device.wireless_modem.transmit, + wseq = math.random(100, 100000), + rseq = math.random(100, 100000), + timers = { }, + messages = { }, } setmetatable(socket, { __index = socketClass }) @@ -169,7 +124,9 @@ function Socket.connect(host, port) type = 'OPEN', shost = socket.shost, dhost = socket.dhost, - t = Crypto.encrypt({ ts = os.time() }, exchange.publicKey), + t = Crypto.encrypt({ ts = os.time(), seq = socket.seq }, exchange.publicKey), + rseq = socket.wseq, + wseq = socket.rseq, }) local timerId = os.startTimer(3) @@ -185,11 +142,10 @@ function Socket.connect(host, port) Logger.log('socket', 'connection established to %d %d->%d', host, socket.sport, socket.dport) - if msg.keepAlive then - pinger(socket) - end os.cancelTimer(timerId) + transport.open(socket) + return socket end until e == 'timer' and id == timerId @@ -199,7 +155,8 @@ end function trusted(msg, port) - if port == 19 then -- no auth for trust server + if port == 19 or msg.shost == os.getComputerID() then + -- no auth for trust server or loopback return true end @@ -214,7 +171,7 @@ function trusted(msg, port) end end -function Socket.server(port, keepAlive) +function Socket.server(port) device.wireless_modem.open(port) Logger.log('socket', 'Waiting for connections on port ' .. port) @@ -232,18 +189,16 @@ function Socket.server(port, keepAlive) socket.dport = dport socket.dhost = msg.shost socket.connected = true - + socket.wseq = msg.wseq + socket.rseq = msg.rseq socket.transmit(socket.dport, socket.sport, { type = 'CONN', dhost = socket.dhost, shost = socket.shost, - keepAlive = keepAlive, }) Logger.log('socket', 'Connection established %d->%d', socket.sport, socket.dport) - if keepAlive then - pinger(socket) - end + transport.open(socket) return socket end end diff --git a/sys/apis/util.lua b/sys/apis/util.lua index 46979aa..7ba934d 100644 --- a/sys/apis/util.lua +++ b/sys/apis/util.lua @@ -21,7 +21,7 @@ function Util.tryTimes(attempts, f, ...) return unpack(result) end -function Util.print(pattern, ...) +function Util.tostring(pattern, ...) local function serialize(tbl, width) local str = '{\n' @@ -43,12 +43,15 @@ function Util.print(pattern, ...) end if type(pattern) == 'string' then - print(string.format(pattern, ...)) + return string.format(pattern, ...) elseif type(pattern) == 'table' then - print(serialize(pattern, term.current().getSize())) - else - print(tostring(pattern)) + return serialize(pattern, term.current().getSize()) end + return tostring(pattern) +end + +function Util.print(pattern, ...) + print(Util.tostring(pattern, ...)) end function Util.runFunction(env, fn, ...) diff --git a/sys/network/snmp.lua b/sys/network/snmp.lua index 43b083c..ef9f74c 100644 --- a/sys/network/snmp.lua +++ b/sys/network/snmp.lua @@ -101,7 +101,7 @@ process:newThread('discovery_server', function() while true do local e, s, sport, id, info, distance = os.pullEvent('modem_message') - if sport == 999 then + if sport == 999 and tonumber(id) and type(info) == 'table' then if not network[id] then network[id] = { } end diff --git a/sys/network/telnet.lua b/sys/network/telnet.lua index a5fa35e..859bbf3 100644 --- a/sys/network/telnet.lua +++ b/sys/network/telnet.lua @@ -40,7 +40,7 @@ local function telnetHost(socket, termInfo) socket:close() end) - local queueThread = process:newThread('telnet_read', function() + process:newThread('telnet_read', function() while true do local data = socket:read() if not data then @@ -49,7 +49,6 @@ local function telnetHost(socket, termInfo) if data.type == 'shellRemote' then local event = table.remove(data.event, 1) - shellThread:resume(event, unpack(data.event)) end end @@ -65,7 +64,10 @@ local function telnetHost(socket, termInfo) break end if socket.queue then - socket:write(socket.queue) + if not socket:write(socket.queue) then + print('telnet: connection lost to ' .. socket.dhost) + break + end socket.queue = nil end end diff --git a/sys/network/trust.lua b/sys/network/trust.lua index 282b56f..6a981da 100644 --- a/sys/network/trust.lua +++ b/sys/network/trust.lua @@ -17,9 +17,9 @@ process:newThread('trust_server', function() socket:write('No password has been set') else data = Crypto.decrypt(data, password) - if data and data.pk then + if data and data.pk and data.dh then local trustList = Util.readTable('.known_hosts') or { } - trustList[socket.dhost] = data.pk + trustList[data.dh] = data.pk Util.writeTable('.known_hosts', trustList) socket:write('Trust accepted') diff --git a/sys/services/transort.lua b/sys/services/transort.lua new file mode 100644 index 0000000..53bff02 --- /dev/null +++ b/sys/services/transort.lua @@ -0,0 +1,97 @@ +--[[ + Low level socket protocol implementation. + + * sequencing + * write acknowledgements + * background read buffering +]]-- + +multishell.setTitle(multishell.getCurrent(), 'Net transport') + +local computerId = os.getComputerID() + +_G.transport = { + timers = { }, + sockets = { }, +} + +function transport.open(socket) + transport.sockets[socket.sport] = socket +end + +function transport.read(socket) + local data = table.remove(socket.messages, 1) + if data then + return unpack(data) + end +end + +function transport.write(socket, data) + + --debug('>> ' .. Util.tostring({ type = 'DATA', seq = socket.wseq })) + socket.transmit(socket.dport, socket.dhost, data) + + local timerId = os.startTimer(2) + + transport.timers[timerId] = socket + socket.timers[socket.wseq] = timerId + + socket.wseq = socket.wseq + 1 +end + +function transport.close(socket) + transport.sockets[socket.sport] = nil +end + +print('Net transport started') + +while true do + local e, timerId, dport, dhost, msg, distance = os.pullEvent() + + if e == 'timer' then + local socket = transport.timers[timerId] + if socket and socket.connected then + socket:close() + transport.timers[timerId] = nil + end + + elseif e == 'modem_message' and dhost == computerId and msg then + local socket = transport.sockets[dport] + if socket and socket.connected then + + --if msg.type then debug('<< ' .. Util.tostring(msg)) end + + if msg.type == 'DISC' then + -- received disconnect from other end + socket.connected = false + socket:close() + + elseif msg.type == 'ACK' then + local timerId = socket.timers[msg.seq] + + os.cancelTimer(timerId) + socket.timers[msg.seq] = nil + transport.timers[timerId] = nil + + elseif msg.type == 'DATA' and msg.data then + if msg.seq ~= socket.rseq then + socket:close() + else + socket.rseq = socket.rseq + 1 + table.insert(socket.messages, { msg.data, distance }) + + -- use resume instead ?? + if not socket.messages[2] then -- table size is 1 + os.queueEvent('transport_' .. dport) + end + + --debug('>> ' .. Util.tostring({ type = 'ACK', seq = msg.seq })) + socket.transmit(socket.dport, socket.dhost, { + type = 'ACK', + seq = msg.seq, + }) + end + end + end + end +end