mirror of https://github.com/osmarks/autobotrobot (synced 2024-10-30 05:26:17 +00:00)

autogollark refactor to new bot

This commit is contained in:
parent 7f06436252
commit a311f30602

src/autogollark.py | 103 (new file)
@@ -0,0 +1,103 @@
import discord
import logging
import asyncio
import aiohttp
import random
import prometheus_client
from datetime import datetime
import discord.ext.commands as commands

import util

config = util.config

intents = discord.Intents.default()
intents.members = True
intents.presences = True
intents.message_content = True

bot = discord.Client(allowed_mentions=discord.AllowedMentions(everyone=False, users=True, roles=True), intents=intents)

cleaner = commands.clean_content()
def clean(ctx, text):
    return cleaner.convert(ctx, text)

AUTOGOLLARK_MARKER = "\u200b"

async def serialize_history(ctx, n=20):
    prompt = []
    seen = set()
    async for message in ctx.channel.history(limit=n):
        display_name = message.author.display_name
        content = message.content
        if not content and message.embeds:
            content = message.embeds[0].title
        elif not content and message.attachments:
            content = "[attachments]"
        if not content:
            continue
        if message.content.startswith(AUTOGOLLARK_MARKER):
            content = message.content.removeprefix(AUTOGOLLARK_MARKER)
        if message.author == bot.user:
            display_name = config["autogollark"]["name"]

        if content in seen: continue
        seen.add(content)
        prompt.append(f"[{util.render_time(message.created_at)}] {display_name}: {content}\n")
        if sum(len(x) for x in prompt) > util.config["ai"]["max_len"]:
            break
    prompt.reverse()
    return prompt

async def autogollark(ctx, session):
    display_name = config["autogollark"]["name"]
    prompt = await serialize_history(ctx, n=20)
    prompt.append(f"[{util.render_time(datetime.utcnow())}] {display_name}:")
    conversation = "".join(prompt)
    # retrieve gollark data from backend
    gollark_chunks = []
    async with session.post(util.config["autogollark"]["api"], json={"query": conversation}) as res:
        for chunk in (await res.json()):
            gollark_chunk = []
            if sum(len(y) for x in gollark_chunks for y in x) > util.config["autogollark"]["max_context_chars"]: gollark_chunks.pop(0)
            for message in chunk:
                dt = datetime.fromisoformat(message["timestamp"])
                line = f"[{util.render_time(dt)}] {message['author'] or display_name}: {await clean(ctx, message['contents'])}\n"
                gollark_chunk.append(line)

                # ugly hack to remove duplicates
                ds = []
                for chunk in gollark_chunks:
                    if line in chunk and line != "---\n": ds.append(chunk)
                for d in ds:
                    try:
                        gollark_chunks.remove(d)
                    except ValueError: pass

            gollark_chunk.append("---\n")
            gollark_chunks.append(gollark_chunk)

    gollark_data = "".join("".join(x) for x in gollark_chunks)

    print(gollark_data + conversation)

    # generate response
    generation = await util.generate(session, gollark_data + conversation, stop=["\n["])
    while True:
        new_generation = generation.strip().strip("[\n ")
        new_generation = new_generation.removesuffix("---")
        if new_generation == generation:
            break
        generation = new_generation
    if generation:
        await ctx.send(generation)

@bot.event
async def on_message(message):
    if message.channel.id in util.config["autogollark"]["channels"] and not message.author == bot.user:
        await autogollark(commands.Context(bot=bot, message=message, prefix="", view=None), bot.session)

async def run_bot():
    bot.session = aiohttp.ClientSession()
    logging.info("Autogollark starting")
    await bot.start(config["autogollark"]["token"])
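Note: the new module pulls its settings from util.config. A hypothetical sketch of the sections it reads follows; the key names come from the code above, but every value is a placeholder rather than the real configuration.

# Hypothetical sketch of the config sections used by src/autogollark.py
config = {
    "autogollark": {
        "name": "gollark",                     # display name used when rendering the prompt
        "token": "DISCORD_BOT_TOKEN",          # token for the separate autogollark client
        "api": "http://localhost:8080/query",  # retrieval backend POSTed the serialized conversation
        "max_context_chars": 4000,             # cap on retrieved-chunk characters before old chunks are dropped
        "channels": [123456789012345678],      # channel IDs the bot responds in
    },
    "ai": {
        "max_len": 2000,                       # cap on serialized history length
    },
}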
@@ -20,6 +20,7 @@ import util
import eventbus
import irc_link
import achievement
+import autogollark

config = util.config

@@ -30,7 +31,7 @@ intents.members = True
intents.presences = True
intents.message_content = True

bot = commands.Bot(command_prefix=commands.when_mentioned_or(config["prefix"]), description="AutoBotRobot, the omniscient, omnipotent, omnibenevolent Discord bot by gollark." + util.config.get("description_suffix", ""),
    case_insensitive=True, allowed_mentions=discord.AllowedMentions(everyone=False, users=True, roles=True), intents=intents)
bot._skip_check = lambda x, y: False

@@ -78,7 +79,7 @@ async def andrew_bad(ctx):
@bot.event
async def on_ready():
    logging.info("Connected as " + bot.user.name)
    await bot.change_presence(status=discord.Status.online,
        activity=discord.Activity(name=f"{config['prefix']}help", type=discord.ActivityType.listening))

visible_users = prometheus_client.Gauge("abr_visible_users", "Users the bot can see")

@@ -108,6 +109,7 @@ async def run_bot():
    for ext in util.extensions:
        logging.info("Loaded %s", ext)
        bot.load_extension(ext)
+    asyncio.create_task(autogollark.run_bot())
    await bot.start(config["token"])

if __name__ == "__main__":
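The changed run_bot() above schedules the autogollark client on the same event loop with asyncio.create_task() before blocking on the main bot's start(). A minimal sketch of that pattern, with placeholder tokens (not the repository's code):

import asyncio
import discord

main_bot = discord.Client(intents=discord.Intents.default())
side_bot = discord.Client(intents=discord.Intents.default())

async def run_both():
    # create_task() starts the secondary client concurrently; awaiting start()
    # on the main client then runs until shutdown and keeps both alive.
    asyncio.create_task(side_bot.start("SIDE_TOKEN"))
    await main_bot.start("MAIN_TOKEN")

# asyncio.run(run_both())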
@@ -13,16 +13,10 @@ from pathlib import Path
import tio
import util

-def render(dt: datetime):
-    return f"{dt.hour:02}:{dt.minute:02}"

cleaner = commands.clean_content()
def clean(ctx, text):
    return cleaner.convert(ctx, text)

AUTOGOLLARK_MARKER = "\u200b"
AUTOGOLLARK_GOLLARK = "autogollark"

class Sentience(commands.Cog):
    def __init__(self, bot):
        self.bot = bot

@@ -48,13 +42,9 @@ class Sentience(commands.Cog):
            if not content:
                continue
-            if message.author == self.bot.user:
-                if message.content.startswith(AUTOGOLLARK_MARKER):
-                    content = message.content.removeprefix(AUTOGOLLARK_MARKER)
-                display_name = AUTOGOLLARK_GOLLARK

            if content in seen: continue
            seen.add(content)
-            prompt.append(f"[{render(message.created_at)}] {display_name}: {content}\n")
+            prompt.append(f"[{util.render_time(message.created_at)}] {display_name}: {content}\n")
            if sum(len(x) for x in prompt) > util.config["ai"]["max_len"]:
                break
        prompt.reverse()

@@ -63,60 +53,11 @@ class Sentience(commands.Cog):
    @commands.command(help="Highly advanced AI Assistant.")
    async def ai(self, ctx, *, query=None):
        prompt = await self.serialize_history(ctx)
-        prompt.append(f'[{render(datetime.utcnow())}] {util.config["ai"]["own_name"]}:')
+        prompt.append(f'[{util.render_time(datetime.utcnow())}] {util.config["ai"]["own_name"]}:')
        generation = await util.generate(self.session, util.config["ai"]["prompt_start"] + "".join(prompt))
        if generation.strip():
            await ctx.send(generation.strip())

-    @commands.command(help="Emulated gollark instance.", aliases=["gollark", "ag"])
-    async def autogollark(self, ctx):
-        prompt = await self.serialize_history(ctx, n=50)
-        prompt.append(f"[{render(datetime.utcnow())}] {AUTOGOLLARK_GOLLARK}:")
-        conversation = "".join(prompt)
-        # retrieve gollark data from backend
-        gollark_chunks = []
-        async with self.session.post(util.config["ai"]["autogollark_server"], json={"query": conversation}) as res:
-            for chunk in (await res.json()):
-                gollark_chunk = []
-                if sum(len(y) for x in gollark_chunks for y in x) > util.config["ai"]["max_gollark_len"]: gollark_chunks.pop(0)
-                for message in chunk:
-                    dt = datetime.fromisoformat(message["timestamp"])
-                    line = f"[{render(dt)}] {message['author'] or AUTOGOLLARK_GOLLARK}: {await clean(ctx, message['contents'])}\n"
-                    gollark_chunk.append(line)
-
-                    # ugly hack to remove duplicates
-                    ds = []
-                    for chunk in gollark_chunks:
-                        if line in chunk and line != "---\n": ds.append(chunk)
-                    for d in ds:
-                        print("delete chunk", d)
-                        try:
-                            gollark_chunks.remove(d)
-                        except ValueError: pass
-
-                gollark_chunk.append("---\n")
-                gollark_chunks.append(gollark_chunk)
-
-        gollark_data = "".join("".join(x) for x in gollark_chunks)
-
-        print(gollark_data + conversation)
-
-        # generate response
-        generation = await util.generate(self.session, gollark_data + conversation, stop=["\n["])
-        while True:
-            new_generation = generation.strip().strip("[\n ")
-            new_generation = new_generation.removesuffix("---")
-            if new_generation == generation:
-                break
-            generation = new_generation
-        if generation:
-            await ctx.send(AUTOGOLLARK_MARKER + generation)
-
-    @commands.Cog.listener("on_message")
-    async def autogollark_listener(self, message):
-        if message.channel.id in util.config["ai"]["autogollark_channels"] and not message.content.startswith(AUTOGOLLARK_MARKER):
-            await self.autogollark(commands.Context(bot=self.bot, message=message, prefix="", view=None))

    @commands.command(help="Search meme library.", aliases=["memes"])
    async def meme(self, ctx, *, query=None):
        search_many = ctx.invoked_with == "memes"
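The removed cog tagged its own replies with a zero-width-space prefix (AUTOGOLLARK_MARKER) so the listener could skip them and serialize_history could re-attribute them, while the new standalone client simply checks message.author == bot.user. A small, purely illustrative sketch of the marker convention (not code from the commit):

AUTOGOLLARK_MARKER = "\u200b"  # zero-width space, invisible in Discord

def tag(reply: str) -> str:
    # prefix outgoing autogollark replies so they can be recognised later
    return AUTOGOLLARK_MARKER + reply

def is_tagged(content: str) -> bool:
    return content.startswith(AUTOGOLLARK_MARKER)

def untag(content: str) -> str:
    # strip the marker before putting the message back into the prompt
    return content.removeprefix(AUTOGOLLARK_MARKER)

assert untag(tag("hello")) == "hello" and is_tagged(tag("hello"))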
src/util.py | 65
@@ -18,12 +18,18 @@ import collections
import aiohttp
import string
from pathlib import Path
import sys
+import dataclasses
+import logging
+import prometheus_client

config = {}

+config_path = os.path.join(os.path.dirname(__file__), "../config.toml")
+
# update in place for runtime config reload
def load_config():
-    for k, v in toml.load(open(os.path.join(os.path.dirname(__file__), "../config.toml"), "r")).items(): config[k] = v
+    for k, v in toml.load(open(config_path, "r")).items(): config[k] = v

load_config()

@@ -74,7 +80,7 @@ fractional_tu_mappings = {
}

def rpartfor(u):
    if u[0][-1] == "s":
        l = [u[0] + "?"]
        l.extend(u[1:])
    else: l = u

@@ -203,7 +209,7 @@ lyrictable_raw = {
}
lyrictable = str.maketrans({v: k for k, v in lyrictable_raw.items()})

apioinfixes = ["cryo", "pyro", "chrono", "meta", "anarcho", "arachno", "aqua", "accelero", "hydro", "radio", "xeno", "morto", "thanato", "memeto",
    "contra", "umbra", "macrono", "acantho", "acousto", "aceto", "acro", "aeolo", "hexa", "aero", "aesthio", "agro", "ferro", "alumino",
    "ammonio", "anti", "ankylo", "aniso", "annulo", "apo", "abio", "archeo", "argento", "arseno", "arithmo", "astro", "atlo", "auto", "axo",
    "azido", "bacillo", "bario", "balneo", "baryo", "basi", "benzo", "bismuto", "boreo", "biblio", "spatio", "boro", "bromo", "brachio",

@@ -229,7 +235,7 @@ apioinfixes = ["cryo", "pyro", "chrono", "meta", "anarcho", "arachno", "aqua", "
    "temporo", "tera", "tetra", "thalasso", "thaumato", "thermo", "tephro", "tessera", "thio", "titano", "tomo", "topo", "tono", "tungsto",
    "turbo", "tyranno", "ultra", "undeca", "tribo", "trito", "tropho", "tropo", "uni", "urano", "video", "viro", "visuo", "xantho", "xenna",
    "xeri", "xipho", "xylo", "xyro", "yocto", "yttro", "zepto", "zetta", "zinco", "zirco", "zoo", "zono", "zygo", "templateo", "rustaceo", "mnesto",
    "amnesto", "cetaceo", "anthropo", "ioctlo", "crustaceo", "citrono", "apeiro", "Ægypto", "equi", "anglo", "atto", "ortho", "macro", "micro", "auro",
    "Australo", "dys", "eu", "giga", "Inver", "omni", "semi", "Scando", "sub", "super", "trans", "ur-", "un", "mid", "mis", "ante", "intra"]
apiosuffixes = ["hazard", "form"]

@@ -339,17 +345,55 @@ def chunks(source, length):
    for i in range(0, len(source), length):
        yield source[i : i+length]

-async def generate(sess: aiohttp.ClientSession, prompt, stop=["\n"]):
-    async with sess.post(config["ai"]["llm_backend"], json={
+@dataclasses.dataclass
+class BackendStatus:
+    consecutive_failures: int = 0
+    avoid_until: datetime.datetime | None = None
+
+last_failures = {}
+
+backend_successes = prometheus_client.Counter("abr_llm_backend_success", "Number of successful requests to LLM backends", labelnames=["backend"])
+backend_failures = prometheus_client.Counter("abr_llm_backend_failure", "Number of failed requests to LLM backends", labelnames=["backend"])
+
+async def generate_raw(sess: aiohttp.ClientSession, backend, prompt, stop):
+    async with sess.post(backend["url"], json={
        "prompt": prompt,
        "max_tokens": 200,
        "stop": stop,
        "client": "abr",
-        **config["ai"].get("params", {})
-    }, headers=config["ai"].get("headers", {})) as res:
+        **backend.get("params", {})
+    }, headers=backend.get("headers", {}), timeout=aiohttp.ClientTimeout(total=30)) as res:
        data = await res.json()
        return data["choices"][0]["text"]

+async def generate(sess: aiohttp.ClientSession, prompt, stop=["\n"]):
+    backends = config["ai"]["llm_backend"]
+    for backend in backends:
+        if backend["url"] not in last_failures:
+            last_failures[backend["url"]] = BackendStatus()
+        status = last_failures[backend["url"]]
+
+    now = datetime.datetime.now(datetime.UTC)
+
+    # high to low
+    def sort_key(backend):
+        failure_stats = last_failures[backend["url"]]
+        return (failure_stats.avoid_until is None or failure_stats.avoid_until < now), -failure_stats.consecutive_failures, backend["priority"]
+
+    backends = sorted(backends, key=sort_key, reverse=True)
+
+    for backend in backends:
+        try:
+            result = await generate_raw(sess, backend, prompt, stop)
+            backend_successes.labels(backend["url"]).inc()
+            return result
+        except Exception as e:
+            backend_failures.labels(backend["url"]).inc()
+            logging.warning("LLM backend %s failed: %s", backend["url"], e)
+            failure_stats = last_failures[backend["url"]]
+            failure_stats.avoid_until = now + datetime.timedelta(seconds=2 ** failure_stats.consecutive_failures)
+            failure_stats.consecutive_failures += 1

filesafe_charset = string.ascii_letters + string.digits + "-"

TARGET_FORMAT = "jpegh"
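The new generate() tries backends in the order produced by sort_key: backends outside their backoff window first, then fewer consecutive failures, then higher priority (the sort is reverse=True), and each failure doubles the avoid window (2 ** consecutive_failures seconds). A self-contained sketch of that ordering with made-up backends; the URLs and priority values below are placeholders:

import datetime
import dataclasses

@dataclasses.dataclass
class BackendStatus:
    consecutive_failures: int = 0
    avoid_until: datetime.datetime | None = None

now = datetime.datetime.now(datetime.UTC)
last_failures = {
    "https://a.example": BackendStatus(),                                             # healthy
    "https://b.example": BackendStatus(3, now + datetime.timedelta(seconds=2 ** 3)),  # backing off
}
backends = [
    {"url": "https://a.example", "priority": 1},
    {"url": "https://b.example", "priority": 10},
]

def sort_key(backend):
    s = last_failures[backend["url"]]
    # usable-now first, then fewest recent failures, then highest priority
    return (s.avoid_until is None or s.avoid_until < now), -s.consecutive_failures, backend["priority"]

print([b["url"] for b in sorted(backends, key=sort_key, reverse=True)])
# ['https://a.example', 'https://b.example']: the healthy low-priority backend is
# tried before the high-priority one still inside its backoff window.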
@@ -366,4 +410,7 @@ def meme_thumbnail(results, result):
    if result[3] & format_id != 0:
        return Path(config["memetics"]["thumb_base"]) / f"{result[2]}{TARGET_FORMAT}.{results['extensions'][TARGET_FORMAT]}"
    else:
        return Path(config["memetics"]["meme_base"]) / result[1]
+
+def render_time(dt: datetime.datetime):
+    return f"{dt.hour:02}:{dt.minute:02}"