diff --git a/.gitignore b/.gitignore index 96b40ce93..675c66cf6 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ /projects/*/logs /projects/fabric/fabricloader.log /projects/*/build +/projects/*/src/test/generated_tests/ /buildSrc/build /out /buildSrc/out diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index ad95ddf27..0a865a702 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -51,6 +51,7 @@ sodium = "mc1.20-0.4.10" hamcrest = "2.2" jqwik = "1.8.2" junit = "5.10.1" +jmh = "1.37" # Build tools cctJavadoc = "1.8.2" @@ -127,6 +128,8 @@ junit-jupiter-api = { module = "org.junit.jupiter:junit-jupiter-api", version.re junit-jupiter-engine = { module = "org.junit.jupiter:junit-jupiter-engine", version.ref = "junit" } junit-jupiter-params = { module = "org.junit.jupiter:junit-jupiter-params", version.ref = "junit" } slf4j-simple = { module = "org.slf4j:slf4j-simple", version.ref = "slf4j" } +jmh = { module = "org.openjdk.jmh:jmh-core", version.ref = "jmh" } +jmh-processor = { module = "org.openjdk.jmh:jmh-generator-annprocess", version.ref = "jmh" } # LWJGL lwjgl-bom = { module = "org.lwjgl:lwjgl-bom", version.ref = "lwjgl" } diff --git a/projects/common/build.gradle.kts b/projects/common/build.gradle.kts index 2b4810975..2cf545c63 100644 --- a/projects/common/build.gradle.kts +++ b/projects/common/build.gradle.kts @@ -46,6 +46,9 @@ dependencies { testImplementation(libs.bundles.test) testRuntimeOnly(libs.bundles.testRuntime) + testImplementation(libs.jmh) + testAnnotationProcessor(libs.jmh.processor) + testModCompileOnly(libs.mixin) testModImplementation(testFixtures(project(":core"))) testModImplementation(testFixtures(project(":common"))) diff --git a/projects/common/src/test/java/dan200/computercraft/impl/network/wired/NetworkBenchmark.java b/projects/common/src/test/java/dan200/computercraft/impl/network/wired/NetworkBenchmark.java new file mode 100644 index 000000000..ad3e02040 --- /dev/null +++ b/projects/common/src/test/java/dan200/computercraft/impl/network/wired/NetworkBenchmark.java @@ -0,0 +1,183 @@ +// SPDX-FileCopyrightText: 2024 The CC: Tweaked Developers +// +// SPDX-License-Identifier: MPL-2.0 + +package dan200.computercraft.impl.network.wired; + +import dan200.computercraft.api.network.wired.WiredNetwork; +import dan200.computercraft.impl.network.wired.NetworkTest.NetworkElement; +import dan200.computercraft.shared.util.DirectionUtil; +import it.unimi.dsi.fastutil.objects.Object2IntMap; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; +import net.minecraft.core.BlockPos; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +public class NetworkBenchmark { + private static final int BRUTE_SIZE = 16; + + public static void main(String[] args) throws RunnerException { + var opts = new OptionsBuilder() + .include(NetworkBenchmark.class.getName() + "\\..*") + .warmupIterations(2) + .measurementIterations(5) + .forks(1) + .build(); + new Runner(opts).run(); + } + + @Benchmark + @Warmup(time = 1, timeUnit = TimeUnit.SECONDS) + @Measurement(time = 2, timeUnit = TimeUnit.SECONDS) + public void removeEveryNode(ConnectedGrid grid) { + grid.grid.forEach((node, pos) -> node.remove()); + } + + @Benchmark + @Warmup(time = 1, timeUnit = TimeUnit.SECONDS) + @Measurement(time = 2, timeUnit = TimeUnit.SECONDS) + public void connectAndDisconnect(SplitGrid connectedGrid) { + WiredNodeImpl left = connectedGrid.left, right = connectedGrid.right; + + assertNotEquals(left.getNetwork(), right.getNetwork()); + left.connectTo(right); + assertEquals(left.getNetwork(), right.getNetwork()); + left.disconnectFrom(right); + assertNotEquals(left.getNetwork(), right.getNetwork()); + } + + @Benchmark + @Warmup(time = 1, timeUnit = TimeUnit.SECONDS) + @Measurement(time = 2, timeUnit = TimeUnit.SECONDS) + public void connectAndRemove(SplitGrid connectedGrid) { + WiredNodeImpl left = connectedGrid.left, right = connectedGrid.right, centre = connectedGrid.centre; + + assertNotEquals(left.getNetwork(), right.getNetwork()); + centre.connectTo(left); + centre.connectTo(right); + assertEquals(left.getNetwork(), right.getNetwork()); + centre.remove(); + assertNotEquals(left.getNetwork(), right.getNetwork()); + } + + /** + * Create a grid where all nodes are connected to their neighbours. + */ + @State(Scope.Thread) + public static class ConnectedGrid { + Grid grid; + + @Setup(Level.Invocation) + public void setup() { + var grid = this.grid = new Grid<>(BRUTE_SIZE); + grid.map((existing, pos) -> new NetworkElement("n_" + pos, pos.getX() == pos.getY() && pos.getY() == pos.getZ()).getNode()); + + // Connect every node + grid.forEach((node, pos) -> { + for (var facing : DirectionUtil.FACINGS) { + var other = grid.get(pos.relative(facing)); + if (other != null) node.connectTo(other); + } + }); + + var networks = countNetworks(grid); + if (networks.size() != 1) throw new AssertionError("Expected exactly one network."); + } + } + + /** + * Create a grid where the nodes at {@code x < BRUTE_SIZE/2} and {@code x >= BRUTE_SIZE/2} are in separate networks, + * but otherwise connected to their neighbours. + */ + @State(Scope.Thread) + public static class SplitGrid { + Grid grid; + WiredNodeImpl left, right, centre; + + @Setup + public void setup() { + var grid = this.grid = new Grid<>(BRUTE_SIZE); + grid.map((existing, pos) -> new NetworkElement("n_" + pos, pos.getX() == pos.getY() && pos.getY() == pos.getZ()).getNode()); + + // Connect every node + grid.forEach((node, pos) -> { + for (var facing : DirectionUtil.FACINGS) { + var offset = pos.relative(facing); + if (offset.getX() >= BRUTE_SIZE / 2 == pos.getX() >= BRUTE_SIZE / 2) { + var other = grid.get(offset); + if (other != null) node.connectTo(other); + } + } + }); + + var networks = countNetworks(grid); + if (networks.size() != 2) throw new AssertionError("Expected exactly two networks."); + for (var network : networks.object2IntEntrySet()) { + if (network.getIntValue() != BRUTE_SIZE * BRUTE_SIZE * (BRUTE_SIZE / 2)) { + throw new AssertionError("Network is the wrong size"); + } + } + + left = Objects.requireNonNull(grid.get(new BlockPos(BRUTE_SIZE / 2 - 1, 0, 0))); + right = Objects.requireNonNull(grid.get(new BlockPos(BRUTE_SIZE / 2, 0, 0))); + centre = new NetworkElement("c", false).getNode(); + } + } + + private static Object2IntMap countNetworks(Grid grid) { + Object2IntMap networks = new Object2IntOpenHashMap<>(); + grid.forEach((node, pos) -> networks.put(node.network, networks.getOrDefault(node.network, 0) + 1)); + return networks; + } + + private static class Grid { + private final int size; + private final T[] box; + + @SuppressWarnings("unchecked") + Grid(int size) { + this.size = size; + this.box = (T[]) new Object[size * size * size]; + } + + public T get(BlockPos pos) { + int x = pos.getX(), y = pos.getY(), z = pos.getZ(); + + return x >= 0 && x < size && y >= 0 && y < size && z >= 0 && z < size + ? box[x * size * size + y * size + z] + : null; + } + + public void forEach(BiConsumer transform) { + for (var x = 0; x < size; x++) { + for (var y = 0; y < size; y++) { + for (var z = 0; z < size; z++) { + transform.accept(box[x * size * size + y * size + z], new BlockPos(x, y, z)); + } + } + } + } + + public void map(BiFunction transform) { + for (var x = 0; x < size; x++) { + for (var y = 0; y < size; y++) { + for (var z = 0; z < size; z++) { + box[x * size * size + y * size + z] = transform.apply(box[x * size * size + y * size + z], new BlockPos(x, y, z)); + } + } + } + } + } +} diff --git a/projects/common/src/test/java/dan200/computercraft/impl/network/wired/NetworkTest.java b/projects/common/src/test/java/dan200/computercraft/impl/network/wired/NetworkTest.java index 090a5d1b1..0c76f697e 100644 --- a/projects/common/src/test/java/dan200/computercraft/impl/network/wired/NetworkTest.java +++ b/projects/common/src/test/java/dan200/computercraft/impl/network/wired/NetworkTest.java @@ -9,19 +9,14 @@ import dan200.computercraft.api.network.wired.WiredNetworkChange; import dan200.computercraft.api.network.wired.WiredNode; import dan200.computercraft.api.peripheral.IPeripheral; -import dan200.computercraft.shared.util.DirectionUtil; -import net.minecraft.core.BlockPos; import net.minecraft.world.level.Level; import net.minecraft.world.phys.Vec3; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import javax.annotation.Nullable; import java.util.HashMap; import java.util.Map; import java.util.Set; -import java.util.function.BiConsumer; -import java.util.function.BiFunction; import static org.junit.jupiter.api.Assertions.*; @@ -29,11 +24,11 @@ public class NetworkTest { @Test public void testConnect() { NetworkElement - aE = new NetworkElement(null, null, "a"), - bE = new NetworkElement(null, null, "b"), - cE = new NetworkElement(null, null, "c"); + aE = new NetworkElement("a"), + bE = new NetworkElement("b"), + cE = new NetworkElement("c"); - WiredNode + WiredNodeImpl aN = aE.getNode(), bN = bE.getNode(), cN = cE.getNode(); @@ -42,8 +37,8 @@ public void testConnect() { assertNotEquals(aN.getNetwork(), cN.getNetwork(), "A's and C's network must be different"); assertNotEquals(bN.getNetwork(), cN.getNetwork(), "B's and C's network must be different"); - assertTrue(aN.getNetwork().connect(aN, bN), "Must be able to add connection"); - assertFalse(aN.getNetwork().connect(aN, bN), "Cannot add connection twice"); + assertTrue(aN.connectTo(bN), "Must be able to add connection"); + assertFalse(aN.connectTo(bN), "Cannot add connection twice"); assertEquals(aN.getNetwork(), bN.getNetwork(), "A's and B's network must be equal"); assertEquals(Set.of(aN, bN), nodes(aN.getNetwork()), "A's network should be A and B"); @@ -51,7 +46,7 @@ public void testConnect() { assertEquals(Set.of("a", "b"), aE.allPeripherals().keySet(), "A's peripheral set should be A, B"); assertEquals(Set.of("a", "b"), bE.allPeripherals().keySet(), "B's peripheral set should be A, B"); - aN.getNetwork().connect(aN, cN); + aN.connectTo(cN); assertEquals(aN.getNetwork(), bN.getNetwork(), "A's and B's network must be equal"); assertEquals(aN.getNetwork(), cN.getNetwork(), "A's and C's network must be equal"); @@ -69,20 +64,20 @@ public void testConnect() { @Test public void testDisconnectNoChange() { NetworkElement - aE = new NetworkElement(null, null, "a"), - bE = new NetworkElement(null, null, "b"), - cE = new NetworkElement(null, null, "c"); + aE = new NetworkElement("a"), + bE = new NetworkElement("b"), + cE = new NetworkElement("c"); - WiredNode + WiredNodeImpl aN = aE.getNode(), bN = bE.getNode(), cN = cE.getNode(); - aN.getNetwork().connect(aN, bN); - aN.getNetwork().connect(aN, cN); - aN.getNetwork().connect(bN, cN); + aN.connectTo(bN); + aN.connectTo(cN); + bN.connectTo(cN); - aN.getNetwork().disconnect(aN, bN); + aN.disconnectFrom(bN); assertEquals(aN.getNetwork(), bN.getNetwork(), "A's and B's network must be equal"); assertEquals(aN.getNetwork(), cN.getNetwork(), "A's and C's network must be equal"); @@ -96,19 +91,19 @@ public void testDisconnectNoChange() { @Test public void testDisconnectLeaf() { NetworkElement - aE = new NetworkElement(null, null, "a"), - bE = new NetworkElement(null, null, "b"), - cE = new NetworkElement(null, null, "c"); + aE = new NetworkElement("a"), + bE = new NetworkElement("b"), + cE = new NetworkElement("c"); - WiredNode + WiredNodeImpl aN = aE.getNode(), bN = bE.getNode(), cN = cE.getNode(); - aN.getNetwork().connect(aN, bN); - aN.getNetwork().connect(aN, cN); + aN.connectTo(bN); + aN.connectTo(cN); - aN.getNetwork().disconnect(aN, bN); + aN.disconnectFrom(bN); assertNotEquals(aN.getNetwork(), bN.getNetwork(), "A's and B's network must not be equal"); assertEquals(aN.getNetwork(), cN.getNetwork(), "A's and C's network must be equal"); @@ -123,23 +118,23 @@ public void testDisconnectLeaf() { @Test public void testDisconnectSplit() { NetworkElement - aE = new NetworkElement(null, null, "a"), - aaE = new NetworkElement(null, null, "a_"), - bE = new NetworkElement(null, null, "b"), - bbE = new NetworkElement(null, null, "b_"); + aE = new NetworkElement("a"), + aaE = new NetworkElement("a_"), + bE = new NetworkElement("b"), + bbE = new NetworkElement("b_"); - WiredNode + WiredNodeImpl aN = aE.getNode(), aaN = aaE.getNode(), bN = bE.getNode(), bbN = bbE.getNode(); - aN.getNetwork().connect(aN, aaN); - bN.getNetwork().connect(bN, bbN); + aN.connectTo(aaN); + bN.connectTo(bbN); - aN.getNetwork().connect(aN, bN); + aN.connectTo(bN); - aN.getNetwork().disconnect(aN, bN); + aN.disconnectFrom(bN); assertNotEquals(aN.getNetwork(), bN.getNetwork(), "A's and B's network must not be equal"); assertEquals(aN.getNetwork(), aaN.getNetwork(), "A's and A_'s network must be equal"); @@ -154,7 +149,7 @@ public void testDisconnectSplit() { @Test public void testRemoveSingle() { - var aE = new NetworkElement(null, null, "a"); + var aE = new NetworkElement("a"); var aN = aE.getNode(); var network = aN.getNetwork(); @@ -165,20 +160,20 @@ public void testRemoveSingle() { @Test public void testRemoveLeaf() { NetworkElement - aE = new NetworkElement(null, null, "a"), - bE = new NetworkElement(null, null, "b"), - cE = new NetworkElement(null, null, "c"); + aE = new NetworkElement("a"), + bE = new NetworkElement("b"), + cE = new NetworkElement("c"); - WiredNode + WiredNodeImpl aN = aE.getNode(), bN = bE.getNode(), cN = cE.getNode(); - aN.getNetwork().connect(aN, bN); - aN.getNetwork().connect(aN, cN); + aN.connectTo(bN); + aN.connectTo(cN); - assertTrue(aN.getNetwork().remove(bN), "Must be able to remove node"); - assertFalse(aN.getNetwork().remove(bN), "Cannot remove a second time"); + assertTrue(bN.remove(), "Must be able to remove node"); + assertFalse(bN.remove(), "Cannot remove a second time"); assertNotEquals(aN.getNetwork(), bN.getNetwork(), "A's and B's network must not be equal"); assertEquals(aN.getNetwork(), cN.getNetwork(), "A's and C's network must be equal"); @@ -194,26 +189,26 @@ public void testRemoveLeaf() { @Test public void testRemoveSplit() { NetworkElement - aE = new NetworkElement(null, null, "a"), - aaE = new NetworkElement(null, null, "a_"), - bE = new NetworkElement(null, null, "b"), - bbE = new NetworkElement(null, null, "b_"), - cE = new NetworkElement(null, null, "c"); + aE = new NetworkElement("a"), + aaE = new NetworkElement("a_"), + bE = new NetworkElement("b"), + bbE = new NetworkElement("b_"), + cE = new NetworkElement("c"); - WiredNode + WiredNodeImpl aN = aE.getNode(), aaN = aaE.getNode(), bN = bE.getNode(), bbN = bbE.getNode(), cN = cE.getNode(); - aN.getNetwork().connect(aN, aaN); - bN.getNetwork().connect(bN, bbN); + aN.connectTo(aaN); + bN.connectTo(bbN); - cN.getNetwork().connect(aN, cN); - cN.getNetwork().connect(bN, cN); + cN.connectTo(aN); + cN.connectTo(bN); - cN.getNetwork().remove(cN); + cN.remove(); assertNotEquals(aN.getNetwork(), bN.getNetwork(), "A's and B's network must not be equal"); assertEquals(aN.getNetwork(), aaN.getNetwork(), "A's and A_'s network must be equal"); @@ -228,96 +223,30 @@ public void testRemoveSplit() { assertEquals(Set.of(), cE.allPeripherals().keySet(), "C's peripheral set should be empty"); } - private static final int BRUTE_SIZE = 16; - private static final int TOGGLE_CONNECTION_TIMES = 5; - private static final int TOGGLE_NODE_TIMES = 5; - - @Test - @Disabled("Takes a long time to run, mostly for stress testing") - public void testLarge() { - var grid = new Grid(BRUTE_SIZE); - grid.map((existing, pos) -> new NetworkElement(null, null, "n_" + pos).getNode()); - - // Test connecting - { - var start = System.nanoTime(); - - grid.forEach((existing, pos) -> { - for (var facing : DirectionUtil.FACINGS) { - var offset = pos.relative(facing); - if (offset.getX() > BRUTE_SIZE / 2 == pos.getX() > BRUTE_SIZE / 2) { - var other = grid.get(offset); - if (other != null) existing.getNetwork().connect(existing, other); - } - } - }); - - var end = System.nanoTime(); - - System.out.printf("Connecting %s³ nodes took %s seconds\n", BRUTE_SIZE, (end - start) * 1e-9); - } - - // Test toggling - { - var left = grid.get(new BlockPos(BRUTE_SIZE / 2, 0, 0)); - var right = grid.get(new BlockPos(BRUTE_SIZE / 2 + 1, 0, 0)); - assertNotEquals(left.getNetwork(), right.getNetwork()); - - var start = System.nanoTime(); - for (var i = 0; i < TOGGLE_CONNECTION_TIMES; i++) { - left.getNetwork().connect(left, right); - left.getNetwork().disconnect(left, right); - } - - var end = System.nanoTime(); - - System.out.printf("Toggling connection %s times took %s seconds\n", TOGGLE_CONNECTION_TIMES, (end - start) * 1e-9); - } - - { - var left = grid.get(new BlockPos(BRUTE_SIZE / 2, 0, 0)); - var right = grid.get(new BlockPos(BRUTE_SIZE / 2 + 1, 0, 0)); - var centre = new NetworkElement(null, null, "c").getNode(); - assertNotEquals(left.getNetwork(), right.getNetwork()); - - var start = System.nanoTime(); - for (var i = 0; i < TOGGLE_NODE_TIMES; i++) { - left.getNetwork().connect(left, centre); - right.getNetwork().connect(right, centre); - - left.getNetwork().remove(centre); - } - - var end = System.nanoTime(); - - System.out.printf("Toggling node %s times took %s seconds\n", TOGGLE_NODE_TIMES, (end - start) * 1e-9); - } - } - - private static final class NetworkElement implements WiredElement { - private final Level world; - private final Vec3 position; + static final class NetworkElement implements WiredElement { private final String id; - private final WiredNode node; + private final WiredNodeImpl node; private final Map localPeripherals = new HashMap<>(); private final Map remotePeripherals = new HashMap<>(); - private NetworkElement(Level world, Vec3 position, String id) { - this.world = world; - this.position = position; + NetworkElement(String id) { + this(id, true); + } + + NetworkElement(String id, boolean peripheral) { this.id = id; this.node = new WiredNodeImpl(this); - this.addPeripheral(id); + if (peripheral) addPeripheral(id); } @Override public Level getLevel() { - return world; + throw new IllegalStateException("Unexpected call to getLevel()"); } @Override public Vec3 getPosition() { - return position; + throw new IllegalStateException("Unexpected call to getPosition()"); } @Override @@ -331,7 +260,7 @@ public String toString() { } @Override - public WiredNode getNode() { + public WiredNodeImpl getNode() { return node; } @@ -364,45 +293,6 @@ public boolean equals(@Nullable IPeripheral other) { } } - private static class Grid { - private final int size; - private final T[] box; - - @SuppressWarnings("unchecked") - Grid(int size) { - this.size = size; - this.box = (T[]) new Object[size * size * size]; - } - - public T get(BlockPos pos) { - int x = pos.getX(), y = pos.getY(), z = pos.getZ(); - - return x >= 0 && x < size && y >= 0 && y < size && z >= 0 && z < size - ? box[x * size * size + y * size + z] - : null; - } - - public void forEach(BiConsumer transform) { - for (var x = 0; x < size; x++) { - for (var y = 0; y < size; y++) { - for (var z = 0; z < size; z++) { - transform.accept(box[x * size * size + y * size + z], new BlockPos(x, y, z)); - } - } - } - } - - public void map(BiFunction transform) { - for (var x = 0; x < size; x++) { - for (var y = 0; y < size; y++) { - for (var z = 0; z < size; z++) { - box[x * size * size + y * size + z] = transform.apply(box[x * size * size + y * size + z], new BlockPos(x, y, z)); - } - } - } - } - } - private static Set nodes(WiredNetwork network) { return ((WiredNetworkImpl) network).nodes; }