diff --git a/src/achievement.py b/src/achievement.py index e09c73d..8fcb76c 100644 --- a/src/achievement.py +++ b/src/achievement.py @@ -44,7 +44,7 @@ async def achieve(bot: commands.Bot, message: discord.Message, achievement): metrics.achievements_achieved.inc() await bot.database.execute("INSERT INTO achievements VALUES (?, ?, ?)", (uid, achievement, util.timestamp())) await bot.database.commit() - logging.info("awarded achievement %s to %s", message.author.name, achievement) + logging.info("Awarded achievement %s to %s", message.author.name, achievement) def setup(bot): @bot.group(name="achievements", aliases=["ach", "achieve", "achievement"], brief="Achieve a wide variety of fun achievements!", help=f""" diff --git a/src/debug.py b/src/debug.py index 11db314..d2f9107 100644 --- a/src/debug.py +++ b/src/debug.py @@ -71,7 +71,7 @@ def setup(bot): @magic.command(help="Reload configuration file.") async def reload_config(ctx): util.load_config() - ctx.send("Done!") + await ctx.send("Done!") @magic.command(help="Reload extensions (all or the specified one).") async def reload_ext(ctx, ext="all"): diff --git a/src/discord_link.py b/src/discord_link.py new file mode 100644 index 0000000..a75702d --- /dev/null +++ b/src/discord_link.py @@ -0,0 +1,90 @@ +import eventbus +import discord +import asyncio +import logging +import re +import discord.ext.commands as commands + +def parse_formatting(bot, text): + def parse_match(m): + try: + target = int(m.group(2)) + except ValueError: return m.string + if m.group(1) == "@": # user ping + user = bot.get_user(target) + if user: return { "type": "user_mention", "name": user.name, "id": target } + return f"@{target}" + else: # channel "ping" + channel = bot.get_channel(target) + if channel: return { "type": "channel_mention", "name": channel.name, "id": target } + return f"#{target}" + remaining = text + out = [] + while match := re.search(r"<([@#])!?([0-9]+)>", remaining): + start, end = match.span() + out.append(remaining[:start]) + out.append(parse_match(match)) + remaining = remaining[end:] + out.append(remaining) + return list(filter(lambda x: x != "", out)) + +def render_formatting(dest_channel, message): + out = "" + for seg in message: + if isinstance(seg, str): + out += seg + else: + kind = seg["type"] + # TODO: use python 3.10 pattern matching + if kind == "user_mention": + member = dest_channel.guild.get_member(seg["id"]) + if member: out += f"<@{member.id}>" + else: out += f"@{seg['name']}" + elif kind == "channel_mention": # these appear to be clickable across servers/guilds + out += f"<#{seg['id']}>" + else: logging.warn("Unrecognized message seg %s", kind) + return out + +class DiscordLink(commands.Cog): + def __init__(self, bot): + self.webhooks = {} + self.bot = bot + self.unlisten = eventbus.add_listener("discord", self.on_bridge_message) + + async def initial_load_webhooks(self): + rows = await self.bot.database.execute_fetchall("SELECT * FROM discord_webhooks") + for row in rows: + self.webhooks[row["channel_id"]] = row["webhook"] + logging.info("Loaded %d webhooks", len(rows)) + + async def on_bridge_message(self, channel_id, msg): + channel = self.bot.get_channel(channel_id) + if channel: + webhook = self.webhooks.get(channel_id) + if webhook: + wh_obj = discord.Webhook.from_url(webhook, adapter=discord.AsyncWebhookAdapter(self.bot.http._HTTPClient__session)) + await wh_obj.send( + content=render_formatting(channel, msg.message)[:2000], username=msg.author.name, avatar_url=msg.author.avatar_url, + allowed_mentions=discord.AllowedMentions(everyone=False, roles=False, users=False)) + else: + text = f"<{msg.author.name}> {render_formatting(channel, msg.message)}" + await channel.send(text[:2000], allowed_mentions=discord.AllowedMentions(everyone=False, roles=False, users=False)) + else: + logging.warning("Channel %d not found", channel_id) + + @commands.Cog.listener("on_message") + async def send_to_bridge(self, msg): + # discard webhooks and bridge messages (hackily, admittedly, not sure how else to do this) + if msg.content == "": return + if (msg.author == self.bot.user and msg.content[0] == "<") or msg.author.discriminator == "0000": return + channel_id = msg.channel.id + msg = eventbus.Message(eventbus.AuthorInfo(msg.author.name, msg.author.id, str(msg.author.avatar_url), msg.author.bot), parse_formatting(self.bot, msg.content), ("discord", channel_id), msg.id) + await eventbus.push(msg) + + def cog_unload(self): + self.unlisten() + +def setup(bot): + cog = DiscordLink(bot) + bot.add_cog(cog) + asyncio.create_task(cog.initial_load_webhooks()) \ No newline at end of file diff --git a/src/eventbus.py b/src/eventbus.py index ba0d336..9e63ee1 100644 --- a/src/eventbus.py +++ b/src/eventbus.py @@ -22,7 +22,7 @@ def unpack_dataclass_without(d, without): @dataclasses.dataclass class Message: author: AuthorInfo - message: str + message: list[typing.Union[str, dict]] source: (str, any) id: int @@ -75,7 +75,9 @@ async def push(msg: Message): for listener in listeners[dest_type]: asyncio.ensure_future(listener(dest_channel, msg)) -def add_listener(s, l): listeners[s].add(l) +def add_listener(s, l): + listeners[s].add(l) + return lambda: listeners[s].remove(l) async def add_bridge_link(db, c1, c2): logging.info("Bridging %s and %s", repr(c1), repr(c2)) diff --git a/src/irc_link.py b/src/irc_link.py index 4e3815a..c3b061c 100644 --- a/src/irc_link.py +++ b/src/irc_link.py @@ -5,6 +5,7 @@ import random import util import logging import hashlib +import discord.ext.commands as commands def scramble(text): n = list(text) @@ -15,7 +16,23 @@ def color_code(x): return f"\x03{x}" def random_color(id): return color_code(hashlib.blake2b(str(id).encode("utf-8")).digest()[0] % 13 + 2) +def render_formatting(message): + out = "" + for seg in message: + if isinstance(seg, str): + out += seg.replace("\n", " ") + else: + kind = seg["type"] + # TODO: check if user exists on both ends, and possibly drop if so + if kind == "user_mention": + out += f"@{random_color(seg['id'])}{seg['name']}{color_code('')}" + elif kind == "channel_mention": # these appear to be clickable across servers/guilds + out += f"#{seg['name']}" + else: logging.warn("Unrecognized message seg %s", kind) + return out.strip() + global_conn = None +unlisten = None async def initialize(): logging.info("Initializing IRC link") @@ -32,15 +49,14 @@ async def initialize(): conn.nick(scramble(conn.get_nickname())) def pubmsg(conn, event): - msg = eventbus.Message(eventbus.AuthorInfo(event.source.nick, str(event.source), None), " ".join(event.arguments), (util.config["irc"]["name"], event.target), util.random_id()) + msg = eventbus.Message(eventbus.AuthorInfo(event.source.nick, str(event.source), None), [" ".join(event.arguments)], (util.config["irc"]["name"], event.target), util.random_id()) asyncio.create_task(eventbus.push(msg)) async def on_bridge_message(channel_name, msg): if channel_name in util.config["irc"]["channels"]: if channel_name not in joined: conn.join(channel_name) - line = msg.message.replace("\n", " ") # ping fix - zero width space embedded in messages - line = f"<{random_color(msg.author.id)}{msg.author.name[0]}\u200B{msg.author.name[1:]}{color_code('')}> " + line.strip()[:400] + line = f"<{random_color(msg.author.id)}{msg.author.name[0]}\u200B{msg.author.name[1:]}{color_code('')}> " + render_formatting(msg.message)[:400] conn.privmsg(channel_name, line) else: logging.warning("IRC channel %s not allowed", channel_name) @@ -48,19 +64,21 @@ async def initialize(): def connect(conn, event): for channel in util.config["irc"]["channels"]: conn.join(channel) - logging.info("connected to %s", channel) + logging.info("Connected to %s on IRC", channel) joined.add(channel) # TODO: do better thing conn.add_global_handler("welcome", connect) - conn.add_global_handler("disconnect", lambda conn, event: logging.warn("disconnected from IRC, oh no")) + conn.add_global_handler("disconnect", lambda conn, event: logging.warn("Disconnected from IRC")) conn.add_global_handler("nicknameinuse", inuse) conn.add_global_handler("pubmsg", pubmsg) - eventbus.add_listener(util.config["irc"]["name"], on_bridge_message) + global unlisten + unlisten = eventbus.add_listener(util.config["irc"]["name"], on_bridge_message) def setup(bot): asyncio.create_task(initialize()) def teardown(bot): - if global_conn: global_conn.disconnect() \ No newline at end of file + if global_conn: global_conn.disconnect() + if unlisten: unlisten() \ No newline at end of file diff --git a/src/main.py b/src/main.py index 3cba1de..fa9a60e 100644 --- a/src/main.py +++ b/src/main.py @@ -51,14 +51,16 @@ command_errors = prometheus_client.Counter("abr_errors", "Count of errors encoun @bot.event async def on_command_error(ctx, err): if isinstance(err, (commands.CommandNotFound, commands.CheckFailure)): return - if isinstance(err, commands.CommandInvokeError) and isinstance(err.original, ValueError): return await ctx.send(embed=util.error_embed(str(err.original))) + if isinstance(err, commands.CommandInvokeError) and isinstance(err.original, ValueError): + return await ctx.send(embed=util.error_embed(str(err.original), title=f"Error in {ctx.invoked_with}")) # TODO: really should find a way to detect ALL user errors here? - if isinstance(err, (commands.UserInputError)): return await ctx.send(embed=util.error_embed(str(err))) + if isinstance(err, (commands.UserInputError)): + return await ctx.send(embed=util.error_embed(str(err), title=f"Error in {ctx.invoked_with}")) try: command_errors.inc() trace = re.sub("\n\n+", "\n", "\n".join(traceback.format_exception(err, err, err.__traceback__))) - logging.error("command error occured (in %s)", ctx.invoked_with, exc_info=err) - await ctx.send(embed=util.error_embed(util.gen_codeblock(trace), title="Internal error")) + logging.error("Command error occured (in %s)", ctx.invoked_with, exc_info=err) + await ctx.send(embed=util.error_embed(util.gen_codeblock(trace), title=f"Internal error in {ctx.invoked_with}")) await achievement.achieve(ctx.bot, ctx.message, "error") except Exception as e: print("meta-error:", e) @@ -72,38 +74,6 @@ async def on_ready(): await bot.change_presence(status=discord.Status.online, activity=discord.Activity(name=f"{bot.command_prefix}help", type=discord.ActivityType.listening)) -webhooks = {} - -async def initial_load_webhooks(db): - for row in await db.execute_fetchall("SELECT * FROM discord_webhooks"): - webhooks[row["channel_id"]] = row["webhook"] - -@bot.listen("on_message") -async def send_to_bridge(msg): - # discard webhooks and bridge messages (hackily, admittedly, not sure how else to do this) - if (msg.author == bot.user and msg.content[0] == "<") or msg.author.discriminator == "0000": return - if msg.content == "": return - channel_id = msg.channel.id - msg = eventbus.Message(eventbus.AuthorInfo(msg.author.name, msg.author.id, str(msg.author.avatar_url), msg.author.bot), msg.content, ("discord", channel_id), msg.id) - await eventbus.push(msg) - -async def on_bridge_message(channel_id, msg): - channel = bot.get_channel(channel_id) - if channel: - webhook = webhooks.get(channel_id) - if webhook: - wh_obj = discord.Webhook.from_url(webhook, adapter=discord.AsyncWebhookAdapter(bot.http._HTTPClient__session)) - await wh_obj.send( - content=msg.message, username=msg.author.name, avatar_url=msg.author.avatar_url, - allowed_mentions=discord.AllowedMentions(everyone=False, roles=False, users=False)) - else: - text = f"<{msg.author.name}> {msg.message}" - await channel.send(text[:2000], allowed_mentions=discord.AllowedMentions(everyone=False, roles=False, users=False)) - else: - logging.warning("channel %d not found", channel_id) - -eventbus.add_listener("discord", on_bridge_message) - visible_users = prometheus_client.Gauge("abr_visible_users", "Users the bot can see") def get_visible_users(): return len(bot.users) @@ -128,9 +98,8 @@ guild_count.set_function(get_guild_count) async def run_bot(): bot.database = await db.init(config["database"]) await eventbus.initial_load(bot.database) - await initial_load_webhooks(bot.database) for ext in util.extensions: - logging.info("loaded %s", ext) + logging.info("Loaded %s", ext) bot.load_extension(ext) await bot.start(config["token"]) diff --git a/src/reminders.py b/src/reminders.py index 0ca7619..72769e0 100644 --- a/src/reminders.py +++ b/src/reminders.py @@ -90,7 +90,7 @@ def setup(bot): metrics.reminders_fired.inc() to_expire.append((1, rid)) # 1 = expired normally break - except Exception as e: logging.warning("failed to send %d to %s", rid, method_name, exc_info=e) + except Exception as e: logging.warning("Failed to send %d to %s", rid, method_name, exc_info=e) except Exception as e: logging.warning("Could not send reminder %d", rid, exc_info=e) to_expire.append((2, rid)) # 2 = errored diff --git a/src/telephone.py b/src/telephone.py index 49474c5..b7d1e84 100644 --- a/src/telephone.py +++ b/src/telephone.py @@ -38,20 +38,22 @@ def setup(bot): """) @commands.check(util.admin_check) async def link(ctx, target_type, target_id): + target_id = util.extract_codeblock(target_id) try: target_id = int(target_id) except ValueError: pass - await eventbus.add_bridge_link(bot.database, ("discord", ctx.channel.id), (target_type, util.extract_codeblock(target_id))) + await eventbus.add_bridge_link(bot.database, ("discord", ctx.channel.id), (target_type, target_id)) await ctx.send(f"Link established.") pass @telephone.command(brief="Undo link commands.") @commands.check(util.admin_check) async def unlink(ctx, target_type, target_id): + target_id = util.extract_codeblock(target_id) try: target_id = int(target_id) except ValueError: pass - await eventbus.remove_bridge_link(bot.database, ("discord", ctx.channel.id), (target_type, util.extract_codeblock(target_id))) + await eventbus.remove_bridge_link(bot.database, ("discord", ctx.channel.id), (target_type, target_id)) await ctx.send(f"Successfully deleted.") pass @@ -77,7 +79,7 @@ def setup(bot): await bot.database.execute("INSERT OR REPLACE INTO discord_webhooks VALUES (?, ?)", (ctx.channel.id, webhook)) await ctx.send("Created webhook.") except discord.Forbidden as f: - logging.warn("could not create webhook in #%s %s", ctx.channel.name, ctx.guild.name, exc_info=f) + logging.warn("Could not create webhook in #%s %s", ctx.channel.name, ctx.guild.name, exc_info=f) await ctx.send("Webhook creation failed - please ensure permissions are available. This is not necessary but is recommended.") await bot.database.execute("INSERT OR REPLACE INTO telephone_config VALUES (?, ?, ?, ?)", (num, ctx.guild.id, ctx.channel.id, webhook)) await bot.database.commit() diff --git a/src/util.py b/src/util.py index 8b94279..82829d2 100644 --- a/src/util.py +++ b/src/util.py @@ -264,7 +264,8 @@ extensions = ( "voice", "commands", "userdata", - "irc_link" + "irc_link", + "discord_link" ) # https://github.com/SawdustSoftware/simpleflake/blob/master/simpleflake/simpleflake.py