1
0
mirror of https://github.com/SquidDev-CC/CC-Tweaked synced 2024-11-05 09:36:19 +00:00

Add a couple of tests for HTTP

This commit is contained in:
Jonathan Coates 2023-06-02 20:57:45 +01:00
parent 9cca908bff
commit 12ca8583f4
No known key found for this signature in database
GPG Key ID: B9E431FF07C98D06
4 changed files with 228 additions and 73 deletions

View File

@ -0,0 +1,135 @@
// SPDX-FileCopyrightText: 2023 The CC: Tweaked Developers
//
// SPDX-License-Identifier: MPL-2.0
package dan200.computercraft.core.apis.http
import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.ByteBufUtil
import io.netty.buffer.Unpooled
import io.netty.channel.*
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.http.*
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame
import io.netty.handler.codec.http.websocketx.WebSocketFrame
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.HandshakeComplete
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler
import java.nio.charset.StandardCharsets
/**
* Runs a small HTTP server to run alongside [TestHttpApi]
*/
object HttpServer {
const val PORT: Int = 8378
const val URL: String = "http://127.0.0.1:$PORT"
const val WS_URL: String = "ws://127.0.0.1:$PORT/ws"
fun runServer(run: () -> Unit) {
val workerGroup: EventLoopGroup = NioEventLoopGroup(2)
try {
val ch = ServerBootstrap()
.group(workerGroup)
.channel(NioServerSocketChannel::class.java)
.childHandler(
object : ChannelInitializer<SocketChannel>() {
override fun initChannel(ch: SocketChannel) {
val p: ChannelPipeline = ch.pipeline()
p.addLast(HttpServerCodec())
p.addLast(HttpContentCompressor())
p.addLast(HttpObjectAggregator(8192))
p.addLast(HttpServerHandler())
p.addLast(WebSocketServerCompressionHandler())
p.addLast(WebSocketServerProtocolHandler("/ws", null, true))
p.addLast(WebSocketFrameHandler())
}
},
).bind(PORT).sync().channel()
try {
run()
} finally {
ch.close().sync()
}
} finally {
workerGroup.shutdownGracefully()
}
}
}
/**
* A HTTP handler which hosts `/` (a simple static page) and `/ws` (see [WebSocketFrameHandler])
*/
private class HttpServerHandler : SimpleChannelInboundHandler<FullHttpRequest>() {
companion object {
private val CONTENT = "Hello, world!".toByteArray(StandardCharsets.UTF_8)
}
override fun channelReadComplete(ctx: ChannelHandlerContext) {
ctx.flush()
}
public override fun channelRead0(ctx: ChannelHandlerContext, request: FullHttpRequest) {
when (request.uri()) {
"/", "/index.html" -> handleIndex(ctx, request)
"/ws" -> handleWebsocket(ctx, request)
else -> sendHttpResponse(ctx, request, DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.NOT_FOUND))
}
}
private fun handleIndex(ctx: ChannelHandlerContext, request: FullHttpRequest) {
sendHttpResponse(
ctx,
request,
DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.OK, Unpooled.wrappedBuffer(CONTENT)),
)
}
private fun handleWebsocket(ctx: ChannelHandlerContext, request: FullHttpRequest) {
if (!request.headers().contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true)) {
return sendHttpResponse(ctx, request, DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.BAD_REQUEST))
}
ctx.fireChannelRead(request.retain())
}
private fun sendHttpResponse(ctx: ChannelHandlerContext, request: FullHttpRequest, response: FullHttpResponse) {
// Generate an error page if response getStatus code is not OK (200).
val responseStatus = response.status()
if (responseStatus.code() != 200) {
ByteBufUtil.writeUtf8(response.content(), responseStatus.toString())
HttpUtil.setContentLength(response, response.content().readableBytes().toLong())
}
// Send the response and close the connection if necessary.
val keepAlive = HttpUtil.isKeepAlive(request) && responseStatus.code() == 200
HttpUtil.setKeepAlive(response, keepAlive)
val future = ctx.writeAndFlush(response)
if (!keepAlive) future.addListener(ChannelFutureListener.CLOSE)
}
}
/**
* A basic WS server which just sends back the original message.
*/
private class WebSocketFrameHandler : SimpleChannelInboundHandler<WebSocketFrame>() {
override fun channelRead0(ctx: ChannelHandlerContext, frame: WebSocketFrame) {
if (frame is TextWebSocketFrame) {
// Send the uppercase string back.
val request = frame.text()
ctx.channel().writeAndFlush(TextWebSocketFrame(request.uppercase()))
} else {
throw UnsupportedOperationException("unsupported frame type: ${frame.javaClass.name}")
}
}
override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) {
if (evt is HandshakeComplete) {
// Channel upgrade to websocket, remove WebSocketIndexPageHandler.
ctx.pipeline().remove(HttpServerHandler::class.java)
} else {
super.userEventTriggered(ctx, evt)
}
}
}

