-- potatOS/src/xlib/00_cbor.lua
--
-- 611 lines
-- 15 KiB
-- Lua

-- From https://www.zash.se/lua-cbor.html
--[[
Copyright (c) 2014-2015 Kim Alvefur
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
]]
-- Concise Binary Object Representation (CBOR)
-- RFC 7049
-- Optional require: loads `pkg` under pcall and returns the module, or
-- `mod[field]` when `field` is given. Returns nothing when loading fails.
local function softreq(pkg, field)
  local loaded, mod = pcall(require, pkg);
  if loaded then
    if field then
      return mod[field];
    end
    return mod;
  end
end
-- Compiles the source string `s` and runs it, returning whatever the chunk
-- returns. Yields nothing when the string cannot be compiled.
local function dostring(s)
  local compile = loadstring or load; -- luacheck: read globals loadstring
  local ok, chunk = pcall(compile, s);
  if not ok or not chunk then
    return;
  end
  return chunk();
end
local setmetatable = setmetatable;
local getmetatable = getmetatable;
local dbg_getmetatable
if debug then dbg_getmetatable = debug.getmetatable else dbg_getmetatable = getmetatable end
local assert = assert;
local error = error;
local type = type;
local pairs = pairs;
local ipairs = ipairs;
local tostring = tostring;
local s_char = string.char;
local t_concat = table.concat;
local t_sort = table.sort;
local m_floor = math.floor;
local m_abs = math.abs;
local m_huge = math.huge;
local m_max = math.max;
local maxint = math.maxinteger or 9007199254740992;
local minint = math.mininteger or -9007199254740992;
local NaN = 0/0;
local m_frexp = math.frexp;
local m_ldexp = math.ldexp or function (x, exp) return x * 2.0 ^ exp; end;
local m_type = math.type or function (n) return n % 1 == 0 and n <= maxint and n >= minint and "integer" or "float" end;
local s_pack = string.pack or softreq("struct", "pack");
local s_unpack = string.unpack or softreq("struct", "unpack");
local b_rshift = softreq("bit32", "rshift") or softreq("bit", "rshift") or
dostring "return function(a,b) return a >> b end" or
function (a, b) return m_max(0, m_floor(a / (2 ^ b))); end;
-- sanity check: only keep the pack/unpack implementations if they handle the
-- big-endian 2-byte unsigned format ">I2" as expected. When either one is
-- discarded here, the pure-Lua code paths defined below are used instead.
if s_pack and s_pack(">I2", 0) ~= "\0\0" then
s_pack = nil;
end
if s_unpack and s_unpack(">I2", "\1\2\3\4") ~= 0x102 then
s_unpack = nil;
end
local _ENV = nil; -- luacheck: ignore 211
local encoder = {};
-- Encodes any Lua value to a CBOR string by dispatching on type(obj) through
-- the `encoder` table. `opts` is handed through unchanged to the type encoder.
-- Raises an error for types that have no encoder registered (e.g. threads).
local function encode(obj, opts)
  local kind = type(obj);
  local handler = encoder[kind];
  if not handler then
    -- without this check the failure is an opaque "attempt to call a nil value"
    error("can't encode " .. kind);
  end
  return handler(obj, opts);
