1
0
mirror of https://github.com/SquidDev-CC/CC-Tweaked synced 2025-02-11 00:20:05 +00:00

Cancel no-longer-needed timers

Several functions accept a "timeout" argument, which is implemented by
starting a timer, and then racing the desired output against the timer
event.

However, if the timer never wins, we weren't cancelling the timer, and
so it was still queued. This is especially problematic if dozens or
hundreds of rednet (or websocket) messages are received in quick
succession, as we could fill the entire event queue, and stall the
computer.

See #1995
This commit is contained in:
Jonathan Coates 2024-11-10 21:07:44 +00:00
parent 3293639adf
commit b742745854
No known key found for this signature in database
GPG Key ID: B9E431FF07C98D06
5 changed files with 54 additions and 3 deletions

View File

@ -62,7 +62,7 @@ public class WebsocketHandle {
? environment.startTimer(Math.round(checkFinite(0, timeout.get()) / 0.05)) ? environment.startTimer(Math.round(checkFinite(0, timeout.get()) / 0.05))
: -1; : -1;
return new ReceiveCallback(timeoutId).pull; return new ReceiveCallback(environment, timeoutId).pull;
} }
/** /**
@ -110,18 +110,22 @@ public class WebsocketHandle {
private final class ReceiveCallback implements ILuaCallback { private final class ReceiveCallback implements ILuaCallback {
final MethodResult pull = MethodResult.pullEvent(null, this); final MethodResult pull = MethodResult.pullEvent(null, this);
private final IAPIEnvironment environment;
private final int timeoutId; private final int timeoutId;
ReceiveCallback(int timeoutId) { ReceiveCallback(IAPIEnvironment environment, int timeoutId) {
this.timeoutId = timeoutId; this.timeoutId = timeoutId;
this.environment = environment;
} }
@Override @Override
public MethodResult resume(Object[] event) { public MethodResult resume(Object[] event) {
if (event.length >= 3 && Objects.equals(event[0], MESSAGE_EVENT) && Objects.equals(event[1], address)) { if (event.length >= 3 && Objects.equals(event[0], MESSAGE_EVENT) && Objects.equals(event[1], address)) {
environment.cancelTimer(timeoutId);
return MethodResult.of(Arrays.copyOfRange(event, 2, event.length)); return MethodResult.of(Arrays.copyOfRange(event, 2, event.length));
} else if (event.length >= 2 && Objects.equals(event[0], CLOSE_EVENT) && Objects.equals(event[1], address) && websocket.isClosed()) { } else if (event.length >= 2 && Objects.equals(event[0], CLOSE_EVENT) && Objects.equals(event[1], address) && websocket.isClosed()) {
// If the socket is closed abort. // If the socket is closed abort.
environment.cancelTimer(timeoutId);
return MethodResult.of(); return MethodResult.of();
} else if (event.length >= 2 && timeoutId != -1 && Objects.equals(event[0], TIMER_EVENT) } else if (event.length >= 2 && timeoutId != -1 && Objects.equals(event[0], TIMER_EVENT)
&& event[1] instanceof Number id && id.intValue() == timeoutId) { && event[1] instanceof Number id && id.intValue() == timeoutId) {

View File

@ -196,6 +196,8 @@ function locate(_nTimeout, _bDebug)
modem.close(CHANNEL_GPS) modem.close(CHANNEL_GPS)
end end
os.cancelTimer(timeout)
-- Return the response -- Return the response
if pos1 and pos2 then if pos1 and pos2 then
if _bDebug then if _bDebug then

View File

@ -298,6 +298,7 @@ function receive(protocol_filter, timeout)
-- Return the first matching rednet_message -- Return the first matching rednet_message
local sender_id, message, protocol = p1, p2, p3 local sender_id, message, protocol = p1, p2, p3
if protocol_filter == nil or protocol == protocol_filter then if protocol_filter == nil or protocol == protocol_filter then
if timer then os.cancelTimer(timer) end
return sender_id, message, protocol return sender_id, message, protocol
end end
elseif event == "timer" then elseif event == "timer" then
@ -431,6 +432,7 @@ function lookup(protocol, hostname)
if hostname == nil then if hostname == nil then
table.insert(results, sender_id) table.insert(results, sender_id)
elseif message.sHostname == hostname then elseif message.sHostname == hostname then
os.cancelTimer(timer)
return sender_id return sender_id
end end
end end
@ -440,6 +442,9 @@ function lookup(protocol, hostname)
break break
end end
end end
os.cancelTimer(timer)
if results then if results then
return table.unpack(results) return table.unpack(results)
end end

View File

@ -171,6 +171,36 @@ describe("The rednet library", function()
fake_computer.run_all { computer_1, computer_2 } fake_computer.run_all { computer_1, computer_2 }
end) end)
describe("timeouts", function()
it("can time out messages", function()
local computer = computer_with_rednet(1, function(rednet)
local id = rednet.receive(1)
expect(id):eq(nil)
end, { open = true })
local computers = { computer }
fake_computer.run_all(computers, false)
fake_computer.advance_all(computers, 1)
fake_computer.run_all(computers)
end)
it("cancels the pending timer", function()
local computer = computer_with_rednet(1, function(rednet)
-- Send a message to ourselves with a timer
rednet.send(1, "hello")
local id = rednet.receive(1)
expect(id):eq(1)
end, { open = true })
fake_computer.run_all({ computer })
-- Our pending timer list only contains the rednet.run timer.
expect(computer.pending_timers):same {
[debugx.getupvalue(computer.env.rednet.run, "prune_received_timer")] = 10,
}
end)
end)
it("repeats messages between computers", function() it("repeats messages between computers", function()
local computer_1, modem_1 = computer_with_rednet(1, function(rednet) local computer_1, modem_1 = computer_with_rednet(1, function(rednet)
rednet.send(3, "Hello") rednet.send(3, "Hello")

View File

@ -50,6 +50,7 @@ local function make_computer(id, fn)
pending_timers[t], next_timer = clock + delay, next_timer + 1 pending_timers[t], next_timer = clock + delay, next_timer + 1
return t return t
end, end,
cancelTimer = function(id) pending_timers[id] = nil end,
clock = function() return clock end, clock = function() return clock end,
sleep = function(time) sleep = function(time)
local timer = env.os.startTimer(time or 0) local timer = env.os.startTimer(time or 0)
@ -98,7 +99,16 @@ local function make_computer(id, fn)
local position = vector.new(0, 0, 0) local position = vector.new(0, 0, 0)
return { env = env, peripherals = peripherals, queue_event = queue_event, step = step, co = co, advance = advance, position = position } return {
env = env,
peripherals = peripherals,
queue_event = queue_event,
step = step,
co = co,
advance = advance,
position = position,
pending_timers = pending_timers,
}
end end
local function parse_channel(c) local function parse_channel(c)