1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-11-10 22:09:54 +00:00

hackily patch horrifyingly nondeterministic-but-fast image encoder in

This commit is contained in:
osmarks 2024-05-27 20:21:44 +01:00
parent d4e136b6a7
commit 129b769a56
3 changed files with 79 additions and 11 deletions

View File

@ -133,9 +133,7 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True):
mod.set_many_constants_with_tensors(params_ait) mod.set_many_constants_with_tensors(params_ait)
mod.fold_constants(sync=True) mod.fold_constants(sync=True)
# prepare input/output tensor inputs = [torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()]
s = torch.stack([image.cuda().half().permute((1, 2, 0)) for _ in range(batch_size)], dim=0)
inputs = [s] #[torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()]
ys = [] ys = []
num_outputs = len(mod.get_output_name_to_index_map()) num_outputs = len(mod.get_output_name_to_index_map())
for i in range(num_outputs): for i in range(num_outputs):
@ -149,10 +147,10 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True):
repeat=1, repeat=1,
graph_mode=graph_mode, graph_mode=graph_mode,
) )
q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed))) #q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed)))
# = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2) ## = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2)
print("expected", q, q.shape) #print("expected", q, q.shape)
print("actual", ys[0], ys[0].shape) #print("actual", ys[0], ys[0].shape)
""" """
batch = ys[0][:, 0, :] batch = ys[0][:, 0, :]
batch = torch.nn.functional.normalize(batch, dim=-1) batch = torch.nn.functional.normalize(batch, dim=-1)
@ -160,6 +158,6 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True):
print(f"batch_size: {batch_size}, latency: {t}") print(f"batch_size: {batch_size}, latency: {t}")
""" """
#for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256): #for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256):
for bs in (batch_size,): for bs in (1, 2, 4, 8, 16, 32):
compile_vit(siglip_so400m_384_14, bs, use_fp16_acc=True) compile_vit(siglip_so400m_384_14, bs, use_fp16_acc=True)
benchmark("siglip_so400m_384_14", siglip_so400m_384_14, bs, graph_mode=True) benchmark("siglip_so400m_384_14", siglip_so400m_384_14, bs, graph_mode=True)

View File

@ -28,6 +28,59 @@ print("Model loaded")
BS = CONFIG["max_batch_size"] BS = CONFIG["max_batch_size"]
MODELNAME = CONFIG["model_name"] MODELNAME = CONFIG["model_name"]
fast_image_fns = {}
# ugly hack, sorry
if CONFIG.get("aitemplate_image_models"):
from aitemplate.compiler import Model
from aitemplate.testing import detect_target
USE_CUDA = detect_target().name() == "cuda"
state = model.state_dict()
conv_weights = state["visual.trunk.patch_embed.proj.weight"].permute((0, 2, 3, 1)).contiguous().cuda().half()
def load_pretrained():
params = {}
for key, value in state.items():
orig_key = key
if key.startswith("visual."):
key = key.removeprefix("visual.") \
.replace("trunk.patch_embed", "patch_embed") \
.replace("trunk.blocks", "encoder.layers") \
.replace(".attn.", ".mha.") \
.replace(".norm1.", ".ln1.") \
.replace(".norm2.", ".ln2.") \
.replace("trunk.pos_embed", "pos_emb_pos_emb") \
.replace("trunk.norm.", "encoder.ln.") \
.replace("trunk.attn_pool.latent", "pool.probe") \
.replace("trunk.attn_pool", "pool") \
.replace("pool.norm", "pool.ln")
if "patch_embed.proj.weight" not in key:
params[key.replace(".", "_")] = value.cuda()
#print(orig_key, key.replace(".", "_"))
params["patch_embed_proj_weight"] = conv_weights
return params
def generate_wrapper(path):
ait_model = Model(path)
ait_model.set_many_constants_with_tensors(load_pretrained())
ait_model.fold_constants(sync=True)
def wrapper(batch):
xs = [batch.permute((0, 2, 3, 1)).contiguous()]
ys = []
for i in range(len(ait_model.get_output_name_to_index_map())):
shape = ait_model.get_output_maximum_shape(i)
ys.append(torch.empty(shape).cuda().half())
ait_model.run_with_tensors(xs, ys)
return ys[0][:, 0, :]
return wrapper
for batch_size, path in CONFIG["aitemplate_image_models"]:
fast_image_fns[batch_size] = generate_wrapper(path)
print("loaded", batch_size, path)
InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "images", "callback"]) InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "images", "callback"])
items_ctr = Counter("modelserver_total_items", "Items run through model server", ["model", "modality"]) items_ctr = Counter("modelserver_total_items", "Items run through model server", ["model", "modality"])
@ -48,6 +101,16 @@ def do_inference(params: InferenceParameters):
elif images is not None: elif images is not None:
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time(): with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
items_ctr.labels(MODELNAME, "image").inc(images.shape[0]) items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
batch = images.shape[0]
if fast_image_fns:
progress = 0
features = torch.zeros((batch, model.text.text_projection.out_features))
while progress < batch:
biggest_available = max(x for x in fast_image_fns.keys() if x <= (batch - progress))
chunk = fast_image_fns[biggest_available](images[progress:progress + biggest_available])
features[progress:progress + biggest_available] = chunk
progress += biggest_available
else:
features = model.encode_image(images) features = model.encode_image(images)
features /= features.norm(dim=-1, keepdim=True) features /= features.norm(dim=-1, keepdim=True)
features = features.cpu().numpy() features = features.cpu().numpy()

View File

@ -13,9 +13,15 @@ from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
import io import io
import json import json
import sys import sys
import torch
from transformers import SiglipTokenizer, SiglipImageProcessor, T5TokenizerFast, SiglipTextConfig, SiglipVisionConfig
import numpy import numpy
import big_vision.models.proj.image_text.two_towers as model_mod
import jax
import jax.numpy as jnp
import ml_collections
import big_vision.pp.builder as pp_builder
import big_vision.pp.ops_general
import big_vision.pp.ops_image
import big_vision.pp.ops_text
with open(sys.argv[1], "r") as config_file: with open(sys.argv[1], "r") as config_file:
CONFIG = json.load(config_file) CONFIG = json.load(config_file)
@ -96,6 +102,7 @@ def do_inference(params: InferenceParameters):
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time(): with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
features = run_image_model(images) features = run_image_model(images)
batch_count_ctr.labels(MODELNAME).inc() batch_count_ctr.labels(MODELNAME).inc()
# TODO got to divide somewhere
callback(True, numpy.asarray(features)) callback(True, numpy.asarray(features))
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()