end
-- Major types 0, 1 and length encoding for others
-- Pure-Lua encoder for a CBOR head: the major type bits `m` (already shifted
-- into the top three bits, e.g. 64 for byte strings) combined with the
-- unsigned argument `num`, using the shortest of the 1/2/3/5/9 byte forms.
-- `m` defaults to 0 (unsigned integer), matching the string.pack based
-- variant. With m == 0 a negative `num` is encoded as major type 1.
-- Raises "int too large" when `num` does not fit in 64 bits.
local function integer(num, m)
  m = m or 0;
  if m == 0 and num < 0 then
    -- negative integer, major type 1
    num, m = - num - 1, 32;
  end
  if num < 24 then
    -- small values live directly in the low five bits of the initial byte
    return s_char(m + num);
  elseif num < 2 ^ 8 then
    return s_char(m + 24, num);
  elseif num < 2 ^ 16 then
    return s_char(m + 25, b_rshift(num, 8), num % 0x100);
  elseif num < 2 ^ 32 then
    return s_char(m + 26,
      b_rshift(num, 24) % 0x100,
      b_rshift(num, 16) % 0x100,
      b_rshift(num, 8) % 0x100,
      num % 0x100);
  elseif num < 2 ^ 64 then
    -- split into two 32-bit halves so the shift helper never sees more
    -- than 32 bits
    local high = m_floor(num / 2 ^ 32);
    num = num % 2 ^ 32;
    return s_char(m + 27,
      b_rshift(high, 24) % 0x100,
      b_rshift(high, 16) % 0x100,
      b_rshift(high, 8) % 0x100,
      high % 0x100,
      b_rshift(num, 24) % 0x100,
      b_rshift(num, 16) % 0x100,
      b_rshift(num, 8) % 0x100,
      num % 0x100);
  end
  error "int too large";
end
if s_pack then
  -- string.pack based replacement for the pure-Lua `integer` above.
  -- `m` holds the major type bits and defaults to 0; `num` is written
  -- big-endian using the shortest head form that fits.
  function integer(num, m)
    local major = m or 0;
    if num < 24 then
      return s_pack(">B", major + num, num);
    end
    if num < 256 then
      return s_pack(">BB", major + 24, num);
    end
    if num < 65536 then
      return s_pack(">BI2", major + 25, num);
    end
    if num < 4294967296 then
      return s_pack(">BI4", major + 26, num);
    end
    return s_pack(">BI8", major + 27, num);
  end
end
-- Metatable shared by all CBOR "simple value" objects.
local simple_mt = {
  -- Shows the symbolic name when there is one, otherwise "simple(N)".
  __tostring = function (self)
    if self.name then
      return self.name;
    end
    return ("simple(%d)"):format(self.value);
  end;
  -- Uses the pre-computed encoding when present, otherwise encodes the
  -- value under major type 7 (224).
  __tocbor = function (self)
    if self.cbor then
      return self.cbor;
    end
    return integer(self.value, 224);
  end;
};
-- Creates a simple value object. `value` must be within 0..255; `name` and
-- `cbor` (a fixed encoding) are optional.
local function simple(value, name, cbor)
  local in_range = value >= 0 and value <= 255;
  assert(in_range, "bad argument #1 to 'simple' (integer in range 0..255 expected)");
  local obj = { value = value, name = name, cbor = cbor };
  return setmetatable(obj, simple_mt);
end
-- Metatable shared by all tagged value objects.
local tagged_mt = {
  -- Renders as "TAG(VALUE)".
  __tostring = function (self)
    local inner = tostring(self.value);
    return ("%d(%s)"):format(self.tag, inner);
  end;
  -- Tag head under major type 6 (192) followed by the encoded value.
  __tocbor = function (self)
    local head = integer(self.tag, 192);
    return head .. encode(self.value);
  end;
};
-- Wraps `value` with the non-negative semantic tag number `tag`.
local function tagged(tag, value)
  assert(tag >= 0, "bad argument #1 to 'tagged' (positive integer expected)");
  local obj = { tag = tag, value = value };
  return setmetatable(obj, tagged_mt);
end
-- Pre-defined simple value objects.
local null = simple(22, "null"); -- explicit null
local undefined = simple(23, "undefined"); -- undefined or nil
-- Sentinel the decoder compares against (by identity) to detect the end of
-- indefinite-length strings, arrays and maps; its fixed encoding is "\255".
local BREAK = simple(31, "break", "\255");
-- Number types dispatch
-- Picks the "integer" or "float" encoder according to the number's subtype.
function encoder.number(num)
  local subtype = m_type(num);
  local handler = encoder[subtype];
  return handler(num);
