1
0
mirror of https://github.com/osmarks/autobotrobot synced 2025-01-08 14:50:25 +00:00

Unified bridge system

ABR can now bridge to IRC, because of course.
It can also bridge Discord to Discord.
Bridging works transitively because I have a fairly elegant (if I do say so myself) way of handling it:
links are internally point-to-point, and when something is sent in a channel with links configured the bot traverses the graph of links to work out where to send to.
It is planned to expose a private websocket API to synchronize with other servers providing virtual channels.
This system is now used for telephone calls.
There may be issues in certain situations due to the lack of (meaningful) transaction support in aiosqlite.
The telephone command has been extended with (un)link commands, currently only for me as they can link anywhere.
This commit is contained in:
osmarks 2021-02-25 17:48:06 +00:00
parent d449d5356b
commit d8344d9759
7 changed files with 270 additions and 40 deletions

View File

@ -76,6 +76,20 @@ CREATE TABLE user_data (
value TEXT NOT NULL,
UNIQUE (user_id, guild_id, key)
);
""",
"""
CREATE TABLE links (
to_type TEXT NOT NULL,
to_id BLOB NOT NULL,
from_type TEXT NOT NULL,
from_id BLOB NOT NULL,
established_at INTEGER NOT NULL,
UNIQUE (to_type, to_id, from_type, from_id)
);
CREATE TABLE discord_webhooks (
channel_id INTEGER PRIMARY KEY,
webhook TEXT NOT NULL
);
"""
]

View File

@ -1,9 +1,10 @@
import util
import asyncio
import traceback
import re
from discord.ext import commands
import util
import eventbus
def setup(bot):
@bot.group(help="Debug/random messing around utilities. Owner-only.")
@ -26,7 +27,9 @@ def setup(bot):
**locals(),
"bot": bot,
"ctx": ctx,
"db": bot.database
"db": bot.database,
"util": util,
"eventbus": eventbus
}
def check(re, u): return str(re.emoji) == "" and u == ctx.author

100
src/eventbus.py Normal file
View File

@ -0,0 +1,100 @@
import asyncio
import prometheus_client
import dataclasses
import typing
import collections
import logging
import util
@dataclasses.dataclass
class AuthorInfo:
name: str
id: any
avatar_url: str = None
deprioritize: bool = False
def unpack_dataclass_without(d, without):
dct = dict([(field, getattr(d, field)) for field in type(d).__dataclass_fields__])
del dct[without]
return dct
@dataclasses.dataclass
class Message:
author: AuthorInfo
message: str
source: (str, any)
id: int
evbus_messages = prometheus_client.Counter("abr_evbus_messages", "Messages processed by event bus", ["source_type"])
evbus_messages_dropped = prometheus_client.Counter("abr_evbus_messages_dropped", "Messages received by event bus but dropped by rate limits", ["source_type"])
# maps each bridge destination type (discord/APIONET/etc) to the listeners for it
listeners = collections.defaultdict(set)
# maintains a list of all the unidirectional links between channels - key is source, values are targets
links = collections.defaultdict(set)
def find_all_destinations(source):
visited = set()
targets = set(links[source])
while len(targets) > 0:
current = targets.pop()
targets.update(adjacent for adjacent in links[current] if not adjacent in visited)
visited.add(current)
return visited
RATE = 10.0
PER = 5000000.0 # µs
RLData = collections.namedtuple("RLData", ["allowance", "last_check"])
rate_limiting = collections.defaultdict(lambda: RLData(RATE, util.timestamp()))
async def push(msg: Message):
destinations = find_all_destinations(msg.source)
if len(destinations) > 0:
# "token bucket" rate limiting algorithm - max 10 messages per 5 seconds (half that for bots)
# TODO: maybe separate buckets for bot and unbot?
current = util.timestamp_µs()
time_passed = current - rate_limiting[msg.source].last_check
allowance = rate_limiting[msg.source].allowance
allowance += time_passed * (RATE / PER)
if allowance > RATE:
allowance = RATE
rate_limiting[msg.source] = RLData(allowance, current)
if allowance < 1:
evbus_messages_dropped.labels(msg.source[0]).inc()
return
allowance -= 2.0 if msg.author.deprioritize else 1.0
rate_limiting[msg.source] = RLData(allowance, current)
evbus_messages.labels(msg.source[0]).inc()
for dest in destinations:
if dest == msg.source: continue
dest_type, dest_channel = dest
for listener in listeners[dest_type]:
asyncio.ensure_future(listener(dest_channel, msg))
def add_listener(s, l): listeners[s].add(l)
async def add_bridge_link(db, c1, c2):
logging.info("Bridging %s and %s", repr(c1), repr(c2))
links[c1].add(c2)
links[c2].add(c1)
await db.execute("INSERT INTO links VALUES (?, ?, ?, ?, ?) ON CONFLICT DO NOTHING", (c1[0], c1[1], c2[0], c2[1], util.timestamp()))
await db.execute("INSERT INTO links VALUES (?, ?, ?, ?, ?) ON CONFLICT DO NOTHING", (c2[0], c2[1], c1[0], c1[1], util.timestamp()))
await db.commit()
async def remove_bridge_link(db, c1, c2):
logging.info("Unbridging %s and %s", repr(c1), repr(c2))
links[c1].remove(c2)
links[c2].remove(c1)
await db.execute("DELETE FROM links WHERE (to_type = ? AND to_id = ?) AND (from_type = ? AND from_id = ?)", (c1[0], c1[1], c2[0], c2[1]))
await db.execute("DELETE FROM links WHERE (to_type = ? AND to_id = ?) AND (from_type = ? AND from_id = ?)", (c2[0], c2[1], c1[0], c1[1]))
await db.commit()
async def initial_load(db):
rows = await db.execute_fetchall("SELECT * FROM links")
for row in rows:
links[(row["from_type"], row["from_id"])].add((row["to_type"], row["to_id"]))
logging.info("Loaded %d links", len(rows))

53
src/irc_link.py Normal file
View File

@ -0,0 +1,53 @@
import eventbus
import asyncio
import irc.client_aio
import random
import util
import logging
import hashlib
def scramble(text):
n = list(text)
random.shuffle(n)
return "".join(n)
def color_code(x):
return f"\x03{x}"
def random_color(id): return color_code(hashlib.blake2b(str(id).encode("utf-8")).digest()[0] % 13 + 2)
async def initialize():
joined = set()
loop = asyncio.get_event_loop()
reactor = irc.client_aio.AioReactor(loop=loop)
conn = await reactor.server().connect(util.config["irc"]["server"], util.config["irc"]["port"], util.config["irc"]["nick"])
def inuse(conn, event):
conn.nick(scramble(conn.get_nickname()))
def pubmsg(conn, event):
msg = eventbus.Message(eventbus.AuthorInfo(event.source.nick, str(event.source), None), " ".join(event.arguments), (util.config["irc"]["name"], event.target), util.random_id())
asyncio.create_task(eventbus.push(msg))
async def on_bridge_message(channel_name, msg):
if channel_name in util.config["irc"]["channels"]:
if channel_name not in joined: conn.join(channel_name)
line = msg.message.replace("\n", " ")
line = f"<{random_color(msg.author.id)}{msg.author.name}{color_code('')}> " + line.strip()[:400]
conn.privmsg(channel_name, line)
else:
logging.warning("IRC channel %s not allowed", channel_name)
def connect(conn, event):
for channel in util.config["irc"]["channels"]:
conn.join(channel)
logging.info("connected to %s", channel)
joined.add(channel)
# TODO: do better thing
conn.add_global_handler("welcome", connect)
conn.add_global_handler("disconnect", lambda conn, event: logging.warn("disconnected from IRC, oh no"))
conn.add_global_handler("nicknameinuse", inuse)
conn.add_global_handler("pubmsg", pubmsg)
eventbus.add_listener(util.config["irc"]["name"], on_bridge_message)

View File

@ -12,10 +12,13 @@ import collections
import prometheus_client
import prometheus_async.aio
import typing
import sys
import tio
import db
import util
import eventbus
import irc_link
import achievement
config = util.config
@ -47,7 +50,6 @@ async def on_message(message):
command_errors = prometheus_client.Counter("abr_errors", "Count of errors encountered in executing commands.")
@bot.event
async def on_command_error(ctx, err):
#print(ctx, err)
if isinstance(err, (commands.CommandNotFound, commands.CheckFailure)): return
if isinstance(err, commands.CommandInvokeError) and isinstance(err.original, ValueError): return await ctx.send(embed=util.error_embed(str(err.original)))
if isinstance(err, commands.MissingRequiredArgument): return await ctx.send(embed=util.error_embed(str(err)))
@ -69,15 +71,52 @@ async def on_ready():
await bot.change_presence(status=discord.Status.online,
activity=discord.Activity(name=f"{bot.command_prefix}help", type=discord.ActivityType.listening))
webhooks = {}
async def initial_load_webhooks(db):
for row in await db.execute_fetchall("SELECT * FROM discord_webhooks"):
webhooks[row["channel_id"]] = row["webhook"]
@bot.listen("on_message")
async def send_to_bridge(msg):
if msg.author == bot.user or msg.author.discriminator == "0000": return
if msg.content == "": return
channel_id = msg.channel.id
msg = eventbus.Message(eventbus.AuthorInfo(msg.author.name, msg.author.id, str(msg.author.avatar_url), msg.author.bot), msg.content, ("discord", channel_id), msg.id)
await eventbus.push(msg)
async def on_bridge_message(channel_id, msg):
channel = bot.get_channel(channel_id)
if channel:
webhook = webhooks.get(channel_id)
if webhook:
wh_obj = discord.Webhook.from_url(webhook, adapter=discord.AsyncWebhookAdapter(bot.http._HTTPClient__session))
await wh_obj.send(
content=msg.message, username=msg.author.name, avatar_url=msg.author.avatar_url,
allowed_mentions=discord.AllowedMentions(everyone=False, roles=False, users=False))
else:
text = f"<{msg.author.name}> {msg.message}"
await channel.send(text[:2000], allowed_mentions=discord.AllowedMentions(everyone=False, roles=False, users=False))
else:
logging.warning("channel %d not found", channel_id)
eventbus.add_listener("discord", on_bridge_message)
visible_users = prometheus_client.Gauge("abr_visible_users", "Users the bot can see")
def get_visible_users():
return len(bot.users)
visible_users.set_function(get_visible_users)
heavserver_members = prometheus_client.Gauge("abr_heavserver_members", "Current member count of heavserver")
heavserver_bots = prometheus_client.Gauge("abr_heavserver_bots", "Current bot count of heavserver")
def get_heavserver_members():
if not bot.get_guild(util.config["heavserver"]["id"]): return 0
return len(bot.get_guild(util.config["heavserver"]["id"]).members)
def get_heavserver_bots():
if not bot.get_guild(util.config["heavserver"]["id"]): return 0
return len([ None for member in bot.get_guild(util.config["heavserver"]["id"]).members if member.bot ])
heavserver_members.set_function(get_heavserver_members)
heavserver_bots.set_function(get_heavserver_bots)
guild_count = prometheus_client.Gauge("abr_guilds", "Guilds the bot is in")
def get_guild_count():
@ -86,6 +125,9 @@ guild_count.set_function(get_guild_count)
async def run_bot():
bot.database = await db.init(config["database"])
await eventbus.initial_load(bot.database)
await initial_load_webhooks(bot.database)
await irc_link.initialize()
for ext in util.extensions:
logging.info("loaded %s", ext)
bot.load_extension(ext)
@ -98,5 +140,6 @@ if __name__ == '__main__':
loop.run_until_complete(run_bot())
except KeyboardInterrupt:
loop.run_until_complete(bot.logout())
sys.exit(0)
finally:
loop.close()

View File

@ -6,6 +6,7 @@ import hashlib
from datetime import datetime
import util
import eventbus
# Generate a "phone" address
# Not actually for phones
@ -32,38 +33,35 @@ def setup(bot):
async def get_addr_config(addr):
return await bot.database.execute_fetchone("SELECT * FROM telephone_config WHERE id = ?", (addr,))
@bot.listen("on_message")
async def forward_call_messages(message):
channel = message.channel.id
if (message.author.discriminator == "0000" and message.author.bot) or message.author == bot.user or message.content == "": # check if webhook, from itself, or only has embeds
return
calls = await bot.database.execute_fetchall("""SELECT tcf.channel_id AS from_channel, tct.channel_id AS to_channel,
tcf.webhook AS from_webhook, tct.webhook AS to_webhook FROM calls
JOIN telephone_config AS tcf ON tcf.id = calls.from_id JOIN telephone_config AS tct ON tct.id = calls.to_id
WHERE from_channel = ? OR to_channel = ?""", (channel, channel))
if calls == []: return
async def send_to(call):
if call["from_channel"] == channel:
other_channel, other_webhook = call["to_channel"], call["to_webhook"]
else:
other_channel, other_webhook = call["from_channel"], call["from_webhook"]
@telephone.command(brief="Link to other channels", help="""Connect to another channel on Discord or any supported bridges.
Virtual channels also exist.
""")
@commands.check(util.admin_check)
async def link(ctx, target_type, target_id):
try:
target_id = int(target_id)
except ValueError: pass
await eventbus.add_bridge_link(bot.database, ("discord", ctx.channel.id), (target_type, util.extract_codeblock(target_id)))
await ctx.send(f"Link established.")
pass
async def send_normal_message():
m = f"**{message.author.name}**: "
m += message.content[:2000 - len(m)]
await bot.get_channel(other_channel).send(m)
@telephone.command(brief="Undo link commands.")
@commands.check(util.admin_check)
async def unlink(ctx, target_type, target_id):
try:
target_id = int(target_id)
except ValueError: pass
await eventbus.remove_bridge_link(bot.database, ("discord", ctx.channel.id), (target_type, util.extract_codeblock(target_id)))
await ctx.send(f"Successfully deleted.")
pass
if other_webhook:
try:
await discord.Webhook.from_url(other_webhook, adapter=discord.AsyncWebhookAdapter(bot.http._HTTPClient__session)).send(
content=message.content, username=message.author.name, avatar_url=message.author.avatar_url,
allowed_mentions=discord.AllowedMentions(everyone=False, roles=False, users=False))
except discord.errors.NotFound:
logging.warn("channel %d webhook missing", other_channel)
await send_normal_message()
else: await send_normal_message()
await asyncio.gather(*map(send_to, calls))
@telephone.command(brief="Generate a webhook")
@commands.check(util.admin_check)
async def init_webhook(ctx):
webhook = (await ctx.channel.create_webhook(name="ABR webhook", reason=f"requested by {ctx.author.name}")).url
await bot.database.execute("INSERT OR REPLACE INTO discord_webhooks VALUES (?, ?)", (ctx.channel.id, webhook))
await bot.database.commit()
await ctx.send("Done.")
@telephone.command()
@commands.check(util.server_mod_check)
@ -76,6 +74,7 @@ def setup(bot):
if not info or not webhook:
try:
webhook = (await ctx.channel.create_webhook(name="incoming message display", reason="configure for apiotelephone")).url
await bot.database.execute("INSERT OR REPLACE INTO discord_webhooks VALUES (?, ?)", (ctx.channel.id, webhook))
await ctx.send("Created webhook.")
except discord.Forbidden as f:
logging.warn("could not create webhook in #%s %s", ctx.channel.name, ctx.guild.name, exc_info=f)
@ -84,7 +83,7 @@ def setup(bot):
await bot.database.commit()
await ctx.send("Configured.")
@telephone.command(aliases=["call"])
@telephone.command(aliases=["call"], brief="Dial another telephone channel.")
async def dial(ctx, address):
# basic checks - ensure this is a phone channel and has no other open calls
channel_info = await get_channel_config(ctx.channel.id)
@ -127,6 +126,7 @@ def setup(bot):
if em == "": # accept call
await bot.database.execute("INSERT INTO calls VALUES (?, ?, ?)", (originating_address, address, util.timestamp()))
await bot.database.commit()
await eventbus.add_bridge_link(bot.database, ("discord", ctx.channel.id), ("discord", recv_channel.id))
await asyncio.gather(
ctx.send(embed=util.info_embed("Outgoing call", "Call accepted and connected.")),
recv_channel.send(embed=util.info_embed("Incoming call", "Call accepted and connected."))
@ -134,10 +134,7 @@ def setup(bot):
elif em == "": # drop call
await ctx.send(embed=util.error_embed("Your call was declined.", "Call declined"))
async def get_calls(addr):
pass
@telephone.command(aliases=["disconnect", "quit"])
@telephone.command(aliases=["disconnect", "quit"], brief="Disconnect latest call.")
async def hangup(ctx):
channel_info = await get_channel_config(ctx.channel.id)
addr = channel_info["id"]
@ -155,13 +152,14 @@ def setup(bot):
await bot.database.execute("DELETE FROM calls WHERE to_id = ? AND from_id = ?", (addr, other))
await bot.database.commit()
other_channel = (await get_addr_config(other))["channel_id"]
await eventbus.remove_bridge_link(bot.database, ("discord", other_channel), ("discord", ctx.channel.id))
await asyncio.gather(
ctx.send(embed=util.info_embed("Hung up", f"Call to {other} disconnected.")),
bot.get_channel(other_channel).send(embed=util.info_embed("Hung up", f"Call to {addr} disconnected."))
)
@telephone.command(aliases=["status"])
@telephone.command(aliases=["status"], brief="List inbound/outbound calls.")
async def info(ctx):
channel_info = await get_channel_config(ctx.channel.id)
if not channel_info: return await ctx.send(embed=util.info_embed("Phone status", "Not a phone channel"))

View File

@ -11,6 +11,7 @@ import toml
import os.path
from discord.ext import commands
import hashlib
import time
config = {}
@ -21,6 +22,7 @@ def load_config():
load_config()
def timestamp(): return int(datetime.datetime.now(tz=datetime.timezone.utc).timestamp())
def timestamp_µs(): return int(datetime.datetime.now(tz=datetime.timezone.utc).timestamp() * 1e6)
prefixes = {
# big SI prefixes
@ -262,4 +264,21 @@ extensions = (
"voice",
"commands",
"userdata"
)
)
# https://github.com/SawdustSoftware/simpleflake/blob/master/simpleflake/simpleflake.py
SIMPLEFLAKE_EPOCH = 946702800
#field lengths in bits
SIMPLEFLAKE_TIMESTAMP_LENGTH = 43
SIMPLEFLAKE_RANDOM_LENGTH = 21
#left shift amounts
SIMPLEFLAKE_RANDOM_SHIFT = 0
SIMPLEFLAKE_TIMESTAMP_SHIFT = 21
def random_id():
second_time = time.time()
second_time -= SIMPLEFLAKE_EPOCH
millisecond_time = int(second_time * 1000)
randomness = random.getrandbits(SIMPLEFLAKE_RANDOM_LENGTH)
return (millisecond_time << SIMPLEFLAKE_TIMESTAMP_SHIFT) + randomness