-- From https://www.zash.se/lua-cbor.html
--[[
Copyright (c) 2014-2015 Kim Alvefur

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
]]

-- Concise Binary Object Representation (CBOR)
-- RFC 7049

--- Require `pkg` without failing if it is missing.
-- Returns `mod[field]` when `field` is given, the module otherwise,
-- or nothing when `require` fails.
local function softreq(pkg, field)
	local ok, mod = pcall(require, pkg);
	if not ok then return end
	if field then return mod[field]; end
	return mod;
end

--- Compile and run a chunk of Lua source, returning its first result,
-- or nothing if the source does not compile (e.g. syntax unsupported
-- by the running Lua version).
local dostring = function (s)
	local ok, f = pcall(loadstring or load, s); -- luacheck: read globals loadstring
	if ok and f then return f(); end
end

local setmetatable = setmetatable;
local getmetatable = getmetatable;
-- debug.getmetatable also sees metatables of userdata; fall back when
-- the debug library is unavailable.
local dbg_getmetatable;
if debug then
	dbg_getmetatable = debug.getmetatable;
else
	dbg_getmetatable = getmetatable;
end
local assert = assert;
local error = error;
local type = type;
local pairs = pairs;
local ipairs = ipairs;
local tostring = tostring;
local s_char = string.char;
local t_concat = table.concat;
local t_sort = table.sort;
local m_floor = math.floor;
local m_abs = math.abs;
local m_huge = math.huge;
local m_max = math.max;
local maxint = math.maxinteger or 9007199254740992;
local minint = math.mininteger or -9007199254740992;
local NaN = 0/0;
local m_frexp = math.frexp;
local m_ldexp = math.ldexp or function (x, exp)
	return x * 2.0 ^ exp;
end;
-- Emulates math.type on Lua versions without an integer subtype.
local m_type = math.type or function (n)
	return n % 1 == 0 and n <= maxint and n >= minint and "integer" or "float"
end;
local s_pack = string.pack or softreq("struct", "pack");
local s_unpack = string.unpack or softreq("struct", "unpack");
-- Right shift: bit32 (5.2), bit (LuaJIT), native operator (5.3+),
-- or an arithmetic fallback.
local b_rshift = softreq("bit32", "rshift") or softreq("bit", "rshift") or
	dostring "return function(a,b) return a >> b end" or
	function (a, b) return m_max(0, m_floor(a / (2 ^ b))); end;

-- sanity check: only keep pack/unpack if they behave as big-endian
-- unsigned packers.
if s_pack and s_pack(">I2", 0) ~= "\0\0" then
	s_pack = nil;
end
if s_unpack and s_unpack(">I2", "\1\2\3\4") ~= 0x102 then
	s_unpack = nil;
end

-- From here on, no global may be accessed (Lua 5.2+).
local _ENV = nil; -- luacheck: ignore 211

local encoder = {};

--- Encode any Lua value by dispatching on its `type()`.
-- @param obj value to encode
-- @param opts optional table mapping metatables to encoder functions
-- @treturn string CBOR encoding
local function encode(obj, opts)
	return encoder[type(obj)](obj, opts);
end

-- Major types 0, 1 and length encoding for others.
-- `m` is the major type already shifted into the top three bits
-- (0, 32, 64, ...); `num` is the argument to encode after it.
local function integer(num, m)
	if m == 0 and num < 0 then
		-- negative integer, major type 1
		num, m = - num - 1, 32;
	end
	if num < 24 then
		return s_char(m + num);
	elseif num < 2 ^ 8 then
		return s_char(m + 24, num);
	elseif num < 2 ^ 16 then
		return s_char(m + 25, b_rshift(num, 8), num % 0x100);
	elseif num < 2 ^ 32 then
		return s_char(m + 26,
			b_rshift(num, 24) % 0x100,
			b_rshift(num, 16) % 0x100,
			b_rshift(num, 8) % 0x100,
			num % 0x100);
	elseif num < 2 ^ 64 then
		local high = m_floor(num / 2 ^ 32);
		num = num % 2 ^ 32;
		return s_char(m + 27,
			b_rshift(high, 24) % 0x100,
			b_rshift(high, 16) % 0x100,
			b_rshift(high, 8) % 0x100,
			high % 0x100,
			b_rshift(num, 24) % 0x100,
			b_rshift(num, 16) % 0x100,
			b_rshift(num, 8) % 0x100,
			num % 0x100);
	end
	error "int too large";