end
-- Major types 0, 1
-- Non-negative integers use major type 0; a negative n is stored as the
-- unsigned value -1 - n under major type 1 (32).
function encoder.integer(num)
  local major, magnitude = 0, num;
  if num < 0 then
    major, magnitude = 32, -1 - num;
  end
  return integer(magnitude, major);
end
-- Major type 7
-- Pure-Lua encoder for floats: always emits the 9-byte form consisting of the
-- byte 251 followed by eight bytes built from sign, biased exponent and
-- fraction (most significant byte first).
function encoder.float(num)
if num ~= num then -- NaN shortcut
return "\251\127\255\255\255\255\255\255\255";
end
-- 1 / num > 0 distinguishes +0 from -0, which compare equal to each other
local sign = (num > 0 or 1 / num > 0) and 0 or 1;
num = m_abs(num)
if num == m_huge then
-- infinity: first payload byte is 127 or 255 depending on sign, then 240
return s_char(251, sign * 128 + 128 - 1) .. "\240\0\0\0\0\0\0";
end
local fraction, exponent = m_frexp(num)
if fraction == 0 then
-- zero: only the sign bit may be set
return s_char(251, sign * 128) .. "\0\0\0\0\0\0\0";
end
-- rescale the frexp result and apply the exponent bias
fraction = fraction * 2;
exponent = exponent + 1024 - 2;
if exponent <= 0 then
-- too small for a biased exponent >= 1: store with exponent field 0 and a
-- correspondingly scaled-down fraction
fraction = fraction * 2 ^ (exponent - 1)
exponent = 0;
else
-- drop the leading 1, which is not stored
fraction = fraction - 1;
end
-- byte 1: sign bit plus the upper 7 exponent bits; byte 2: lower 4 exponent
-- bits plus the top 4 fraction bits; remaining bytes: 8 fraction bits each
return s_char(251,
sign * 2 ^ 7 + m_floor(exponent / 2 ^ 4) % 2 ^ 7,
exponent % 2 ^ 4 * 2 ^ 4 +
m_floor(fraction * 2 ^ 4 % 0x100),
m_floor(fraction * 2 ^ 12 % 0x100),
m_floor(fraction * 2 ^ 20 % 0x100),
m_floor(fraction * 2 ^ 28 % 0x100),
m_floor(fraction * 2 ^ 36 % 0x100),
m_floor(fraction * 2 ^ 44 % 0x100),
m_floor(fraction * 2 ^ 52 % 0x100)
)
end
if s_pack then
  -- With a working pack function, emit the byte 251 followed by the number
  -- packed as a big-endian double.
  encoder.float = function (num)
    local head = 251;
    return s_pack(">Bd", head, num);
  end
