From 0bd0f4d313d8b9df36834bf335c558a2a4a5f584 Mon Sep 17 00:00:00 2001 From: SquidDev Date: Mon, 1 May 2017 18:57:10 +0100 Subject: [PATCH] Prefix all loaded strings with "=" Whilst this is not consistent with normal Lua, this is required in order to remain compatible with LuaJ. --- .../core/lua/CobaltLuaMachine.java | 92 +++++++++++++++++++ .../assets/computercraft/lua/bios.lua | 2 +- 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/src/main/java/dan200/computercraft/core/lua/CobaltLuaMachine.java b/src/main/java/dan200/computercraft/core/lua/CobaltLuaMachine.java index b250ea392..36d4e7983 100644 --- a/src/main/java/dan200/computercraft/core/lua/CobaltLuaMachine.java +++ b/src/main/java/dan200/computercraft/core/lua/CobaltLuaMachine.java @@ -21,6 +21,7 @@ import org.squiddev.cobalt.compiler.LoadState; import org.squiddev.cobalt.debug.DebugFrame; import org.squiddev.cobalt.debug.DebugHandler; import org.squiddev.cobalt.debug.DebugState; +import org.squiddev.cobalt.function.LibFunction; import org.squiddev.cobalt.function.LuaFunction; import org.squiddev.cobalt.function.VarArgFunction; import org.squiddev.cobalt.lib.*; @@ -121,6 +122,9 @@ public class CobaltLuaMachine implements ILuaMachine m_globals.load( state, new MathLib() ); m_globals.load( state, new CoroutineLib() ); + // Register custom load/loadstring provider which automatically adds prefixes. + LibFunction.bind( state, m_globals, PrefixLoader.class, new String[]{ "load", "loadstring" } ); + // Remove globals we don't want to expose m_globals.rawset( "collectgarbage", Constants.NIL ); m_globals.rawset( "dofile", Constants.NIL ); @@ -641,4 +645,92 @@ public class CobaltLuaMachine implements ILuaMachine } return objects; } + + private static class PrefixLoader extends VarArgFunction + { + private static final LuaString FUNCTION_STR = valueOf( "function" ); + private static final LuaString EQ_STR = valueOf( "=" ); + + @Override + public Varargs invoke( LuaState state, Varargs args ) throws LuaError + { + switch (opcode) + { + case 0: // "load", // ( func [,chunkname] ) -> chunk | nil, msg + { + LuaValue func = args.arg( 1 ).checkFunction(); + LuaString chunkname = args.arg( 2 ).optLuaString( FUNCTION_STR ); + if( !chunkname.startsWith( '@' ) && !chunkname.startsWith( '=' ) ) + { + chunkname = OperationHelper.concat( EQ_STR, chunkname ); + } + return BaseLib.loadStream( state, new StringInputStream( state, func ), chunkname ); + } + case 1: // "loadstring", // ( string [,chunkname] ) -> chunk | nil, msg + { + LuaString script = args.arg( 1 ).checkLuaString(); + LuaString chunkname = args.arg( 2 ).optLuaString( script ); + if( !chunkname.startsWith( '@' ) && !chunkname.startsWith( '=' ) ) + { + chunkname = OperationHelper.concat( EQ_STR, chunkname ); + } + return BaseLib.loadStream( state, script.toInputStream(), chunkname ); + } + } + + return NONE; + } + } + + private static class StringInputStream extends InputStream + { + private final LuaState state; + private final LuaValue func; + private byte[] bytes; + private int offset, remaining = 0; + + public StringInputStream( LuaState state, LuaValue func ) + { + this.state = state; + this.func = func; + } + + @Override + public int read() throws IOException + { + if( remaining <= 0 ) + { + LuaValue s; + try + { + s = OperationHelper.call( state, func ); + } catch (LuaError e) + { + throw new IOException( e ); + } + + if( s.isNil() ) + { + return -1; + } + LuaString ls; + try + { + ls = s.strvalue(); + } catch (LuaError e) + { + throw new IOException( e ); + } + bytes = ls.bytes; + offset = ls.offset; + remaining = ls.length; + if( remaining <= 0 ) + { + return -1; + } + } + --remaining; + return bytes[offset++]; + } + } } diff --git a/src/main/resources/assets/computercraft/lua/bios.lua b/src/main/resources/assets/computercraft/lua/bios.lua index e0ef79190..e98a34fbe 100644 --- a/src/main/resources/assets/computercraft/lua/bios.lua +++ b/src/main/resources/assets/computercraft/lua/bios.lua @@ -541,7 +541,7 @@ loadfile = function( _sFile, _tEnv ) end local file = fs.open( _sFile, "r" ) if file then - local func, err = load( file.readAll(), fs.getName( _sFile ), "t", _tEnv ) + local func, err = load( file.readAll(), "@" .. fs.getName( _sFile ), "t", _tEnv ) file.close() return func, err end