autobotrobot/src/eventbus.py

106 lines
4.2 KiB
Python

import asyncio
import prometheus_client
import dataclasses
import typing
import collections
import logging
import discord
import util
@dataclasses.dataclass
class AuthorInfo:
    """Identity of a message author as carried over the event bus."""

    name: str
    # typing.Any rather than the builtin function `any`, which is not a type.
    id: typing.Any
    # Optional: defaults to None when no avatar URL is supplied.
    avatar_url: typing.Optional[str] = None
    # push() charges deprioritized authors two rate-limit tokens instead of one.
    deprioritize: bool = False
def unpack_dataclass_without(d, without):
    """Return the fields of dataclass instance `d` as a dict, minus one field.

    Args:
        d: a dataclass instance.
        without: name of the field to leave out.

    Returns:
        A new dict mapping each remaining field name to its current value,
        in field-definition order.

    Raises:
        KeyError: if `without` is not a field of `d`.
    """
    dct = {field: getattr(d, field) for field in type(d).__dataclass_fields__}
    # `del` (not pop-with-default) so an unknown field name still raises KeyError.
    del dct[without]
    return dct
@dataclasses.dataclass
class Message:
    """A single message travelling over the event bus."""

    author: AuthorInfo
    # Message body as a list of string fragments and dict entries
    # (the meaning of the dict entries is not visible in this module).
    message: list[typing.Union[str, dict]]
    # Two-element pair identifying the originating channel. push() uses element 0
    # as the metrics label and the whole pair as the key for link lookup and
    # rate limiting, so it must be hashable.
    source: tuple[str, typing.Any]
    id: int
    attachments: list[discord.Attachment]
    # (author, text) pair - presumably describing the message being replied to;
    # None when absent.
    reply: typing.Optional[tuple[AuthorInfo, str]] = None
# Prometheus counters; push() labels both with element 0 of the message source.
evbus_messages = prometheus_client.Counter(
    "abr_evbus_messages",
    "Messages processed by event bus",
    ["source_type"],
)
evbus_messages_dropped = prometheus_client.Counter(
    "abr_evbus_messages_dropped",
    "Messages received by event bus but dropped by rate limits",
    ["source_type"],
)

# Bridge destination type (discord/APIONET/etc) -> set of listeners registered for it.
listeners = collections.defaultdict(set)

# Every unidirectional link between channels: source channel -> set of target channels.
links = collections.defaultdict(set)
def find_all_destinations(source):
    """Return every channel reachable from `source` by following `links`.

    Links are followed transitively. The result contains `source` itself
    when some chain of links leads back to it; callers that must not echo
    to the origin have to filter it out.

    Args:
        source: hashable channel key as used in `links`.

    Returns:
        A set of channel keys; empty when `source` has no outgoing links.
    """
    visited = set()
    # .get() rather than indexing: `links` is a defaultdict, and indexing it
    # would insert an empty entry for every channel that is merely looked up.
    pending = set(links.get(source, ()))
    while pending:
        current = pending.pop()
        visited.add(current)
        pending.update(
            adjacent
            for adjacent in links.get(current, ())
            if adjacent not in visited
        )
    return visited
# Token bucket parameters, per input channel: a bucket holds at most RATE tokens
# and regains RATE tokens every PER microseconds (10 messages per 5 seconds).
RATE = 10.0
PER = 5000000.0 # µs
# Bucket state: tokens currently available and the time of the last update.
RLData = collections.namedtuple("RLData", ["allowance", "last_check"])
# New buckets start full. They are stamped with the µs clock because push()
# subtracts last_check from util.timestamp_µs().
rate_limiting = collections.defaultdict(lambda: RLData(RATE, util.timestamp_µs()))
# Strong references to in-flight listener deliveries. asyncio only keeps weak
# references to tasks, so a task nobody holds may be garbage-collected before
# it finishes.
_pending_deliveries = set()


