Mirror of https://github.com/osmarks/meme-search-engine.git, synced 2025-06-20 15:24:05 +00:00

Commit 46fca3eb7f (parent 2c9ce67ab2): faster indexing, SigLIP models
Model server script:

@@ -1,4 +1,4 @@
-import torch
+import os
 import time
 import threading
 from aiohttp import web
@@ -8,21 +8,34 @@ import traceback
 import umsgpack
 import collections
 import queue
-import open_clip
 from PIL import Image
 from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
 import io
 import json
 import sys
+import torch
+from transformers import SiglipImageProcessor, T5Tokenizer, SiglipModel, SiglipConfig
+from accelerate import init_empty_weights
+from accelerate.utils.modeling import set_module_tensor_to_device
+from safetensors import safe_open
+import numpy
 
 with open(sys.argv[1], "r") as config_file:
     CONFIG = json.load(config_file)
 
-device = torch.device(CONFIG["device"])
-model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=dict(open_clip.list_pretrained())[CONFIG["model"]], precision="fp16")
-model.eval()
-tokenizer = open_clip.get_tokenizer(CONFIG["model"])
-print("Model loaded")
+DEVICE = "cuda:0"
+
+# So400m/14@384
+with init_empty_weights():
+    model = SiglipModel(config=SiglipConfig.from_pretrained(CONFIG["model"])).half().eval()
+with safe_open(os.path.join(CONFIG["model"], "model.safetensors"), framework="pt", device=DEVICE) as f:
+    for key in f.keys():
+        set_module_tensor_to_device(model, key, device=DEVICE, value=f.get_tensor(key))
+model = model.to(DEVICE)
+EMBDIM = model.config.vision_config.hidden_size # NOT projection_dim, why is that even there
+RES = model.config.vision_config.image_size
+tokenizer = T5Tokenizer(vocab_file=os.path.join(CONFIG["model"], "sentencepiece.model"), extra_ids=0, model_max_length=64, pad_token="</s>", legacy=False)
+image_processor = SiglipImageProcessor(size={"height": RES, "width":RES})
 
 BS = CONFIG["max_batch_size"]
 MODELNAME = CONFIG["model_name"]
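As an aside (not part of the commit), the weight-loading pattern above can be read in isolation. A minimal sketch, assuming a local SigLIP checkpoint directory ./out containing config.json and model.safetensors:

import os
from transformers import SiglipModel, SiglipConfig
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors import safe_open

CKPT = "./out"        # assumed checkpoint directory, matching the config below
DEVICE = "cuda:0"

# Build the module tree on the "meta" device: no real weight memory is allocated yet.
with init_empty_weights():
    model = SiglipModel(config=SiglipConfig.from_pretrained(CKPT)).half().eval()

# Stream each tensor straight from the safetensors file onto the GPU, so the weights
# are only materialized once (no throwaway random initialization on the CPU first).
with safe_open(os.path.join(CKPT, "model.safetensors"), framework="pt", device=DEVICE) as f:
    for key in f.keys():
        set_module_tensor_to_device(model, key, device=DEVICE, value=f.get_tensor(key))

This is presumably why the new code can skip from_pretrained entirely and still end up with a usable fp16 model on the GPU.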
@@ -33,7 +46,6 @@ items_ctr = Counter("modelserver_total_items", "Items run through model server",
 inference_time_hist = Histogram("modelserver_inftime", "Time running inference", ["model", "batch_size"])
 batch_count_ctr = Counter("modelserver_batchcount", "Inference batches run", ["model"])
 
-torch.set_grad_enabled(False)
 def do_inference(params: InferenceParameters):
     with torch.no_grad():
         try:
@@ -41,19 +53,17 @@ def do_inference(params: InferenceParameters):
             if text is not None:
                 items_ctr.labels(MODELNAME, "text").inc(text.shape[0])
                 with inference_time_hist.labels(MODELNAME + "-text", text.shape[0]).time():
-                    features = model.encode_text(text)
+                    features = model.text_model.forward(input_ids=torch.tensor(text, device=DEVICE)).pooler_output
             elif images is not None:
-                with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
-                    items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
-                    features = model.encode_image(images)
-            batch_count_ctr.labels(MODELNAME).inc()
+                items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
+                with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
+                    features = model.vision_model.forward(torch.tensor(images, device=DEVICE)).pooler_output
             features /= features.norm(dim=-1, keepdim=True)
+            batch_count_ctr.labels(MODELNAME).inc()
             callback(True, features.cpu().numpy())
         except Exception as e:
             traceback.print_exc()
             callback(False, str(e))
-        finally:
-            torch.cuda.empty_cache()
 
 iq = queue.Queue(10)
 def infer_thread():
@@ -67,10 +77,11 @@ def preprocessing_thread():
         try:
             if text:
                 assert len(text) <= BS, f"max batch size is {BS}"
-                text = tokenizer(text).to(device)
+                # I feel like this ought to be batchable but I can't see how to do that
+                text = numpy.array(tokenizer(text, padding="max_length", truncation=True)["input_ids"])
             elif images:
                 assert len(images) <= BS, f"max batch size is {BS}"
-                images = torch.stack([ preprocess(Image.open(io.BytesIO(im))).half() for im in images ]).to(device)
+                images = numpy.array(image_processor([ Image.open(io.BytesIO(bs)) for bs in images ])["pixel_values"]).astype("float16")
             else:
                 assert False, "images or text required"
             iq.put(InferenceParameters(text, images, callback))
@@ -105,10 +116,10 @@ async def run_inference(request):
 @routes.get("/config")
 async def config(request):
     return web.Response(body=umsgpack.dumps({
-        "model": CONFIG["model"],
+        "model": MODELNAME,
         "batch": BS,
-        "image_size": model.visual.image_size,
-        "embedding_size": model.visual.output_dim
+        "image_size": (RES, RES),
+        "embedding_size": EMBDIM
     }), status=200, content_type="application/msgpack")
 
 @routes.get("/")
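For illustration (not from the commit): a client can read these fields back over msgpack. A minimal sketch, assuming the server is reachable at localhost:1708 (the port in the config below) and that the requests and u-msgpack-python packages are installed:

import requests
import umsgpack

# Hypothetical client; port 1708 comes from the example config below.
resp = requests.get("http://localhost:1708/config")
cfg = umsgpack.loads(resp.content)   # the body is msgpack, not JSON
print(cfg["model"], cfg["batch"], cfg["image_size"], cfg["embedding_size"])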
Model server config (JSON):

@@ -1,7 +1,6 @@
 {
-    "device": "cuda:0",
-    "model": "ViT-H-14",
-    "model_name": "openclip-ViT-H-14",
+    "model": "./out",
+    "model_name": "siglip-so400m/14@384",
     "max_batch_size": 128,
     "port": 1708
 }
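Read as a whole file, the new config would presumably be (reconstructed from the hunk above; indentation assumed):

{
    "model": "./out",
    "model_name": "siglip-so400m/14@384",
    "max_batch_size": 128,
    "port": 1708
}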
mse.py (37 changed lines):
@@ -13,6 +13,7 @@ import aiohttp_cors
 import json
 import io
 import sys
+from concurrent.futures import ProcessPoolExecutor
 
 with open(sys.argv[1], "r") as config_file:
     CONFIG = json.load(config_file)
@@ -36,7 +37,8 @@ async def run_query(request):
     data = await request.json()
     embeddings = []
     if images := data.get("images", []):
-        embeddings.extend(await clip_server({ "images": [ base64.b64decode(x) for x, w in images ] }))
+        target_image_size = app["index"].inference_server_config["image_size"]
+        embeddings.extend(await clip_server({ "images": [ load_image(io.BytesIO(base64.b64decode(x)), target_image_size)[0] for x, w in images ] }))
     if text := data.get("text", []):
         embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
     weights = [ w for x, w in images ] + [ w for x, w in text ]
@@ -54,6 +56,13 @@ async def reload_index_route(request):
     await request.app["index"].reload()
     return web.json_response(True)
 
+def load_image(path, image_size):
+    im = Image.open(path)
+    im.draft("RGB", image_size)
+    buf = io.BytesIO()
+    im.resize(image_size).convert("RGB").save(buf, format="BMP")
+    return buf.getvalue(), path
+
 class Index:
     def __init__(self, inference_server_config):
         self.faiss_index = faiss.IndexFlatIP(inference_server_config["embedding_size"])
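A note on the load_image helper pulled out above (an editorial explanation, not the commit's): Image.draft("RGB", size) asks Pillow's JPEG decoder for a reduced-resolution decode before the final resize(), which is much cheaper than decoding at full size; it was already present in the old inline code and is simply kept here. A minimal illustration with a hypothetical file:

from PIL import Image

im = Image.open("meme.jpg")      # hypothetical file
im.draft("RGB", (384, 384))      # in-place hint; only helps for formats like JPEG
print(im.size)                   # often already much smaller than the original
im = im.resize((384, 384)).convert("RGB")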
@@ -72,6 +81,7 @@ class Index:
 
     async def reload(self):
         async with self.lock:
+            with ProcessPoolExecutor(max_workers=12) as executor:
                 print("Indexing")
                 conn = await aiosqlite.connect(CONFIG["db_path"], parent_loop=asyncio.get_running_loop())
                 conn.row_factory = aiosqlite.Row
@@ -99,6 +109,7 @@ class Index:
                     for filename, _, _ in batch:
                         modified.add(filename)
                         sys.stdout.write(".")
+                        sys.stdout.flush()
                 finally:
                     batch_sem.release()
 
@@ -112,28 +123,37 @@ class Index:
                 await conn.commit()
                 batch = []
 
+                failed = set()
                 for dirpath, _, filenames in os.walk(CONFIG["files"]):
+                    paths = set()
+                    done = set()
                     for file in filenames:
                         path = os.path.join(dirpath, file)
                         file = os.path.relpath(path, CONFIG["files"])
                         st = os.stat(path)
                         if st.st_mtime != files.get(file):
+                            paths.add(path)
+                    for task in asyncio.as_completed([ asyncio.get_running_loop().run_in_executor(executor, load_image, path, self.inference_server_config["image_size"]) for path in paths ]):
                         try:
-                            im = Image.open(path)
-                            im.draft("RGB", self.inference_server_config["image_size"])
-                            buf = io.BytesIO()
-                            im.resize(self.inference_server_config["image_size"]).convert("RGB").save(buf, format="BMP")
-                            b = buf.getvalue()
+                            b, path = await task
+                            st = os.stat(path)
+                            file = os.path.relpath(path, CONFIG["files"])
+                            done.add(path)
                         except Exception as e:
-                            print(file, "failed", e)
+                            # print(file, "failed", e) we can't have access to file when we need it, oops
                             continue
                         batch.append((file, st.st_mtime, b))
-                        if len(batch) % self.inference_server_config["batch"] == self.inference_server_config["batch"] - 1:
+                        if len(batch) == self.inference_server_config["batch"]:
                             await dispatch_batch(batch)
                             batch = []
+                    failed |= paths - done
                 if batch:
                     await dispatch_batch(batch)
 
+                print()
+                for failed_ in failed:
+                    print(failed_, "failed")
+
                 remove_indices = []
                 for index, filename in enumerate(self.associated_filenames):
                     if filename not in files or filename in modified:
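Most of the "faster indexing" here appears to come from fanning the image decoding out to worker processes while the event loop keeps building and dispatching batches to the inference server. A stripped-down sketch of the same asyncio + ProcessPoolExecutor pattern (hypothetical names, not the project's code):

import asyncio
from concurrent.futures import ProcessPoolExecutor

def prepare(path):
    # Stand-in for load_image: CPU-bound decode/resize work. It must be a
    # module-level function so it can be pickled into the worker processes.
    return path.upper(), path

async def run(paths):
    with ProcessPoolExecutor(max_workers=4) as executor:
        loop = asyncio.get_running_loop()
        tasks = [loop.run_in_executor(executor, prepare, p) for p in paths]
        # as_completed yields results as workers finish, not in submission order,
        # so one slow file does not stall the whole batch being built.
        for task in asyncio.as_completed(tasks):
            data, path = await task
            print(path, "->", data)

if __name__ == "__main__":
    asyncio.run(run(["a.jpg", "b.png", "c.webp"]))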
@@ -195,6 +215,7 @@ async def main():
     site = web.TCPSite(runner, "", CONFIG["port"])
     await site.start()
 
+if __name__ == "__main__":
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     loop.run_until_complete(main())