mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-10 22:09:54 +00:00
hackily patch horrifyingly nondeterministic-but-fast image encoder in
This commit is contained in:
parent
d4e136b6a7
commit
129b769a56
@ -133,9 +133,7 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True):
|
|||||||
mod.set_many_constants_with_tensors(params_ait)
|
mod.set_many_constants_with_tensors(params_ait)
|
||||||
mod.fold_constants(sync=True)
|
mod.fold_constants(sync=True)
|
||||||
|
|
||||||
# prepare input/output tensor
|
inputs = [torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()]
|
||||||
s = torch.stack([image.cuda().half().permute((1, 2, 0)) for _ in range(batch_size)], dim=0)
|
|
||||||
inputs = [s] #[torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()]
|
|
||||||
ys = []
|
ys = []
|
||||||
num_outputs = len(mod.get_output_name_to_index_map())
|
num_outputs = len(mod.get_output_name_to_index_map())
|
||||||
for i in range(num_outputs):
|
for i in range(num_outputs):
|
||||||
@ -149,10 +147,10 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True):
|
|||||||
repeat=1,
|
repeat=1,
|
||||||
graph_mode=graph_mode,
|
graph_mode=graph_mode,
|
||||||
)
|
)
|
||||||
q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed)))
|
#q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed)))
|
||||||
# = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2)
|
## = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2)
|
||||||
print("expected", q, q.shape)
|
#print("expected", q, q.shape)
|
||||||
print("actual", ys[0], ys[0].shape)
|
#print("actual", ys[0], ys[0].shape)
|
||||||
"""
|
"""
|
||||||
batch = ys[0][:, 0, :]
|
batch = ys[0][:, 0, :]
|
||||||
batch = torch.nn.functional.normalize(batch, dim=-1)
|
batch = torch.nn.functional.normalize(batch, dim=-1)
|
||||||
@ -160,6 +158,6 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True):
|
|||||||
print(f"batch_size: {batch_size}, latency: {t}")
|
print(f"batch_size: {batch_size}, latency: {t}")
|
||||||
"""
|
"""
|
||||||
#for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256):
|
#for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256):
|
||||||
for bs in (batch_size,):
|
for bs in (1, 2, 4, 8, 16, 32):
|
||||||
compile_vit(siglip_so400m_384_14, bs, use_fp16_acc=True)
|
compile_vit(siglip_so400m_384_14, bs, use_fp16_acc=True)
|
||||||
benchmark("siglip_so400m_384_14", siglip_so400m_384_14, bs, graph_mode=True)
|
benchmark("siglip_so400m_384_14", siglip_so400m_384_14, bs, graph_mode=True)
|
@ -28,6 +28,59 @@ print("Model loaded")
|
|||||||
BS = CONFIG["max_batch_size"]
|
BS = CONFIG["max_batch_size"]
|
||||||
MODELNAME = CONFIG["model_name"]
|
MODELNAME = CONFIG["model_name"]
|
||||||
|
|
||||||
|
fast_image_fns = {}
# ugly hack, sorry
# Optional fast path: when the config lists precompiled AITemplate image
# models, build one encoder callable per entry and register it in
# fast_image_fns, keyed by the batch size given in the config.
if CONFIG.get("aitemplate_image_models"):
    from aitemplate.compiler import Model
    from aitemplate.testing import detect_target

    USE_CUDA = detect_target().name() == "cuda"

    state = model.state_dict()
    # The patch-embedding projection weight is reordered with
    # permute((0, 2, 3, 1)) and stored as contiguous fp16 on the GPU.
    # NOTE(review): presumably this converts the conv kernel to the
    # channels-last layout the compiled module expects - confirm.
    conv_weights = state["visual.trunk.patch_embed.proj.weight"].permute((0, 2, 3, 1)).contiguous().cuda().half()

    def load_pretrained():
        """Build the constant-name -> tensor mapping for a compiled image model.

        Every ``visual.*`` entry of the model state dict is renamed through
        the replacement chain below, has its dots turned into underscores,
        and is moved to the GPU. The patch-embedding projection weight is not
        taken from the state dict directly; the pre-permuted ``conv_weights``
        tensor is used for it instead.

        Returns:
            dict mapping constant names (str) to CUDA tensors.
        """
        params = {}
        for key, value in state.items():
            # NOTE(review): the source this was reconstructed from had lost
            # its indentation; the weight filter below is assumed to apply
            # only to "visual." keys (non-visual keys are skipped entirely).
            if not key.startswith("visual."):
                continue
            # The order of the replacements matters (e.g. the attn_pool.latent
            # rule must run before the generic attn_pool rule).
            key = (
                key.removeprefix("visual.")
                .replace("trunk.patch_embed", "patch_embed")
                .replace("trunk.blocks", "encoder.layers")
                .replace(".attn.", ".mha.")
                .replace(".norm1.", ".ln1.")
                .replace(".norm2.", ".ln2.")
                .replace("trunk.pos_embed", "pos_emb_pos_emb")
                .replace("trunk.norm.", "encoder.ln.")
                .replace("trunk.attn_pool.latent", "pool.probe")
                .replace("trunk.attn_pool", "pool")
                .replace("pool.norm", "pool.ln")
            )
            if "patch_embed.proj.weight" not in key:
                params[key.replace(".", "_")] = value.cuda()

        params["patch_embed_proj_weight"] = conv_weights

        return params

    def generate_wrapper(path):
        """Load the compiled module at *path* and return an encoder callable.

        The module's constants are set from ``load_pretrained()`` and folded
        once, up front. The returned ``wrapper(batch)`` permutes the batch
        with (0, 2, 3, 1), runs the module, and returns ``[:, 0, :]`` of the
        first output.
        """
        ait_model = Model(path)
        ait_model.set_many_constants_with_tensors(load_pretrained())
        ait_model.fold_constants(sync=True)

        # The outputs of a loaded module do not change between calls, so
        # their maximum shapes are queried once here rather than per batch.
        output_shapes = [
            ait_model.get_output_maximum_shape(i)
            for i in range(len(ait_model.get_output_name_to_index_map()))
        ]

        def wrapper(batch):
            xs = [batch.permute((0, 2, 3, 1)).contiguous()]
            # Allocate the fp16 output buffers directly on the GPU instead of
            # creating fp32 host tensors and copying/converting them.
            ys = [
                torch.empty(shape, dtype=torch.float16, device="cuda")
                for shape in output_shapes
            ]
            ait_model.run_with_tensors(xs, ys)
            return ys[0][:, 0, :]

        return wrapper

    for batch_size, path in CONFIG["aitemplate_image_models"]:
        fast_image_fns[batch_size] = generate_wrapper(path)
        print("loaded", batch_size, path)
