# Mirror of https://github.com/osmarks/meme-search-engine.git
# (synced 2025-01-04 22:40:31 +00:00; 210 lines, 7.8 KiB, Python)
import asyncio
import collections
import io
import json
import queue
import sys
import threading
import time
import traceback

import aiohttp
import msgpack
import open_clip
import torch
import torchvision.transforms.transforms as transforms
from aiohttp import web
from PIL import Image
from prometheus_client import CONTENT_TYPE_LATEST, REGISTRY, Counter, Histogram, generate_latest

# Load the server configuration from the JSON file named by the first CLI
# argument. JSON text is UTF-8 by specification, so decode it explicitly
# instead of relying on the platform's locale-dependent default encoding.
with open(sys.argv[1], "r", encoding="utf-8") as config_file:
    CONFIG = json.load(config_file)

# Build the model on the configured device in half precision. The pretrained
# tag is looked up by collapsing open_clip.list_pretrained() into a dict, so if
# a model name appears more than once the last listed tag wins.
device = torch.device(CONFIG["device"])
model, _, preprocess = open_clip.create_model_and_transforms(
    CONFIG["model"],
    device=device,
    pretrained=dict(open_clip.list_pretrained())[CONFIG["model"]],
    precision="fp16",
)
model.eval()  # inference only
tokenizer = open_clip.get_tokenizer(CONFIG["model"])
print("Model loaded")

# Upper bound on the number of items in one request (checked in preprocessing).
BS = CONFIG["max_batch_size"]
# Model label attached to the Prometheus metrics.
MODELNAME = CONFIG["model_name"]

# Compiled batch size -> callable running an AITemplate image encoder on exactly
# that many images. Left empty when no AITemplate models are configured, in
# which case do_inference uses model.encode_image instead.
fast_image_fns = {}
# ugly hack, sorry
# Optional fast path: CONFIG["aitemplate_image_models"] is iterated as
# (batch_size, path) pairs; each compiled module is loaded with this model's
# visual-tower weights.
# NOTE(review): indentation here was reconstructed from a whitespace-stripped
# copy — in particular, confirm upstream that the "patch_embed.proj.weight"
# filter sits inside the `visual.` branch.
if CONFIG.get("aitemplate_image_models"):
    from aitemplate.compiler import Model
    from aitemplate.testing import detect_target

    # Not referenced anywhere else in this file.
    USE_CUDA = detect_target().name() == "cuda"

    state = model.state_dict()
    # Patch-embedding conv weight permuted (0, 2, 3, 1); assuming the usual
    # PyTorch conv layout (out, in, kh, kw) this yields (out, kh, kw, in) —
    # presumably the channels-last layout the compiled module expects.
    conv_weights = state["visual.trunk.patch_embed.proj.weight"].permute((0, 2, 3, 1)).contiguous().cuda().half()

    def load_pretrained():
        """Rename the visual-tower weights to the AITemplate constant names.

        Returns:
            dict mapping constant name ("." replaced by "_") to a CUDA tensor,
            plus "patch_embed_proj_weight" holding the permuted conv weight.
        """
        params = {}
        for key, value in state.items():
            orig_key = key
            if key.startswith("visual."):
                # Map open_clip/timm parameter paths onto the compiled module's
                # naming. Order matters: "trunk.attn_pool.latent" must be
                # rewritten before the broader "trunk.attn_pool".
                key = key.removeprefix("visual.") \
                    .replace("trunk.patch_embed", "patch_embed") \
                    .replace("trunk.blocks", "encoder.layers") \
                    .replace(".attn.", ".mha.") \
                    .replace(".norm1.", ".ln1.") \
                    .replace(".norm2.", ".ln2.") \
                    .replace("trunk.pos_embed", "pos_emb_pos_emb") \
                    .replace("trunk.norm.", "encoder.ln.") \
                    .replace("trunk.attn_pool.latent", "pool.probe") \
                    .replace("trunk.attn_pool", "pool") \
                    .replace("pool.norm", "pool.ln")
                # The conv weight is supplied separately (permuted) below.
                if "patch_embed.proj.weight" not in key:
                    params[key.replace(".", "_")] = value.cuda()
                    #print(orig_key, key.replace(".", "_"))

        params["patch_embed_proj_weight"] = conv_weights

        return params

    def generate_wrapper(path):
        """Load the compiled AITemplate module at `path` and return a runner for it.

        The returned callable takes a batch tensor and returns ys[0][:, 0, :] of
        the module's first output.
        """
        ait_model = Model(path)
        ait_model.set_many_constants_with_tensors(load_pretrained())
        ait_model.fold_constants(sync=True)
        def wrapper(batch):
            # Permute (0, 2, 3, 1): assuming NCHW input from the torchvision
            # preprocessing, this feeds NHWC to the compiled module.
            xs = [batch.permute((0, 2, 3, 1)).contiguous()]
            # One fresh fp16 CUDA buffer per module output, at its maximum shape.
            # NOTE(review): torch.empty(...) allocates fp32 on the CPU before
            # .cuda().half(); allocating directly with device/dtype would be cheaper.
            ys = []
            for i in range(len(ait_model.get_output_name_to_index_map())):
                shape = ait_model.get_output_maximum_shape(i)
                ys.append(torch.empty(shape).cuda().half())
            ait_model.run_with_tensors(xs, ys)
            # Index 0 along dim 1 of the first output — presumably the pooled
            # embedding token; confirm against the compiled model definition.
            return ys[0][:, 0, :]
        return wrapper

    for batch_size, path in CONFIG["aitemplate_image_models"]:
        fast_image_fns[batch_size] = generate_wrapper(path)
        print("loaded", batch_size, path)

