mirror of
https://github.com/osmarks/autobotrobot
synced 2025-04-16 07:43:14 +00:00
rearrange a bit, add machine learning for general trendiness
This commit is contained in:
parent
a7f6156727
commit
134a6e6e31
@ -1,44 +0,0 @@
|
||||
import aiohttp
|
||||
import discord
|
||||
import asyncio
|
||||
import logging
|
||||
import discord.ext.commands as commands
|
||||
import html.parser
|
||||
|
||||
class Parser(html.parser.HTMLParser):
    """Extracts organic result links from a DuckDuckGo HTML results page.

    Collects the href of every <a class="result__a"> into self.links,
    skipping DuckDuckGo ad-redirect links (y.js?ad_provider URLs).
    """
    def __init__(self):
        self.links = []
        super().__init__()

    def handle_starttag(self, tag, attrs):
        attrs = dict(attrs)
        if tag == "a" and attrs.get("class") == "result__a":
            # Fix: use .get for href too — a malformed result anchor without
            # an href previously raised KeyError and aborted parsing.
            href = attrs.get("href", "")
            if href and "https://duckduckgo.com/y.js?ad_provider" not in href:
                self.links.append(href)
|
||||
|
||||
class DuckDuckGo(commands.Cog):
    """Cog providing a DuckDuckGo web-search command."""

    def __init__(self, bot):
        self.bot = bot
        self.session = aiohttp.ClientSession()

    @commands.command()
    async def search(self, ctx, *, query):
        async with ctx.typing():
            payload = { "q": query, "d": "" }
            async with self.session.post("https://html.duckduckgo.com/html/", data=payload) as resp:
                # A non-empty redirect history means DDG sent us straight to a page.
                if resp.history:
                    await ctx.send(resp.url, reference=ctx.message)
                    return
                parser = Parser()
                parser.feed(await resp.text())
                parser.close()
                if parser.links:
                    return await ctx.send(parser.links[0], reference=ctx.message)
                return await ctx.send("No results.", reference=ctx.message)

    def cog_unload(self):
        # Session close is async; schedule it rather than blocking unload.
        asyncio.create_task(self.session.close())
|
||||
|
||||
def setup(bot):
    """discord.py extension entry point: register the DuckDuckGo cog."""
    bot.add_cog(DuckDuckGo(bot))
|
@ -62,7 +62,8 @@ async def on_command_error(ctx, err):
|
||||
logging.error("Command error occured (in %s)", ctx.invoked_with, exc_info=err)
|
||||
await ctx.send(embed=util.error_embed(util.gen_codeblock(trace), title=f"Internal error in {ctx.invoked_with}"))
|
||||
await achievement.achieve(ctx.bot, ctx.message, "error")
|
||||
except Exception as e: print("meta-error:", e)
|
||||
except Exception as e:
|
||||
logging.exception("Error in error handling!", e)
|
||||
|
||||
@bot.check
|
||||
async def andrew_bad(ctx):
|
||||
@ -103,7 +104,7 @@ async def run_bot():
|
||||
bot.load_extension(ext)
|
||||
await bot.start(config["token"])
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.create_task(prometheus_async.aio.web.start_http_server(port=config["metrics_port"]))
|
||||
try:
|
||||
|
122
src/search.py
Normal file
122
src/search.py
Normal file
@ -0,0 +1,122 @@
|
||||
import aiohttp
|
||||
import discord
|
||||
import asyncio
|
||||
import logging
|
||||
import discord.ext.commands as commands
|
||||
import html.parser
|
||||
import collections
|
||||
import util
|
||||
import io
|
||||
import concurrent.futures
|
||||
|
||||
def pool_load_model(model):
    """Worker-process initializer: load the question-answering model once.

    Runs inside each ProcessPoolExecutor worker; stashes the pipeline in the
    worker's module globals so pool_operate can reach it on later calls.
    """
    global qa_pipeline
    from transformers import pipeline
    qa_pipeline = pipeline("question-answering", model)
|
||||
|
||||
def pool_operate(question, context):
    """Run one question-answering inference in the worker process.

    Relies on pool_load_model having populated the module-global
    qa_pipeline in this process beforehand.
    """
    return globals()["qa_pipeline"](question=question, context=context)
|
||||
|
||||
class Parser(html.parser.HTMLParser):
    """Extracts organic result links from a DuckDuckGo HTML results page.

    Collects the href of every <a class="result__a"> into self.links,
    skipping DuckDuckGo ad-redirect links (y.js?ad_provider URLs).
    """
    def __init__(self):
        self.links = []
        super().__init__()

    def handle_starttag(self, tag, attrs):
        attrs = dict(attrs)
        if tag == "a" and attrs.get("class") == "result__a":
            # Fix: use .get for href too — a malformed result anchor without
            # an href previously raised KeyError and aborted parsing.
            href = attrs.get("href", "")
            if href and "https://duckduckgo.com/y.js?ad_provider" not in href:
                self.links.append(href)
