Use method handlees in our generated Lua methods

When the target method is in a different class loader to CC, our
generated method fails, as it cannot find the target class. To get
around that, we create a MethodHandle to the target method, and then
inject that into the generated class (with Java's new dynamic constant
system). We can then invoke the MethodHandle in our generated code,
avoiding any references to the target class/method.
This commit is contained in:
Jonathan Coates 2023-09-03 16:38:00 +01:00
parent 48bd75faac
commit 8896e30ad6
No known key found for this signature in database
GPG Key ID: B9E431FF07C98D06
3 changed files with 97 additions and 52 deletions

View File

@ -1,19 +0,0 @@
// SPDX-FileCopyrightText: 2020 The CC: Tweaked Developers
//
// SPDX-License-Identifier: MPL-2.0
package dan200.computercraft.core.asm;
import java.security.ProtectionDomain;
final class DeclaringClassLoader extends ClassLoader {
static final DeclaringClassLoader INSTANCE = new DeclaringClassLoader();
private DeclaringClassLoader() {
super(DeclaringClassLoader.class.getClassLoader());
}
Class<?> define(String name, byte[] bytes, ProtectionDomain protectionDomain) throws ClassFormatError {
return defineClass(name, bytes, 0, bytes.length, protectionDomain);
}
}

View File

