1
0
mirror of https://github.com/SquidDev-CC/CC-Tweaked synced 2025-02-22 22:10:09 +00:00

Propagate exceptions from parallel where possible (#2095)

In the original implementation of our prettier runtime errors (#1320), we
wrapped the errors thrown within parallel functions into an exception
object. This means the call-stack is available to the catching-code, and
so is able to report a pretty exception message.

Unfortunately, this was a breaking change, and so we had to roll that
back. Some people were pcalling the parallel function, and matching on
the result of the error.

This is a second attempt at this, using a technique I've affectionately
dubbed "magic throws". The parallel API is now aware of whether it is
being pcalled or not, and thus able to decide whether to wrap the error
into an exception or not:

 - Add a new `cc.internal.tiny_require` module. This is a tiny
   reimplementation of require, for use in our global APIs.

 - Add a new (global, in the debug registry) `cc_try_barrier` function.
   This acts as a marker function, and is used to store additional
   information about the current coroutine.

   Currently this stores the parent coroutine (used to walk the full call
   stack) and a cache of whether any `pcall`-like function is on the
   stack.

   Both `parallel` and `cc.internal.exception.try` add this function to
   the root of the call stack.

 - When an error occurs within `parallel`, we walk up the call stack,
   using `cc_try_barrier` to traverse up the parent coroutine's stack
   too. If we do not find any `pcall`-like functions, then we know the
   error is never intercepted by user code, and so its safe to throw a
   full exception.
This commit is contained in:
Jonathan Coates 2025-02-13 17:38:57 +00:00 committed by GitHub
parent 2e2f308ff3
commit 051c70a731
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 201 additions and 15 deletions

View File

@ -39,7 +39,11 @@ the other.
@since 1.2
]]
local exception = dofile("rom/modules/main/cc/internal/tiny_require.lua")("cc.internal.exception")
local function create(...)
local barrier_ctx = { co = coroutine.running() }
local functions = table.pack(...)
local threads = {}
for i = 1, functions.n, 1 do
@ -48,7 +52,7 @@ local function create(...)
error("bad argument #" .. i .. " (function expected, got " .. type(fn) .. ")", 3)
end
threads[i] = { co = coroutine.create(fn), filter = nil }
threads[i] = { co = coroutine.create(function() return exception.try_barrier(barrier_ctx, fn) end), filter = nil }
end
return threads
@ -65,11 +69,14 @@ local function runUntilLimit(threads, limit)
local thread = threads[i]
if thread and (thread.filter == nil or thread.filter == event[1] or event[1] == "terminate") then
local ok, param = coroutine.resume(thread.co, table.unpack(event, 1, event.n))
if not ok then
error(param, 0)
else
if ok then
thread.filter = param
elseif type(param) == "string" and exception.can_wrap_errors() then
error(exception.make_exception(param, thread.co))
else
error(param, 0)
end
if coroutine.status(thread.co) == "dead" then
threads[i] = false
living = living - 1

View File

@ -7,9 +7,7 @@
-- @module textutils
-- @since 1.2
local pgk_env = setmetatable({}, { __index = _ENV })
pgk_env.require = dofile("rom/modules/main/cc/require.lua").make(pgk_env, "rom/modules/main")
local require = pgk_env.require
local require = dofile("rom/modules/main/cc/internal/tiny_require.lua")
local expect = require("cc.expect")
local expect, field = expect.expect, expect.field

View File