async def push(msg: Message):
    """Forward `msg` to the listeners of every channel linked from its source.

    Does nothing when the source has no destinations. Otherwise the message is
    charged against the source channel's token bucket and dropped (counted in
    `evbus_messages_dropped`) when fewer than one token is available. Accepted
    messages are counted in `evbus_messages` and handed to each listener
    registered for each destination's type; deliveries are scheduled as
    background tasks and not awaited here.

    Args:
        msg: the message to distribute; `msg.source` identifies the input channel.
    """
    destinations = find_all_destinations(msg.source)
    if not destinations:
        return

    # "token bucket" rate limiting algorithm - max 10 messages per 5 seconds
    # (deprioritized authors are charged double, so half that for them)
    # TODO: maybe separate buckets for bot and unbot?
    current = util.timestamp_µs()
    bucket = rate_limiting[msg.source]
    time_passed = current - bucket.last_check
    # Refill in proportion to elapsed time, capped at a full bucket.
    allowance = min(RATE, bucket.allowance + time_passed * (RATE / PER))
    if allowance < 1:
        # Record the refill and timestamp even though the message is dropped.
        rate_limiting[msg.source] = RLData(allowance, current)
        evbus_messages_dropped.labels(msg.source[0]).inc()
        return
    allowance -= 2.0 if msg.author.deprioritize else 1.0
    rate_limiting[msg.source] = RLData(allowance, current)

    evbus_messages.labels(msg.source[0]).inc()
    for dest in destinations:
        # The reachable set can contain the source itself; never echo back to it.
        if dest == msg.source:
            continue
        dest_type, dest_channel = dest
        for listener in listeners[dest_type]:
            task = asyncio.ensure_future(listener(dest_channel, msg))
            _pending_deliveries.add(task)
            task.add_done_callback(_pending_deliveries.discard)
def add_listener(s, l):
    """Register listener `l` for destination type `s`.

    Returns:
        A zero-argument callable that unregisters `l` again. Calling it when
        `l` is no longer registered raises KeyError.
    """
    listeners[s].add(l)

    def unregister():
        # Look the set up at call time, exactly as at registration.
        return listeners[s].remove(l)

    return unregister
async def add_bridge_link(db, c1, c2, cause=None, bidirectional=True):
    """Link channel `c1` to channel `c2` in memory and in the database.

    Args:
        db: database handle providing awaitable `execute` and `commit`.
        c1: two-element channel key the link starts from.
        c2: two-element channel key the link points to.
        cause: stored alongside each inserted row.
        bidirectional: when true, the reverse link `c2` -> `c1` is added too.
    """
    logging.info("Bridging %s and %s (bidirectional: %s)", repr(c1), repr(c2), bidirectional)
    directions = [(c1, c2), (c2, c1)] if bidirectional else [(c1, c2)]

    # Update the in-memory graph first, then persist each direction.
    for origin, target in directions:
        links[origin].add(target)

    # NOTE(review): positional INSERT - the column order of the `links` table is
    # not visible here; confirm it matches (origin type, origin id, target type, target id).
    statement = "INSERT INTO links VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT DO NOTHING"
    for origin, target in directions:
        await db.execute(statement, (origin[0], origin[1], target[0], target[1], util.timestamp(), cause))
    await db.commit()
async def remove_bridge_link(db, c1, c2, bidirectional=True):
    """Remove the link from `c1` to `c2` in memory and in the database.

    Args:
        db: database handle providing awaitable `execute` and `commit`.
        c1: two-element channel key the link starts from.
        c2: two-element channel key the link points to.
        bidirectional: when true, the reverse link `c2` -> `c1` is removed too.

    Raises:
        KeyError: if a link to be removed is not present in `links`; this
            happens before any database statement is issued.
    """
    logging.info("Unbridging %s and %s (bidirectional: %s)", repr(c1), repr(c2), bidirectional)
    directions = [(c1, c2), (c2, c1)] if bidirectional else [(c1, c2)]

    for origin, target in directions:
        links[origin].remove(target)

    # NOTE(review): the first channel of each pair is bound to the to_* columns,
    # whereas initial_load() keys `links` by the from_* columns. Harmless when
    # bidirectional; for a one-way link this may delete the opposite row - confirm
    # against the table schema.
    statement = "DELETE FROM links WHERE (to_type = ? AND to_id = ?) AND (from_type = ? AND from_id = ?)"
    for first, second in directions:
        await db.execute(statement, (first[0], first[1], second[0], second[1]))
    await db.commit()
async def initial_load(db):
    """Populate the in-memory `links` graph from the `links` database table.

    Args:
        db: database handle providing awaitable `execute_fetchall`; rows must
            support lookup by column name.
    """
    rows = await db.execute_fetchall("SELECT * FROM links")
    for row in rows:
        # Each row is one directed edge: (from_type, from_id) -> (to_type, to_id).
        outgoing = links[(row["from_type"], row["from_id"])]
        outgoing.add((row["to_type"], row["to_id"]))
    logging.info("Loaded %d links", len(rows))