# A queued unit of work: tokenized `text` or a preprocessed `images` batch,
# plus `callback(ok, payload)` used to report the outcome.
InferenceParameters = collections.namedtuple(
    "InferenceParameters", "text images callback"
)

# Prometheus metrics exported via /metrics.
items_ctr = Counter(
    "modelserver_total_items",
    "Items run through model server",
    labelnames=["model", "modality"],
)
inference_time_hist = Histogram(
    "modelserver_inftime",
    "Time running inference",
    labelnames=["model", "batch_size"],
)
batch_count_ctr = Counter(
    "modelserver_batchcount",
    "Inference batches run",
    labelnames=["model"],
)

# Turn autograd off. PyTorch grad mode is thread-local, so this covers only the
# importing thread; worker threads must disable gradients themselves.
torch.set_grad_enabled(False)

def _encode_images_fast(images):
    """Encode `images` with the AITemplate modules, largest fitting batch first.

    Any tail smaller than every compiled batch size is encoded with
    `model.encode_image` instead of failing. Returns an unnormalized float32
    CPU tensor of shape (batch, embedding_size).
    """
    batch = images.shape[0]
    features = torch.zeros((batch, model.text.text_projection.out_features))
    progress = 0
    while progress < batch:
        remaining = batch - progress
        size = max((s for s in fast_image_fns if s <= remaining), default=None)
        if size is None:
            # No compiled module is small enough for the remainder.
            features[progress:] = model.encode_image(images[progress:])
            break
        features[progress:progress + size] = fast_image_fns[size](images[progress:progress + size])
        progress += size
    return features


def do_inference(params: InferenceParameters):
    """Run one preprocessed request through the model and report via its callback.

    If `params.text` is set it is encoded with the text tower; otherwise
    `params.images` is encoded with the image tower, using the AITemplate
    fast path when any compiled modules were loaded. Embeddings are
    L2-normalized and delivered as `callback(True, numpy_array)`; any failure
    is reported as `callback(False, message)`. The CUDA cache is emptied
    afterwards either way.
    """
    # Grad mode is thread-local: the module-level set_grad_enabled(False) does
    # not reach the worker thread this runs on, so disable it here.
    with torch.no_grad():
        try:
            text, images, callback = params
            if text is not None:
                items_ctr.labels(MODELNAME, "text").inc(text.shape[0])
                with inference_time_hist.labels(MODELNAME + "-text", text.shape[0]).time():
                    features = model.encode_text(text)
                    features /= features.norm(dim=-1, keepdim=True)
                    features = features.cpu().numpy()
            elif images is not None:
                with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
                    items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
                    if fast_image_fns:
                        features = _encode_images_fast(images)
                    else:
                        features = model.encode_image(images)
                    features /= features.norm(dim=-1, keepdim=True)
                    features = features.cpu().numpy()
            else:
                raise ValueError("images or text required")
            batch_count_ctr.labels(MODELNAME).inc()
            callback(True, features)
        except Exception as e:
            traceback.print_exc()
            callback(False, str(e))
        finally:
            torch.cuda.empty_cache()

# Preprocessed requests waiting for the model. Bounded, so the preprocessing
# thread's put() blocks when inference falls behind.
iq = queue.Queue(maxsize=10)


def infer_thread():
    """Worker loop: take each queued request from `iq` and run it, forever."""
    while True:
        job = iq.get()
        do_inference(job)

# Raw requests from the HTTP handler awaiting preprocessing (bounded).
pq = queue.Queue(10)


