diff --git a/src/irc_link.py b/src/irc_link.py index 5bab220..4e3815a 100644 --- a/src/irc_link.py +++ b/src/irc_link.py @@ -15,12 +15,18 @@ def color_code(x): return f"\x03{x}" def random_color(id): return color_code(hashlib.blake2b(str(id).encode("utf-8")).digest()[0] % 13 + 2) +global_conn = None + async def initialize(): + logging.info("Initializing IRC link") + joined = set() loop = asyncio.get_event_loop() reactor = irc.client_aio.AioReactor(loop=loop) conn = await reactor.server().connect(util.config["irc"]["server"], util.config["irc"]["port"], util.config["irc"]["nick"]) + global global_conn + global_conn = conn def inuse(conn, event): conn.nick(scramble(conn.get_nickname())) @@ -51,4 +57,10 @@ async def initialize(): conn.add_global_handler("nicknameinuse", inuse) conn.add_global_handler("pubmsg", pubmsg) - eventbus.add_listener(util.config["irc"]["name"], on_bridge_message) \ No newline at end of file + eventbus.add_listener(util.config["irc"]["name"], on_bridge_message) + +def setup(bot): + asyncio.create_task(initialize()) + +def teardown(bot): + if global_conn: global_conn.disconnect() \ No newline at end of file diff --git a/src/main.py b/src/main.py index b5b155f..b12a241 100644 --- a/src/main.py +++ b/src/main.py @@ -52,7 +52,8 @@ command_errors = prometheus_client.Counter("abr_errors", "Count of errors encoun async def on_command_error(ctx, err): if isinstance(err, (commands.CommandNotFound, commands.CheckFailure)): return if isinstance(err, commands.CommandInvokeError) and isinstance(err.original, ValueError): return await ctx.send(embed=util.error_embed(str(err.original))) - if isinstance(err, commands.MissingRequiredArgument): return await ctx.send(embed=util.error_embed(str(err))) + # TODO: really should find a way to detect ALL user errors here? + if isinstance(err, (commands.MissingRequiredArgument, commands.ExpectedClosingQuoteError, commands.InvalidEndOfQuotedStringError)): return await ctx.send(embed=util.error_embed(str(err))) try: command_errors.inc() trace = re.sub("\n\n+", "\n", "\n".join(traceback.format_exception(err, err, err.__traceback__))) @@ -127,7 +128,6 @@ async def run_bot(): bot.database = await db.init(config["database"]) await eventbus.initial_load(bot.database) await initial_load_webhooks(bot.database) - await irc_link.initialize() for ext in util.extensions: logging.info("loaded %s", ext) bot.load_extension(ext) diff --git a/src/util.py b/src/util.py index 5d6a24f..8b94279 100644 --- a/src/util.py +++ b/src/util.py @@ -263,7 +263,8 @@ extensions = ( "heavserver", "voice", "commands", - "userdata" + "userdata", + "irc_link" ) # https://github.com/SawdustSoftware/simpleflake/blob/master/simpleflake/simpleflake.py