mirror of
https://github.com/SquidDev-CC/CC-Tweaked
synced 2025-02-22 14:00:09 +00:00
Propagate exceptions from parallel where possible (#2095)
In the original implementation of our prettier runtime errors (#1320), we wrapped the errors thrown within parallel functions into an exception object. This means the call-stack is available to the catching-code, and so is able to report a pretty exception message. Unfortunately, this was a breaking change, and so we had to roll that back. Some people were pcalling the parallel function, and matching on the result of the error. This is a second attempt at this, using a technique I've affectionately dubbed "magic throws". The parallel API is now aware of whether it is being pcalled or not, and thus able to decide whether to wrap the error into an exception or not: - Add a new `cc.internal.tiny_require` module. This is a tiny reimplementation of require, for use in our global APIs. - Add a new (global, in the debug registry) `cc_try_barrier` function. This acts as a marker function, and is used to store additional information about the current coroutine. Currently this stores the parent coroutine (used to walk the full call stack) and a cache of whether any `pcall`-like function is on the stack. Both `parallel` and `cc.internal.exception.try` add this function to the root of the call stack. - When an error occurs within `parallel`, we walk up the call stack, using `cc_try_barrier` to traverse up the parent coroutine's stack too. If we do not find any `pcall`-like functions, then we know the error is never intercepted by user code, and so its safe to throw a full exception.
This commit is contained in:
parent
2e2f308ff3
commit
051c70a731
@ -39,7 +39,11 @@ the other.
|
||||
@since 1.2
|
||||
]]
|
||||
|
||||
local exception = dofile("rom/modules/main/cc/internal/tiny_require.lua")("cc.internal.exception")
|
||||
|
||||
local function create(...)
|
||||
local barrier_ctx = { co = coroutine.running() }
|
||||
|
||||
local functions = table.pack(...)
|
||||
local threads = {}
|
||||
for i = 1, functions.n, 1 do
|
||||
@ -48,7 +52,7 @@ local function create(...)
|
||||
error("bad argument #" .. i .. " (function expected, got " .. type(fn) .. ")", 3)
|
||||
end
|
||||
|
||||
threads[i] = { co = coroutine.create(fn), filter = nil }
|
||||
threads[i] = { co = coroutine.create(function() return exception.try_barrier(barrier_ctx, fn) end), filter = nil }
|
||||
end
|
||||
|
||||
return threads
|
||||
@ -65,11 +69,14 @@ local function runUntilLimit(threads, limit)
|
||||
local thread = threads[i]
|
||||
if thread and (thread.filter == nil or thread.filter == event[1] or event[1] == "terminate") then
|
||||
local ok, param = coroutine.resume(thread.co, table.unpack(event, 1, event.n))
|
||||
if not ok then
|
||||
error(param, 0)
|
||||
else
|
||||
if ok then
|
||||
thread.filter = param
|
||||
elseif type(param) == "string" and exception.can_wrap_errors() then
|
||||
error(exception.make_exception(param, thread.co))
|
||||
else
|
||||
error(param, 0)
|
||||
end
|
||||
|
||||
if coroutine.status(thread.co) == "dead" then
|
||||
threads[i] = false
|
||||
living = living - 1
|
||||
|
@ -7,9 +7,7 @@
|
||||
-- @module textutils
|
||||
-- @since 1.2
|
||||
|
||||
local pgk_env = setmetatable({}, { __index = _ENV })
|
||||
pgk_env.require = dofile("rom/modules/main/cc/require.lua").make(pgk_env, "rom/modules/main")
|
||||
local require = pgk_env.require
|
||||
local require = dofile("rom/modules/main/cc/internal/tiny_require.lua")
|
||||
|
||||
local expect = require("cc.expect")
|
||||
local expect, field = expect.expect, expect.field
|
||||
|
@ -12,7 +12,7 @@
|
||||
]]
|
||||
|
||||
local expect = require "cc.expect".expect
|
||||
local error_printer = require "cc.internal.error_printer"
|
||||
local type, debug, coroutine = type, debug, coroutine
|
||||
|
||||
local function find_frame(thread, file, line)
|
||||
-- Scan the first 16 frames for something interesting.
|
||||
@ -28,7 +28,7 @@ end
|
||||
|
||||
--[[- Check whether this error is an exception.
|
||||
|
||||
Currently we don't provide a stable API for throwing (and propogating) rich
|
||||
Currently we don't provide a stable API for throwing (and propagating) rich
|
||||
errors, like those supported by this module. In lieu of that, we describe the
|
||||
exception protocol, which may be used by user-written coroutine managers to
|
||||
throw exceptions which are pretty-printed by the shell:
|
||||
@ -64,6 +64,86 @@ local function is_exception(exn)
|
||||
return mt and mt.__name == "exception" and type(rawget(exn, "message")) == "string" and type(rawget(exn, "thread")) == "thread"
|
||||
end
|
||||
|
||||
local exn_mt = {
|
||||
__name = "exception",
|
||||
__tostring = function(self) return self.message end,
|
||||
}
|
||||
|
||||
--[[- Create a new exception from a message and thread.
|
||||
|
||||
@tparam string message The exception message.
|
||||
@tparam coroutine thread The coroutine the error occurred on.
|
||||
@return The constructed exception.
|
||||
]]
|
||||
local function make_exception(message, thread)
|
||||
return setmetatable({ message = message, thread = thread }, exn_mt)
|
||||
end
|
||||
|
||||
--[[- A marker function for [`try`] and the wider exception machinery.
|
||||
|
||||
This function is typically the first function on the call stack. It acts as both
|
||||
a signifier that this function is exception aware, and allows us to store
|
||||
additional information for the exception machinery on the call stack.
|
||||
|
||||
@see can_wrap_errors
|
||||
]]
|
||||
local try_barrier = debug.getregistry().cc_try_barrier
|
||||
if not try_barrier then
|
||||
-- We define an extra "bounce" function to prevent f(...) being treated as a
|
||||
-- tail call, and so ensure the barrier remains on the stack.
|
||||
local function bounce(...) return ... end
|
||||
|
||||
--- @tparam { co = coroutine, can_wrap ?= boolean } parent The parent coroutine.
|
||||
-- @tparam function f The function to call.
|
||||
-- @param ... The arguments to this function.
|
||||
try_barrier = function(parent, f, ...) return bounce(f(...)) end
|
||||
|
||||
debug.getregistry().cc_try_barrier = try_barrier
|
||||
end
|
||||
|
||||
-- Functions that act as a barrier for exceptions.
|
||||
local pcall_functions = { [pcall] = true, [xpcall] = true, [load] = true }
|
||||
|
||||
--[[- Check to see whether we can wrap errors into an exception.
|
||||
|
||||
This scans the current thread (up to a limit), and any parent threads, to
|
||||
determine if there is a pcall anywhere on the callstack. If not, then we know
|
||||
the error message is not observed by user code, and so may be wrapped into an
|
||||
exception.
|
||||
|
||||
@tparam[opt] coroutine The thread to check. Defaults to the current thread.
|
||||
@treturn boolean Whether we can wrap errors into exceptions.
|
||||
]]
|
||||
local function can_wrap_errors(thread)
|
||||
if not thread then thread = coroutine.running() end
|
||||
|
||||
for offset = 0, 31 do
|
||||
local frame = debug.getinfo(thread, offset, "f")
|
||||
if not frame then return false end
|
||||
|
||||
local func = frame.func
|
||||
if func == try_barrier then
|
||||
-- If we've a try barrier, then extract the parent coroutine and
|
||||
-- check if it can wrap errors.
|
||||
local _, parent = debug.getlocal(thread, offset, 1)
|
||||
if type(parent) ~= "table" or type(parent.co) ~= "thread" then return false end
|
||||
|
||||
local result = parent.can_wrap
|
||||
if result == nil then
|
||||
result = can_wrap_errors(parent.co)
|
||||
parent.can_wrap = result
|
||||
end
|
||||
|
||||
return result
|
||||
elseif pcall_functions[func] then
|
||||
-- If we're a pcall, then abort.
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
--[[- Attempt to call the provided function `func` with the provided arguments.
|
||||
|
||||
@tparam function func The function to call.
|
||||
@ -79,8 +159,8 @@ end
|
||||
local function try(func, ...)
|
||||
expect(1, func, "function")
|
||||
|
||||
local co = coroutine.create(func)
|
||||
local result = table.pack(coroutine.resume(co, ...))
|
||||
local co = coroutine.create(try_barrier)
|
||||
local result = table.pack(coroutine.resume(co, { co = co, can_wrap = true }, func, ...))
|
||||
|
||||
while coroutine.status(co) ~= "dead" do
|
||||
local event = table.pack(os.pullEventRaw(result[2]))
|
||||
@ -152,7 +232,7 @@ local function report(err, thread, source_map)
|
||||
-- Could not determine the line. Bail.
|
||||
if not line_contents or #line_contents == "" then return end
|
||||
|
||||
error_printer({
|
||||
require("cc.internal.error_printer")({
|
||||
get_pos = function() return line, column end,
|
||||
get_line = function() return line_contents end,
|
||||
}, {
|
||||
@ -162,6 +242,11 @@ end
|
||||
|
||||
|
||||
return {
|
||||
make_exception = make_exception,
|
||||
|
||||
try_barrier = try_barrier,
|
||||
can_wrap_errors = can_wrap_errors,
|
||||
|
||||
try = try,
|
||||
report = report,
|
||||
}
|
||||
|
@ -0,0 +1,37 @@
|
||||
-- SPDX-FileCopyrightText: 2025 The CC: Tweaked Developers
|
||||
--
|
||||
-- SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
--[[- A minimal implementation of require.
|
||||
|
||||
This is intended for use with APIs, and other internal code which is not run in
|
||||
the [`shell`] environment. This allows us to avoid some of the overhead of
|
||||
loading the full [`cc.require`] module.
|
||||
|
||||
> [!DANGER]
|
||||
> This is an internal module and SHOULD NOT be used in your own code. It may
|
||||
> be removed or changed at any time.
|
||||
|
||||
@local
|
||||
|
||||
@tparam string name The module to require.
|
||||
@return The required module.
|
||||
]]
|
||||
|
||||
local loaded = {}
|
||||
local env = setmetatable({}, { __index = _G })
|
||||
local function require(name)
|
||||
local result = loaded[name]
|
||||
if result then return result end
|
||||
|
||||
local path = "rom/modules/main/" .. name:gsub("%.", "/")
|
||||
if fs.exists(path .. ".lua") then
|
||||
result = assert(loadfile(path .. ".lua", nil, env))()
|
||||
else
|
||||
result = assert(loadfile(path .. "/init.lua", nil, env))()
|
||||
end
|
||||
loaded[name] = result
|
||||
return result
|
||||
end
|
||||
env.require = require
|
||||
return require
|
@ -413,7 +413,7 @@ public class ComputerTestDelegate {
|
||||
var wholeMessage = new StringBuilder();
|
||||
if (message != null) wholeMessage.append(message);
|
||||
if (trace != null) {
|
||||
if (wholeMessage.length() != 0) wholeMessage.append('\n');
|
||||
if (!wholeMessage.isEmpty()) wholeMessage.append('\n');
|
||||
wholeMessage.append(trace);
|
||||
}
|
||||
|
||||
|
@ -193,7 +193,7 @@ local function format(value)
|
||||
return "\"" .. escaped .. "\""
|
||||
else
|
||||
local ok, res = pcall(textutils.serialise, value)
|
||||
if ok then return res else return tostring(value) end
|
||||
if ok then return (res:gsub("\\\n", "\\n")) else return tostring(value) end
|
||||
end
|
||||
end
|
||||
|
||||
@ -379,7 +379,7 @@ end
|
||||
function expect_mt:str_match(pattern)
|
||||
local actual_type = type(self.value)
|
||||
if actual_type ~= "string" then
|
||||
self:_fail(("Expected value of type string\nbut got %s"):format(actual_type))
|
||||
self:_fail(("Expected value of type string\nbut got %s (of type %s)"):format(format(self.value), actual_type))
|
||||
end
|
||||
if not self.value:find(pattern) then
|
||||
self:_fail(("Expected %q\n to match pattern %q"):format(self.value, pattern))
|
||||
|
@ -129,4 +129,63 @@ describe("The parallel library", function()
|
||||
expect(exitCount):eq(3)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("exceptions", function()
|
||||
local try = require "cc.internal.exception".try
|
||||
local function check_failure(fn, ...)
|
||||
local ok, message, thread = try(fn, ...)
|
||||
expect(ok):eq(false)
|
||||
expect(message):str_match("/parallel_spec.lua:%d+: Oh no$")
|
||||
return thread
|
||||
end
|
||||
|
||||
it("throws an exception when within a try", function()
|
||||
local expected_thread
|
||||
local thread = check_failure(parallel.waitForAny, function()
|
||||
expected_thread = coroutine.running()
|
||||
error("Oh no")
|
||||
end)
|
||||
|
||||
expect(thread):eq(expected_thread)
|
||||
end)
|
||||
|
||||
it("throws an exception when within a try (nested)", function()
|
||||
local expected_thread
|
||||
local thread = check_failure(parallel.waitForAny, function()
|
||||
parallel.waitForAny(function()
|
||||
expected_thread = coroutine.running()
|
||||
error("Oh no")
|
||||
end)
|
||||
end)
|
||||
expect(thread):eq(expected_thread)
|
||||
end)
|
||||
|
||||
it("throws the raw error when within a pcall", function()
|
||||
local expected_thread
|
||||
local thread = check_failure(function()
|
||||
expected_thread = coroutine.running()
|
||||
|
||||
local ok, err = pcall(parallel.waitForAny, function() error("Oh no") end)
|
||||
expect(ok):eq(false)
|
||||
expect(err):str_match("/parallel_spec.lua:%d+: Oh no$")
|
||||
error(err, 0)
|
||||
end)
|
||||
expect(thread):eq(expected_thread)
|
||||
end)
|
||||
|
||||
it("throws the raw error when within a pcall (nested)", function()
|
||||
local expected_thread
|
||||
local thread = check_failure(function()
|
||||
expected_thread = coroutine.running()
|
||||
|
||||
local ok, err = pcall(parallel.waitForAny, function()
|
||||
parallel.waitForAny(function() error("Oh no") end)
|
||||
end)
|
||||
expect(ok):eq(false)
|
||||
expect(err):str_match("/parallel_spec.lua:%d+: Oh no$")
|
||||
error(err, 0)
|
||||
end)
|
||||
expect(thread):eq(expected_thread)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
Loading…
x
Reference in New Issue
Block a user