Reduce coupling in websocket code

- Add a new WebsocketClient interface, which WebsocketHandle uses for
   sending messages and closing. This reduces coupling between Websocket
   and WebsocketHandle, which is nice, though admitedly only use for
   copy-cat :).

 - WebsocketHandle now uses Websocket(Client).isClosed(), rather than
   tracking the closed state itself - this makes the class mostly a thin
   Lua wrapper over the client, which is nice.

 - Convert Options into a record.

 - Clarify the behaviour of ws.close() and the websocket_closed event.
   Our previous test was incorrect as it called WebsocketHandle.close
   (rather than WebsocketHandle.doClose), which had slightly different
   semantics in whether the event is queued.
This commit is contained in:
Jonathan Coates 2023-09-21 18:59:15 +01:00
parent 3188197447
commit ae71eb3cae
No known key found for this signature in database
GPG Key ID: B9E431FF07C98D06
15 changed files with 207 additions and 133 deletions

View File

@ -12,6 +12,7 @@
import dan200.computercraft.core.apis.http.*;
import dan200.computercraft.core.apis.http.request.HttpRequest;
import dan200.computercraft.core.apis.http.websocket.Websocket;
import dan200.computercraft.core.apis.http.websocket.WebsocketClient;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
@ -165,7 +166,7 @@ public final Object[] websocket(IArguments args) throws LuaException {
var timeout = getTimeout(timeoutArg);
try {
var uri = Websocket.checkUri(address);
var uri = WebsocketClient.parseUri(address);
if (!new Websocket(websockets, apiEnvironment, uri, address, headers, timeout).queue(Websocket::connect)) {
throw new LuaException("Too many websockets already open");
}

View File

@ -134,7 +134,7 @@ public static InetSocketAddress getAddress(String host, int port, boolean ssl) t
*/
public static Options getOptions(String host, InetSocketAddress address) throws HTTPRequestException {
var options = AddressRule.apply(CoreConfig.httpRules, host, address);
if (options.action == Action.DENY) throw new HTTPRequestException("Domain not permitted");
if (options.action() == Action.DENY) throw new HTTPRequestException("Domain not permitted");
return options;
}
@ -150,7 +150,7 @@ public static Options getOptions(String host, InetSocketAddress address) throws
* @throws HTTPRequestException If a proxy is required but not configured correctly.
*/
public static @Nullable Consumer<SocketChannel> getProxyHandler(Options options, int timeout) throws HTTPRequestException {
if (!options.useProxy) return null;
if (!options.useProxy()) return null;
var type = CoreConfig.httpProxyType;
var host = CoreConfig.httpProxyHost;

View File

@ -6,20 +6,13 @@
/**
* Options about a specific domain.
* Options for a given HTTP request or websocket, which control its resource constraints.
*
* @param action Whether to {@link Action#ALLOW} or {@link Action#DENY} this request.
* @param maxUpload The maximum size of the HTTP request.
* @param maxDownload The maximum size of the HTTP response.
* @param websocketMessage The maximum size of a websocket message (outgoing and incoming).
* @param useProxy Whether to use the configured proxy.
*/
public final class Options {
public final Action action;
public final long maxUpload;
public final long maxDownload;
public final int websocketMessage;
public final boolean useProxy;
Options(Action action, long maxUpload, long maxDownload, int websocketMessage, boolean useProxy) {
this.action = action;
this.maxUpload = maxUpload;
this.maxDownload = maxDownload;
this.websocketMessage = websocketMessage;
this.useProxy = useProxy;
}
public record Options(Action action, long maxUpload, long maxDownload, int websocketMessage, boolean useProxy) {
}

View File

@ -34,7 +34,7 @@ public PartialOptions(@Nullable Action action, OptionalLong maxUpload, OptionalL
this.useProxy = useProxy;
}
Options toOptions() {
public Options toOptions() {
if (options != null) return options;
return options = new Options(

View File

@ -130,7 +130,7 @@ private void doRequest(URI uri, HttpMethod method) {
if (isClosed()) return;
var requestBody = getHeaderSize(headers) + postBuffer.capacity();
if (options.maxUpload != 0 && requestBody > options.maxUpload) {
if (options.maxUpload() != 0 && requestBody > options.maxUpload()) {
failure("Request body is too large");
return;
}

View File

@ -136,7 +136,7 @@ public void channelRead0(ChannelHandlerContext ctx, HttpObject message) {
var partial = content.content();
if (partial.isReadable()) {
// If we've read more than we're allowed to handle, abort as soon as possible.
if (options.maxDownload != 0 && responseBody.readableBytes() + partial.readableBytes() > options.maxDownload) {
if (options.maxDownload() != 0 && responseBody.readableBytes() + partial.readableBytes() > options.maxDownload()) {
closed = true;
ctx.close();

View File

@ -16,8 +16,8 @@
* A version of {@link WebSocketClientHandshaker13} which doesn't add the {@link HttpHeaderNames#ORIGIN} header to the
* original HTTP request.
*/
public class NoOriginWebSocketHandshaker extends WebSocketClientHandshaker13 {
public NoOriginWebSocketHandshaker(URI webSocketURL, WebSocketVersion version, String subprotocol, boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) {
class NoOriginWebSocketHandshaker extends WebSocketClientHandshaker13 {
NoOriginWebSocketHandshaker(URI webSocketURL, WebSocketVersion version, String subprotocol, boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) {
super(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength);
}

View File

@ -12,8 +12,9 @@
import dan200.computercraft.core.apis.http.Resource;
import dan200.computercraft.core.apis.http.ResourceGroup;
import dan200.computercraft.core.apis.http.options.Options;
import dan200.computercraft.core.util.IoUtil;
import dan200.computercraft.core.metrics.Metrics;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
@ -23,21 +24,22 @@
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nullable;
import java.lang.ref.WeakReference;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.concurrent.Future;
/**
* Provides functionality to verify and connect to a remote websocket.
*/
public class Websocket extends Resource<Websocket> {
public class Websocket extends Resource<Websocket> implements WebsocketClient {
private static final Logger LOG = LoggerFactory.getLogger(Websocket.class);
/**
@ -46,14 +48,8 @@ public class Websocket extends Resource<Websocket> {
*/
public static final int MAX_MESSAGE_SIZE = 1 << 30;
static final String SUCCESS_EVENT = "websocket_success";
static final String FAILURE_EVENT = "websocket_failure";
static final String CLOSE_EVENT = "websocket_closed";
static final String MESSAGE_EVENT = "websocket_message";
private @Nullable Future<?> executorFuture;
private @Nullable ChannelFuture connectFuture;
private @Nullable WeakReference<WebsocketHandle> websocketHandle;
private @Nullable ChannelFuture channelFuture;
private final IAPIEnvironment environment;
private final URI uri;
@ -70,38 +66,6 @@ public Websocket(ResourceGroup<Websocket> limiter, IAPIEnvironment environment,
this.timeout = timeout;
}
public static URI checkUri(String address) throws HTTPRequestException {
URI uri = null;
try {
uri = new URI(address);
} catch (URISyntaxException ignored) {
// Fall through to the case below
}
if (uri == null || uri.getHost() == null) {
try {
uri = new URI("ws://" + address);
} catch (URISyntaxException ignored) {
// Fall through to the case below
}
}
if (uri == null || uri.getHost() == null) throw new HTTPRequestException("URL malformed");
var scheme = uri.getScheme();
if (scheme == null) {
try {
uri = new URI("ws://" + uri);
} catch (URISyntaxException e) {
throw new HTTPRequestException("URL malformed");
}
} else if (!scheme.equalsIgnoreCase("wss") && !scheme.equalsIgnoreCase("ws")) {
throw new HTTPRequestException("Invalid scheme '" + scheme + "'");
}
return uri;
}
public void connect() {
if (isClosed()) return;
executorFuture = NetworkUtils.EXECUTOR.submit(this::doConnect);
@ -122,7 +86,7 @@ private void doConnect() {
// getAddress may have a slight delay, so let's perform another cancellation check.
if (isClosed()) return;
connectFuture = new Bootstrap()
channelFuture = new Bootstrap()
.group(NetworkUtils.LOOP_GROUP)
.channel(NioSocketChannel.class)
.handler(new ChannelInitializer<SocketChannel>() {
@ -133,7 +97,7 @@ protected void initChannel(SocketChannel ch) {
var subprotocol = headers.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
var handshaker = new NoOriginWebSocketHandshaker(
uri, WebSocketVersion.V13, subprotocol, true, headers,
options.websocketMessage <= 0 ? MAX_MESSAGE_SIZE : options.websocketMessage
options.websocketMessage() <= 0 ? MAX_MESSAGE_SIZE : options.websocketMessage()
);
var p = ch.pipeline();
@ -162,12 +126,12 @@ protected void initChannel(SocketChannel ch) {
}
}
void success(Channel channel, Options options) {
void success(Options options) {
if (isClosed()) return;
var handle = new WebsocketHandle(this, options, channel);
var handle = new WebsocketHandle(environment, address, this, options);
environment().queueEvent(SUCCESS_EVENT, address, handle);
websocketHandle = createOwnerReference(handle);
createOwnerReference(handle);
checkClosed();
}
@ -189,19 +153,35 @@ protected void dispose() {
super.dispose();
executorFuture = closeFuture(executorFuture);
connectFuture = closeChannel(connectFuture);
var websocketHandleRef = websocketHandle;
var websocketHandle = websocketHandleRef == null ? null : websocketHandleRef.get();
IoUtil.closeQuietly(websocketHandle);
this.websocketHandle = null;
channelFuture = closeChannel(channelFuture);
}
public IAPIEnvironment environment() {
IAPIEnvironment environment() {
return environment;
}
public String address() {
String address() {
return address;
}
private @Nullable Channel channel() {
var channel = channelFuture;
return channel == null ? null : channel.channel();
}
@Override
public void sendText(String message) {
environment.observe(Metrics.WEBSOCKET_OUTGOING, message.length());
var channel = channel();
if (channel != null) channel.writeAndFlush(new TextWebSocketFrame(message));
}
@Override
public void sendBinary(ByteBuffer message) {
environment.observe(Metrics.WEBSOCKET_OUTGOING, message.remaining());
var channel = channel();
if (channel != null) channel.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(message)));
}
}

View File

@ -0,0 +1,90 @@
// SPDX-FileCopyrightText: 2023 The CC: Tweaked Developers
//
// SPDX-License-Identifier: MPL-2.0
package dan200.computercraft.core.apis.http.websocket;
import dan200.computercraft.core.apis.http.HTTPRequestException;
import java.io.Closeable;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
/**
* A client-side websocket, which can be used to send messages to a remote server.
* <p>
* {@link WebsocketHandle} wraps this into a Lua-compatible interface.
*/
public interface WebsocketClient extends Closeable {
String SUCCESS_EVENT = "websocket_success";
String FAILURE_EVENT = "websocket_failure";
String CLOSE_EVENT = "websocket_closed";
String MESSAGE_EVENT = "websocket_message";
/**
* Determine whether this websocket is closed.
*
* @return Whether this websocket is closed.
*/
boolean isClosed();
/**
* Close this websocket.
*/
@Override
void close();
/**
* Send a text websocket frame.
*
* @param message The message to send.
*/
void sendText(String message);
/**
* Send a binary websocket frame.
*
* @param message The message to send.
*/
void sendBinary(ByteBuffer message);
/**
* Parse an address, ensuring it is a valid websocket URI.
*
* @param address The address to parse.
* @return The parsed URI.
* @throws HTTPRequestException If the address is not valid.
*/
static URI parseUri(String address) throws HTTPRequestException {
URI uri = null;
try {
uri = new URI(address);
} catch (URISyntaxException ignored) {
// Fall through to the case below
}
if (uri == null || uri.getHost() == null) {
try {
uri = new URI("ws://" + address);
} catch (URISyntaxException ignored) {
// Fall through to the case below
}
}
if (uri == null || uri.getHost() == null) throw new HTTPRequestException("URL malformed");
var scheme = uri.getScheme();
if (scheme == null) {
try {
uri = new URI("ws://" + uri);
} catch (URISyntaxException e) {
throw new HTTPRequestException("URL malformed");
}
} else if (!scheme.equalsIgnoreCase("wss") && !scheme.equalsIgnoreCase("ws")) {
throw new HTTPRequestException("Invalid scheme '" + scheme + "'");
}
return uri;
}
}

View File

@ -6,40 +6,34 @@
import com.google.common.base.Objects;
import dan200.computercraft.api.lua.*;
import dan200.computercraft.core.apis.IAPIEnvironment;
import dan200.computercraft.core.apis.http.options.Options;
import dan200.computercraft.core.metrics.Metrics;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import javax.annotation.Nullable;
import java.io.Closeable;
import java.util.Arrays;
import java.util.Optional;
import static dan200.computercraft.api.lua.LuaValues.checkFinite;
import static dan200.computercraft.core.apis.IAPIEnvironment.TIMER_EVENT;
import static dan200.computercraft.core.apis.http.websocket.Websocket.CLOSE_EVENT;
import static dan200.computercraft.core.apis.http.websocket.Websocket.MESSAGE_EVENT;
import static dan200.computercraft.core.apis.http.websocket.WebsocketClient.CLOSE_EVENT;
import static dan200.computercraft.core.apis.http.websocket.WebsocketClient.MESSAGE_EVENT;
/**
* A websocket, which can be used to send an receive messages with a web server.
* A websocket, which can be used to send and receive messages with a web server.
*
* @cc.module http.Websocket
* @see dan200.computercraft.core.apis.HTTPAPI#websocket On how to open a websocket.
*/
public class WebsocketHandle implements Closeable {
private final Websocket websocket;
public class WebsocketHandle {
private final IAPIEnvironment environment;
private final String address;
private final WebsocketClient websocket;
private final Options options;
private boolean closed = false;
private @Nullable Channel channel;
public WebsocketHandle(Websocket websocket, Options options, Channel channel) {
public WebsocketHandle(IAPIEnvironment environment, String address, WebsocketClient websocket, Options options) {
this.environment = environment;
this.address = address;
this.websocket = websocket;
this.options = options;
this.channel = channel;
}
/**
@ -58,7 +52,7 @@ public WebsocketHandle(Websocket websocket, Options options, Channel channel) {
public final MethodResult receive(Optional<Double> timeout) throws LuaException {
checkOpen();
var timeoutId = timeout.isPresent()
? websocket.environment().startTimer(Math.round(checkFinite(0, timeout.get()) / 0.05))
? environment.startTimer(Math.round(checkFinite(0, timeout.get()) / 0.05))
: -1;
return new ReceiveCallback(timeoutId).pull;
@ -78,17 +72,14 @@ public final void send(Coerced<String> message, Optional<Boolean> binary) throws
checkOpen();
var text = message.value();
if (options.websocketMessage != 0 && text.length() > options.websocketMessage) {
if (options.websocketMessage() != 0 && text.length() > options.websocketMessage()) {
throw new LuaException("Message is too large");
}
websocket.environment().observe(Metrics.WEBSOCKET_OUTGOING, text.length());
var channel = this.channel;
if (channel != null) {
channel.writeAndFlush(binary.orElse(false)
? new BinaryWebSocketFrame(Unpooled.wrappedBuffer(LuaValues.encode(text)))
: new TextWebSocketFrame(text));
if (binary.orElse(false)) {
websocket.sendBinary(LuaValues.encode(text));
} else {
websocket.sendText(text);
}
}
@ -96,25 +87,13 @@ public final void send(Coerced<String> message, Optional<Boolean> binary) throws
* Close this websocket. This will terminate the connection, meaning messages can no longer be sent or received
* along it.
*/
@LuaFunction("close")
public final void doClose() {
close();
@LuaFunction
public final void close() {
websocket.close();
}
private void checkOpen() throws LuaException {
if (closed) throw new LuaException("attempt to use a closed file");
}
@Override
public void close() {
closed = true;
var channel = this.channel;
if (channel != null) {
channel.close();
this.channel = null;
}
if (websocket.isClosed()) throw new LuaException("attempt to use a closed file");
}
private final class ReceiveCallback implements ILuaCallback {
@ -127,9 +106,9 @@ private final class ReceiveCallback implements ILuaCallback {
@Override
public MethodResult resume(Object[] event) {
if (event.length >= 3 && Objects.equal(event[0], MESSAGE_EVENT) && Objects.equal(event[1], websocket.address())) {
if (event.length >= 3 && Objects.equal(event[0], MESSAGE_EVENT) && Objects.equal(event[1], address)) {
return MethodResult.of(Arrays.copyOfRange(event, 2, event.length));
} else if (event.length >= 2 && Objects.equal(event[0], CLOSE_EVENT) && Objects.equal(event[1], websocket.address()) && closed) {
} else if (event.length >= 2 && Objects.equal(event[0], CLOSE_EVENT) && Objects.equal(event[1], address) && websocket.isClosed()) {
// If the socket is closed abort.
return MethodResult.of();
} else if (event.length >= 2 && timeoutId != -1 && Objects.equal(event[0], TIMER_EVENT)

View File

@ -13,14 +13,14 @@
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.CharsetUtil;
import static dan200.computercraft.core.apis.http.websocket.Websocket.MESSAGE_EVENT;
import static dan200.computercraft.core.apis.http.websocket.WebsocketClient.MESSAGE_EVENT;
public class WebsocketHandler extends SimpleChannelInboundHandler<Object> {
class WebsocketHandler extends SimpleChannelInboundHandler<Object> {
private final Websocket websocket;
private final Options options;
private boolean handshakeComplete = false;
public WebsocketHandler(Websocket websocket, Options options) {
WebsocketHandler(Websocket websocket, Options options) {
this.websocket = websocket;
this.options = options;
}
@ -32,9 +32,9 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
websocket.success(ctx.channel(), options);
websocket.success(options);
handshakeComplete = true;
} else if (evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) {
websocket.failure("Timed out");

View File

@ -23,8 +23,8 @@ public void matchesPort() {
Action.ALLOW.toPartial()
));
assertEquals(apply(rules, "localhost", 8080).action, Action.ALLOW);
assertEquals(apply(rules, "localhost", 8081).action, Action.DENY);
assertEquals(apply(rules, "localhost", 8080).action(), Action.ALLOW);
assertEquals(apply(rules, "localhost", 8081).action(), Action.DENY);
}
@ParameterizedTest
@ -43,7 +43,7 @@ public void matchesPort() {
"169.254.169.254", // AWS, Digital Ocean, GCP, etc..
})
public void blocksLocalDomains(String domain) {
assertEquals(apply(CoreConfig.httpRules, domain, 80).action, Action.DENY);
assertEquals(apply(CoreConfig.httpRules, domain, 80).action(), Action.DENY);
}
@ParameterizedTest
@ -52,7 +52,7 @@ public void blocksLocalDomains(String domain) {
"100.63.255.255", "100.128.0.0"
})
public void allowsNonLocalDomains(String domain) {
assertEquals(apply(CoreConfig.httpRules, domain, 80).action, Action.ALLOW);
assertEquals(apply(CoreConfig.httpRules, domain, 80).action(), Action.ALLOW);
}
private Options apply(Iterable<AddressRule> rules, String host, int port) {

View File

@ -27,7 +27,7 @@
const val URL: String = "http://127.0.0.1:$PORT"
const val WS_URL: String = "ws://127.0.0.1:$PORT/ws"
fun runServer(run: () -> Unit) {
fun runServer(run: (stop: () -> Unit) -> Unit) {
val workerGroup: EventLoopGroup = NioEventLoopGroup(2)
try {
val ch = ServerBootstrap()
@ -48,7 +48,7 @@ override fun initChannel(ch: SocketChannel) {
},
).bind(PORT).sync().channel()
try {
run()
run { workerGroup.shutdownGracefully() }
} finally {
ch.close().sync()
}

View File

@ -5,6 +5,7 @@
package dan200.computercraft.core.apis.http
import dan200.computercraft.api.lua.Coerced
import dan200.computercraft.api.lua.LuaException
import dan200.computercraft.api.lua.ObjectArguments
import dan200.computercraft.core.CoreConfig
import dan200.computercraft.core.apis.HTTPAPI
@ -22,7 +23,9 @@
import org.junit.jupiter.api.AfterAll
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import java.util.*
import kotlin.time.Duration.Companion.milliseconds
class TestHttpApi {
companion object {
@ -79,12 +82,36 @@ fun after() {
websocket.close()
val closeEvent = pullEventOrTimeout(500.milliseconds, "websocket_closed")
assertThat("No event was queued", closeEvent, equalTo(null))
}
}
}
@Test
fun `Queues an event when the socket is externally closed`() {
runServer { stop ->
LuaTaskRunner.runTest {
val httpApi = addApi(HTTPAPI(environment))
assertThat("http.websocket succeeded", httpApi.websocket(ObjectArguments(WS_URL)), array(equalTo(true)))
val connectEvent = pullEvent()
assertThat(connectEvent, array(equalTo("websocket_success"), equalTo(WS_URL), isA(WebsocketHandle::class.java)))
val websocket = connectEvent[2] as WebsocketHandle
stop()
val closeEvent = pullEvent("websocket_closed")
assertThat(
"Websocket was closed",
closeEvent,
array(equalTo("websocket_closed"), equalTo(WS_URL), equalTo("Connection closed"), equalTo(null)),
)
assertThrows<LuaException>("Throws an exception when sending") {
websocket.send(Coerced("hello"), Optional.of(false))
}
}
}
}

View File

@ -12,6 +12,7 @@
import dan200.computercraft.core.apis.PeripheralAPI
import kotlinx.coroutines.CancellableContinuation
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.withTimeoutOrNull
import kotlin.time.Duration
/**
@ -29,6 +30,9 @@ interface LuaTaskContext {
/** Pull a Lua event */
suspend fun pullEvent(event: String? = null): Array<out Any?>
suspend fun pullEventOrTimeout(timeout: Duration, event: String? = null): Array<out Any?>? =
withTimeoutOrNull(timeout) { pullEvent(event) }
/** Resolve a [MethodResult] until completion, returning the resulting values. */
suspend fun MethodResult.await(): Array<out Any?>? {
var result = this