diff --git a/API/skynet.lua b/API/skynet.lua new file mode 100644 index 0000000..f4f6e43 --- /dev/null +++ b/API/skynet.lua @@ -0,0 +1,704 @@ +-- Skynet was created by Gollark +-- This version was modified to include CBOR in the same file + +local CBOR = (function() + -- Concise Binary Object Representation (CBOR) + -- RFC 7049 + + local function softreq(pkg, field) + local ok, mod = pcall(require, pkg); + if not ok then return end + if field then return mod[field]; end + return mod; + end + local dostring = function (s) + local ok, f = pcall(loadstring or load, s); -- luacheck: read globals loadstring + if ok and f then return f(); end + end + + local setmetatable, getmetatable = setmetatable, getmetatable; + local dbg_getmetatable + if debug then dbg_getmetatable = debug.getmetatable else dbg_getmetatable = getmetatable end + local assert = assert; + local error = error; + local type = type; + local pairs = pairs; + local ipairs = ipairs; + local tostring = tostring; + local s_char = string.char; + local t_concat = table.concat; + local t_sort = table.sort; + local m_floor = math.floor; + local m_abs = math.abs; + local m_huge = math.huge; + local m_max = math.max; + local maxint = math.maxinteger or 9007199254740992; + local minint = math.mininteger or -9007199254740992; + local NaN = 0/0; + local m_frexp = math.frexp; + local m_ldexp = math.ldexp or function (x, exp) return x * 2.0 ^ exp; end; + local m_type = math.type or function (n) return n % 1 == 0 and n <= maxint and n >= minint and "integer" or "float" end; + local s_pack = string.pack or softreq("struct", "pack"); + local s_unpack = string.unpack or softreq("struct", "unpack"); + local b_rshift = softreq("bit32", "rshift") or softreq("bit", "rshift") or + dostring "return function(a,b) return a >> b end" or + function (a, b) return m_max(0, m_floor(a / (2 ^ b))); end; + + -- sanity check + if s_pack and s_pack(">I2", 0) ~= "\0\0" then + s_pack = nil; + end + if s_unpack and s_unpack(">I2", "\1\2\3\4") ~= 0x102 then + s_unpack = nil; + end + + local _ENV = nil; -- luacheck: ignore 211 + + local encoder = {}; + + local function encode(obj, opts) + return encoder[type(obj)](obj, opts); + end + + -- Major types 0, 1 and length encoding for others + local function integer(num, m) + if m == 0 and num < 0 then + -- negative integer, major type 1 + num, m = - num - 1, 32; + end + if num < 24 then + return s_char(m + num); + elseif num < 2 ^ 8 then + return s_char(m + 24, num); + elseif num < 2 ^ 16 then + return s_char(m + 25, b_rshift(num, 8), num % 0x100); + elseif num < 2 ^ 32 then + return s_char(m + 26, + b_rshift(num, 24) % 0x100, + b_rshift(num, 16) % 0x100, + b_rshift(num, 8) % 0x100, + num % 0x100); + elseif num < 2 ^ 64 then + local high = m_floor(num / 2 ^ 32); + num = num % 2 ^ 32; + return s_char(m + 27, + b_rshift(high, 24) % 0x100, + b_rshift(high, 16) % 0x100, + b_rshift(high, 8) % 0x100, + high % 0x100, + b_rshift(num, 24) % 0x100, + b_rshift(num, 16) % 0x100, + b_rshift(num, 8) % 0x100, + num % 0x100); + end + error "int too large"; + end + + if s_pack then + function integer(num, m) + local fmt; + m = m or 0; + if num < 24 then + fmt, m = ">B", m + num; + elseif num < 256 then + fmt, m = ">BB", m + 24; + elseif num < 65536 then + fmt, m = ">BI2", m + 25; + elseif num < 4294967296 then + fmt, m = ">BI4", m + 26; + else + fmt, m = ">BI8", m + 27; + end + return s_pack(fmt, m, num); + end + end + + local simple_mt = {}; + function simple_mt:__tostring() return self.name or ("simple(%d)"):format(self.value); end + function simple_mt:__tocbor() return self.cbor or integer(self.value, 224); end + + local function simple(value, name, cbor) + assert(value >= 0 and value <= 255, "bad argument #1 to 'simple' (integer in range 0..255 expected)"); + return setmetatable({ value = value, name = name, cbor = cbor }, simple_mt); + end + + local tagged_mt = {}; + function tagged_mt:__tostring() return ("%d(%s)"):format(self.tag, tostring(self.value)); end + function tagged_mt:__tocbor() return integer(self.tag, 192) .. encode(self.value); end + + local function tagged(tag, value) + assert(tag >= 0, "bad argument #1 to 'tagged' (positive integer expected)"); + return setmetatable({ tag = tag, value = value }, tagged_mt); + end + + local null = simple(22, "null"); -- explicit null + local undefined = simple(23, "undefined"); -- undefined or nil + local BREAK = simple(31, "break", "\255"); + + -- Number types dispatch + function encoder.number(num) + return encoder[m_type(num)](num); + end + + -- Major types 0, 1 + function encoder.integer(num) + if num < 0 then + return integer(-1 - num, 32); + end + return integer(num, 0); + end + + -- Major type 7 + function encoder.float(num) + if num ~= num then -- NaN shortcut + return "\251\127\255\255\255\255\255\255\255"; + end + local sign = (num > 0 or 1 / num > 0) and 0 or 1; + num = m_abs(num) + if num == m_huge then + return s_char(251, sign * 128 + 128 - 1) .. "\240\0\0\0\0\0\0"; + end + local fraction, exponent = m_frexp(num) + if fraction == 0 then + return s_char(251, sign * 128) .. "\0\0\0\0\0\0\0"; + end + fraction = fraction * 2; + exponent = exponent + 1024 - 2; + if exponent <= 0 then + fraction = fraction * 2 ^ (exponent - 1) + exponent = 0; + else + fraction = fraction - 1; + end + return s_char(251, + sign * 2 ^ 7 + m_floor(exponent / 2 ^ 4) % 2 ^ 7, + exponent % 2 ^ 4 * 2 ^ 4 + + m_floor(fraction * 2 ^ 4 % 0x100), + m_floor(fraction * 2 ^ 12 % 0x100), + m_floor(fraction * 2 ^ 20 % 0x100), + m_floor(fraction * 2 ^ 28 % 0x100), + m_floor(fraction * 2 ^ 36 % 0x100), + m_floor(fraction * 2 ^ 44 % 0x100), + m_floor(fraction * 2 ^ 52 % 0x100) + ) + end + + if s_pack then + function encoder.float(num) + return s_pack(">Bd", 251, num); + end + end + + + -- Major type 2 - byte strings + function encoder.bytestring(s) + return integer(#s, 64) .. s; + end + + -- Major type 3 - UTF-8 strings + function encoder.utf8string(s) + return integer(#s, 96) .. s; + end + + function encoder.string(s) + if s:match "^[\0-\127]*$" then -- If string is entirely ASCII characters, then treat it as a UTF-8 string + return encoder.utf8string(s) + else + return encoder.bytestring(s) + end + end + + function encoder.boolean(bool) + return bool and "\245" or "\244"; + end + + encoder["nil"] = function() return "\246"; end + + function encoder.userdata(ud, opts) + local mt = dbg_getmetatable(ud); + if mt then + local encode_ud = opts and opts[mt] or mt.__tocbor; + if encode_ud then + return encode_ud(ud, opts); + end + end + error "can't encode userdata"; + end + + function encoder.table(t, opts) + local mt = getmetatable(t); + if mt then + local encode_t = opts and opts[mt] or mt.__tocbor; + if encode_t then + return encode_t(t, opts); + end + end + -- the table is encoded as an array iff when we iterate over it, + -- we see succesive integer keys starting from 1. The lua + -- language doesn't actually guarantee that this will be the case + -- when we iterate over a table with successive integer keys, but + -- due an implementation detail in PUC Rio Lua, this is what we + -- usually observe. See the Lua manual regarding the # (length) + -- operator. In the case that this does not happen, we will fall + -- back to a map with integer keys, which becomes a bit larger. + local array, map, i, p = { integer(#t, 128) }, { "\191" }, 1, 2; + local is_array = true; + for k, v in pairs(t) do + is_array = is_array and i == k; + i = i + 1; + + local encoded_v = encode(v, opts); + array[i] = encoded_v; + + map[p], p = encode(k, opts), p + 1; + map[p], p = encoded_v, p + 1; + end + -- map[p] = "\255"; + map[1] = integer(i - 1, 160); + return t_concat(is_array and array or map); + end + + -- Array or dict-only encoders, which can be set as __tocbor metamethod + function encoder.array(t, opts) + local array = { }; + for i, v in ipairs(t) do + array[i] = encode(v, opts); + end + return integer(#array, 128) .. t_concat(array); + end + + function encoder.map(t, opts) + local map, p, len = { "\191" }, 2, 0; + for k, v in pairs(t) do + map[p], p = encode(k, opts), p + 1; + map[p], p = encode(v, opts), p + 1; + len = len + 1; + end + -- map[p] = "\255"; + map[1] = integer(len, 160); + return t_concat(map); + end + encoder.dict = encoder.map; -- COMPAT + + function encoder.ordered_map(t, opts) + local map = {}; + if not t[1] then -- no predefined order + local i = 0; + for k in pairs(t) do + i = i + 1; + map[i] = k; + end + t_sort(map); + end + for i, k in ipairs(t[1] and t or map) do + map[i] = encode(k, opts) .. encode(t[k], opts); + end + return integer(#map, 160) .. t_concat(map); + end + + encoder["function"] = function () + error "can't encode function"; + end + + -- Decoder + -- Reads from a file-handle like object + local function read_bytes(fh, len) + return fh:read(len); + end + + local function read_byte(fh) + return fh:read(1):byte(); + end + + local function read_length(fh, mintyp) + if mintyp < 24 then + return mintyp; + elseif mintyp < 28 then + local out = 0; + for _ = 1, 2 ^ (mintyp - 24) do + out = out * 256 + read_byte(fh); + end + return out; + else + error "invalid length"; + end + end + + local decoder = {}; + + local function read_type(fh) + local byte = read_byte(fh); + return b_rshift(byte, 5), byte % 32; + end + + local function read_object(fh, opts) + local typ, mintyp = read_type(fh); + return decoder[typ](fh, mintyp, opts); + end + + local function read_integer(fh, mintyp) + return read_length(fh, mintyp); + end + + local function read_negative_integer(fh, mintyp) + return -1 - read_length(fh, mintyp); + end + + local function read_string(fh, mintyp) + if mintyp ~= 31 then + return read_bytes(fh, read_length(fh, mintyp)); + end + local out = {}; + local i = 1; + local v = read_object(fh); + while v ~= BREAK do + out[i], i = v, i + 1; + v = read_object(fh); + end + return t_concat(out); + end + + local function read_unicode_string(fh, mintyp) + return read_string(fh, mintyp); + end + + local function read_array(fh, mintyp, opts) + local out = {}; + if mintyp == 31 then + local i = 1; + local v = read_object(fh, opts); + while v ~= BREAK do + out[i], i = v, i + 1; + v = read_object(fh, opts); + end + else + local len = read_length(fh, mintyp); + for i = 1, len do + out[i] = read_object(fh, opts); + end + end + return out; + end + + local function read_map(fh, mintyp, opts) + local out = {}; + local k; + if mintyp == 31 then + local i = 1; + k = read_object(fh, opts); + while k ~= BREAK do + out[k], i = read_object(fh, opts), i + 1; + k = read_object(fh, opts); + end + else + local len = read_length(fh, mintyp); + for _ = 1, len do + k = read_object(fh, opts); + out[k] = read_object(fh, opts); + end + end + return out; + end + + local tagged_decoders = {}; + + local function read_semantic(fh, mintyp, opts) + local tag = read_length(fh, mintyp); + local value = read_object(fh, opts); + local postproc = opts and opts[tag] or tagged_decoders[tag]; + if postproc then + return postproc(value); + end + return tagged(tag, value); + end + + local function read_half_float(fh) + local exponent = read_byte(fh); + local fraction = read_byte(fh); + local sign = exponent < 128 and 1 or -1; -- sign is highest bit + + fraction = fraction + (exponent * 256) % 1024; -- copy two(?) bits from exponent to fraction + exponent = b_rshift(exponent, 2) % 32; -- remove sign bit and two low bits from fraction; + + if exponent == 0 then + return sign * m_ldexp(fraction, -24); + elseif exponent ~= 31 then + return sign * m_ldexp(fraction + 1024, exponent - 25); + elseif fraction == 0 then + return sign * m_huge; + else + return NaN; + end + end + + local function read_float(fh) + local exponent = read_byte(fh); + local fraction = read_byte(fh); + local sign = exponent < 128 and 1 or -1; -- sign is highest bit + exponent = exponent * 2 % 256 + b_rshift(fraction, 7); + fraction = fraction % 128; + fraction = fraction * 256 + read_byte(fh); + fraction = fraction * 256 + read_byte(fh); + + if exponent == 0 then + return sign * m_ldexp(exponent, -149); + elseif exponent ~= 0xff then + return sign * m_ldexp(fraction + 2 ^ 23, exponent - 150); + elseif fraction == 0 then + return sign * m_huge; + else + return NaN; + end + end + + local function read_double(fh) + local exponent = read_byte(fh); + local fraction = read_byte(fh); + local sign = exponent < 128 and 1 or -1; -- sign is highest bit + + exponent = exponent % 128 * 16 + b_rshift(fraction, 4); + fraction = fraction % 16; + fraction = fraction * 256 + read_byte(fh); + fraction = fraction * 256 + read_byte(fh); + fraction = fraction * 256 + read_byte(fh); + fraction = fraction * 256 + read_byte(fh); + fraction = fraction * 256 + read_byte(fh); + fraction = fraction * 256 + read_byte(fh); + + if exponent == 0 then + return sign * m_ldexp(exponent, -149); + elseif exponent ~= 0xff then + return sign * m_ldexp(fraction + 2 ^ 52, exponent - 1075); + elseif fraction == 0 then + return sign * m_huge; + else + return NaN; + end + end + + + if s_unpack then + function read_float(fh) return s_unpack(">f", read_bytes(fh, 4)) end + function read_double(fh) return s_unpack(">d", read_bytes(fh, 8)) end + end + + local function read_simple(fh, value, opts) + if value == 24 then + value = read_byte(fh); + end + if value == 20 then + return false; + elseif value == 21 then + return true; + elseif value == 22 then + return null; + elseif value == 23 then + return undefined; + elseif value == 25 then + return read_half_float(fh); + elseif value == 26 then + return read_float(fh); + elseif value == 27 then + return read_double(fh); + elseif value == 31 then + return BREAK; + end + if opts and opts.simple then + return opts.simple(value); + end + return simple(value); + end + + decoder[0] = read_integer; + decoder[1] = read_negative_integer; + decoder[2] = read_string; + decoder[3] = read_unicode_string; + decoder[4] = read_array; + decoder[5] = read_map; + decoder[6] = read_semantic; + decoder[7] = read_simple; + + -- opts.more(n) -> want more data + -- opts.simple -> decode simple value + -- opts[int] -> tagged decoder + local function decode(s, opts) + local fh = {}; + local pos = 1; + + local more; + if type(opts) == "function" then + more = opts; + elseif type(opts) == "table" then + more = opts.more; + elseif opts ~= nil then + error(("bad argument #2 to 'decode' (function or table expected, got %s)"):format(type(opts))); + end + if type(more) ~= "function" then + function more() + error "input too short"; + end + end + + function fh:read(bytes) + local ret = s:sub(pos, pos + bytes - 1); + if #ret < bytes then + ret = more(bytes - #ret, fh, opts); + if ret then self:write(ret); end + return self:read(bytes); + end + pos = pos + bytes; + return ret; + end + + function fh:write(bytes) -- luacheck: no self + s = s .. bytes; + if pos > 256 then + s = s:sub(pos + 1); + pos = 1; + end + return #bytes; + end + + return read_object(fh, opts); + end + + return { + -- en-/decoder functions + encode = encode; + decode = decode; + decode_file = read_object; + + -- tables of per-type en-/decoders + type_encoders = encoder; + type_decoders = decoder; + + -- special treatment for tagged values + tagged_decoders = tagged_decoders; + + -- constructors for annotated types + simple = simple; + tagged = tagged; + + -- pre-defined simple values + null = null; + undefined = undefined; + }; +end)() + +local skynet = { + server = "wss://osmarks.tk/skynet2/connect/", + socket = nil, + open_channels = {}, + CBOR = CBOR +} + +function skynet.connect(force) + if not skynet.socket or force then + -- If we already have a socket and are throwing it away, close old one. + if skynet.socket then skynet.socket.close() end + local sock = http.websocket(skynet.server) + if not sock then error "Skynet server unavailable, broken or running newer protocol version." end + skynet.socket = sock + + for _, c in pairs(skynet.open_channels) do + skynet.open(c) + end + end +end + +function skynet.disconnect() + if skynet.socket then skynet.socket.close() end +end + +local function value_in_table(t, v) + for k, tv in pairs(t) do if tv == v then return true, k end end + return false +end + +local function send_raw(data, tries) + local tries = tries or 0 + skynet.connect() + local ok, err = pcall(skynet.socket.send, CBOR.encode(data), true) -- send in binary mode + if not ok then + if tries > 0 then sleep(tries) end + if tries > 5 then error("Max reconnection attempts exceeded. " .. err) end + pcall(skynet.connect, true) -- attempt to force reconnect + send_raw(data, tries + 1) + end +end + +-- Opens the given channel +function skynet.open(channel) + -- Don't send unnecessary channel-open messages + if not value_in_table(skynet.open_channels, channel) then + send_raw { + "open", + channel + } + table.insert(skynet.open_channels, channel) + end +end + +local function recv_one(filter) + skynet.connect() + while true do + local contents = skynet.socket.receive() + local result = CBOR.decode(contents) + if type(result) == "table" then + if result[1] == "error" then error(result[2] .. ": " .. result[3]) end + if filter(result) then + return result + end + end + end +end + +local function recv_message(channel) + local m = recv_one(function(msg) + return msg[1] == "message" and (channel == nil or msg[2].channel == channel) + end) + return m[2].channel, m[2].message, m[2] +end + +function skynet.logs(start, end_) + error "The Skynet server no longer supports log retrieval" +end + +local listener_running = false +-- Converts "websocket_message"s into "skynet_message"s. +function skynet.listen(force_run) + local function run() + while true do + os.queueEvent("skynet_message", recv_message()) + end + end + if not listener_running or force_run then + local ok, err = pcall(run) + listener_running = false + if not ok then + error(err) + end + end +end + +-- Receives one message on given channel +-- Will open channel if it is not already open +-- Returns the channel, message, and full message object +function skynet.receive(channel) + if channel then skynet.open(channel) end + return recv_message(channel) +end + +-- Send given data on given channel +-- Can accept a third argument - an object of extra metadata to send +function skynet.send(channel, data, full) + local obj = full or {} + obj.message = data + obj.channel = channel + send_raw { + "message", + obj + } +end + +return skynet