|
||||||
|
|
||||||
InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "images", "callback"])
|
InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "images", "callback"])
|
||||||
|
|
||||||
items_ctr = Counter("modelserver_total_items", "Items run through model server", ["model", "modality"])
|
items_ctr = Counter("modelserver_total_items", "Items run through model server", ["model", "modality"])
|
||||||
@ -48,6 +101,16 @@ def do_inference(params: InferenceParameters):
|
|||||||
elif images is not None:
|
elif images is not None:
|
||||||
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
|
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
|
||||||
items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
|
items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
|
||||||
|
batch = images.shape[0]
|
||||||
|
if fast_image_fns:
|
||||||
|
progress = 0
|
||||||
|
features = torch.zeros((batch, model.text.text_projection.out_features))
|
||||||
|
while progress < batch:
|
||||||
|
biggest_available = max(x for x in fast_image_fns.keys() if x <= (batch - progress))
|
||||||
|
chunk = fast_image_fns[biggest_available](images[progress:progress + biggest_available])
|
||||||
|
features[progress:progress + biggest_available] = chunk
|
||||||
|
progress += biggest_available
|
||||||
|
else:
|
||||||
features = model.encode_image(images)
|
features = model.encode_image(images)
|
||||||
features /= features.norm(dim=-1, keepdim=True)
|
features /= features.norm(dim=-1, keepdim=True)
|
||||||
features = features.cpu().numpy()
|
features = features.cpu().numpy()
|
||||||
|
@ -13,9 +13,15 @@ from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
import torch
|
|
||||||
from transformers import SiglipTokenizer, SiglipImageProcessor, T5TokenizerFast, SiglipTextConfig, SiglipVisionConfig
|
|
||||||
import numpy
|
import numpy
|
||||||
|
import big_vision.models.proj.image_text.two_towers as model_mod
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import ml_collections
|
||||||
|
import big_vision.pp.builder as pp_builder
|
||||||
|
import big_vision.pp.ops_general
|
||||||
|
import big_vision.pp.ops_image
|
||||||
|
import big_vision.pp.ops_text
|
||||||
|
|
||||||
with open(sys.argv[1], "r") as config_file:
|
with open(sys.argv[1], "r") as config_file:
|
||||||
CONFIG = json.load(config_file)
|
CONFIG = json.load(config_file)
|
||||||
@ -96,6 +102,7 @@ def do_inference(params: InferenceParameters):
|
|||||||
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
|
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
|
||||||
features = run_image_model(images)
|
features = run_image_model(images)
|
||||||
batch_count_ctr.labels(MODELNAME).inc()
|
batch_count_ctr.labels(MODELNAME).inc()
|
||||||
|
# TODO got to divide somewhere
|
||||||
callback(True, numpy.asarray(features))
|
callback(True, numpy.asarray(features))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
Loading…
Reference in New Issue
Block a user