@ -10,26 +10,44 @@
import com.google.common.primitives.Primitives;
import com.google.common.reflect.TypeToken;
import dan200.computercraft.api.lua.*;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Type;
import dan200.computercraft.core.methods.LuaMethod;
import org.objectweb.asm.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nullable;
import java.lang.constant.ConstantDescs;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import static org.objectweb.asm.Opcodes.*;
/**
* The underlying generator for {@link LuaFunction}-annotated methods.
* <p>
* The constructor {@link Generator#Generator(Class, List, Function)} takes in the type of interface to generate (i.e.
* {@link LuaMethod}), the context arguments for this function (in the case of {@link LuaMethod}, this will just be
* {@link ILuaContext}) and a "wrapper" function to lift a function to execute on the main thread.
* <p>
* The generated class then implements this interface - the {@code apply} method calls the appropriate methods on
* {@link IArguments} to extract the arguments, and then calls the original method.
* <p>
* As the method is not guaranteed to come from the same classloader, we cannot call the method directly, as that may
* result in linkage errors. We instead inject a {@link MethodHandle} into the class as a dynamic constant, and then
* call the method with {@link MethodHandle#invokeExact(Object...)}. The method handle is constant, and so this has
* equivalent performance to the direct call.
*
* @param <T> The type of the interface the generated classes implement.
*/
final class Generator<T> {
private static final Logger LOG = LoggerFactory.getLogger(Generator.class);
private static final AtomicInteger METHOD_ID = new AtomicInteger();
private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
private static final String METHOD_NAME = "apply";
private static final String[] EXCEPTIONS = new String[]{ Type.getInternalName(LuaException.class) };
@ -42,11 +60,17 @@ final class Generator<T> {
private static final String INTERNAL_COERCED = Type.getInternalName(Coerced.class);
private static final ConstantDynamic METHOD_CONSTANT = new ConstantDynamic(ConstantDescs.DEFAULT_NAME, MethodHandle.class.descriptorString(), new Handle(
H_INVOKESTATIC, Type.getInternalName(MethodHandles.class), "classData",
MethodType.methodType(Object.class, MethodHandles.Lookup.class, String.class, Class.class).descriptorString(), false
));
private final Class<T> base;
private final List<Class<?>> context;
private final String[] interfaces;
private final String methodDesc;
private final String classPrefix;
private final Function<T, T> wrap;
@ -64,6 +88,8 @@ final class Generator<T> {
for (var klass : context) methodDesc.append(Type.getDescriptor(klass));
methodDesc.append(DESC_ARGUMENTS).append(")").append(DESC_METHOD_RESULT);
this.methodDesc = methodDesc.toString();
classPrefix = Generator.class.getPackageName() + "." + base.getSimpleName() + "$";
}
Optional<T> getMethod(Method method) {
@ -110,11 +136,17 @@ private Optional<T> build(Method method) {
var target = Modifier.isStatic(modifiers) ? method.getParameterTypes()[0] : method.getDeclaringClass();
try {
var className = method.getDeclaringClass().getName() + "$cc$" + method.getName() + METHOD_ID.getAndIncrement();
var bytes = generate(className, target, method, annotation.unsafe());
var handle = LOOKUP.unreflect(method);
// Convert the handle from one of the form (target, ...) -> ret type to (Object, ...) -> Object. This both
// handles the boxing of primitives for us, and ensures our bytecode does not reference any external types.
// We could handle the conversion to MethodResult here too, but it doesn't feel worth it.
var widenedHandle = handle.asType(widenMethodType(handle.type(), target));
var bytes = generate(classPrefix + method.getName(), target, method, widenedHandle.type().descriptorString(), annotation.unsafe());
if (bytes == null) return Optional.empty();
var klass = DeclaringClassLoader.INSTANCE.define(className, bytes, method.getDeclaringClass().getProtectionDomain());
var klass = LOOKUP.defineHiddenClassWithClassData(bytes, widenedHandle, true).lookupClass();
var instance = klass.asSubclass(base).getDeclaredConstructor().newInstance();
return Optional.of(annotation.mainThread() ? wrap.apply(instance) : instance);
@ -122,16 +154,29 @@ private Optional<T> build(Method method) {
LOG.error("Error generating wrapper for {}.", name, e);
return Optional.empty();
}
}
private static MethodType widenMethodType(MethodType source, Class<?> target) {
// Treat the target argument as just Object - we'll do the cast in the method handle.
var args = source.parameterArray();
for (var i = 0; i < args.length; i++) {
if (args[i] == target) args[i] = Object.class;
}
// And convert the return value to Object if needed.
var ret = source.returnType();
return ret == void.class || ret == MethodResult.class || ret == Object[].class
? MethodType.methodType(ret, args)
: MethodType.methodType(Object.class, args);
}
@Nullable
private byte[] generate(String className, Class<?> target, Method method, boolean unsafe) {
private byte[] generate(String className, Class<?> target, Method targetMethod, String targetDescriptor, boolean unsafe) {
var internalName = className.replace(".", "/");
// Construct a public final class which extends Object and implements MethodInstance.Delegate
var cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
cw.visit(V1_8, ACC_PUBLIC | ACC_FINAL, internalName, null, "java/lang/Object", interfaces);
cw.visit(V17, ACC_PUBLIC | ACC_FINAL, internalName, null, "java/lang/Object", interfaces);
cw.visitSource("CC generated method", null);
{ // Constructor just invokes super.
@ -148,35 +193,26 @@ private byte[] generate(String className, Class<?> target, Method method, boolea
var mw = cw.visitMethod(ACC_PUBLIC, METHOD_NAME, methodDesc, null, EXCEPTIONS);
mw.visitCode();
// If we're an instance method, load the this parameter.
if (!Modifier.isStatic(method.getModifiers())) {
mw.visitVarInsn(ALOAD, 1);
mw.visitTypeInsn(CHECKCAST, Type.getInternalName(target));
}
mw.visitLdcInsn(METHOD_CONSTANT);
// If we're an instance method, load the target as the first argument.
if (!Modifier.isStatic(targetMethod.getModifiers())) mw.visitVarInsn(ALOAD, 1);
var argIndex = 0;
for (var genericArg : method.getGenericParameterTypes()) {
var loadedArg = loadArg(mw, target, method, unsafe, genericArg, argIndex);
for (var genericArg : targetMethod.getGenericParameterTypes()) {
var loadedArg = loadArg(mw, target, targetMethod, unsafe, genericArg, argIndex);
if (loadedArg == null) return null;
if (loadedArg) argIndex++;
}
mw.visitMethodInsn(
Modifier.isStatic(method.getModifiers()) ? INVOKESTATIC : INVOKEVIRTUAL,
Type.getInternalName(method.getDeclaringClass()), method.getName(),
Type.getMethodDescriptor(method), false
);
mw.visitMethodInsn(INVOKEVIRTUAL, "java/lang/invoke/MethodHandle", "invokeExact", targetDescriptor, false);
// We allow a reasonable amount of flexibility on the return value's type. Alongside the obvious MethodResult,
// we convert basic types into an immediate result.
var ret = method.getReturnType();
var ret = targetMethod.getReturnType();
if (ret != MethodResult.class) {
if (ret == void.class) {
mw.visitMethodInsn(INVOKESTATIC, INTERNAL_METHOD_RESULT, "of", "()" + DESC_METHOD_RESULT, false);
} else if (ret.isPrimitive()) {
var boxed = Primitives.wrap(ret);
mw.visitMethodInsn(INVOKESTATIC, Type.getInternalName(boxed), "valueOf", "(" + Type.getDescriptor(ret) + ")" + Type.getDescriptor(boxed), false);
mw.visitMethodInsn(INVOKESTATIC, INTERNAL_METHOD_RESULT, "of", "(Ljava/lang/Object;)" + DESC_METHOD_RESULT, false);
} else if (ret == Object[].class) {
mw.visitMethodInsn(INVOKESTATIC, INTERNAL_METHOD_RESULT, "of", "([Ljava/lang/Object;)" + DESC_METHOD_RESULT, false);
} else {
@ -199,7 +235,6 @@ private byte[] generate(String className, Class<?> target, Method method, boolea
private Boolean loadArg(MethodVisitor mw, Class<?> target, Method method, boolean unsafe, java.lang.reflect.Type genericArg, int argIndex) {
if (genericArg == target) {
mw.visitVarInsn(ALOAD, 1);
mw.visitTypeInsn(CHECKCAST, Type.getInternalName(target));
return false;
}

View File

@ -11,12 +11,14 @@
import dan200.computercraft.core.methods.NamedMethod;
import org.hamcrest.Matcher;
import org.junit.jupiter.api.Test;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Opcodes;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.lang.invoke.MethodHandles;
import java.util.*;
import static dan200.computercraft.test.core.ContramapMatcher.contramap;
import static org.hamcrest.MatcherAssert.assertThat;
@ -116,6 +118,33 @@ public void testUnsafe() {
assertThat(methods, contains(named("withUnsafe")));
}
@Test
public void testClassNotAccessible() throws IOException, ReflectiveOperationException, LuaException {
var basicName = Basic.class.getName().replace('.', '/');
// Load our Basic class, rewriting it to be a separate (hidden) class which is not part of the same nest as
// the existing Basic.
ClassReader reader;
try (var input = getClass().getClassLoader().getResourceAsStream(basicName + ".class")) {
reader = new ClassReader(Objects.requireNonNull(input, "Cannot find " + basicName));
}
var writer = new ClassWriter(reader, 0);
reader.accept(new ClassVisitor(Opcodes.ASM9, writer) {
@Override
public void visitNestHost(String nestHost) {
}
@Override
public void visitInnerClass(String name, String outerName, String innerName, int access) {
}
}, 0);
var klass = MethodHandles.lookup().defineHiddenClass(writer.toByteArray(), true).lookupClass();
var methods = GENERATOR.getMethods(klass);
assertThat(apply(methods, klass.getConstructor().newInstance(), "go"), equalTo(MethodResult.of()));
}
public static class Basic {
@LuaFunction
public final void go() {