From 4626f53bcb848091f4dc28b912daca61f6d32985 Mon Sep 17 00:00:00 2001
From: osmarks
Date: Mon, 13 Nov 2023 17:31:43 +0000
Subject: [PATCH] Return to OpenCLIP

---
 README.md               |  5 +++--
 clip_server.py          | 47 +++++++++++++++++++----------------------------
 clip_server_config.json |  5 +++--
 requirements.txt        |  3 ++-
 4 files changed, 27 insertions(+), 33 deletions(-)

diff --git a/README.md b/README.md
index d122927..129510e 100644
--- a/README.md
+++ b/README.md
@@ -20,8 +20,9 @@ This is untested. It might work.
 * Serve your meme library from a static webserver.
     * I use nginx. If you're in a hurry, you can use `python -m http.server`.
 * Install Python dependencies with `pip` from `requirements.txt` (the versions probably shouldn't need to match exactly if you need to change them; I just put in what I currently have installed).
-    * You now need a [patched version](https://github.com/osmarks/transformers-patch-siglip) of `transformers` due to SigLIP support.
-    * I have converted exactly one SigLIP model: [https://huggingface.co/gollark/siglip-so400m-14-384](https://huggingface.co/gollark/siglip-so400m-14-384). It's apparently the best one. If you don't like it, find out how to convert more. You need to download that repo.
+    * ~~You now need a [patched version](https://github.com/osmarks/transformers-patch-siglip) of `transformers` due to SigLIP support.~~ OpenCLIP now supports SigLIP, so I use that instead.
+    * ~~I have converted exactly one SigLIP model: [https://huggingface.co/gollark/siglip-so400m-14-384](https://huggingface.co/gollark/siglip-so400m-14-384). It's apparently the best one. If you don't like it, find out how to convert more. You need to download that repo.~~ You can use any model OpenCLIP supports.
+* Run `thumbnailer.py` (periodically, ideally at the same time as index reloads).
 * Run `clip_server.py` (as a background service).
     * It is configured with a JSON file given to it as its first argument. An example is in `clip_server_config.json`.
     * `device` should probably be `cuda` or `cpu`. The model will run on here.
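A note on choosing a model name: the rewritten `clip_server.py` in the next diff resolves `CONFIG["model"]` against `open_clip.list_pretrained()`, so the `model` value in `clip_server_config.json` must be a name OpenCLIP ships pretrained weights for. The sketch below is not part of the patch; it only assumes the `open-clip-torch` package pinned in `requirements.txt` is installed.

```python
# Sketch only, not part of the patch: enumerate OpenCLIP's pretrained
# (model name, pretrained tag) pairs so you can pick a valid "model" value
# for clip_server_config.json, e.g. ViT-SO400M-14-SigLIP-384.
import open_clip

for model_name, pretrained_tag in open_clip.list_pretrained():
    if "SigLIP" in model_name:
        print(model_name, pretrained_tag)
```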
diff --git a/clip_server.py b/clip_server.py
index 34fc7af..7935348 100644
--- a/clip_server.py
+++ b/clip_server.py
@@ -1,4 +1,4 @@
-import os
+import torch
 import time
 import threading
 from aiohttp import web
@@ -8,34 +8,22 @@ import traceback
 import umsgpack
 import collections
 import queue
+import open_clip
 from PIL import Image
 from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
 import io
 import json
+import torchvision.transforms.transforms as transforms
 import sys
-import torch
-from transformers import SiglipImageProcessor, T5Tokenizer, SiglipModel, SiglipConfig
-from accelerate import init_empty_weights
-from accelerate.utils.modeling import set_module_tensor_to_device
-from safetensors import safe_open
-import numpy
 
 with open(sys.argv[1], "r") as config_file:
     CONFIG = json.load(config_file)
 
-DEVICE = CONFIG["device"]
-
-# So400m/14@384
-with init_empty_weights():
-    model = SiglipModel(config=SiglipConfig.from_pretrained(CONFIG["model"])).half().eval()
-with safe_open(os.path.join(CONFIG["model"], "model.safetensors"), framework="pt", device=DEVICE) as f:
-    for key in f.keys():
-        set_module_tensor_to_device(model, key, device=DEVICE, value=f.get_tensor(key))
-model = model.to(DEVICE)
-EMBDIM = model.config.vision_config.hidden_size # NOT projection_dim, why is that even there
-RES = model.config.vision_config.image_size
-tokenizer = T5Tokenizer(vocab_file=os.path.join(CONFIG["model"], "sentencepiece.model"), extra_ids=0, model_max_length=64, pad_token="</s>", legacy=False)
-image_processor = SiglipImageProcessor(size={"height": RES, "width":RES})
+device = torch.device(CONFIG["device"])
+model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=dict(open_clip.list_pretrained())[CONFIG["model"]], precision="fp16")
+model.eval()
+tokenizer = open_clip.get_tokenizer(CONFIG["model"])
+print("Model loaded")
 
 BS = CONFIG["max_batch_size"]
 MODELNAME = CONFIG["model_name"]
@@ -46,6 +34,7 @@ items_ctr = Counter("modelserver_total_items", "Items run through model server",
 inference_time_hist = Histogram("modelserver_inftime", "Time running inference", ["model", "batch_size"])
 batch_count_ctr = Counter("modelserver_batchcount", "Inference batches run", ["model"])
 
+torch.set_grad_enabled(False)
 def do_inference(params: InferenceParameters):
     with torch.no_grad():
         try:
@@ -53,13 +42,13 @@ def do_inference(params: InferenceParameters):
             if text is not None:
                 items_ctr.labels(MODELNAME, "text").inc(text.shape[0])
                 with inference_time_hist.labels(MODELNAME + "-text", text.shape[0]).time():
-                    features = model.text_model.forward(input_ids=torch.tensor(text, device=DEVICE)).pooler_output
+                    features = model.encode_text(text)
                     features /= features.norm(dim=-1, keepdim=True)
                     features = features.cpu().numpy()
             elif images is not None:
-                items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
                 with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
-                    features = model.vision_model.forward(torch.tensor(images, device=DEVICE)).pooler_output
+                    items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
+                    features = model.encode_image(images)
                     features /= features.norm(dim=-1, keepdim=True)
                     features = features.cpu().numpy()
             batch_count_ctr.labels(MODELNAME).inc()
@@ -67,6 +56,8 @@
         except Exception as e:
             traceback.print_exc()
             callback(False, str(e))
+        finally:
+            torch.cuda.empty_cache()
 
 iq = queue.Queue(10)
 def infer_thread():
@@ -80,10 +71,10 @@ def preprocessing_thread():
         try:
             if text:
                 assert len(text) <= BS, f"max batch size is {BS}"
-                text = numpy.array(tokenizer([ t.lower() for t in text ], padding="max_length", truncation=True)["input_ids"])
+                text = tokenizer(text).to(device)
             elif images:
                 assert len(images) <= BS, f"max batch size is {BS}"
-                images = numpy.array(image_processor([ Image.open(io.BytesIO(bs)) for bs in images ])["pixel_values"]).astype("float16")
+                images = torch.stack([ preprocess(Image.open(io.BytesIO(im))).half() for im in images ]).to(device)
             else:
                 assert False, "images or text required"
             iq.put(InferenceParameters(text, images, callback))
@@ -118,10 +109,10 @@ async def run_inference(request):
 @routes.get("/config")
 async def config(request):
     return web.Response(body=umsgpack.dumps({
-        "model": MODELNAME,
+        "model": CONFIG["model"],
         "batch": BS,
-        "image_size": (RES, RES),
-        "embedding_size": EMBDIM
+        "image_size": [ t for t in preprocess.transforms if isinstance(t, transforms.Resize) ][0].size,
+        "embedding_size": model.text.text_projection.out_features
     }), status=200, content_type="application/msgpack")
 
 @routes.get("/")
diff --git a/clip_server_config.json b/clip_server_config.json
index f3d401a..b6c36ae 100644
--- a/clip_server_config.json
+++ b/clip_server_config.json
@@ -1,6 +1,7 @@
 {
-    "model": "./out",
+    "model": "ViT-SO400M-14-SigLIP-384",
     "model_name": "siglip-so400m/14@384",
     "max_batch_size": 128,
-    "port": 1708
+    "port": 1708,
+    "device": "cuda:0"
 }
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index c1e20a4..ad182f5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ u-msgpack-python==2.8.0
 aiohttp==3.8.5
 aiohttp-cors==0.7.0
 faiss-cpu==1.7.4
-aiosqlite==0.19.0
\ No newline at end of file
+aiosqlite==0.19.0
+open-clip-torch==2.23.0
\ No newline at end of file
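For reference, a minimal client sketch (not part of the patch) that reads back the metadata the reworked `/config` endpoint serves in the hunk above. It assumes the server is running locally on port 1708, as in `clip_server_config.json`, and reuses the `u-msgpack-python` package pinned in `requirements.txt`.

```python
# Sketch only: fetch the msgpack-encoded /config blob exposed by clip_server.py
# and print the fields it emits (model, batch, image_size, embedding_size).
# Assumes the server is reachable at localhost:1708.
import urllib.request
import umsgpack

with urllib.request.urlopen("http://localhost:1708/config") as resp:
    info = umsgpack.loads(resp.read())

print(info["model"], info["batch"], info["image_size"], info["embedding_size"])
```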