mirror of
https://github.com/osmarks/autobotrobot
synced 2025-01-23 13:46:52 +00:00
Unified bridge system
ABR can now bridge to IRC, because of course. It can also bridge Discord to Discord. Bridging works transitively because I have a fairly elegant (if I do say so myself) way of handling it: links are internally point-to-point, and when something is sent in a channel with links configured the bot traverses the graph of links to work out where to send to. It is planned to expose a private websocket API to synchronize with other servers providing virtual channels. This system is now used for telephone calls. There may be issues in certain situations due to the lack of (meaningful) transaction support in aiosqlite. The telephone command has been extended with (un)link commands, currently only for me as they can link anywhere.
This commit is contained in:
parent
d449d5356b
commit
d8344d9759
14
src/db.py
14
src/db.py
@ -76,6 +76,20 @@ CREATE TABLE user_data (
|
||||
value TEXT NOT NULL,
|
||||
UNIQUE (user_id, guild_id, key)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE links (
|
||||
to_type TEXT NOT NULL,
|
||||
to_id BLOB NOT NULL,
|
||||
from_type TEXT NOT NULL,
|
||||
from_id BLOB NOT NULL,
|
||||
established_at INTEGER NOT NULL,
|
||||
UNIQUE (to_type, to_id, from_type, from_id)
|
||||
);
|
||||
CREATE TABLE discord_webhooks (
|
||||
channel_id INTEGER PRIMARY KEY,
|
||||
webhook TEXT NOT NULL
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
|
@ -1,9 +1,10 @@
|
||||
import util
|
||||
import asyncio
|
||||
import traceback
|
||||
import re
|
||||
from discord.ext import commands
|
||||
|
||||
import util
|
||||
import eventbus
|
||||
|
||||
def setup(bot):
|
||||
@bot.group(help="Debug/random messing around utilities. Owner-only.")
|
||||
@ -26,7 +27,9 @@ def setup(bot):
|
||||
**locals(),
|
||||
"bot": bot,
|
||||
"ctx": ctx,
|
||||
"db": bot.database
|
||||
"db": bot.database,
|
||||
"util": util,
|
||||
"eventbus": eventbus
|
||||
}
|
||||
|
||||
def check(re, u): return str(re.emoji) == "❌" and u == ctx.author
|
||||
|
100
src/eventbus.py
Normal file
100
src/eventbus.py
Normal file
@ -0,0 +1,100 @@
|
||||
import asyncio
|
||||
import prometheus_client
|
||||
import dataclasses
|
||||
import typing
|
||||
import collections
|
||||
import logging
|
||||
|
||||
import util
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AuthorInfo:
|
||||
name: str
|
||||
id: any
|
||||
avatar_url: str = None
|
||||
deprioritize: bool = False
|
||||
|
||||
def unpack_dataclass_without(d, without):
|
||||
dct = dict([(field, getattr(d, field)) for field in type(d).__dataclass_fields__])
|
||||
del dct[without]
|
||||
return dct
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Message:
|
||||
author: AuthorInfo
|
||||
message: str
|
||||
source: (str, any)
|
||||
id: int
|
||||
|
||||
evbus_messages = prometheus_client.Counter("abr_evbus_messages", "Messages processed by event bus", ["source_type"])
|
||||
evbus_messages_dropped = prometheus_client.Counter("abr_evbus_messages_dropped", "Messages received by event bus but dropped by rate limits", ["source_type"])
|
||||
|
||||
# maps each bridge destination type (discord/APIONET/etc) to the listeners for it
|
||||
listeners = collections.defaultdict(set)
|
||||
|
||||
# maintains a list of all the unidirectional links between channels - key is source, values are targets
|
||||
links = collections.defaultdict(set)
|
||||
|
||||
def find_all_destinations(source):
|
||||
visited = set()
|
||||
targets = set(links[source])
|
||||
while len(targets) > 0:
|
||||
current = targets.pop()
|
||||
targets.update(adjacent for adjacent in links[current] if not adjacent in visited)
|
||||
visited.add(current)
|
||||
return visited
|
||||
|
||||
RATE = 10.0
|
||||
PER = 5000000.0 # µs
|
||||
|
||||
RLData = collections.namedtuple("RLData", ["allowance", "last_check"])
|
||||
rate_limiting = collections.defaultdict(lambda: RLData(RATE, util.timestamp()))
|
||||
|
||||
async def push(msg: Message):
|
||||
destinations = find_all_destinations(msg.source)
|
||||
if len(destinations) > 0:
|
||||
# "token bucket" rate limiting algorithm - max 10 messages per 5 seconds (half that for bots)
|
||||
# TODO: maybe separate buckets for bot and unbot?
|
||||
current = util.timestamp_µs()
|
||||
time_passed = current - rate_limiting[msg.source].last_check
|
||||
allowance = rate_limiting[msg.source].allowance
|
||||
allowance += time_passed * (RATE / PER)
|
||||
if allowance > RATE:
|
||||
allowance = RATE
|
||||
rate_limiting[msg.source] = RLData(allowance, current)
|
||||
if allowance < 1:
|
||||
evbus_messages_dropped.labels(msg.source[0]).inc()
|
||||
return
|
||||
allowance -= 2.0 if msg.author.deprioritize else 1.0
|
||||
rate_limiting[msg.source] = RLData(allowance, current)
|
||||
|
||||
evbus_messages.labels(msg.source[0]).inc()
|
||||
for dest in destinations:
|
||||
if dest == msg.source: continue
|
||||
dest_type, dest_channel = dest
|
||||
for listener in listeners[dest_type]:
|
||||
asyncio.ensure_future(listener(dest_channel, msg))
|
||||
|
||||
def add_listener(s, l): listeners[s].add(l)
|
||||
|
||||
async def add_bridge_link(db, c1, c2):
|
||||
logging.info("Bridging %s and %s", repr(c1), repr(c2))
|
||||
links[c1].add(c2)
|
||||
links[c2].add(c1)
|
||||
await db.execute("INSERT INTO links VALUES (?, ?, ?, ?, ?) ON CONFLICT DO NOTHING", (c1[0], c1[1], c2[0], c2[1], util.timestamp()))
|
||||
await db.execute("INSERT INTO links VALUES (?, ?, ?, ?, ?) ON CONFLICT DO NOTHING", (c2[0], c2[1], c1[0], c1[1], util.timestamp()))
|
||||
await db.commit()
|
||||
|
||||
async def remove_bridge_link(db, c1, c2):
|
||||
logging.info("Unbridging %s and %s", repr(c1), repr(c2))
|
||||
links[c1].remove(c2)
|
||||
links[c2].remove(c1)
|
||||
await db.execute("DELETE FROM links WHERE (to_type = ? AND to_id = ?) AND (from_type = ? AND from_id = ?)", (c1[0], c1[1], c2[0], c2[1]))
|
||||
await db.execute("DELETE FROM links WHERE (to_type = ? AND to_id = ?) AND (from_type = ? AND from_id = ?)", (c2[0], c2[1], c1[0], c1[1]))
|
||||
await db.commit()
|
||||
|
||||
async def initial_load(db):
|
||||
rows = await db.execute_fetchall("SELECT * FROM links")
|
||||
for row in rows:
|
||||
links[(row["from_type"], row["from_id"])].add((row["to_type"], row["to_id"]))
|
||||
logging.info("Loaded %d links", len(rows))
|
53
src/irc_link.py
Normal file
53
src/irc_link.py
Normal file
@ -0,0 +1,53 @@
|
||||
import eventbus
|
||||
import asyncio
|
||||
import irc.client_aio
|
||||
import random
|
||||
import util
|
||||
import logging
|
||||
import hashlib
|
||||
|
||||
def scramble(text):
|
||||
n = list(text)
|
||||
random.shuffle(n)
|
||||
return "".join(n)
|
||||
|
||||
def color_code(x):
|
||||
return f"\x03{x}"
|
||||
def random_color(id): return color_code(hashlib.blake2b(str(id).encode("utf-8")).digest()[0] % 13 + 2)
|
||||
|
||||
async def initialize():
|
||||
joined = set()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
reactor = irc.client_aio.AioReactor(loop=loop)
|
||||
conn = await reactor.server().connect(util.config["irc"]["server"], util.config["irc"]["port"], util.config["irc"]["nick"])
|
||||
|
||||
def inuse(conn, event):
|
||||
conn.nick(scramble(conn.get_nickname()))
|
||||
|
||||
def pubmsg(conn, event):
|
||||
msg = eventbus.Message(eventbus.AuthorInfo(event.source.nick, str(event.source), None), " ".join(event.arguments), (util.config["irc"]["name"], event.target), util.random_id())
|
||||
asyncio.create_task(eventbus.push(msg))
|
||||
|
||||
async def on_bridge_message(channel_name, msg):
|
||||
if channel_name in util.config["irc"]["channels"]:
|
||||
if channel_name not in joined: conn.join(channel_name)
|
||||
line = msg.message.replace("\n", " ")
|
||||
line = f"<{random_color(msg.author.id)}{msg.author.name}{color_code('')}> " + line.strip()[:400]
|
||||
conn.privmsg(channel_name, line)
|
||||
else:
|
||||
logging.warning("IRC channel %s not allowed", channel_name)
|
||||
|
||||
def connect(conn, event):
|
||||
for channel in util.config["irc"]["channels"]:
|
||||
conn.join(channel)
|
||||
logging.info("connected to %s", channel)
|
||||
joined.add(channel)
|
||||
|
||||
# TODO: do better thing
|
||||
conn.add_global_handler("welcome", connect)
|
||||
conn.add_global_handler("disconnect", lambda conn, event: logging.warn("disconnected from IRC, oh no"))
|
||||
conn.add_global_handler("nicknameinuse", inuse)
|
||||
conn.add_global_handler("pubmsg", pubmsg)
|
||||
|
||||
eventbus.add_listener(util.config["irc"]["name"], on_bridge_message)
|
45
src/main.py
45
src/main.py
@ -12,10 +12,13 @@ import collections
|
||||
import prometheus_client
|
||||
import prometheus_async.aio
|
||||
import typing
|
||||
import sys
|
||||
|
||||
import tio
|
||||
import db
|
||||
import util
|
||||
import eventbus
|
||||
import irc_link
|
||||
import achievement
|
||||
|
||||
config = util.config
|
||||
@ -47,7 +50,6 @@ async def on_message(message):
|
||||
command_errors = prometheus_client.Counter("abr_errors", "Count of errors encountered in executing commands.")
|
||||
@bot.event
|
||||
async def on_command_error(ctx, err):
|
||||
#print(ctx, err)
|
||||
if isinstance(err, (commands.CommandNotFound, commands.CheckFailure)): return
|
||||
if isinstance(err, commands.CommandInvokeError) and isinstance(err.original, ValueError): return await ctx.send(embed=util.error_embed(str(err.original)))
|
||||
if isinstance(err, commands.MissingRequiredArgument): return await ctx.send(embed=util.error_embed(str(err)))
|
||||
@ -69,15 +71,52 @@ async def on_ready():
|
||||
await bot.change_presence(status=discord.Status.online,
|
||||
activity=discord.Activity(name=f"{bot.command_prefix}help", type=discord.ActivityType.listening))
|
||||
|
||||
webhooks = {}
|
||||
|
||||
async def initial_load_webhooks(db):
|
||||
for row in await db.execute_fetchall("SELECT * FROM discord_webhooks"):
|
||||
webhooks[row["channel_id"]] = row["webhook"]
|
||||
|
||||
@bot.listen("on_message")
|
||||
async def send_to_bridge(msg):
|
||||
if msg.author == bot.user or msg.author.discriminator == "0000": return
|
||||
if msg.content == "": return
|
||||
channel_id = msg.channel.id
|
||||
msg = eventbus.Message(eventbus.AuthorInfo(msg.author.name, msg.author.id, str(msg.author.avatar_url), msg.author.bot), msg.content, ("discord", channel_id), msg.id)
|
||||
await eventbus.push(msg)
|
||||
|
||||
async def on_bridge_message(channel_id, msg):
|
||||
channel = bot.get_channel(channel_id)
|
||||
if channel:
|
||||
webhook = webhooks.get(channel_id)
|
||||
if webhook:
|
||||
wh_obj = discord.Webhook.from_url(webhook, adapter=discord.AsyncWebhookAdapter(bot.http._HTTPClient__session))
|
||||
await wh_obj.send(
|
||||
content=msg.message, username=msg.author.name, avatar_url=msg.author.avatar_url,
|
||||
allowed_mentions=discord.AllowedMentions(everyone=False, roles=False, users=False))
|
||||
else:
|
||||
text = f"<{msg.author.name}> {msg.message}"
|
||||
await channel.send(text[:2000], allowed_mentions=discord.AllowedMentions(everyone=False, roles=False, users=False))
|
||||
else:
|
||||
logging.warning("channel %d not found", channel_id)
|
||||
|
||||
eventbus.add_listener("discord", on_bridge_message)
|
||||
|
||||
visible_users = prometheus_client.Gauge("abr_visible_users", "Users the bot can see")
|
||||
def get_visible_users():
|
||||
return len(bot.users)
|
||||
visible_users.set_function(get_visible_users)
|
||||
|
||||
heavserver_members = prometheus_client.Gauge("abr_heavserver_members", "Current member count of heavserver")
|
||||
heavserver_bots = prometheus_client.Gauge("abr_heavserver_bots", "Current bot count of heavserver")
|
||||
def get_heavserver_members():
|
||||
if not bot.get_guild(util.config["heavserver"]["id"]): return 0
|
||||
return len(bot.get_guild(util.config["heavserver"]["id"]).members)
|
||||
def get_heavserver_bots():
|
||||
if not bot.get_guild(util.config["heavserver"]["id"]): return 0
|
||||
return len([ None for member in bot.get_guild(util.config["heavserver"]["id"]).members if member.bot ])
|
||||
heavserver_members.set_function(get_heavserver_members)
|
||||
heavserver_bots.set_function(get_heavserver_bots)
|
||||
|
||||
guild_count = prometheus_client.Gauge("abr_guilds", "Guilds the bot is in")
|
||||
def get_guild_count():
|
||||
@ -86,6 +125,9 @@ guild_count.set_function(get_guild_count)
|
||||
|
||||
async def run_bot():
|
||||
bot.database = await db.init(config["database"])
|
||||
await eventbus.initial_load(bot.database)
|
||||
await initial_load_webhooks(bot.database)
|
||||
await irc_link.initialize()
|
||||
for ext in util.extensions:
|
||||
logging.info("loaded %s", ext)
|
||||
bot.load_extension(ext)
|
||||
@ -98,5 +140,6 @@ if __name__ == '__main__':
|
||||
loop.run_until_complete(run_bot())
|
||||
except KeyboardInterrupt:
|
||||
loop.run_until_complete(bot.logout())
|
||||
sys.exit(0)
|
||||
finally:
|
||||
loop.close()
|
||||
|
@ -6,6 +6,7 @@ import hashlib
|
||||
from datetime import datetime
|
||||
|
||||
import util
|
||||
import eventbus
|
||||
|
||||
# Generate a "phone" address
|
||||
# Not actually for phones
|
||||
@ -32,38 +33,35 @@ def setup(bot):
|
||||
async def get_addr_config(addr):
|
||||
return await bot.database.execute_fetchone("SELECT * FROM telephone_config WHERE id = ?", (addr,))
|
||||
|
||||
@bot.listen("on_message")
|
||||
async def forward_call_messages(message):
|
||||
channel = message.channel.id
|
||||
if (message.author.discriminator == "0000" and message.author.bot) or message.author == bot.user or message.content == "": # check if webhook, from itself, or only has embeds
|
||||
return
|
||||
calls = await bot.database.execute_fetchall("""SELECT tcf.channel_id AS from_channel, tct.channel_id AS to_channel,
|
||||
tcf.webhook AS from_webhook, tct.webhook AS to_webhook FROM calls
|
||||
JOIN telephone_config AS tcf ON tcf.id = calls.from_id JOIN telephone_config AS tct ON tct.id = calls.to_id
|
||||
WHERE from_channel = ? OR to_channel = ?""", (channel, channel))
|
||||
if calls == []: return
|
||||
async def send_to(call):
|
||||
if call["from_channel"] == channel:
|
||||
other_channel, other_webhook = call["to_channel"], call["to_webhook"]
|
||||
else:
|
||||
other_channel, other_webhook = call["from_channel"], call["from_webhook"]
|
||||
|
||||
async def send_normal_message():
|
||||
m = f"**{message.author.name}**: "
|
||||
m += message.content[:2000 - len(m)]
|
||||
await bot.get_channel(other_channel).send(m)
|
||||
|
||||
if other_webhook:
|
||||
@telephone.command(brief="Link to other channels", help="""Connect to another channel on Discord or any supported bridges.
|
||||
Virtual channels also exist.
|
||||
""")
|
||||
@commands.check(util.admin_check)
|
||||
async def link(ctx, target_type, target_id):
|
||||
try:
|
||||
await discord.Webhook.from_url(other_webhook, adapter=discord.AsyncWebhookAdapter(bot.http._HTTPClient__session)).send(
|
||||
content=message.content, username=message.author.name, avatar_url=message.author.avatar_url,
|
||||
allowed_mentions=discord.AllowedMentions(everyone=False, roles=False, users=False))
|
||||
except discord.errors.NotFound:
|
||||
logging.warn("channel %d webhook missing", other_channel)
|
||||
await send_normal_message()
|
||||
else: await send_normal_message()
|
||||
target_id = int(target_id)
|
||||
except ValueError: pass
|
||||
await eventbus.add_bridge_link(bot.database, ("discord", ctx.channel.id), (target_type, util.extract_codeblock(target_id)))
|
||||
await ctx.send(f"Link established.")
|
||||
pass
|
||||
|
||||
await asyncio.gather(*map(send_to, calls))
|
||||
@telephone.command(brief="Undo link commands.")
|
||||
@commands.check(util.admin_check)
|
||||
async def unlink(ctx, target_type, target_id):
|
||||
try:
|
||||
target_id = int(target_id)
|
||||
except ValueError: pass
|
||||
await eventbus.remove_bridge_link(bot.database, ("discord", ctx.channel.id), (target_type, util.extract_codeblock(target_id)))
|
||||
await ctx.send(f"Successfully deleted.")
|
||||
pass
|
||||
|
||||
@telephone.command(brief="Generate a webhook")
|
||||
@commands.check(util.admin_check)
|
||||
async def init_webhook(ctx):
|
||||
webhook = (await ctx.channel.create_webhook(name="ABR webhook", reason=f"requested by {ctx.author.name}")).url
|
||||
await bot.database.execute("INSERT OR REPLACE INTO discord_webhooks VALUES (?, ?)", (ctx.channel.id, webhook))
|
||||
await bot.database.commit()
|
||||
await ctx.send("Done.")
|
||||
|
||||
@telephone.command()
|
||||
@commands.check(util.server_mod_check)
|
||||
@ -76,6 +74,7 @@ def setup(bot):
|
||||
if not info or not webhook:
|
||||
try:
|
||||
webhook = (await ctx.channel.create_webhook(name="incoming message display", reason="configure for apiotelephone")).url
|
||||
await bot.database.execute("INSERT OR REPLACE INTO discord_webhooks VALUES (?, ?)", (ctx.channel.id, webhook))
|
||||
await ctx.send("Created webhook.")
|
||||
except discord.Forbidden as f:
|
||||
logging.warn("could not create webhook in #%s %s", ctx.channel.name, ctx.guild.name, exc_info=f)
|
||||
@ -84,7 +83,7 @@ def setup(bot):
|
||||
await bot.database.commit()
|
||||
await ctx.send("Configured.")
|
||||
|
||||
@telephone.command(aliases=["call"])
|
||||
@telephone.command(aliases=["call"], brief="Dial another telephone channel.")
|
||||
async def dial(ctx, address):
|
||||
# basic checks - ensure this is a phone channel and has no other open calls
|
||||
channel_info = await get_channel_config(ctx.channel.id)
|
||||
@ -127,6 +126,7 @@ def setup(bot):
|
||||
if em == "✅": # accept call
|
||||
await bot.database.execute("INSERT INTO calls VALUES (?, ?, ?)", (originating_address, address, util.timestamp()))
|
||||
await bot.database.commit()
|
||||
await eventbus.add_bridge_link(bot.database, ("discord", ctx.channel.id), ("discord", recv_channel.id))
|
||||
await asyncio.gather(
|
||||
ctx.send(embed=util.info_embed("Outgoing call", "Call accepted and connected.")),
|
||||
recv_channel.send(embed=util.info_embed("Incoming call", "Call accepted and connected."))
|
||||
@ -134,10 +134,7 @@ def setup(bot):
|
||||
elif em == "❎": # drop call
|
||||
await ctx.send(embed=util.error_embed("Your call was declined.", "Call declined"))
|
||||
|
||||
async def get_calls(addr):
|
||||
pass
|
||||
|
||||
@telephone.command(aliases=["disconnect", "quit"])
|
||||
@telephone.command(aliases=["disconnect", "quit"], brief="Disconnect latest call.")
|
||||
async def hangup(ctx):
|
||||
channel_info = await get_channel_config(ctx.channel.id)
|
||||
addr = channel_info["id"]
|
||||
@ -155,13 +152,14 @@ def setup(bot):
|
||||
await bot.database.execute("DELETE FROM calls WHERE to_id = ? AND from_id = ?", (addr, other))
|
||||
await bot.database.commit()
|
||||
other_channel = (await get_addr_config(other))["channel_id"]
|
||||
await eventbus.remove_bridge_link(bot.database, ("discord", other_channel), ("discord", ctx.channel.id))
|
||||
|
||||
await asyncio.gather(
|
||||
ctx.send(embed=util.info_embed("Hung up", f"Call to {other} disconnected.")),
|
||||
bot.get_channel(other_channel).send(embed=util.info_embed("Hung up", f"Call to {addr} disconnected."))
|
||||
)
|
||||
|
||||
@telephone.command(aliases=["status"])
|
||||
@telephone.command(aliases=["status"], brief="List inbound/outbound calls.")
|
||||
async def info(ctx):
|
||||
channel_info = await get_channel_config(ctx.channel.id)
|
||||
if not channel_info: return await ctx.send(embed=util.info_embed("Phone status", "Not a phone channel"))
|
||||
|
19
src/util.py
19
src/util.py
@ -11,6 +11,7 @@ import toml
|
||||
import os.path
|
||||
from discord.ext import commands
|
||||
import hashlib
|
||||
import time
|
||||
|
||||
config = {}
|
||||
|
||||
@ -21,6 +22,7 @@ def load_config():
|
||||
load_config()
|
||||
|
||||
def timestamp(): return int(datetime.datetime.now(tz=datetime.timezone.utc).timestamp())
|
||||
def timestamp_µs(): return int(datetime.datetime.now(tz=datetime.timezone.utc).timestamp() * 1e6)
|
||||
|
||||
prefixes = {
|
||||
# big SI prefixes
|
||||
@ -263,3 +265,20 @@ extensions = (
|
||||
"commands",
|
||||
"userdata"
|
||||
)
|
||||
|
||||
# https://github.com/SawdustSoftware/simpleflake/blob/master/simpleflake/simpleflake.py
|
||||
|
||||
SIMPLEFLAKE_EPOCH = 946702800
|
||||
#field lengths in bits
|
||||
SIMPLEFLAKE_TIMESTAMP_LENGTH = 43
|
||||
SIMPLEFLAKE_RANDOM_LENGTH = 21
|
||||
#left shift amounts
|
||||
SIMPLEFLAKE_RANDOM_SHIFT = 0
|
||||
SIMPLEFLAKE_TIMESTAMP_SHIFT = 21
|
||||
|
||||
def random_id():
|
||||
second_time = time.time()
|
||||
second_time -= SIMPLEFLAKE_EPOCH
|
||||
millisecond_time = int(second_time * 1000)
|
||||
randomness = random.getrandbits(SIMPLEFLAKE_RANDOM_LENGTH)
|
||||
return (millisecond_time << SIMPLEFLAKE_TIMESTAMP_SHIFT) + randomness
|
Loading…
Reference in New Issue
Block a user