View File

@ -0,0 +1,91 @@
// SPDX-FileCopyrightText: 2022 The CC: Tweaked Developers
//
// SPDX-License-Identifier: MPL-2.0
package dan200.computercraft.core.apis.http
import dan200.computercraft.api.lua.Coerced
import dan200.computercraft.api.lua.ObjectArguments
import dan200.computercraft.core.CoreConfig
import dan200.computercraft.core.apis.HTTPAPI
import dan200.computercraft.core.apis.handles.EncodedReadableHandle
import dan200.computercraft.core.apis.http.HttpServer.URL
import dan200.computercraft.core.apis.http.HttpServer.WS_URL
import dan200.computercraft.core.apis.http.HttpServer.runServer
import dan200.computercraft.core.apis.http.options.Action
import dan200.computercraft.core.apis.http.options.AddressRule
import dan200.computercraft.core.apis.http.request.HttpResponseHandle
import dan200.computercraft.core.apis.http.websocket.WebsocketHandle
import dan200.computercraft.test.core.computer.LuaTaskRunner
import org.hamcrest.MatcherAssert.assertThat
import org.hamcrest.Matchers.*
import org.junit.jupiter.api.AfterAll
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import java.util.*
class TestHttpApi {
companion object {
@JvmStatic
@BeforeAll
fun before() {
CoreConfig.httpRules = listOf(AddressRule.parse("*", OptionalInt.empty(), Action.ALLOW.toPartial()))
}
@JvmStatic
@AfterAll
fun after() {
CoreConfig.httpRules = Collections.unmodifiableList(
listOf(
AddressRule.parse("\$private", OptionalInt.empty(), Action.DENY.toPartial()),
AddressRule.parse("*", OptionalInt.empty(), Action.ALLOW.toPartial()),
),
)
}
}
@Test
fun `Connects to a HTTP server`() {
runServer {
LuaTaskRunner.runTest {
val httpApi = addApi(HTTPAPI(environment))
assertThat("http.request succeeded", httpApi.request(ObjectArguments(URL)), array(equalTo(true)))
val result = pullEvent("http_success")
assertThat(result, array(equalTo("http_success"), equalTo(URL), isA(HttpResponseHandle::class.java)))
val handle = result[2] as HttpResponseHandle
val reader = handle.extra.iterator().next() as EncodedReadableHandle
assertThat(reader.readAll(), array(equalTo("Hello, world!")))
}
}
}
@Test
fun `Connects to websocket`() {
runServer {
LuaTaskRunner.runTest {
val httpApi = addApi(HTTPAPI(environment))
assertThat("http.websocket succeeded", httpApi.websocket(ObjectArguments(WS_URL)), array(equalTo(true)))
val connectEvent = pullEvent()
assertThat(connectEvent, array(equalTo("websocket_success"), equalTo(WS_URL), isA(WebsocketHandle::class.java)))
val websocket = connectEvent[2] as WebsocketHandle
websocket.send(Coerced("Hello"), Optional.of(false))
val message = websocket.receive(Optional.empty()).await()
assertThat("Received a return message", message, array(equalTo("HELLO"), equalTo(false)))
websocket.close()
val closeEvent = pullEvent("websocket_closed")
assertThat(
"Websocket was closed",
closeEvent,
array(equalTo("websocket_closed"), equalTo(WS_URL), equalTo("Connection closed"), equalTo(null)),
)
}
}
}
}