@ -12,7 +12,7 @@
]]
local expect = require "cc.expect".expect
local error_printer = require "cc.internal.error_printer"
local type, debug, coroutine = type, debug, coroutine
local function find_frame(thread, file, line)
-- Scan the first 16 frames for something interesting.
@ -28,7 +28,7 @@ end
--[[- Check whether this error is an exception.
Currently we don't provide a stable API for throwing (and propogating) rich
Currently we don't provide a stable API for throwing (and propagating) rich
errors, like those supported by this module. In lieu of that, we describe the
exception protocol, which may be used by user-written coroutine managers to
throw exceptions which are pretty-printed by the shell:
@ -64,6 +64,86 @@ local function is_exception(exn)
return mt and mt.__name == "exception" and type(rawget(exn, "message")) == "string" and type(rawget(exn, "thread")) == "thread"
end
local exn_mt = {
__name = "exception",
__tostring = function(self) return self.message end,
}
--[[- Create a new exception from a message and thread.
@tparam string message The exception message.
@tparam coroutine thread The coroutine the error occurred on.
@return The constructed exception.
]]
local function make_exception(message, thread)
return setmetatable({ message = message, thread = thread }, exn_mt)
end
--[[- A marker function for [`try`] and the wider exception machinery.
This function is typically the first function on the call stack. It acts as both
a signifier that this function is exception aware, and allows us to store
additional information for the exception machinery on the call stack.
@see can_wrap_errors
]]
local try_barrier = debug.getregistry().cc_try_barrier
if not try_barrier then
-- We define an extra "bounce" function to prevent f(...) being treated as a
-- tail call, and so ensure the barrier remains on the stack.
local function bounce(...) return ... end
--- @tparam { co = coroutine, can_wrap ?= boolean } parent The parent coroutine.
-- @tparam function f The function to call.
-- @param ... The arguments to this function.
try_barrier = function(parent, f, ...) return bounce(f(...)) end
debug.getregistry().cc_try_barrier = try_barrier
end
-- Functions that act as a barrier for exceptions.
local pcall_functions = { [pcall] = true, [xpcall] = true, [load] = true }
--[[- Check to see whether we can wrap errors into an exception.
This scans the current thread (up to a limit), and any parent threads, to
determine if there is a pcall anywhere on the callstack. If not, then we know
the error message is not observed by user code, and so may be wrapped into an
exception.
@tparam[opt] coroutine The thread to check. Defaults to the current thread.
@treturn boolean Whether we can wrap errors into exceptions.
]]
local function can_wrap_errors(thread)
if not thread then thread = coroutine.running() end
for offset = 0, 31 do
local frame = debug.getinfo(thread, offset, "f")
if not frame then return false end
local func = frame.func
if func == try_barrier then
-- If we've a try barrier, then extract the parent coroutine and
-- check if it can wrap errors.
local _, parent = debug.getlocal(thread, offset, 1)
if type(parent) ~= "table" or type(parent.co) ~= "thread" then return false end
local result = parent.can_wrap
if result == nil then
result = can_wrap_errors(parent.co)
parent.can_wrap = result
end
return result
elseif pcall_functions[func] then
-- If we're a pcall, then abort.
return false
end
end
return false
end
--[[- Attempt to call the provided function `func` with the provided arguments.
@tparam function func The function to call.
@ -79,8 +159,8 @@ end
local function try(func, ...)
expect(1, func, "function")
local co = coroutine.create(func)
local result = table.pack(coroutine.resume(co, ...))
local co = coroutine.create(try_barrier)
local result = table.pack(coroutine.resume(co, { co = co, can_wrap = true }, func, ...))
while coroutine.status(co) ~= "dead" do
local event = table.pack(os.pullEventRaw(result[2]))
@ -152,7 +232,7 @@ local function report(err, thread, source_map)
-- Could not determine the line. Bail.
if not line_contents or #line_contents == "" then return end
error_printer({
require("cc.internal.error_printer")({
get_pos = function() return line, column end,
get_line = function() return line_contents end,
}, {
@ -162,6 +242,11 @@ end
return {
make_exception = make_exception,
try_barrier = try_barrier,
can_wrap_errors = can_wrap_errors,
try = try,
report = report,
}

View File

@ -0,0 +1,37 @@
-- SPDX-FileCopyrightText: 2025 The CC: Tweaked Developers
--
-- SPDX-License-Identifier: MPL-2.0
--[[- A minimal implementation of require.
This is intended for use with APIs, and other internal code which is not run in
the [`shell`] environment. This allows us to avoid some of the overhead of
loading the full [`cc.require`] module.
> [!DANGER]
> This is an internal module and SHOULD NOT be used in your own code. It may
> be removed or changed at any time.
@local
@tparam string name The module to require.
@return The required module.
]]
local loaded = {}
local env = setmetatable({}, { __index = _G })
local function require(name)
local result = loaded[name]
if result then return result end
local path = "rom/modules/main/" .. name:gsub("%.", "/")
if fs.exists(path .. ".lua") then
result = assert(loadfile(path .. ".lua", nil, env))()
else
result = assert(loadfile(path .. "/init.lua", nil, env))()
end
loaded[name] = result
return result
end
env.require = require
return require

View File

@ -413,7 +413,7 @@ public class ComputerTestDelegate {
var wholeMessage = new StringBuilder();
if (message != null) wholeMessage.append(message);
if (trace != null) {
if (wholeMessage.length() != 0) wholeMessage.append('\n');
if (!wholeMessage.isEmpty()) wholeMessage.append('\n');
wholeMessage.append(trace);
}

View File

@ -193,7 +193,7 @@ local function format(value)
return "\"" .. escaped .. "\""
else
local ok, res = pcall(textutils.serialise, value)
if ok then return res else return tostring(value) end
if ok then return (res:gsub("\\\n", "\\n")) else return tostring(value) end
end
end
@ -379,7 +379,7 @@ end
function expect_mt:str_match(pattern)
local actual_type = type(self.value)
if actual_type ~= "string" then
self:_fail(("Expected value of type string\nbut got %s"):format(actual_type))
self:_fail(("Expected value of type string\nbut got %s (of type %s)"):format(format(self.value), actual_type))
end
if not self.value:find(pattern) then
self:_fail(("Expected %q\n to match pattern %q"):format(self.value, pattern))

View File

@ -129,4 +129,63 @@ describe("The parallel library", function()
expect(exitCount):eq(3)
end)
end)
describe("exceptions", function()
local try = require "cc.internal.exception".try
local function check_failure(fn, ...)
local ok, message, thread = try(fn, ...)
expect(ok):eq(false)
expect(message):str_match("/parallel_spec.lua:%d+: Oh no$")
return thread
end
it("throws an exception when within a try", function()
local expected_thread
local thread = check_failure(parallel.waitForAny, function()
expected_thread = coroutine.running()
error("Oh no")
end)
expect(thread):eq(expected_thread)
end)
it("throws an exception when within a try (nested)", function()
local expected_thread
local thread = check_failure(parallel.waitForAny, function()
parallel.waitForAny(function()
expected_thread = coroutine.running()
error("Oh no")
end)
end)
expect(thread):eq(expected_thread)
end)
it("throws the raw error when within a pcall", function()
local expected_thread
local thread = check_failure(function()
expected_thread = coroutine.running()
local ok, err = pcall(parallel.waitForAny, function() error("Oh no") end)
expect(ok):eq(false)
expect(err):str_match("/parallel_spec.lua:%d+: Oh no$")
error(err, 0)
end)
expect(thread):eq(expected_thread)
end)
it("throws the raw error when within a pcall (nested)", function()
local expected_thread
local thread = check_failure(function()
expected_thread = coroutine.running()
local ok, err = pcall(parallel.waitForAny, function()
parallel.waitForAny(function() error("Oh no") end)
end)
expect(ok):eq(false)
expect(err):str_match("/parallel_spec.lua:%d+: Oh no$")
error(err, 0)
end)
expect(thread):eq(expected_thread)
end)
end)
end)