end
-- Major type 2 - byte strings
-- Length head under major type 2 (64), followed by the raw bytes.
function encoder.bytestring(s)
  local head = integer(#s, 64);
  return head .. s;
end
-- Major type 3 - UTF-8 strings
-- Length head under major type 3 (96), followed by the string bytes as-is.
function encoder.utf8string(s)
  local head = integer(#s, 96);
  return head .. s;
end
-- Encodes a Lua string: a string made up entirely of ASCII bytes (including
-- the empty string) is emitted as a UTF-8 text string, anything containing a
-- byte >= 128 is emitted as a byte string.
function encoder.string(s)
  -- Search for a non-ASCII byte rather than matching "^[\0-\127]*$":
  -- Lua 5.1 patterns cannot contain an embedded "\0".
  if s:find("[\128-\255]") then
    return encoder.bytestring(s);
  end
  return encoder.utf8string(s);
end
-- Any truthy argument encodes as "\245", a falsy one as "\244".
function encoder.boolean(bool)
  if bool then
    return "\245";
  end
  return "\244";
end
-- nil always encodes as the single byte "\246".
encoder["nil"] = function ()
  return "\246";
end
-- Userdata can only be encoded through a custom encoder: either opts[mt]
-- (keyed by the userdata's metatable) or the metatable's __tocbor field.
-- Raises "can't encode userdata" when neither is available.
function encoder.userdata(ud, opts)
  local mt = dbg_getmetatable(ud);
  local custom = mt and (opts and opts[mt] or mt.__tocbor);
  if custom then
    return custom(ud, opts);
  end
  error "can't encode userdata";
end
-- Encodes a table. A custom encoder (opts[mt] or mt.__tocbor) takes priority;
-- otherwise the table becomes a CBOR array or map as described below.
function encoder.table(t, opts)
  local mt = getmetatable(t);
  local custom = mt and (opts and opts[mt] or mt.__tocbor);
  if custom then
    return custom(t, opts);
  end
  -- The table is encoded as an array iff, when we iterate over it, we see
  -- successive integer keys starting from 1. The Lua language doesn't
  -- actually guarantee that this will be the case when we iterate over a
  -- table with successive integer keys, but due to an implementation detail
  -- in PUC-Rio Lua, this is what we usually observe. See the Lua manual
  -- regarding the # (length) operator. In the case that this does not
  -- happen, we fall back to a map with integer keys, which becomes a bit
  -- larger.
  local as_array = { integer(#t, 128) };
  local as_map = { "\191" }; -- slot 1 is replaced by the real map head below
  local count = 0;
  local sequential = true;
  for key, value in pairs(t) do
    count = count + 1;
    if sequential and key ~= count then
      sequential = false;
    end
    -- both candidate encodings are built in the same pass
    local value_cbor = encode(value, opts);
    as_array[count + 1] = value_cbor;
    as_map[2 * count] = encode(key, opts);
    as_map[2 * count + 1] = value_cbor;
  end
  if sequential then
    return t_concat(as_array);
  end
  as_map[1] = integer(count, 160);
  return t_concat(as_map);
end
-- Array or dict-only encoders, which can be set as __tocbor metamethod
-- Encodes the sequence part of `t` (as visited by ipairs) as a CBOR array,
-- ignoring every other key.
function encoder.array(t, opts)
  local items = {};
  for idx, element in ipairs(t) do
    items[idx] = encode(element, opts);
  end
  local head = integer(#items, 128);
  return head .. t_concat(items);
end
-- Encodes every key/value pair of `t` (in pairs() order) as a CBOR map with
-- a definite length.
function encoder.map(t, opts)
  local parts = { "\191" }; -- slot 1 is replaced by the real head below
  local len = 0;
  for k, v in pairs(t) do
    len = len + 1;
    parts[2 * len] = encode(k, opts);
    parts[2 * len + 1] = encode(v, opts);
  end
  parts[1] = integer(len, 160);
  return t_concat(parts);
end
encoder.dict = encoder.map; -- COMPAT
-- Encodes `t` as a CBOR map with a deterministic key order. When t[1] is set,
-- the sequence part of `t` lists the keys in the desired order; otherwise all
-- keys are collected and sorted with table.sort.
function encoder.ordered_map(t, opts)
  local order;
  if t[1] then
    order = t;
  else
    -- no predefined order
    order = {};
    local n = 0;
    for k in pairs(t) do
      n = n + 1;
      order[n] = k;
    end
    t_sort(order);
  end
  local entries = {};
  for i, k in ipairs(order) do
    entries[i] = encode(k, opts) .. encode(t[k], opts);
  end
  return integer(#entries, 160) .. t_concat(entries);
end
-- Functions have no CBOR representation; encoding one always raises.
encoder["function"] = function ()
  error("can't encode function");
end
-- Decoder
-- Reads from a file-handle like object
-- Returns whatever the handle's read method yields for `len` bytes.
local function read_bytes(fh, len)
  return fh.read(fh, len);
end
-- Reads one byte from the handle and returns it as a number (0..255).
-- Raises "input too short" when the handle has no more data, i.e. when its
-- read method returns nil or an empty string.
local function read_byte(fh)
  local chunk = fh:read(1);
  if not chunk or #chunk < 1 then
    error "input too short";
  end
  return chunk:byte();
end
-- Decodes the argument of a CBOR head. Values below 24 are the argument
-- itself; 24..27 mean the argument follows as a big-endian integer of
-- 1, 2, 4 or 8 bytes. Anything else raises "invalid length".
local function read_length(fh, mintyp)
  if mintyp < 24 then
    return mintyp;
  end
  if not (mintyp < 28) then
    error "invalid length";
  end
  local value = 0;
  local nbytes = 2 ^ (mintyp - 24);
  for _ = 1, nbytes do
    value = value * 256 + read_byte(fh);
  end
  return value;
end
-- Per-major-type decoder functions, indexed 0..7 (filled in further down).
local decoder = {};
-- Reads the initial byte of a data item and splits it into the major type
-- (top three bits) and the additional information (low five bits).
local function read_type(fh)
  local initial = read_byte(fh);
  local major = b_rshift(initial, 5);
  local info = initial % 32;
  return major, info;
end
-- Decodes one complete data item from the handle by dispatching on its
-- major type.
local function read_object(fh, opts)
  local major, info = read_type(fh);
  local handler = decoder[major];
  return handler(fh, info, opts);
end
-- Major type 0: the head argument is the value itself.
local function read_integer(fh, mintyp)
  local value = read_length(fh, mintyp);
  return value;
end
-- Major type 1: the head argument n stands for the value -1 - n.
local function read_negative_integer(fh, mintyp)
  local magnitude = read_length(fh, mintyp);
  return -1 - magnitude;
end
-- Reads a string. A definite-length string is read in one go; for the
-- indefinite form (additional info 31) data items are decoded until the
-- BREAK marker and their values concatenated.
local function read_string(fh, mintyp)
  if mintyp == 31 then
    local chunks, n = {}, 0;
    local chunk = read_object(fh);
    while chunk ~= BREAK do
      n = n + 1;
      chunks[n] = chunk;
      chunk = read_object(fh);
    end
    return t_concat(chunks);
  end
  return read_bytes(fh, read_length(fh, mintyp));
end
-- Text strings are read exactly like byte strings; the bytes are returned
-- without being validated as UTF-8.
-- TODO How to handle invalid UTF-8? A possible check would be:
--   if have_utf8 and not utf8.len(str) then ... end
local function read_unicode_string(fh, mintyp)
  local str = read_string(fh, mintyp);
  return str;
end
-- Reads an array into a Lua sequence. The definite form reads exactly the
-- announced number of items; the indefinite form (additional info 31) reads
-- items until the BREAK marker.
local function read_array(fh, mintyp, opts)
  local items = {};
  if mintyp ~= 31 then
    for idx = 1, read_length(fh, mintyp) do
      items[idx] = read_object(fh, opts);
    end
    return items;
  end
  local n = 0;
  local item = read_object(fh, opts);
  while item ~= BREAK do
    n = n + 1;
    items[n] = item;
    item = read_object(fh, opts);
  end
  return items;
end
-- Reads a map into a Lua table. Each entry is a key item followed by a value
-- item. The definite form reads the announced number of entries; the
-- indefinite form (additional info 31) stops when BREAK appears where a key
-- is expected.
local function read_map(fh, mintyp, opts)
  local result = {};
  if mintyp ~= 31 then
    for _ = 1, read_length(fh, mintyp) do
      local key = read_object(fh, opts);
      result[key] = read_object(fh, opts);
    end
    return result;
  end
  local key = read_object(fh, opts);
  while key ~= BREAK do
    result[key] = read_object(fh, opts);
    key = read_object(fh, opts);
  end
  return result;
end
-- Default post-processing functions for tagged values, keyed by tag number.
local tagged_decoders = {};
-- Reads a tag number and the item it applies to. The value is passed through
-- opts[tag] if set, otherwise tagged_decoders[tag]; when neither exists it
-- is wrapped in a tagged object.
local function read_semantic(fh, mintyp, opts)
  local tag = read_length(fh, mintyp);
  local value = read_object(fh, opts);
  local handler = opts and opts[tag];
  if not handler then
    handler = tagged_decoders[tag];
  end
  if not handler then
    return tagged(tag, value);
  end
  return handler(value);
end
-- Decodes a two-byte half-precision float from the handle: 1 sign bit,
-- 5 exponent bits, 10 fraction bits, most significant byte first.
local function read_half_float(fh)
local exponent = read_byte(fh);
local fraction = read_byte(fh);
local sign = exponent < 128 and 1 or -1; -- sign is highest bit
-- the two lowest bits of the first byte are the two highest fraction bits
fraction = fraction + (exponent * 256) % 1024; -- copy two bits from exponent to fraction
exponent = b_rshift(exponent, 2) % 32; -- remove sign bit and two low bits from fraction;
if exponent == 0 then
-- exponent field 0: no implicit leading 1
return sign * m_ldexp(fraction, -24);
elseif exponent ~= 31 then
-- ordinary value: adding 1024 restores the implicit leading 1
return sign * m_ldexp(fraction + 1024, exponent - 25);
elseif fraction == 0 then
-- exponent field all ones with a zero fraction: infinity
return sign * m_huge;
else
-- exponent field all ones with a non-zero fraction: NaN
return NaN;
end
end
-- Pure-Lua decoder for a four-byte single-precision float: 1 sign bit,
-- 8 exponent bits, 23 fraction bits, most significant byte first.
local function read_float(fh)
  local exponent = read_byte(fh);
  local fraction = read_byte(fh);
  local sign = exponent < 128 and 1 or -1; -- sign is highest bit
  -- exponent: low 7 bits of byte 1 followed by the top bit of byte 2
  exponent = exponent * 2 % 256 + b_rshift(fraction, 7);
  fraction = fraction % 128;
  fraction = fraction * 256 + read_byte(fh);
  fraction = fraction * 256 + read_byte(fh);
  if exponent == 0 then
    -- zero or subnormal: no implicit leading 1, the value is
    -- fraction * 2^-149 (scaling the exponent here always produced 0)
    return sign * m_ldexp(fraction, -149);
  elseif exponent ~= 0xff then
    -- normal number: 2 ^ 23 restores the implicit leading 1
    return sign * m_ldexp(fraction + 2 ^ 23, exponent - 150);
  elseif fraction == 0 then
    return sign * m_huge;
  else
    return NaN;
  end
end
-- Pure-Lua decoder for an eight-byte double-precision float: 1 sign bit,
-- 11 exponent bits, 52 fraction bits, most significant byte first.
local function read_double(fh)
  local exponent = read_byte(fh);
  local fraction = read_byte(fh);
  local sign = exponent < 128 and 1 or -1; -- sign is highest bit
  -- exponent: low 7 bits of byte 1 followed by the top 4 bits of byte 2
  exponent = exponent % 128 * 16 + b_rshift(fraction, 4);
  fraction = fraction % 16;
  for _ = 1, 6 do
    fraction = fraction * 256 + read_byte(fh);
  end
  if exponent == 0 then
    -- zero or subnormal: no implicit leading 1, value is fraction * 2^-1074
    return sign * m_ldexp(fraction, -1074);
  elseif exponent ~= 0x7ff then
    -- normal number; the exponent field is 11 bits wide, so only 0x7ff
    -- (not 0xff) marks infinity/NaN
    return sign * m_ldexp(fraction + 2 ^ 52, exponent - 1075);
  elseif fraction == 0 then
    return sign * m_huge;
  else
    return NaN;
  end
end
if s_unpack then
  -- With a working unpack function, let it decode big-endian floats and
  -- doubles. unpack also returns the position after the value; the outer
  -- parentheses drop it so that, like the pure-Lua versions, these return
  -- exactly one value and nothing extra leaks out of decode().
  function read_float(fh)
    return (s_unpack(">f", read_bytes(fh, 4)));
  end
  function read_double(fh)
    return (s_unpack(">d", read_bytes(fh, 8)));
  end
end
-- Major type 7: booleans, null/undefined, floats, the BREAK marker and other
-- simple values. An additional info of 24 means the value follows in the
-- next byte. Values without a built-in meaning go to opts.simple if given,
-- otherwise they are wrapped in a simple value object.
local function read_simple(fh, value, opts)
  if value == 24 then
    value = read_byte(fh);
  end
  if value == 20 then return false; end
  if value == 21 then return true; end
  if value == 22 then return null; end
  if value == 23 then return undefined; end
  if value == 25 then return read_half_float(fh); end
  if value == 26 then return read_float(fh); end
  if value == 27 then return read_double(fh); end
  if value == 31 then return BREAK; end
  local custom = opts and opts.simple;
  if custom then
    return custom(value);
  end
  return simple(value);
end
-- Dispatch table used by read_object: major type (0..7) -> decoder function.
-- Each function is called as f(fh, additional_info, opts).
decoder[0] = read_integer;
decoder[1] = read_negative_integer;
decoder[2] = read_string;
decoder[3] = read_unicode_string;
decoder[4] = read_array;
decoder[5] = read_map;
decoder[6] = read_semantic;
decoder[7] = read_simple;
-- opts.more(n) -> want more data
-- opts.simple -> decode simple value
-- opts[int] -> tagged decoder
-- Decodes the first CBOR data item found in the string `s`.
--
-- `opts` may be nil, a function, or a table:
--   * a function is used as the "more" callback;
--   * a table may carry `more`, `simple` and per-tag decoders (opts[int]).
-- The "more" callback is invoked as more(missing_bytes, fh, opts) when the
-- input runs out. It may return additional data as a string and/or call
-- fh:write(data) itself. Without a callback, running out of input raises
-- "input too short".
-- NOTE(review): a callback that neither returns nor writes any data makes
-- fh:read retry forever.
local function decode(s, opts)
  local fh = {};
  local pos = 1; -- index in `s` of the next unread byte
  local more;
  -- Options handed on to the per-type decoders. Those index their opts
  -- (opts[tag], opts.simple), so a bare callback function must not be
  -- passed through to them.
  local decoder_opts;
  if type(opts) == "function" then
    more = opts;
  elseif type(opts) == "table" then
    more = opts.more;
    decoder_opts = opts;
  elseif opts ~= nil then
    error(("bad argument #2 to 'decode' (function or table expected, got %s)"):format(type(opts)));
  end
  if type(more) ~= "function" then
    function more()
      error "input too short";
    end
  end
  function fh:read(bytes)
    local ret = s:sub(pos, pos + bytes - 1);
    if #ret < bytes then
      ret = more(bytes - #ret, fh, opts);
      if ret then self:write(ret); end
      return self:read(bytes);
    end
    pos = pos + bytes;
    return ret;
  end
  function fh:write(bytes) -- luacheck: no self
    s = s .. bytes;
    if pos > 256 then
      -- discard the consumed prefix; the byte at `pos` is still unread,
      -- so it must be kept
      s = s:sub(pos);
      pos = 1;
    end
    return #bytes;
  end
  return read_object(fh, decoder_opts);
end
-- Public interface of the module.
return {
-- en-/decoder functions
encode = encode;
decode = decode;
-- decodes one item from a file-handle like object (anything with :read(n))
decode_file = read_object;
-- tables of per-type en-/decoders
type_encoders = encoder;
type_decoders = decoder;
-- special treatment for tagged values
tagged_decoders = tagged_decoders;
-- constructors for annotated types
simple = simple;
tagged = tagged;
-- pre-defined simple values
null = null;
undefined = undefined;
};