end

if s_pack then
	-- Faster variant using string.pack / struct.pack. Unlike the
	-- fallback it does not handle negative numbers; encoder.integer
	-- converts those before calling it.
	function integer(num, m)
		local fmt;
		m = m or 0;
		if num < 24 then
			fmt, m = ">B", m + num;
		elseif num < 256 then
			fmt, m = ">BB", m + 24;
		elseif num < 65536 then
			fmt, m = ">BI2", m + 25;
		elseif num < 4294967296 then
			fmt, m = ">BI4", m + 26;
		else
			fmt, m = ">BI8", m + 27;
		end
		return s_pack(fmt, m, num);
	end
end

-- Simple values (major type 7)
local simple_mt = {};
function simple_mt:__tostring() return self.name or ("simple(%d)"):format(self.value); end
function simple_mt:__tocbor() return self.cbor or integer(self.value, 224); end

--- Construct a simple value object.
-- @param value integer 0..255
-- @param name optional display name used by tostring
-- @param cbor optional pre-computed encoding
local function simple(value, name, cbor)
	assert(value >= 0 and value <= 255, "bad argument #1 to 'simple' (integer in range 0..255 expected)");
	return setmetatable({ value = value, name = name, cbor = cbor }, simple_mt);
end

-- Tagged values (major type 6)
local tagged_mt = {};
function tagged_mt:__tostring() return ("%d(%s)"):format(self.tag, tostring(self.value)); end
-- Forward the encoder options so custom encoders also apply to the
-- wrapped value.
function tagged_mt:__tocbor(opts) return integer(self.tag, 192) .. encode(self.value, opts); end

--- Construct a tagged value object.
-- @param tag non-negative integer tag number
-- @param value the wrapped value
local function tagged(tag, value)
	assert(tag >= 0, "bad argument #1 to 'tagged' (positive integer expected)");
	return setmetatable({ tag = tag, value = value }, tagged_mt);
end

local null = simple(22, "null"); -- explicit null
local undefined = simple(23, "undefined"); -- undefined or nil
local BREAK = simple(31, "break", "\255");

-- Number types dispatch
function encoder.number(num)
	return encoder[m_type(num)](num);
end

-- Major types 0, 1
function encoder.integer(num)
	if num < 0 then
		return integer(-1 - num, 32);
	end
	return integer(num, 0);
end

-- Major type 7: floats are always written as 64-bit doubles (initial
-- byte 251 = major type 7, additional info 27). This is the manual
-- fallback used when no pack function is available.
function encoder.float(num)
	if num ~= num then -- NaN shortcut
		return "\251\127\255\255\255\255\255\255\255";
	end
	-- 1 / -0 is -inf, so this also detects negative zero
	local sign = (num > 0 or 1 / num > 0) and 0 or 1;
	num = m_abs(num)
	if num == m_huge then
		return s_char(251, sign * 128 + 128 - 1) .. "\240\0\0\0\0\0\0";
	end
	local fraction, exponent = m_frexp(num)
	if fraction == 0 then
		return s_char(251, sign * 128) .. "\0\0\0\0\0\0\0";
	end
	-- frexp gives fraction in [0.5, 1); rescale to [1, 2) and bias
	-- the exponent by 1023
	fraction = fraction * 2;
	exponent = exponent + 1024 - 2;
	if exponent <= 0 then
		-- subnormal: no implicit leading 1
		fraction = fraction * 2 ^ (exponent - 1)
		exponent = 0;
	else
		fraction = fraction - 1;
	end
	return s_char(251,
		sign * 2 ^ 7 + m_floor(exponent / 2 ^ 4) % 2 ^ 7,
		exponent % 2 ^ 4 * 2 ^ 4 +
		m_floor(fraction * 2 ^ 4 % 0x100),
		m_floor(fraction * 2 ^ 12 % 0x100),
		m_floor(fraction * 2 ^ 20 % 0x100),
		m_floor(fraction * 2 ^ 28 % 0x100),
		m_floor(fraction * 2 ^ 36 % 0x100),
		m_floor(fraction * 2 ^ 44 % 0x100),
		m_floor(fraction * 2 ^ 52 % 0x100)
	)
end

if s_pack then
	function encoder.float(num)
		return s_pack(">Bd", 251, num);
	end
end

