mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-01-02 13:30:30 +00:00
Return to OpenCLIP
This commit is contained in:
parent
74bb1bc343
commit
4626f53bcb
@ -20,8 +20,9 @@ This is untested. It might work.
|
|||||||
* Serve your meme library from a static webserver.
|
* Serve your meme library from a static webserver.
|
||||||
* I use nginx. If you're in a hurry, you can use `python -m http.server`.
|
* I use nginx. If you're in a hurry, you can use `python -m http.server`.
|
||||||
* Install Python dependencies with `pip` from `requirements.txt` (the versions probably don't need to match exactly, so you can change them if necessary; I just put in what I currently have installed).
|
* Install Python dependencies with `pip` from `requirements.txt` (the versions probably don't need to match exactly, so you can change them if necessary; I just put in what I currently have installed).
|
||||||
* You now need a [patched version](https://github.com/osmarks/transformers-patch-siglip) of `transformers` due to SigLIP support.
|
* ~~You now need a [patched version](https://github.com/osmarks/transformers-patch-siglip) of `transformers` due to SigLIP support.~~ OpenCLIP supports SigLIP. I am now using that.
|
||||||
* I have converted exactly one SigLIP model: [https://huggingface.co/gollark/siglip-so400m-14-384](https://huggingface.co/gollark/siglip-so400m-14-384). It's apparently the best one. If you don't like it, find out how to convert more. You need to download that repo.
|
* ~~I have converted exactly one SigLIP model: [https://huggingface.co/gollark/siglip-so400m-14-384](https://huggingface.co/gollark/siglip-so400m-14-384). It's apparently the best one. If you don't like it, find out how to convert more. You need to download that repo.~~ You can use any model which OpenCLIP supports.
|
||||||
|
* Run `thumbnailer.py` (periodically, at the same time as index reloads, ideally)
|
||||||
* Run `clip_server.py` (as a background service).
|
* Run `clip_server.py` (as a background service).
|
||||||
* It is configured with a JSON file given to it as its first argument. An example is in `clip_server_config.json`.
|
* It is configured with a JSON file given to it as its first argument. An example is in `clip_server_config.json`.
|
||||||
* `device` should probably be `cuda` or `cpu`. The model will run on this device.
|
* `device` should probably be `cuda` or `cpu`. The model will run on this device.
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import os
|
import torch
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
@ -8,34 +8,22 @@ import traceback
|
|||||||
import umsgpack
|
import umsgpack
|
||||||
import collections
|
import collections
|
||||||
import queue
|
import queue
|
||||||
|
import open_clip
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
|
from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
import torchvision.transforms.transforms as transforms
|
||||||
import sys
|
import sys
|
||||||
import torch
|
|
||||||
from transformers import SiglipImageProcessor, T5Tokenizer, SiglipModel, SiglipConfig
|
|
||||||
from accelerate import init_empty_weights
|
|
||||||
from accelerate.utils.modeling import set_module_tensor_to_device
|
|
||||||
from safetensors import safe_open
|
|
||||||
import numpy
|
|
||||||
|
|
||||||
with open(sys.argv[1], "r") as config_file:
|
with open(sys.argv[1], "r") as config_file:
|
||||||
CONFIG = json.load(config_file)
|
CONFIG = json.load(config_file)
|
||||||
|
|
||||||
DEVICE = CONFIG["device"]
|
device = torch.device(CONFIG["device"])
|
||||||
|
model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=dict(open_clip.list_pretrained())[CONFIG["model"]], precision="fp16")
|
||||||
# So400m/14@384
|
model.eval()
|
||||||
with init_empty_weights():
|
tokenizer = open_clip.get_tokenizer(CONFIG["model"])
|
||||||
model = SiglipModel(config=SiglipConfig.from_pretrained(CONFIG["model"])).half().eval()
|
print("Model loaded")
|
||||||
with safe_open(os.path.join(CONFIG["model"], "model.safetensors"), framework="pt", device=DEVICE) as f:
|
|
||||||
for key in f.keys():
|
|
||||||
set_module_tensor_to_device(model, key, device=DEVICE, value=f.get_tensor(key))
|
|
||||||
model = model.to(DEVICE)
|
|
||||||
EMBDIM = model.config.vision_config.hidden_size # NOT projection_dim, why is that even there
|
|
||||||
RES = model.config.vision_config.image_size
|
|
||||||
tokenizer = T5Tokenizer(vocab_file=os.path.join(CONFIG["model"], "sentencepiece.model"), extra_ids=0, model_max_length=64, pad_token="</s>", legacy=False)
|
|
||||||
image_processor = SiglipImageProcessor(size={"height": RES, "width":RES})
|
|
||||||
|
|
||||||
BS = CONFIG["max_batch_size"]
|
BS = CONFIG["max_batch_size"]
|
||||||
MODELNAME = CONFIG["model_name"]
|
MODELNAME = CONFIG["model_name"]
|
||||||
@ -46,6 +34,7 @@ items_ctr = Counter("modelserver_total_items", "Items run through model server",
|
|||||||
inference_time_hist = Histogram("modelserver_inftime", "Time running inference", ["model", "batch_size"])
|
inference_time_hist = Histogram("modelserver_inftime", "Time running inference", ["model", "batch_size"])
|
||||||
batch_count_ctr = Counter("modelserver_batchcount", "Inference batches run", ["model"])
|
batch_count_ctr = Counter("modelserver_batchcount", "Inference batches run", ["model"])
|
||||||
|
|
||||||
|
torch.set_grad_enabled(False)
|
||||||
def do_inference(params: InferenceParameters):
|
def do_inference(params: InferenceParameters):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
try:
|
try:
|
||||||
@ -53,13 +42,13 @@ def do_inference(params: InferenceParameters):
|
|||||||
if text is not None:
|
if text is not None:
|
||||||
items_ctr.labels(MODELNAME, "text").inc(text.shape[0])
|
items_ctr.labels(MODELNAME, "text").inc(text.shape[0])
|
||||||
with inference_time_hist.labels(MODELNAME + "-text", text.shape[0]).time():
|
with inference_time_hist.labels(MODELNAME + "-text", text.shape[0]).time():
|
||||||
features = model.text_model.forward(input_ids=torch.tensor(text, device=DEVICE)).pooler_output
|
features = model.encode_text(text)
|
||||||
features /= features.norm(dim=-1, keepdim=True)
|
features /= features.norm(dim=-1, keepdim=True)
|
||||||
features = features.cpu().numpy()
|
features = features.cpu().numpy()
|
||||||
elif images is not None:
|
elif images is not None:
|
||||||
items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
|
|
||||||
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
|
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
|
||||||
features = model.vision_model.forward(torch.tensor(images, device=DEVICE)).pooler_output
|
items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
|
||||||
|
features = model.encode_image(images)
|
||||||
features /= features.norm(dim=-1, keepdim=True)
|
features /= features.norm(dim=-1, keepdim=True)
|
||||||
features = features.cpu().numpy()
|
features = features.cpu().numpy()
|
||||||
batch_count_ctr.labels(MODELNAME).inc()
|
batch_count_ctr.labels(MODELNAME).inc()
|
||||||
@ -67,6 +56,8 @@ def do_inference(params: InferenceParameters):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
callback(False, str(e))
|
callback(False, str(e))
|
||||||
|
finally:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
iq = queue.Queue(10)
|
iq = queue.Queue(10)
|
||||||
def infer_thread():
|
def infer_thread():
|
||||||
@ -80,10 +71,10 @@ def preprocessing_thread():
|
|||||||
try:
|
try:
|
||||||
if text:
|
if text:
|
||||||
assert len(text) <= BS, f"max batch size is {BS}"
|
assert len(text) <= BS, f"max batch size is {BS}"
|
||||||
text = numpy.array(tokenizer([ t.lower() for t in text ], padding="max_length", truncation=True)["input_ids"])
|
text = tokenizer(text).to(device)
|
||||||
elif images:
|
elif images:
|
||||||
assert len(images) <= BS, f"max batch size is {BS}"
|
assert len(images) <= BS, f"max batch size is {BS}"
|
||||||
images = numpy.array(image_processor([ Image.open(io.BytesIO(bs)) for bs in images ])["pixel_values"]).astype("float16")
|
images = torch.stack([ preprocess(Image.open(io.BytesIO(im))).half() for im in images ]).to(device)
|
||||||
else:
|
else:
|
||||||
assert False, "images or text required"
|
assert False, "images or text required"
|
||||||
iq.put(InferenceParameters(text, images, callback))
|
iq.put(InferenceParameters(text, images, callback))
|
||||||
@ -118,10 +109,10 @@ async def run_inference(request):
|
|||||||
@routes.get("/config")
|
@routes.get("/config")
|
||||||
async def config(request):
|
async def config(request):
|
||||||
return web.Response(body=umsgpack.dumps({
|
return web.Response(body=umsgpack.dumps({
|
||||||
"model": MODELNAME,
|
"model": CONFIG["model"],
|
||||||
"batch": BS,
|
"batch": BS,
|
||||||
"image_size": (RES, RES),
|
"image_size": [ t for t in preprocess.transforms if isinstance(t, transforms.Resize) ][0].size,
|
||||||
"embedding_size": EMBDIM
|
"embedding_size": model.text.text_projection.out_features
|
||||||
}), status=200, content_type="application/msgpack")
|
}), status=200, content_type="application/msgpack")
|
||||||
|
|
||||||
@routes.get("/")
|
@routes.get("/")
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
{
|
{
|
||||||
"model": "./out",
|
"model": "ViT-SO400M-14-SigLIP-384",
|
||||||
"model_name": "siglip-so400m/14@384",
|
"model_name": "siglip-so400m/14@384",
|
||||||
"max_batch_size": 128,
|
"max_batch_size": 128,
|
||||||
"port": 1708
|
"port": 1708,
|
||||||
|
"device": "cuda:0"
|
||||||
}
|
}
|
@ -4,4 +4,5 @@ u-msgpack-python==2.8.0
|
|||||||
aiohttp==3.8.5
|
aiohttp==3.8.5
|
||||||
aiohttp-cors==0.7.0
|
aiohttp-cors==0.7.0
|
||||||
faiss-cpu==1.7.4
|
faiss-cpu==1.7.4
|
||||||
aiosqlite==0.19.0
|
aiosqlite==0.19.0
|
||||||
|
open-clip-torch==2.23.0
|
Loading…
Reference in New Issue
Block a user