1
0
mirror of https://github.com/osmarks/maghammer.git synced 2024-10-28 04:46:18 +00:00
maghammer/sbert_server.py

118 lines
3.6 KiB
Python
Raw Normal View History

2024-07-18 15:28:08 +00:00
import torch
import time
import threading
from aiohttp import web
import aiohttp
import asyncio
import traceback
import umsgpack
import gc
import collections
import queue
import io
from sentence_transformers import SentenceTransformer
from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
# Load a local snowflake-arctic-embed-l checkpoint, cast to fp16, on the first CUDA device.
device = torch.device("cuda:0")
model_name = "./snowflake-arctic-embed-l"
model = SentenceTransformer(model_name)
model = model.half().to(device)
model.eval()
print("model loaded")

MODELNAME = "sbert-snowflake-arctic-embed-l"  # value of the "model" label on the metrics below
BS = 256  # largest batch do_inference accepts

# One unit of work for the inference thread: the tokenized batch plus the
# callback through which the result (or error) is reported.
InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "callback"])

# Prometheus metrics.
items_ctr = Counter("modelserver_total_items", "Items run through model server", ["model", "modality"])
inference_time_hist = Histogram("modelserver_inftime", "Time running inference", ["model", "batch_size"])
batch_count_ctr = Counter("modelserver_batchcount", "Inference batches run", ["model"])

# PyTorch grad mode is thread-local, so this only affects the main thread; the
# inference thread disables autograd itself with torch.no_grad().
torch.set_grad_enabled(False)
def do_inference(params: InferenceParameters):
    """Run one queued batch through the model and report the outcome via its callback.

    Args:
        params: InferenceParameters(text, callback).
            ``text`` is the mapping produced by ``model.tokenize``; the first
            dimension of its ``"input_ids"`` tensor is taken as the batch size.
            ``callback`` is called with ``(True, embeddings)`` on success, where
            ``embeddings`` is a NumPy array of L2-normalized sentence embeddings.
            It is called with ``(False, message)`` if anything in the batch raised.

    Batches with more than ``BS`` rows are reported through the callback as a
    failure. The CUDA cache is emptied after every batch, successful or not.
    """
    # Unpack before the try so the except clause can always reach `callback`.
    text, callback = params
    with torch.no_grad():
        try:
            batch_size = text["input_ids"].shape[0]
            if batch_size > BS:
                # Explicit check instead of `assert`: asserts are stripped under
                # `python -O`, which would let oversized batches through.
                raise ValueError(f"max batch size is {BS}")
            items_ctr.labels(MODELNAME, "text").inc(batch_size)
            with inference_time_hist.labels(MODELNAME, batch_size).time():
                features = model(text)["sentence_embedding"]
                # Scale each embedding to unit L2 norm along the last dimension.
                features /= features.norm(dim=-1, keepdim=True)
                features = features.cpu().numpy()
            batch_count_ctr.labels(MODELNAME).inc()
            callback(True, features)
        except Exception as e:
            traceback.print_exc()
            callback(False, str(e))
        finally:
            torch.cuda.empty_cache()
# Hand-off queue between the HTTP handlers and the inference thread; it holds
# at most 10 pending batches.
q = queue.Queue(maxsize=10)


def infer_thread():
    """Worker loop: take batches off `q` and run them one at a time, forever."""
    while True:
        params = q.get()
        do_inference(params)
# Accept request bodies up to 64 MiB (2**26 bytes); aiohttp's default limit is 1 MiB.
app = web.Application(client_max_size=64 * 1024 * 1024)
routes = web.RouteTableDef()
@routes.post("/")
async def run_inference(request):
    """Embed a batch of texts.

    The request body is msgpack containing a "text" field, which is passed to
    ``model.tokenize``. The tokenized batch is handed to the inference thread
    through ``q``, and this coroutine waits for the worker's callback.

    Responses are msgpack:
      * 200: a list of float16 byte strings, one per row of the embedding matrix.
      * 500: the error message reported by the inference thread.
      * 503: an error message when the inference queue is already full.
    """
    loop = asyncio.get_running_loop()
    data = umsgpack.loads(await request.read())
    event = asyncio.Event()
    results = None

    def callback(*argv):
        # Called from the inference thread. asyncio primitives are not
        # thread-safe, so event.set() is scheduled onto the event loop
        # instead of being called directly.
        nonlocal results
        results = argv
        loop.call_soon_threadsafe(event.set)

    tokenized = model.tokenize(data["text"])
    tokenized = {k: v.to(device) for k, v in tokenized.items()}
    try:
        q.put_nowait(InferenceParameters(tokenized, callback))
    except queue.Full:
        # Backpressure: report "busy" explicitly instead of letting queue.Full
        # escape as a generic plain-text 500 from aiohttp.
        return web.Response(body=umsgpack.dumps("inference queue full"), status=503, content_type="application/msgpack")
    await event.wait()
    ok, payload = results
    if ok:
        status = 200
        body_data = [x.astype("float16").tobytes() for x in payload]
    else:
        status = 500
        body_data = payload
        print(payload)
    return web.Response(body=umsgpack.dumps(body_data), status=status, content_type="application/msgpack")
@routes.get("/config")
async def config(request):
    """Describe the served model as msgpack.

    The map contains the model path, the maximum batch size, the embedding
    width and a tokenizer name.
    """
    # NOTE(review): "bert-based-uncased" looks like a typo for the Hugging Face
    # id "bert-base-uncased". It is left as-is because clients may match on the
    # exact string; confirm before changing it.
    info = {
        "model": model_name,
        "batch": BS,
        "embedding_size": model.get_sentence_embedding_dimension(),
        "tokenizer": "bert-based-uncased",
    }
    payload = umsgpack.dumps(info)
    return web.Response(body=payload, status=200, content_type="application/msgpack")
@routes.get("/")
async def health(request):
    """Liveness probe: always answer 204 No Content with an empty body."""
    response = web.Response(status=204)
    return response
@routes.get("/metrics")
async def metrics(request):
    """Expose the default Prometheus registry in the text exposition format.

    The Content-Type header is set explicitly. Without it, aiohttp falls back to
    application/octet-stream, and Prometheus 3.x fails scrapes with an
    unrecognised Content-Type unless the scrape job sets a fallback protocol.
    """
    # Local import, matching the file's existing function-local `import sys`.
    from prometheus_client import CONTENT_TYPE_LATEST

    return web.Response(body=generate_latest(REGISTRY), headers={"Content-Type": CONTENT_TYPE_LATEST})
# Register every handler collected on `routes` with the application's router.
app.add_routes(routes)
async def run_webserver():
    """Start serving `app` on TCP port 1706 and return once the site is started.

    The host is "", which asyncio's create_server treats as all interfaces.
    """
    runner = web.AppRunner(app)
    await runner.setup()
    print("server starting")
    site = web.TCPSite(runner, "", 1706)
    await site.start()
# Entry point: start the inference worker, then run the web server until Ctrl-C.
try:
    # daemon=True: the worker blocks forever in q.get(). A non-daemon thread
    # would be joined at interpreter shutdown, keeping the process alive after
    # sys.exit() below, so Ctrl-C would never actually stop the server.
    th = threading.Thread(target=infer_thread, daemon=True)
    th.start()
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(run_webserver())
    loop.run_forever()
except KeyboardInterrupt:
    print("quitting")
    import sys
    sys.exit(0)