diff --git a/projects/core/src/main/java/dan200/computercraft/core/apis/http/websocket/Websocket.java b/projects/core/src/main/java/dan200/computercraft/core/apis/http/websocket/Websocket.java index e2ff84c0c..5695000ef 100644 --- a/projects/core/src/main/java/dan200/computercraft/core/apis/http/websocket/Websocket.java +++ b/projects/core/src/main/java/dan200/computercraft/core/apis/http/websocket/Websocket.java @@ -5,14 +5,13 @@ package dan200.computercraft.core.apis.http.websocket; import com.google.common.base.Strings; +import dan200.computercraft.api.lua.LuaException; import dan200.computercraft.core.Logging; import dan200.computercraft.core.apis.IAPIEnvironment; -import dan200.computercraft.core.apis.http.HTTPRequestException; -import dan200.computercraft.core.apis.http.NetworkUtils; -import dan200.computercraft.core.apis.http.Resource; -import dan200.computercraft.core.apis.http.ResourceGroup; +import dan200.computercraft.core.apis.http.*; import dan200.computercraft.core.apis.http.options.Options; import dan200.computercraft.core.metrics.Metrics; +import dan200.computercraft.core.util.AtomicHelpers; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; @@ -24,10 +23,8 @@ import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObjectAggregator; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; -import io.netty.handler.codec.http.websocketx.WebSocketVersion; +import io.netty.handler.codec.http.websocketx.*; +import io.netty.util.concurrent.GenericFutureListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +32,7 @@ import javax.annotation.Nullable; import java.net.URI; import java.nio.ByteBuffer; import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; /** * Provides functionality to verify and connect to a remote websocket. @@ -57,6 +55,9 @@ public class Websocket extends Resource implements WebsocketClient { private final HttpHeaders headers; private final int timeout; + private final AtomicInteger inFlight = new AtomicInteger(0); + private final GenericFutureListener> onSend = f -> inFlight.decrementAndGet(); + public Websocket(ResourceGroup limiter, IAPIEnvironment environment, URI uri, String address, HttpHeaders headers, int timeout) { super(limiter); this.environment = environment; @@ -170,18 +171,27 @@ public class Websocket extends Resource implements WebsocketClient { } @Override - public void sendText(String message) { - environment.observe(Metrics.WEBSOCKET_OUTGOING, message.length()); - - var channel = channel(); - if (channel != null) channel.writeAndFlush(new TextWebSocketFrame(message)); + public void sendText(String message) throws LuaException { + sendMessage(new TextWebSocketFrame(message), message.length()); } @Override - public void sendBinary(ByteBuffer message) { - environment.observe(Metrics.WEBSOCKET_OUTGOING, message.remaining()); + public void sendBinary(ByteBuffer message) throws LuaException { + long size = message.remaining(); + sendMessage(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(message)), size); + } + private void sendMessage(WebSocketFrame frame, long size) throws LuaException { var channel = channel(); - if (channel != null) channel.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(message))); + if (channel == null) return; + + // Grow the number of in-flight requests, aborting if we've hit the limit. This is then decremented when the + // promise finishes. + if (!AtomicHelpers.incrementToLimit(inFlight, ResourceQueue.DEFAULT_LIMIT)) { + throw new LuaException("Too many ongoing websocket messages"); + } + + environment.observe(Metrics.WEBSOCKET_OUTGOING, size); + channel.writeAndFlush(frame).addListener(onSend); } } diff --git a/projects/core/src/main/java/dan200/computercraft/core/apis/http/websocket/WebsocketClient.java b/projects/core/src/main/java/dan200/computercraft/core/apis/http/websocket/WebsocketClient.java index 02cf3f43c..55c2a3e3a 100644 --- a/projects/core/src/main/java/dan200/computercraft/core/apis/http/websocket/WebsocketClient.java +++ b/projects/core/src/main/java/dan200/computercraft/core/apis/http/websocket/WebsocketClient.java @@ -4,6 +4,7 @@ package dan200.computercraft.core.apis.http.websocket; +import dan200.computercraft.api.lua.LuaException; import dan200.computercraft.core.apis.http.HTTPRequestException; import java.io.Closeable; @@ -39,15 +40,17 @@ public interface WebsocketClient extends Closeable { * Send a text websocket frame. * * @param message The message to send. + * @throws LuaException If the message could not be sent. */ - void sendText(String message); + void sendText(String message) throws LuaException; /** * Send a binary websocket frame. * * @param message The message to send. + * @throws LuaException If the message could not be sent. */ - void sendBinary(ByteBuffer message); + void sendBinary(ByteBuffer message) throws LuaException; /** * Parse an address, ensuring it is a valid websocket URI. diff --git a/projects/core/src/main/java/dan200/computercraft/core/util/AtomicHelpers.java b/projects/core/src/main/java/dan200/computercraft/core/util/AtomicHelpers.java new file mode 100644 index 000000000..02e745e50 --- /dev/null +++ b/projects/core/src/main/java/dan200/computercraft/core/util/AtomicHelpers.java @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: 2023 The CC: Tweaked Developers +// +// SPDX-License-Identifier: MPL-2.0 + +package dan200.computercraft.core.util; + +import java.util.concurrent.atomic.AtomicInteger; + +public final class AtomicHelpers { + private AtomicHelpers() { + } + + /** + * A version of {@link AtomicInteger#getAndIncrement()}, which increments until a limit is reached. + * + * @param atomic The atomic to increment. + * @param limit The maximum value of {@code value}. + * @return Whether the value was sucessfully incremented. + */ + public static boolean incrementToLimit(AtomicInteger atomic, int limit) { + int value; + do { + value = atomic.get(); + if (value >= limit) return false; + } while (!atomic.compareAndSet(value, value + 1)); + + return true; + } +} diff --git a/projects/core/src/test/kotlin/dan200/computercraft/core/apis/http/TestHttpApi.kt b/projects/core/src/test/kotlin/dan200/computercraft/core/apis/http/TestHttpApi.kt index 8078da185..d47d67f44 100644 --- a/projects/core/src/test/kotlin/dan200/computercraft/core/apis/http/TestHttpApi.kt +++ b/projects/core/src/test/kotlin/dan200/computercraft/core/apis/http/TestHttpApi.kt @@ -89,6 +89,30 @@ class TestHttpApi { } } + @Test + fun `Errors if too many websocket messages are sent`() { + runServer { + LuaTaskRunner.runTest { + val httpApi = addApi(HTTPAPI(environment)) + assertThat("http.websocket succeeded", httpApi.websocket(ObjectArguments(WS_URL)), array(equalTo(true))) + + val connectEvent = pullEvent() + assertThat(connectEvent, array(equalTo("websocket_success"), equalTo(WS_URL), isA(WebsocketHandle::class.java))) + + val websocket = connectEvent[2] as WebsocketHandle + val error = assertThrows { + for (i in 0 until 10_000) { + websocket.send(Coerced(LuaValues.encode("Hello")), Optional.of(false)) + } + } + + websocket.close() + + assertThat(error.message, equalTo("Too many ongoing websocket messages")) + } + } + } + @Test fun `Queues an event when the socket is externally closed`() { runServer { stop ->