def preprocessing_thread():
    """Worker loop: validate and preprocess requests from `pq`, then queue them on `iq`.

    Text is tokenized and moved to `device`. Images (a list of encoded image
    bytes) are decoded with PIL, run through `preprocess`, stacked into one
    fp16 batch and moved to `device`. Invalid requests and preprocessing
    failures are reported as `callback(False, message)`.
    """
    while True:
        text, images, callback = pq.get()
        try:
            # Explicit raises rather than `assert`: assertions are stripped
            # under `python -O`, which would let invalid requests through.
            if text:
                if len(text) > BS:
                    raise ValueError(f"max batch size is {BS}")
                text = tokenizer(text).to(device)
            elif images:
                if len(images) > BS:
                    raise ValueError(f"max batch size is {BS}")
                images = torch.stack([preprocess(Image.open(io.BytesIO(im))).half() for im in images]).to(device)
            else:
                raise ValueError("images or text required")
            iq.put(InferenceParameters(text, images, callback))
        except Exception as e:
            traceback.print_exc()
            callback(False, str(e))

# HTTP application. Request bodies may be up to 64 MiB (2**26 bytes),
# presumably to fit batches of encoded images.
app = web.Application(client_max_size=64 * 1024 * 1024)
routes = web.RouteTableDef()

@routes.post("/")
async def run_inference(request):
    """Embed a msgpack request of `text` or `images` and return the vectors.

    The body is a msgpack map read with `.get("text")` / `.get("images")`.
    On success (200) the response is a msgpack list with one float16 byte
    string per embedding row; on failure (500) it is the msgpack-encoded
    error message. If the preprocessing queue is full, the request is
    rejected with 503 instead of surfacing an unhandled `queue.Full`.
    """
    loop = asyncio.get_running_loop()
    data = msgpack.loads(await request.read())
    event = asyncio.Event()
    results = None

    def callback(*argv):
        # Runs on a worker thread: store the result, then wake this coroutine
        # through its own loop (asyncio.Event is not thread-safe).
        nonlocal results
        results = argv
        loop.call_soon_threadsafe(event.set)

    try:
        pq.put_nowait(InferenceParameters(data.get("text"), data.get("images"), callback))
    except queue.Full:
        return web.Response(body=msgpack.dumps("server busy"), status=503, content_type="application/msgpack")

    await event.wait()
    body_data = results[1]
    if results[0]:
        status = 200
        body_data = [x.astype("float16").tobytes() for x in body_data]
    else:
        status = 500
        print(results[1])
    return web.Response(body=msgpack.dumps(body_data), status=status, content_type="application/msgpack")

@routes.get("/config")
async def config(request):
    """Describe the loaded model as msgpack: name, max batch, input size, embedding width."""
    resize_steps = [t for t in preprocess.transforms if isinstance(t, transforms.Resize)]
    description = {
        "model": CONFIG["model"],
        "batch": BS,
        "image_size": resize_steps[0].size,
        "embedding_size": model.text.text_projection.out_features,
    }
    return web.Response(body=msgpack.dumps(description), status=200, content_type="application/msgpack")

@routes.get("/")
async def health(request):
    """Liveness check: always answer 204 No Content with an empty body."""
    no_content = 204
    return web.Response(status=no_content)

@routes.get("/metrics")
async def metrics(request):
    """Serve the default Prometheus registry in the text exposition format.

    The Content-Type is set explicitly. Without it, aiohttp labels a bytes
    body application/octet-stream, which Prometheus 3.x rejects unless a
    fallback scrape protocol is configured.
    """
    return web.Response(body=generate_latest(REGISTRY), headers={"Content-Type": CONTENT_TYPE_LATEST})

# Register every handler collected on `routes` with the application.
app.add_routes(routes)

async def run_webserver():
    """Start serving `app` on all interfaces at CONFIG["port"].

    Returns once the listening socket is up; the caller keeps the loop running.
    """
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "", CONFIG["port"])
    await site.start()
    # Print only after start() succeeds, so "Ready" means the port is bound.
    print("Ready")

# Entry point: start both worker threads, then serve HTTP on a fresh asyncio
# loop until interrupted.
try:
    # Daemon threads: both workers block forever on queue.get(), and
    # non-daemon threads would keep the interpreter alive after Ctrl-C.
    th = threading.Thread(target=infer_thread, daemon=True)
    th.start()
    th = threading.Thread(target=preprocessing_thread, daemon=True)
    th.start()
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(run_webserver())
    loop.run_forever()
except KeyboardInterrupt:
    sys.exit(0)