diff --git a/sys/apps/network/transport.lua b/sys/apps/network/transport.lua index 9717e4e..2732d12 100644 --- a/sys/apps/network/transport.lua +++ b/sys/apps/network/transport.lua @@ -15,6 +15,7 @@ local computerId = os.getComputerID() local transport = { timers = { }, sockets = { }, + encryptQueue = { }, UID = 0, } @@ -40,8 +41,15 @@ function transport.read(socket) end end -function transport.write(socket, data) - socket.transmit(socket.dport, socket.dhost, data) +function transport.write(socket, msg) + if socket.options.ENCRYPT then + if #transport.encryptQueue == 0 then + os.queueEvent('transport_encrypt') + end + table.insert(transport.encryptQueue, { socket.sport, msg }) + else + socket.transmit(socket.dport, socket.dhost, msg) + end socket.wseq = socket.wrng:nextInt(5) end @@ -63,6 +71,19 @@ function transport.close(socket) transport.sockets[socket.sport] = nil end +Event.on('transport_encrypt', function() + while #transport.encryptQueue > 0 do + local entry = table.remove(transport.encryptQueue, 1) + local socket = transport.sockets[entry[1]] + + if socket and socket.connected then + local msg = entry[2] + msg.data = Crypto.encrypt({ msg.data }, socket.enckey) + socket.transmit(socket.dport, socket.dhost, msg) + end + end +end) + Event.on('timer', function(_, timerId) local socket = transport.timers[timerId] @@ -110,7 +131,7 @@ Event.on('modem_message', function(_, _, dport, dhost, msg, distance) elseif msg.type == 'DATA' and msg.data then if msg.seq ~= socket.rseq then - print('transport seq error - closing socket ' .. socket.sport) + print('transport seq error ' .. socket.sport) _syslog(msg.data) _syslog('expected ' .. socket.rseq) _syslog('got ' .. msg.seq) diff --git a/sys/modules/opus/crypto/chacha20.lua b/sys/modules/opus/crypto/chacha20.lua index f28f8dc..c82763f 100644 --- a/sys/modules/opus/crypto/chacha20.lua +++ b/sys/modules/opus/crypto/chacha20.lua @@ -1,7 +1,9 @@ -- Chacha20 cipher in ComputerCraft -- By Anavrins +local LZW = require('opus.crypto.lualzw') local sha2 = require('opus.crypto.sha2') +local Serializer = require('opus.crypto.serializer') local Util = require('opus.util') local ROUNDS = 20 -- Adjust this for speed tradeoff @@ -116,7 +118,7 @@ local function crypt(data, key, nonce, cntr, round) cntr = tonumber(cntr) or 1 round = tonumber(round) or 20 - local throttle = Util.throttle() + local throttle = Util.throttle(function() _syslog('throttle') end) local out = {} local state = initState(key, nonce, cntr) local blockAmt = math.floor(#data/64) @@ -132,11 +134,11 @@ local function crypt(data, key, nonce, cntr, round) out[#out+1] = bxor(block[j], ks[j]) end - if i % 1000 == 0 then + --if i % 1000 == 0 then throttle() --os.queueEvent("") --os.pullEvent("") - end + --end end return setmetatable(out, mt) end @@ -151,7 +153,8 @@ end local function encrypt(data, key) local nonce = genNonce(12) - data = textutils.serialise(data) + data = Serializer.serialize(data) + data = LZW.compress(data) key = sha2.digest(key) local ctx = crypt(data, key, nonce, 1, ROUNDS) return { nonce:toHex(), ctx:toHex() } @@ -162,7 +165,7 @@ local function decrypt(data, key) data = Util.hexToByteArray(data[2]) key = sha2.digest(key) local ptx = crypt(data, key, nonce, 1, ROUNDS) - return textutils.unserialise(tostring(ptx)) + return textutils.unserialise(LZW.decompress(tostring(ptx))) end local obj = {} diff --git a/sys/modules/opus/crypto/lualzw.lua b/sys/modules/opus/crypto/lualzw.lua new file mode 100644 index 0000000..7bdfa21 --- /dev/null +++ b/sys/modules/opus/crypto/lualzw.lua @@ -0,0 +1,164 @@ +-- see: https://github.com/Rochet2/lualzw +--[[ +MIT License + +Copyright (c) 2016 Rochet2 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +]] + +local char = string.char +local type = type +local sub = string.sub +local tconcat = table.concat + +local basedictcompress = {} +local basedictdecompress = {} +for i = 0, 255 do + local ic, iic = char(i), char(i, 0) + basedictcompress[ic] = iic + basedictdecompress[iic] = ic +end + +local function dictAddA(str, dict, a, b) + if a >= 256 then + a, b = 0, b+1 + if b >= 256 then + dict = {} + b = 1 + end + end + dict[str] = char(a,b) + a = a+1 + return dict, a, b +end + +local function compress(input) + if type(input) ~= "string" then + error ("string expected, got "..type(input)) + end + local len = #input + if len <= 1 then + return "u"..input + end + + local dict = {} + local a, b = 0, 1 + + local result = {"c"} + local resultlen = 1 + local n = 2 + local word = "" + for i = 1, len do + local c = sub(input, i, i) + local wc = word..c + if not (basedictcompress[wc] or dict[wc]) then + local write = basedictcompress[word] or dict[word] + if not write then + error "algorithm error, could not fetch word" + end + result[n] = write + resultlen = resultlen + #write + n = n+1 + if len <= resultlen then + return "u"..input + end + dict, a, b = dictAddA(wc, dict, a, b) + word = c + else + word = wc + end + end + result[n] = basedictcompress[word] or dict[word] + resultlen = resultlen+#result[n] + if len <= resultlen then + return "u"..input + end + return tconcat(result) +end + +local function dictAddB(str, dict, a, b) + if a >= 256 then + a, b = 0, b+1 + if b >= 256 then + dict = {} + b = 1 + end + end + dict[char(a,b)] = str + a = a+1 + return dict, a, b +end + +local function decompress(input) + if type(input) ~= "string" then + error( "string expected, got "..type(input)) + end + + if #input < 1 then + error("invalid input - not a compressed string") + end + + local control = sub(input, 1, 1) + if control == "u" then + return sub(input, 2) + elseif control ~= "c" then + error( "invalid input - not a compressed string") + end + input = sub(input, 2) + local len = #input + + if len < 2 then + error( "invalid input - not a compressed string") + end + + local dict = {} + local a, b = 0, 1 + + local result = {} + local n = 1 + local last = sub(input, 1, 2) + result[n] = basedictdecompress[last] or dict[last] + n = n+1 + for i = 3, len, 2 do + local code = sub(input, i, i+1) + local lastStr = basedictdecompress[last] or dict[last] + if not lastStr then + error( "could not find last from dict. Invalid input?") + end + local toAdd = basedictdecompress[code] or dict[code] + if toAdd then + result[n] = toAdd + n = n+1 + dict, a, b = dictAddB(lastStr..sub(toAdd, 1, 1), dict, a, b) + else + local tmp = lastStr..sub(lastStr, 1, 1) + result[n] = tmp + n = n+1 + dict, a, b = dictAddB(tmp, dict, a, b) + end + last = code + end + return tconcat(result) +end + +return { + compress = compress, + decompress = decompress, +} \ No newline at end of file diff --git a/sys/modules/opus/crypto/serializer.lua b/sys/modules/opus/crypto/serializer.lua new file mode 100644 index 0000000..db19336 --- /dev/null +++ b/sys/modules/opus/crypto/serializer.lua @@ -0,0 +1,50 @@ +local Serializer = { } + +local insert = table.insert +local format = string.format + +function Serializer.serialize(tbl) + local output = { } + + local function recurse(t) + local sType = type(t) + if sType == 'table' then + if next(t) == nil then + insert(output, '{}') + else + insert(output, '{') + local tSeen = {} + for k, v in ipairs(t) do + tSeen[k] = true + recurse(v) + insert(output, ',') + end + for k, v in pairs(t) do + if not tSeen[k] then + if type(k) == 'string' and string.match(k, '^[%a_][%a%d_]*$') then + insert(output, k .. '=') + recurse(v) + insert(output, ',') + else + insert(output, '[') + recurse(k) + insert(output, ']=') + recurse(v) + insert(output, ',') + end + end + end + insert(output, '}') + end + elseif sType == 'string' then + insert(output, format('%q', t)) + else + insert(output, tostring(t)) + end + end + + recurse(tbl) + return table.concat(output) +end + +return Serializer diff --git a/sys/modules/opus/socket.lua b/sys/modules/opus/socket.lua index 9f77f99..3eadac7 100644 --- a/sys/modules/opus/socket.lua +++ b/sys/modules/opus/socket.lua @@ -47,9 +47,6 @@ end function socketClass:write(data) if self.connected then - if self.options.ENCRYPT then - data = Crypto.encrypt({ data }, self.enckey) - end network.getTransport().write(self, { type = 'DATA', seq = self.wseq, @@ -66,19 +63,6 @@ function socketClass:ping() end end -function socketClass:setupEncryption(x) - self.rrng = Crypto.newRNG( - SHA.pbkdf2(self.sharedKey, x and "3rseed" or "4sseed", 1)) - self.wrng = Crypto.newRNG( - SHA.pbkdf2(self.sharedKey, x and "4sseed" or "3rseed", 1)) - - self.sharedKey = ECC.exchange(self.privKey, self.remotePubKey) - self.enckey = SHA.pbkdf2(self.sharedKey, "1enc", 1) - --self.hmackey = SHA.pbkdf2(self.sharedKey, "2hmac", 1) - self.rseq = self.rrng:nextInt(5) - self.wseq = self.wrng:nextInt(5) -end - function socketClass:close() if self.connected then self.transmit(self.dport, self.dhost, { @@ -121,6 +105,19 @@ local function newSocket(isLoopback) error('No ports available') end +local function setupCrypto(socket, isClient) + socket.rrng = Crypto.newRNG( + SHA.pbkdf2(socket.sharedKey, isClient and "3rseed" or "4sseed", 1)) + socket.wrng = Crypto.newRNG( + SHA.pbkdf2(socket.sharedKey, isClient and "4sseed" or "3rseed", 1)) + + socket.sharedKey = ECC.exchange(socket.privKey, socket.remotePubKey) + socket.enckey = SHA.pbkdf2(socket.sharedKey, "1enc", 1) + --self.hmackey = SHA.pbkdf2(self.sharedKey, "2hmac", 1) + socket.rseq = socket.rrng:nextInt(5) + socket.wseq = socket.wrng:nextInt(5) +end + function Socket.connect(host, port, options) if not device.wireless_modem then return false, 'Wireless modem not found', 'NOMODEM' @@ -156,7 +153,7 @@ function Socket.connect(host, port, options) socket.connected = true socket.remotePubKey = Util.hexToByteArray(msg.pk) socket.options = msg.options or { } - socket:setupEncryption(true) + setupCrypto(socket, true) network.getTransport().open(socket) return socket @@ -190,7 +187,7 @@ local function trusted(socket, msg, options) if math.abs(os.epoch('utc') - data.ts) < 4096 then socket.remotePubKey = Util.hexToByteArray(data.pk) socket.privKey, socket.pubKey = network.getKeyPair() - socket:setupEncryption() + setupCrypto(socket) return true end _G._syslog('time diff ' .. math.abs(os.epoch('utc') - data.ts)) diff --git a/sys/modules/opus/util.lua b/sys/modules/opus/util.lua index 9e2d63d..e670857 100644 --- a/sys/modules/opus/util.lua +++ b/sys/modules/opus/util.lua @@ -56,7 +56,7 @@ Util.Timer = Util.timer -- deprecate function Util.throttle(fn) local ts = os.clock() - local timeout = .095 + local timeout = .295 return function(...) local nts = os.clock() if nts > ts + timeout then