mirror of
https://github.com/SquidDev-CC/CC-Tweaked
synced 2025-11-13 03:43:08 +00:00
Add a function for type-checking arguments (#207)
- Define an expect(index, actual_value, types...) helper function which takes an argument index, value and list of permissable types and ensures the value is of one of those types. If not, it will produce an error message with the expected and actual type, as well as the argument number and (if available) the function name. - Expose expect in the global scope as _G["~expect"], hopefully making it clear it is internal. - Replace most manual type checks with this helper method. - Write tests to ensure this argument validation works as expected Also fix a couple of bugs exposed by this refactor and the subsequent tests: - Make rednet checks a little more strict - rednet.close(false) is no longer valid. - Error when attempting to redirect the terminal to itself (term.redirect(term)).
This commit is contained in:
@@ -1,5 +1,48 @@
|
||||
local native_select, native_type = select, type
|
||||
|
||||
--- Expect an argument to have a specific type.
|
||||
--
|
||||
-- @tparam int index The 1-based argument index.
|
||||
-- @param value The argument's value.
|
||||
-- @tparam string ... The allowed types of the argument.
|
||||
-- @throws If the value is not one of the allowed types.
|
||||
local function expect(index, value, ...)
|
||||
local t = native_type(value)
|
||||
for i = 1, native_select("#", ...) do
|
||||
if t == native_select(i, ...) then return true end
|
||||
end
|
||||
|
||||
local types = table.pack(...)
|
||||
for i = types.n, 1, -1 do
|
||||
if types[i] == "nil" then table.remove(types, i) end
|
||||
end
|
||||
|
||||
local type_names
|
||||
if #types <= 1 then
|
||||
type_names = tostring(...)
|
||||
else
|
||||
type_names = table.concat(types, ", ", 1, #types - 1) .. " or " .. types[#types]
|
||||
end
|
||||
|
||||
-- If we can determine the function name with a high level of confidence, try to include it.
|
||||
local name
|
||||
if native_type(debug) == "table" and native_type(debug.getinfo) == "function" then
|
||||
local ok, info = pcall(debug.getinfo, 3, "nS")
|
||||
if ok and info.name and #info.name ~= "" and info.what ~= "C" then name = info.name end
|
||||
end
|
||||
|
||||
if name then
|
||||
error( ("bad argument #%d to '%s' (expected %s, got %s)"):format(index, name, type_names, t), 3 )
|
||||
else
|
||||
error( ("bad argument #%d (expected %s, got %s)"):format(index, type_names, t), 3 )
|
||||
end
|
||||
end
|
||||
|
||||
-- We expose expect in the global table as APIs need to access it, but give it
|
||||
-- a non-identifier name - meaning it does not show up in auto-completion.
|
||||
-- expect is an internal function, and should not be used by users.
|
||||
_G["~expect"] = expect
|
||||
|
||||
local nativegetfenv = getfenv
|
||||
if _VERSION == "Lua 5.1" then
|
||||
-- If we're on Lua 5.1, install parts of the Lua 5.2/5.3 API so that programs can be written against it
|
||||
local type = type
|
||||
@@ -20,18 +63,11 @@ if _VERSION == "Lua 5.1" then
|
||||
end
|
||||
|
||||
function load( x, name, mode, env )
|
||||
if type( x ) ~= "string" and type( x ) ~= "function" then
|
||||
error( "bad argument #1 (expected string or function, got " .. type( x ) .. ")", 2 )
|
||||
end
|
||||
if name ~= nil and type( name ) ~= "string" then
|
||||
error( "bad argument #2 (expected string, got " .. type( name ) .. ")", 2 )
|
||||
end
|
||||
if mode ~= nil and type( mode ) ~= "string" then
|
||||
error( "bad argument #3 (expected string, got " .. type( mode ) .. ")", 2 )
|
||||
end
|
||||
if env ~= nil and type( env) ~= "table" then
|
||||
error( "bad argument #4 (expected table, got " .. type( env ) .. ")", 2 )
|
||||
end
|
||||
expect(1, x, "function", "string")
|
||||
expect(2, name, "string", "nil")
|
||||
expect(3, mode, "string", "nil")
|
||||
expect(4, env, "table", "nil")
|
||||
|
||||
local ok, p1, p2 = pcall( function()
|
||||
if type(x) == "string" then
|
||||
local result, err = nativeloadstring( x, name )
|
||||
@@ -76,10 +112,9 @@ if _VERSION == "Lua 5.1" then
|
||||
math.log10 = nil
|
||||
table.maxn = nil
|
||||
else
|
||||
loadstring = function(string, chunkname) return nativeloadstring(string, prefix( chunkname ))
|
||||
loadstring = function(string, chunkname) return nativeloadstring(string, prefix( chunkname )) end
|
||||
|
||||
-- Inject a stub for the old bit library
|
||||
end
|
||||
_G.bit = {
|
||||
bnot = bit32.bnot,
|
||||
band = bit32.band,
|
||||
@@ -175,9 +210,7 @@ end
|
||||
|
||||
-- Install globals
|
||||
function sleep( nTime )
|
||||
if nTime ~= nil and type( nTime ) ~= "number" then
|
||||
error( "bad argument #1 (expected number, got " .. type( nTime ) .. ")", 2 )
|
||||
end
|
||||
expect(1, nTime, "number", "nil")
|
||||
local timer = os.startTimer( nTime or 0 )
|
||||
repeat
|
||||
local sEvent, param = os.pullEvent( "timer" )
|
||||
@@ -185,9 +218,7 @@ function sleep( nTime )
|
||||
end
|
||||
|
||||
function write( sText )
|
||||
if type( sText ) ~= "string" and type( sText ) ~= "number" then
|
||||
error( "bad argument #1 (expected string or number, got " .. type( sText ) .. ")", 2 )
|
||||
end
|
||||
expect(1, sText, "string", "number")
|
||||
|
||||
local w,h = term.getSize()
|
||||
local x,y = term.getCursorPos()
|
||||
@@ -275,18 +306,11 @@ function printError( ... )
|
||||
end
|
||||
|
||||
function read( _sReplaceChar, _tHistory, _fnComplete, _sDefault )
|
||||
if _sReplaceChar ~= nil and type( _sReplaceChar ) ~= "string" then
|
||||
error( "bad argument #1 (expected string, got " .. type( _sReplaceChar ) .. ")", 2 )
|
||||
end
|
||||
if _tHistory ~= nil and type( _tHistory ) ~= "table" then
|
||||
error( "bad argument #2 (expected table, got " .. type( _tHistory ) .. ")", 2 )
|
||||
end
|
||||
if _fnComplete ~= nil and type( _fnComplete ) ~= "function" then
|
||||
error( "bad argument #3 (expected function, got " .. type( _fnComplete ) .. ")", 2 )
|
||||
end
|
||||
if _sDefault ~= nil and type( _sDefault ) ~= "string" then
|
||||
error( "bad argument #4 (expected string, got " .. type( _sDefault ) .. ")", 2 )
|
||||
end
|
||||
expect(1, _sReplaceChar, "string", "nil")
|
||||
expect(2, _tHistory, "table", "nil")
|
||||
expect(3, _fnComplete, "function", "nil")
|
||||
expect(4, _sDefault, "string", "nil")
|
||||
|
||||
term.setCursorBlink( true )
|
||||
|
||||
local sLine
|
||||
@@ -544,13 +568,10 @@ function read( _sReplaceChar, _tHistory, _fnComplete, _sDefault )
|
||||
return sLine
|
||||
end
|
||||
|
||||
loadfile = function( _sFile, _tEnv )
|
||||
if type( _sFile ) ~= "string" then
|
||||
error( "bad argument #1 (expected string, got " .. type( _sFile ) .. ")", 2 )
|
||||
end
|
||||
if _tEnv ~= nil and type( _tEnv ) ~= "table" then
|
||||
error( "bad argument #2 (expected table, got " .. type( _tEnv ) .. ")", 2 )
|
||||
end
|
||||
function loadfile( _sFile, _tEnv )
|
||||
expect(1, _sFile, "string")
|
||||
expect(2, _tEnv, "table", "nil")
|
||||
|
||||
local file = fs.open( _sFile, "r" )
|
||||
if file then
|
||||
local func, err = load( file.readAll(), "@" .. fs.getName( _sFile ), "t", _tEnv )
|
||||
@@ -560,10 +581,9 @@ loadfile = function( _sFile, _tEnv )
|
||||
return nil, "File not found"
|
||||
end
|
||||
|
||||
dofile = function( _sFile )
|
||||
if type( _sFile ) ~= "string" then
|
||||
error( "bad argument #1 (expected string, got " .. type( _sFile ) .. ")", 2 )
|
||||
end
|
||||
function dofile( _sFile )
|
||||
expect(1, _sFile, "string")
|
||||
|
||||
local fnFile, e = loadfile( _sFile, _G )
|
||||
if fnFile then
|
||||
return fnFile()
|
||||
@@ -574,12 +594,9 @@ end
|
||||
|
||||
-- Install the rest of the OS api
|
||||
function os.run( _tEnv, _sPath, ... )
|
||||
if type( _tEnv ) ~= "table" then
|
||||
error( "bad argument #1 (expected table, got " .. type( _tEnv ) .. ")", 2 )
|
||||
end
|
||||
if type( _sPath ) ~= "string" then
|
||||
error( "bad argument #2 (expected string, got " .. type( _sPath ) .. ")", 2 )
|
||||
end
|
||||
expect(1, _tEnv, "table")
|
||||
expect(2, _sPath, "string")
|
||||
|
||||
local tArgs = table.pack( ... )
|
||||
local tEnv = _tEnv
|
||||
setmetatable( tEnv, { __index = _G } )
|
||||
@@ -604,9 +621,7 @@ end
|
||||
|
||||
local tAPIsLoading = {}
|
||||
function os.loadAPI( _sPath )
|
||||
if type( _sPath ) ~= "string" then
|
||||
error( "bad argument #1 (expected string, got " .. type( _sPath ) .. ")", 2 )
|
||||
end
|
||||
expect(1, _sPath, "string")
|
||||
local sName = fs.getName( _sPath )
|
||||
if sName:sub(-4) == ".lua" then
|
||||
sName = sName:sub(1,-5)
|
||||
@@ -644,9 +659,7 @@ function os.loadAPI( _sPath )
|
||||
end
|
||||
|
||||
function os.unloadAPI( _sName )
|
||||
if type( _sName ) ~= "string" then
|
||||
error( "bad argument #1 (expected string, got " .. type( _sName ) .. ")", 2 )
|
||||
end
|
||||
expect(1, _sName, "string")
|
||||
if _sName ~= "_G" and type(_G[_sName]) == "table" then
|
||||
_G[_sName] = nil
|
||||
end
|
||||
@@ -692,9 +705,11 @@ if http then
|
||||
|
||||
local function checkOptions( options, body )
|
||||
checkKey( options, "url", "string")
|
||||
if body == false
|
||||
then checkKey( options, "body", "nil" )
|
||||
else checkKey( options, "body", "string", not body ) end
|
||||
if body == false then
|
||||
checkKey( options, "body", "nil" )
|
||||
else
|
||||
checkKey( options, "body", "string", not body )
|
||||
end
|
||||
checkKey( options, "headers", "table", true )
|
||||
checkKey( options, "method", "string", true )
|
||||
checkKey( options, "redirect", "boolean", true )
|
||||
@@ -725,15 +740,9 @@ if http then
|
||||
return wrapRequest( _url.url, _url )
|
||||
end
|
||||
|
||||
if type( _url ) ~= "string" then
|
||||
error( "bad argument #1 (expected string, got " .. type( _url ) .. ")", 2 )
|
||||
end
|
||||
if _headers ~= nil and type( _headers ) ~= "table" then
|
||||
error( "bad argument #2 (expected table, got " .. type( _headers ) .. ")", 2 )
|
||||
end
|
||||
if _binary ~= nil and type( _binary ) ~= "boolean" then
|
||||
error( "bad argument #3 (expected boolean, got " .. type( _binary ) .. ")", 2 )
|
||||
end
|
||||
expect(1, _url, "string")
|
||||
expect(2, _headers, "table", "nil")
|
||||
expect(3, _binary, "boolean", "nil")
|
||||
return wrapRequest( _url, _url, nil, _headers, _binary )
|
||||
end
|
||||
|
||||
@@ -743,18 +752,10 @@ if http then
|
||||
return wrapRequest( _url.url, _url )
|
||||
end
|
||||
|
||||
if type( _url ) ~= "string" then
|
||||
error( "bad argument #1 (expected string, got " .. type( _url ) .. ")", 2 )
|
||||
end
|
||||
if type( _post ) ~= "string" then
|
||||
error( "bad argument #2 (expected string, got " .. type( _post ) .. ")", 2 )
|
||||
end
|
||||
if _headers ~= nil and type( _headers ) ~= "table" then
|
||||
error( "bad argument #3 (expected table, got " .. type( _headers ) .. ")", 2 )
|
||||
end
|
||||
if _binary ~= nil and type( _binary ) ~= "boolean" then
|
||||
error( "bad argument #4 (expected boolean, got " .. type( _binary ) .. ")", 2 )
|
||||
end
|
||||
expect(1, _url, "string")
|
||||
expect(2, _post, "string")
|
||||
expect(3, _headers, "table", "nil")
|
||||
expect(4, _binary, "boolean", "nil")
|
||||
return wrapRequest( _url, _url, _post, _headers, _binary )
|
||||
end
|
||||
|
||||
@@ -764,19 +765,10 @@ if http then
|
||||
checkOptions( _url )
|
||||
url = _url.url
|
||||
else
|
||||
if type( _url ) ~= "string" then
|
||||
error( "bad argument #1 (expected string, got " .. type( _url ) .. ")", 2 )
|
||||
end
|
||||
if _post ~= nil and type( _post ) ~= "string" then
|
||||
error( "bad argument #2 (expected string, got " .. type( _post ) .. ")", 2 )
|
||||
end
|
||||
if _headers ~= nil and type( _headers ) ~= "table" then
|
||||
error( "bad argument #3 (expected table, got " .. type( _headers ) .. ")", 2 )
|
||||
end
|
||||
if _binary ~= nil and type( _binary ) ~= "boolean" then
|
||||
error( "bad argument #4 (expected boolean, got " .. type( _binary ) .. ")", 2 )
|
||||
end
|
||||
|
||||
expect(1, _url, "string")
|
||||
expect(2, _post, "string", "nil")
|
||||
expect(3, _headers, "table", "nil")
|
||||
expect(4, _binary, "boolean", "nil")
|
||||
url = _url.url
|
||||
end
|
||||
|
||||
@@ -802,12 +794,9 @@ if http then
|
||||
local nativeWebsocket = http.websocket
|
||||
http.websocketAsync = nativeWebsocket
|
||||
http.websocket = function( _url, _headers )
|
||||
if type( _url ) ~= "string" then
|
||||
error( "bad argument #1 (expected string, got " .. type( _url ) .. ")", 2 )
|
||||
end
|
||||
if _headers ~= nil and type( _headers ) ~= "table" then
|
||||
error( "bad argument #2 (expected table, got " .. type( _headers ) .. ")", 2 )
|
||||
end
|
||||
expect(1, _url, "string")
|
||||
expect(2, _headers, "table", "nil")
|
||||
|
||||
local ok, err = nativeWebsocket( _url, _headers )
|
||||
if not ok then return ok, err end
|
||||
|
||||
@@ -825,18 +814,11 @@ end
|
||||
-- Install the lua part of the FS api
|
||||
local tEmpty = {}
|
||||
function fs.complete( sPath, sLocation, bIncludeFiles, bIncludeDirs )
|
||||
if type( sPath ) ~= "string" then
|
||||
error( "bad argument #1 (expected string, got " .. type( sPath ) .. ")", 2 )
|
||||
end
|
||||
if type( sLocation ) ~= "string" then
|
||||
error( "bad argument #2 (expected string, got " .. type( sLocation ) .. ")", 2 )
|
||||
end
|
||||
if bIncludeFiles ~= nil and type( bIncludeFiles ) ~= "boolean" then
|
||||
error( "bad argument #3 (expected boolean, got " .. type( bIncludeFiles ) .. ")", 2 )
|
||||
end
|
||||
if bIncludeDirs ~= nil and type( bIncludeDirs ) ~= "boolean" then
|
||||
error( "bad argument #4 (expected boolean, got " .. type( bIncludeDirs ) .. ")", 2 )
|
||||
end
|
||||
expect(1, sPath, "string")
|
||||
expect(2, sLocation, "string")
|
||||
expect(3, bIncludeFiles, "boolean", "nil")
|
||||
expect(4, bIncludeDirs, "boolean", "nil")
|
||||
|
||||
bIncludeFiles = (bIncludeFiles ~= false)
|
||||
bIncludeDirs = (bIncludeDirs ~= false)
|
||||
local sDir = sLocation
|
||||
|
||||
Reference in New Issue
Block a user