mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-13 07:19:54 +00:00
hackily patch horrifyingly nondeterministic-but-fast image encoder in
This commit is contained in:
parent
d4e136b6a7
commit
129b769a56
@ -133,9 +133,7 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True):
|
||||
mod.set_many_constants_with_tensors(params_ait)
|
||||
mod.fold_constants(sync=True)
|
||||
|
||||
# prepare input/output tensor
|
||||
s = torch.stack([image.cuda().half().permute((1, 2, 0)) for _ in range(batch_size)], dim=0)
|
||||
inputs = [s] #[torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()]
|
||||
inputs = [torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()]
|
||||
ys = []
|
||||
num_outputs = len(mod.get_output_name_to_index_map())
|
||||
for i in range(num_outputs):
|
||||
@ -149,10 +147,10 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True):
|
||||
repeat=1,
|
||||
graph_mode=graph_mode,
|
||||
)
|
||||
q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed)))
|
||||
# = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2)
|
||||
print("expected", q, q.shape)
|
||||
print("actual", ys[0], ys[0].shape)
|
||||
#q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed)))
|
||||
## = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2)
|
||||
#print("expected", q, q.shape)
|
||||
#print("actual", ys[0], ys[0].shape)
|
||||
"""
|
||||
batch = ys[0][:, 0, :]
|
||||
batch = torch.nn.functional.normalize(batch, dim=-1)
|
||||
@ -160,6 +158,6 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True):
|
||||
print(f"batch_size: {batch_size}, latency: {t}")
|
||||
"""
|
||||
#for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256):
|
||||
for bs in (batch_size,):
|
||||
for bs in (1, 2, 4, 8, 16, 32):
|
||||
compile_vit(siglip_so400m_384_14, bs, use_fp16_acc=True)
|
||||
benchmark("siglip_so400m_384_14", siglip_so400m_384_14, bs, graph_mode=True)
|
@ -28,6 +28,59 @@ print("Model loaded")
|
||||
BS = CONFIG["max_batch_size"]
|
||||
MODELNAME = CONFIG["model_name"]
|
||||
|
||||
fast_image_fns = {}
|
||||
# ugly hack, sorry
|
||||
if CONFIG.get("aitemplate_image_models"):
|
||||
from aitemplate.compiler import Model
|
||||
from aitemplate.testing import detect_target
|
||||
|
||||
USE_CUDA = detect_target().name() == "cuda"
|
||||
|
||||
state = model.state_dict()
|
||||
conv_weights = state["visual.trunk.patch_embed.proj.weight"].permute((0, 2, 3, 1)).contiguous().cuda().half()
|
||||
|
||||
def load_pretrained():
|
||||
params = {}
|
||||
for key, value in state.items():
|
||||
orig_key = key
|
||||
if key.startswith("visual."):
|
||||
key = key.removeprefix("visual.") \
|
||||
.replace("trunk.patch_embed", "patch_embed") \
|
||||
.replace("trunk.blocks", "encoder.layers") \
|
||||
.replace(".attn.", ".mha.") \
|
||||
.replace(".norm1.", ".ln1.") \
|
||||
.replace(".norm2.", ".ln2.") \
|
||||
.replace("trunk.pos_embed", "pos_emb_pos_emb") \
|
||||
.replace("trunk.norm.", "encoder.ln.") \
|
||||
.replace("trunk.attn_pool.latent", "pool.probe") \
|
||||
.replace("trunk.attn_pool", "pool") \
|
||||
.replace("pool.norm", "pool.ln")
|
||||
if "patch_embed.proj.weight" not in key:
|
||||
params[key.replace(".", "_")] = value.cuda()
|
||||
#print(orig_key, key.replace(".", "_"))
|
||||
|
||||
params["patch_embed_proj_weight"] = conv_weights
|
||||
|
||||
return params
|
||||
|
||||
def generate_wrapper(path):
|
||||
ait_model = Model(path)
|
||||
ait_model.set_many_constants_with_tensors(load_pretrained())
|
||||
ait_model.fold_constants(sync=True)
|
||||
def wrapper(batch):
|
||||
xs = [batch.permute((0, 2, 3, 1)).contiguous()]
|
||||
ys = []
|
||||
for i in range(len(ait_model.get_output_name_to_index_map())):
|
||||
shape = ait_model.get_output_maximum_shape(i)
|
||||
ys.append(torch.empty(shape).cuda().half())
|
||||
ait_model.run_with_tensors(xs, ys)
|
||||
return ys[0][:, 0, :]
|
||||
return wrapper
|
||||
|
||||
for batch_size, path in CONFIG["aitemplate_image_models"]:
|
||||
fast_image_fns[batch_size] = generate_wrapper(path)
|
||||
print("loaded", batch_size, path)
|
||||
|
||||
InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "images", "callback"])
|
||||
|
||||
items_ctr = Counter("modelserver_total_items", "Items run through model server", ["model", "modality"])
|
||||
@ -48,7 +101,17 @@ def do_inference(params: InferenceParameters):
|
||||
elif images is not None:
|
||||
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
|
||||
items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
|
||||
features = model.encode_image(images)
|
||||
batch = images.shape[0]
|
||||
if fast_image_fns:
|
||||
progress = 0
|
||||
features = torch.zeros((batch, model.text.text_projection.out_features))
|
||||
while progress < batch:
|
||||
biggest_available = max(x for x in fast_image_fns.keys() if x <= (batch - progress))
|
||||
chunk = fast_image_fns[biggest_available](images[progress:progress + biggest_available])
|
||||
features[progress:progress + biggest_available] = chunk
|
||||
progress += biggest_available
|
||||
else:
|
||||
features = model.encode_image(images)
|
||||
features /= features.norm(dim=-1, keepdim=True)
|
||||
features = features.cpu().numpy()
|
||||
batch_count_ctr.labels(MODELNAME).inc()
|
||||
|
@ -13,9 +13,15 @@ from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
import torch
|
||||
from transformers import SiglipTokenizer, SiglipImageProcessor, T5TokenizerFast, SiglipTextConfig, SiglipVisionConfig
|
||||
import numpy
|
||||
import big_vision.models.proj.image_text.two_towers as model_mod
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import ml_collections
|
||||
import big_vision.pp.builder as pp_builder
|
||||
import big_vision.pp.ops_general
|
||||
import big_vision.pp.ops_image
|
||||
import big_vision.pp.ops_text
|
||||
|
||||
with open(sys.argv[1], "r") as config_file:
|
||||
CONFIG = json.load(config_file)
|
||||
@ -96,6 +102,7 @@ def do_inference(params: InferenceParameters):
|
||||
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
|
||||
features = run_image_model(images)
|
||||
batch_count_ctr.labels(MODELNAME).inc()
|
||||
# TODO got to divide somewhere
|
||||
callback(True, numpy.asarray(features))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
Loading…
Reference in New Issue
Block a user