View File

@ -1,58 +0,0 @@
// SPDX-FileCopyrightText: 2022 The CC: Tweaked Developers
//
// SPDX-License-Identifier: MPL-2.0
package http
import dan200.computercraft.api.lua.ObjectArguments
import dan200.computercraft.core.CoreConfig
import dan200.computercraft.core.apis.HTTPAPI
import dan200.computercraft.core.apis.http.options.Action
import dan200.computercraft.core.apis.http.options.AddressRule
import dan200.computercraft.test.core.computer.LuaTaskRunner
import org.junit.jupiter.api.AfterAll
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
import java.util.*
@Disabled("Requires some setup locally.")
class TestHttpApi {
companion object {
private const val WS_ADDRESS = "ws://127.0.0.1:8080"
@JvmStatic
@BeforeAll
fun before() {
CoreConfig.httpRules = listOf(AddressRule.parse("*", OptionalInt.empty(), Action.ALLOW.toPartial()))
}
@JvmStatic
@AfterAll
fun after() {
CoreConfig.httpRules = Collections.unmodifiableList(
listOf(
AddressRule.parse("\$private", OptionalInt.empty(), Action.DENY.toPartial()),
AddressRule.parse("*", OptionalInt.empty(), Action.ALLOW.toPartial()),
),
)
}
}
@Test
fun `Connects to websocket`() {
LuaTaskRunner.runTest {
val httpApi = addApi(HTTPAPI(environment))
val result = httpApi.websocket(ObjectArguments(WS_ADDRESS))
assertArrayEquals(arrayOf(true), result, "Should have created websocket")
val event = pullEvent()
assertEquals("websocket_success", event[0]) {
"Websocket failed to connect: ${event.contentToString()}"
}
}
}
}

View File

@ -10,7 +10,6 @@ import dan200.computercraft.api.lua.LuaException
import dan200.computercraft.core.apis.IAPIEnvironment import dan200.computercraft.core.apis.IAPIEnvironment
import dan200.computercraft.test.core.apis.BasicApiEnvironment import dan200.computercraft.test.core.apis.BasicApiEnvironment
import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeout import kotlinx.coroutines.withTimeout
import kotlin.time.Duration import kotlin.time.Duration
@ -21,11 +20,7 @@ class LuaTaskRunner : AbstractLuaTaskContext() {
private val apis = mutableListOf<ILuaAPI>() private val apis = mutableListOf<ILuaAPI>()
val environment: IAPIEnvironment = object : BasicApiEnvironment(BasicEnvironment()) { val environment: IAPIEnvironment = object : BasicApiEnvironment(BasicEnvironment()) {
override fun queueEvent(event: String?, vararg args: Any?) { override fun queueEvent(event: String?, vararg args: Any?) = this@LuaTaskRunner.queueEvent(event, args)
if (eventStream.trySend(Event(event, args)).isFailure) {
throw IllegalStateException("Queue is full")
}
}
override fun shutdown() { override fun shutdown() {
super.shutdown() super.shutdown()
@ -46,21 +41,13 @@ class LuaTaskRunner : AbstractLuaTaskContext() {
environment.shutdown() environment.shutdown()
} }
private suspend fun run() {
for (event in eventStream) {
queueEvent(event.name, event.args)
}
}
private class Event(val name: String?, val args: Array<out Any?>) private class Event(val name: String?, val args: Array<out Any?>)
companion object { companion object {
fun runTest(timeout: Duration = 5.seconds, fn: suspend LuaTaskRunner.() -> Unit) { fun runTest(timeout: Duration = 5.seconds, fn: suspend LuaTaskRunner.() -> Unit) {
runBlocking { runBlocking {
withTimeout(timeout) { withTimeout(timeout) {
val runner = LuaTaskRunner() LuaTaskRunner().use { fn(it) }
launch { runner.run() }
runner.use { fn(runner) }
} }
} }
} }