|
||||
|
||||
class Search(commands.Cog):
    """Search commands: DuckDuckGo web search, Wikipedia intro extracts,
    and experimental question answering over those extracts."""

    def __init__(self, bot):
        self.bot = bot
        self.session = aiohttp.ClientSession()
        # page title -> intro extract; OrderedDict so we can evict oldest first
        self.wp_cache = collections.OrderedDict()
        # raw query -> resolved page title (None caches a "no result" answer)
        self.wp_search_cache = collections.OrderedDict()
        # Process pool hosting the QA model; created lazily by ensure_pool
        self.pool = None

    def ensure_pool(self):
        "Lazily start the single-worker process pool that hosts the QA model."
        if self.pool is not None: return
        self.pool = concurrent.futures.ProcessPoolExecutor(max_workers=1, initializer=pool_load_model, initargs=(util.config["ir"]["model"],))

    @commands.command()
    async def search(self, ctx, *, query):
        "Search using DuckDuckGo. Returns the first result as a link."
        async with ctx.typing():
            async with self.session.post("https://html.duckduckgo.com/html/", data={ "q": query, "d": "" }) as resp:
                # A non-empty redirect history means DDG sent us straight to a page.
                if resp.history:
                    await ctx.send(resp.url, reference=ctx.message)
                else:
                    p = Parser()
                    p.feed(await resp.text())
                    p.close()
                    try:
                        return await ctx.send(p.links[0], reference=ctx.message)
                    except IndexError:
                        return await ctx.send("No results.", reference=ctx.message)

    async def wp_search(self, query):
        "Resolve free-text *query* to a Wikipedia page title, or None if no hits."
        async with self.session.get("https://en.wikipedia.org/w/api.php",
            params={ "action": "query", "list": "search", "srsearch": query, "utf8": "1", "format": "json", "srlimit": 1 }) as resp:
            data = (await resp.json())["query"]["search"]
            return data[0]["title"] if data else None

    async def wp_fetch(self, page, *, fallback=True):
        """Fetch the intro extract of a Wikipedia page, with caching.

        If the exact title does not exist and *fallback* is True, the query is
        resolved once through wp_search; the retry runs with fallback=False so
        the recursion cannot loop. Returns None when nothing is found.
        """
        async def fallback_to_search():
            if fallback:
                new_page = await self.wp_search(page)
                if len(self.wp_search_cache) > util.config["ir"]["cache_size"]:
                    self.wp_search_cache.popitem(last=False)  # evict oldest
                self.wp_search_cache[page] = new_page
                if new_page is None: return None
                return await self.wp_fetch(new_page, fallback=False)

        if page in self.wp_cache: return self.wp_cache[page]
        if page in self.wp_search_cache:
            if self.wp_search_cache[page] is None: return None
            return await self.wp_fetch(self.wp_search_cache[page], fallback=False)
        async with self.session.get("https://en.wikipedia.org/w/api.php",
            params={ "action": "query", "format": "json", "titles": page, "prop": "extracts", "exintro": 1, "explaintext": 1 }) as resp:
            data = (await resp.json())["query"]
            # The API reports a missing title as a pages entry keyed "-1".
            if "-1" in data["pages"]:
                return await fallback_to_search()
            else:
                content = next(iter(data["pages"].values()))["extract"]
                if not content: return await fallback_to_search()
                if len(self.wp_cache) > util.config["ir"]["cache_size"]:
                    self.wp_cache.popitem(last=False)  # evict oldest
                self.wp_cache[page] = content
                return content

    @commands.command(aliases=["wp"])
    async def wikipedia(self, ctx, *, page):
        "Have you ever wanted the first section of a Wikipedia page? Obviously, yes. This gets that."
        content = await self.wp_fetch(page)
        if content is None:
            await ctx.send("Not found.")
        else:
            # Send as an attachment so long extracts don't hit message length limits.
            f = io.BytesIO(content.encode("utf-8"))
            file = discord.File(f, "content.txt")
            await ctx.send(file=file)

    @commands.command()
    async def experimental_qa(self, ctx, page, *, query):
        "Answer questions from the first part of a Wikipedia page, using a finetuned ALBERT model."
        self.ensure_pool()
        loop = asyncio.get_running_loop()
        async with ctx.typing():
            content = await self.wp_fetch(page)
            # Fix: a missing page previously passed context=None straight into
            # the QA pipeline and crashed; tell the user instead.
            if content is None:
                return await ctx.send("Not found.")
            result = await loop.run_in_executor(self.pool, pool_operate, query, content)
            await ctx.send("%s (%f)" % (result["answer"].strip(), result["score"]))

    def cog_unload(self):
        # Session close is async; schedule it. Shut down the QA workers if started.
        asyncio.create_task(self.session.close())
        if self.pool is not None:
            self.pool.shutdown()
|
||||
|
||||
def setup(bot):
    """discord.py extension entry point: register the Search cog."""
    bot.add_cog(Search(bot))
|
@ -277,7 +277,7 @@ extensions = (
|
||||
"commands",
|
||||
"userdata",
|
||||
"irc_link",
|
||||
"duckduckgo",
|
||||
"search",
|
||||
"esoserver"
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user