-- Major type 2 - byte strings
function encoder.bytestring(s)
	return integer(#s, 64) .. s;
end

-- Major type 3 - UTF-8 strings
function encoder.utf8string(s)
	return integer(#s, 96) .. s;
end

--- Encode a Lua string as text if it is pure ASCII, else as bytes.
-- The check searches for any byte >= 128 instead of matching
-- "[\0-\127]": Lua 5.1 patterns cannot contain embedded zeros.
function encoder.string(s)
	if not s:find("[\128-\255]") then
		-- If string is entirely ASCII characters, then treat it as a UTF-8 string
		return encoder.utf8string(s)
	else
		return encoder.bytestring(s)
	end
end

function encoder.boolean(bool)
	return bool and "\245" or "\244";
end

encoder["nil"] = function() return "\246"; end

--- Encode userdata via `opts[mt]` or the metatable's `__tocbor`.
function encoder.userdata(ud, opts)
	local mt = dbg_getmetatable(ud);
	if mt then
		local encode_ud = opts and opts[mt] or mt.__tocbor;
		if encode_ud then
			return encode_ud(ud, opts);
		end
	end
	error "can't encode userdata";
end

--- Encode a table, as an array or as a map.
-- A custom encoder (`opts[mt]` or `mt.__tocbor`) takes precedence.
function encoder.table(t, opts)
	local mt = getmetatable(t);
	if mt then
		local encode_t = opts and opts[mt] or mt.__tocbor;
		if encode_t then
			return encode_t(t, opts);
		end
	end
	-- the table is encoded as an array iff when we iterate over it,
	-- we see successive integer keys starting from 1. The lua
	-- language doesn't actually guarantee that this will be the case
	-- when we iterate over a table with successive integer keys, but
	-- due to an implementation detail in PUC Rio Lua, this is what we
	-- usually observe. See the Lua manual regarding the # (length)
	-- operator. In the case that this does not happen, we will fall
	-- back to a map with integer keys, which becomes a bit larger.
	-- Both encodings are built in a single pass; slot 1 holds the header.
	local array, map, i, p = { integer(#t, 128) }, { "\191" }, 1, 2;
	local is_array = true;
	for k, v in pairs(t) do
		is_array = is_array and i == k;
		i = i + 1;

		local encoded_v = encode(v, opts);
		array[i] = encoded_v;

		map[p], p = encode(k, opts), p + 1;
		map[p], p = encoded_v, p + 1;
	end
	-- map[p] = "\255";
	map[1] = integer(i - 1, 160);
	return t_concat(is_array and array or map);
end

-- Array or dict-only encoders, which can be set as __tocbor metamethod
function encoder.array(t, opts)
	local array = { };
	for i, v in ipairs(t) do
		array[i] = encode(v, opts);
	end
	return integer(#array, 128) .. t_concat(array);
end

function encoder.map(t, opts)
	local map, p, len = { "\191" }, 2, 0;
	for k, v in pairs(t) do
		map[p], p = encode(k, opts), p + 1;
		map[p], p = encode(v, opts), p + 1;
		len = len + 1;
	end
	-- map[p] = "\255";
	map[1] = integer(len, 160);
	return t_concat(map);
end
encoder.dict = encoder.map; -- COMPAT

--- Encode a map with deterministic key order.
-- If `t[1]` is set, the array part of `t` lists the keys in order;
-- otherwise all keys are collected and sorted with table.sort.
function encoder.ordered_map(t, opts)
	local map = {};
	if not t[1] then -- no predefined order
		local i = 0;
		for k in pairs(t) do
			i = i + 1;
			map[i] = k;
		end
		t_sort(map);
	end
	for i, k in ipairs(t[1] and t or map) do
		map[i] = encode(k, opts) .. encode(t[k], opts);
	end
	return integer(#map, 160) .. t_concat(map);
end

encoder["function"] = function ()
	error "can't encode function";
end

-- Decoder
-- Reads from a file-handle like object (anything with :read(n))
local function read_bytes(fh, len)
	return fh:read(len);
end

local function read_byte(fh)
	return fh:read(1):byte();
end

--- Read the argument of an initial byte: inline for values below 24,
-- otherwise 1, 2, 4 or 8 following big-endian bytes (24..27).
local function read_length(fh, mintyp)
	if mintyp < 24 then
		return mintyp;
	elseif mintyp < 28 then
		local out = 0;
		for _ = 1, 2 ^ (mintyp - 24) do
			out = out * 256 + read_byte(fh);
		end
		return out;
	else
		error "invalid length";
	end
end

local decoder = {};

-- Split an initial byte into (major type, additional information).
local function read_type(fh)
	local byte = read_byte(fh);
	return b_rshift(byte, 5), byte % 32;
end

local function read_object(fh, opts)
	local typ, mintyp = read_type(fh);
	return decoder[typ](fh, mintyp, opts);
end

local function read_integer(fh, mintyp)
	return read_length(fh, mintyp);
end

local function read_negative_integer(fh, mintyp)
	return -1 - read_length(fh, mintyp);
end

--- Read a byte string; mintyp 31 means indefinite length, i.e. a
-- sequence of chunks terminated by BREAK.
local function read_string(fh, mintyp, opts)
	if mintyp ~= 31 then
		return read_bytes(fh, read_length(fh, mintyp));
	end
	local out = {};
	local i = 1;
	local v = read_object(fh, opts);
	while v ~= BREAK do
		out[i], i = v, i + 1;
		v = read_object(fh, opts);
	end
	return t_concat(out);
end

-- Text strings are returned as-is; UTF-8 validity is not checked.
local function read_unicode_string(fh, mintyp, opts)
	return read_string(fh, mintyp, opts);
	-- local str = read_string(fh, mintyp);
	-- if have_utf8 and not utf8.len(str) then
		-- TODO How to handle this?
	-- end
	-- return str;
end

local function read_array(fh, mintyp, opts)
	local out = {};
	if mintyp == 31 then
		local i = 1;
		local v = read_object(fh, opts);
		while v ~= BREAK do
			out[i], i = v, i + 1;
			v = read_object(fh, opts);
		end
	else
		local len = read_length(fh, mintyp);
		for i = 1, len do
			out[i] = read_object(fh, opts);
		end
	end
	return out;
end

local function read_map(fh, mintyp, opts)
	local out = {};
	local k;
	if mintyp == 31 then
		k = read_object(fh, opts);
		while k ~= BREAK do
			out[k] = read_object(fh, opts);
			k = read_object(fh, opts);
		end
	else
		local len = read_length(fh, mintyp);
		for _ = 1, len do
			k = read_object(fh, opts);
			out[k] = read_object(fh, opts);
		end
	end
	return out;
end

local tagged_decoders = {};

--- Read a tagged value; `opts[tag]` or `tagged_decoders[tag]` may
-- post-process it, otherwise a tagged object is returned.
local function read_semantic(fh, mintyp, opts)
	local tag = read_length(fh, mintyp);
	local value = read_object(fh, opts);
	local postproc = opts and opts[tag] or tagged_decoders[tag];
	if postproc then
		return postproc(value);
	end
	return tagged(tag, value);
end

--- Decode an IEEE 754 half-precision float (1 sign, 5 exponent,
-- 10 fraction bits).
local function read_half_float(fh)
	local exponent = read_byte(fh);
	local fraction = read_byte(fh);
	local sign = exponent < 128 and 1 or -1; -- sign is highest bit

	-- the low two bits of the first byte are the top of the 10-bit fraction
	fraction = fraction + (exponent * 256) % 1024;
	-- drop the sign bit and those two fraction bits, leaving the exponent
	exponent = b_rshift(exponent, 2) % 32;

	if exponent == 0 then
		return sign * m_ldexp(fraction, -24);
	elseif exponent ~= 31 then
		return sign * m_ldexp(fraction + 1024, exponent - 25);
	elseif fraction == 0 then
		return sign * m_huge;
	else
		return NaN;
	end
end

--- Decode an IEEE 754 single-precision float (1 sign, 8 exponent,
-- 23 fraction bits).
local function read_float(fh)
	local exponent = read_byte(fh);
	local fraction = read_byte(fh);
	local sign = exponent < 128 and 1 or -1; -- sign is highest bit
	exponent = exponent * 2 % 256 + b_rshift(fraction, 7);
	fraction = fraction % 128;
	fraction = fraction * 256 + read_byte(fh);
	fraction = fraction * 256 + read_byte(fh);

	if exponent == 0 then
		-- subnormal: fraction * 2^(1 - 127 - 23)
		return sign * m_ldexp(fraction, -149);
	elseif exponent ~= 0xff then
		return sign * m_ldexp(fraction + 2 ^ 23, exponent - 150);
	elseif fraction == 0 then
		return sign * m_huge;
	else
		return NaN;
	end
end

--- Decode an IEEE 754 double-precision float (1 sign, 11 exponent,
-- 52 fraction bits).
local function read_double(fh)
	local exponent = read_byte(fh);
	local fraction = read_byte(fh);
	local sign = exponent < 128 and 1 or -1; -- sign is highest bit

	exponent = exponent % 128 * 16 + b_rshift(fraction, 4);
	fraction = fraction % 16;
	fraction = fraction * 256 + read_byte(fh);
	fraction = fraction * 256 + read_byte(fh);
	fraction = fraction * 256 + read_byte(fh);
	fraction = fraction * 256 + read_byte(fh);
	fraction = fraction * 256 + read_byte(fh);
	fraction = fraction * 256 + read_byte(fh);

	if exponent == 0 then
		-- subnormal: fraction * 2^(1 - 1023 - 52)
		return sign * m_ldexp(fraction, -1074);
	elseif exponent ~= 0x7ff then -- 11-bit exponent: all ones means inf/NaN
		return sign * m_ldexp(fraction + 2 ^ 52, exponent - 1075);
	elseif fraction == 0 then
		return sign * m_huge;
	else
		return NaN;
	end
end

if s_unpack then
	-- Parentheses keep only the value; unpack also returns the next
	-- position, which would otherwise leak out of decode().
	function read_float(fh) return (s_unpack(">f", read_bytes(fh, 4))) end
	function read_double(fh) return (s_unpack(">d", read_bytes(fh, 8))) end
end

--- Read major type 7: booleans, null, undefined, floats, BREAK and
-- other simple values (customisable via opts.simple).
local function read_simple(fh, value, opts)
	if value == 24 then
		value = read_byte(fh);
	end
	if value == 20 then
		return false;
	elseif value == 21 then
		return true;
	elseif value == 22 then
		return null;
	elseif value == 23 then
		return undefined;
	elseif value == 25 then
		return read_half_float(fh);
	elseif value == 26 then
		return read_float(fh);
	elseif value == 27 then
		return read_double(fh);
	elseif value == 31 then
		return BREAK;
	end
	if opts and opts.simple then
		return opts.simple(value);
	end
	return simple(value);
end

decoder[0] = read_integer;
decoder[1] = read_negative_integer;
decoder[2] = read_string;
decoder[3] = read_unicode_string;
decoder[4] = read_array;
decoder[5] = read_map;
decoder[6] = read_semantic;
decoder[7] = read_simple;

--- Decode one CBOR item from a string.
-- opts.more(n) -> want more data
-- opts.simple -> decode simple value
-- opts[int] -> tagged decoder
-- `opts` may also be the `more` function itself; in that case no
-- other decoding options are in effect.
local function decode(s, opts)
	local fh = {};
	local pos = 1; -- index in `s` of the next unread byte

	local more;
	if type(opts) == "function" then
		more = opts;
	elseif type(opts) == "table" then
		more = opts.more;
	elseif opts ~= nil then
		error(("bad argument #2 to 'decode' (function or table expected, got %s)"):format(type(opts)));
	end
	if type(more) ~= "function" then
		function more()
			error "input too short";
		end
	end
	-- Only a table carries decoder options; a function must not be
	-- indexed by read_semantic / read_simple.
	local decode_opts = type(opts) == "table" and opts or nil;

	function fh:read(bytes)
		local ret = s:sub(pos, pos + bytes - 1);
		if #ret < bytes then
			ret = more(bytes - #ret, fh, opts);
			if ret then self:write(ret); end
			return self:read(bytes);
		end
		pos = pos + bytes;
		return ret;
	end

	function fh:write(bytes) -- luacheck: no self
		s = s .. bytes;
		if pos > 256 then
			-- discard consumed input, keeping the byte at `pos`
			s = s:sub(pos);
			pos = 1;
		end
		return #bytes;
	end

	return read_object(fh, decode_opts);
end

return {
	-- en-/decoder functions
	encode = encode;
	decode = decode;
	decode_file = read_object;

	-- tables of per-type en-/decoders
	type_encoders = encoder;
	type_decoders = decoder;

	-- special treatment for tagged values
	tagged_decoders = tagged_decoders;

	-- constructors for annotated types
	simple = simple;
	tagged = tagged;

	-- pre-defined simple values
	null = null;
	undefined = undefined;
};