mirror of
				https://github.com/osmarks/meme-search-engine.git
				synced 2025-11-04 09:13:05 +00:00 
			
		
		
		
	hackily patch horrifyingly nondeterministic-but-fast image encoder in
This commit is contained in:
		@@ -133,9 +133,7 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True):
 | 
			
		||||
    mod.set_many_constants_with_tensors(params_ait)
 | 
			
		||||
    mod.fold_constants(sync=True)
 | 
			
		||||
 | 
			
		||||
    # prepare input/output tensor
 | 
			
		||||
    s = torch.stack([image.cuda().half().permute((1, 2, 0)) for _ in range(batch_size)], dim=0)
 | 
			
		||||
    inputs = [s] #[torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()]
 | 
			
		||||
    inputs = [torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()]
 | 
			
		||||
    ys = []
 | 
			
		||||
    num_outputs = len(mod.get_output_name_to_index_map())
 | 
			
		||||
    for i in range(num_outputs):
 | 
			
		||||
@@ -149,10 +147,10 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True):
 | 
			
		||||
        repeat=1,
 | 
			
		||||
        graph_mode=graph_mode,
 | 
			
		||||
    )
 | 
			
		||||
    q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed)))
 | 
			
		||||
    # = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2)
 | 
			
		||||
    print("expected", q, q.shape)
 | 
			
		||||
    print("actual", ys[0], ys[0].shape)
 | 
			
		||||
    #q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed)))
 | 
			
		||||
    ## = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2)
 | 
			
		||||
    #print("expected", q, q.shape)
 | 
			
		||||
    #print("actual", ys[0], ys[0].shape)
 | 
			
		||||
    """
 | 
			
		||||
    batch = ys[0][:, 0, :]
 | 
			
		||||
    batch = torch.nn.functional.normalize(batch, dim=-1)
 | 
			
		||||
@@ -160,6 +158,6 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True):
 | 
			
		||||
    print(f"batch_size: {batch_size}, latency: {t}")
 | 
			
		||||
    """
 | 
			
		||||
#for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256):
 | 
			
		||||
for bs in (batch_size,):
 | 
			
		||||
for bs in (1, 2, 4, 8, 16, 32):
 | 
			
		||||
    compile_vit(siglip_so400m_384_14, bs, use_fp16_acc=True)
 | 
			
		||||
    benchmark("siglip_so400m_384_14", siglip_so400m_384_14, bs, graph_mode=True)
 | 
			
		||||
@@ -28,6 +28,59 @@ print("Model loaded")
 | 
			
		||||
BS = CONFIG["max_batch_size"]
 | 
			
		||||
MODELNAME = CONFIG["model_name"]
 | 
			
		||||
 | 
			
		||||
fast_image_fns = {}
 | 
			
		||||
# ugly hack, sorry
 | 
			
		||||
if CONFIG.get("aitemplate_image_models"):
 | 
			
		||||
    from aitemplate.compiler import Model
 | 
			
		||||
    from aitemplate.testing import detect_target
 | 
			
		||||
 | 
			
		||||
    USE_CUDA = detect_target().name() == "cuda"
 | 
			
		||||
 | 
			
		||||
    state = model.state_dict()
 | 
			
		||||
    conv_weights = state["visual.trunk.patch_embed.proj.weight"].permute((0, 2, 3, 1)).contiguous().cuda().half()
 | 
			
		||||
 | 
			
		||||
    def load_pretrained():
 | 
			
		||||
        params = {}
 | 
			
		||||
        for key, value in state.items():
 | 
			
		||||
            orig_key = key
 | 
			
		||||
            if key.startswith("visual."):
 | 
			
		||||
                key = key.removeprefix("visual.") \
 | 
			
		||||
                    .replace("trunk.patch_embed", "patch_embed") \
 | 
			
		||||
                    .replace("trunk.blocks", "encoder.layers") \
 | 
			
		||||
                    .replace(".attn.", ".mha.") \
 | 
			
		||||
                    .replace(".norm1.", ".ln1.") \
 | 
			
		||||
                    .replace(".norm2.", ".ln2.") \
 | 
			
		||||
                    .replace("trunk.pos_embed", "pos_emb_pos_emb") \
 | 
			
		||||
                    .replace("trunk.norm.", "encoder.ln.") \
 | 
			
		||||
                    .replace("trunk.attn_pool.latent", "pool.probe") \
 | 
			
		||||
                    .replace("trunk.attn_pool", "pool") \
 | 
			
		||||
                    .replace("pool.norm", "pool.ln")
 | 
			
		||||
                if "patch_embed.proj.weight" not in key:
 | 
			
		||||
                    params[key.replace(".", "_")] = value.cuda()
 | 
			
		||||
                    #print(orig_key, key.replace(".", "_"))
 | 
			
		||||
    
 | 
			
		||||
        params["patch_embed_proj_weight"] = conv_weights
 | 
			
		||||
 | 
			
		||||
        return params
 | 
			
		||||
 | 
			
		||||
    def generate_wrapper(path):
 | 
			
		||||
        ait_model = Model(path)
 | 
			
		||||
        ait_model.set_many_constants_with_tensors(load_pretrained())
 | 
			
		||||
        ait_model.fold_constants(sync=True)
 | 
			
		||||
        def wrapper(batch):
 | 
			
		||||
            xs = [batch.permute((0, 2, 3, 1)).contiguous()]
 | 
			
		||||
            ys = []
 | 
			
		||||
            for i in range(len(ait_model.get_output_name_to_index_map())):
 | 
			
		||||
                shape = ait_model.get_output_maximum_shape(i)
 | 
			
		||||
                ys.append(torch.empty(shape).cuda().half())
 | 
			
		||||
            ait_model.run_with_tensors(xs, ys)
 | 
			
		||||
            return ys[0][:, 0, :]
 | 
			
		||||
        return wrapper
 | 
			
		||||
 | 
			
		||||
    for batch_size, path in CONFIG["aitemplate_image_models"]:
 | 
			
		||||
        fast_image_fns[batch_size] = generate_wrapper(path)
 | 
			
		||||
        print("loaded", batch_size, path)
 | 
			
		||||
 | 
			
		||||
InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "images", "callback"])
 | 
			
		||||
 | 
			
		||||
items_ctr = Counter("modelserver_total_items", "Items run through model server", ["model", "modality"])
 | 
			
		||||
@@ -48,6 +101,16 @@ def do_inference(params: InferenceParameters):
 | 
			
		||||
            elif images is not None:
 | 
			
		||||
                with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
 | 
			
		||||
                    items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
 | 
			
		||||
                    batch = images.shape[0]
 | 
			
		||||
                    if fast_image_fns:
 | 
			
		||||
                        progress = 0
 | 
			
		||||
                        features = torch.zeros((batch, model.text.text_projection.out_features))
 | 
			
		||||
                        while progress < batch:
 | 
			
		||||
                            biggest_available = max(x for x in fast_image_fns.keys() if x <= (batch - progress))
 | 
			
		||||
                            chunk = fast_image_fns[biggest_available](images[progress:progress + biggest_available])
 | 
			
		||||
                            features[progress:progress + biggest_available] = chunk
 | 
			
		||||
                            progress += biggest_available
 | 
			
		||||
                    else:
 | 
			
		||||
                        features = model.encode_image(images)
 | 
			
		||||
                    features /= features.norm(dim=-1, keepdim=True)
 | 
			
		||||
                    features = features.cpu().numpy()
 | 
			
		||||
 
 | 
			
		||||
@@ -13,9 +13,15 @@ from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
 | 
			
		||||
import io
 | 
			
		||||
import json
 | 
			
		||||
import sys
 | 
			
		||||
import torch
 | 
			
		||||
from transformers import SiglipTokenizer, SiglipImageProcessor, T5TokenizerFast, SiglipTextConfig, SiglipVisionConfig
 | 
			
		||||
import numpy
 | 
			
		||||
import big_vision.models.proj.image_text.two_towers as model_mod
 | 
			
		||||
import jax
 | 
			
		||||
import jax.numpy as jnp
 | 
			
		||||
import ml_collections
 | 
			
		||||
import big_vision.pp.builder as pp_builder
 | 
			
		||||
import big_vision.pp.ops_general
 | 
			
		||||
import big_vision.pp.ops_image
 | 
			
		||||
import big_vision.pp.ops_text
 | 
			
		||||
 | 
			
		||||
with open(sys.argv[1], "r") as config_file:
 | 
			
		||||
    CONFIG = json.load(config_file)
 | 
			
		||||
@@ -96,6 +102,7 @@ def do_inference(params: InferenceParameters):
 | 
			
		||||
            with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
 | 
			
		||||
                features = run_image_model(images)
 | 
			
		||||
        batch_count_ctr.labels(MODELNAME).inc()
 | 
			
		||||
        # TODO got to divide somewhere
 | 
			
		||||
        callback(True, numpy.asarray(features))
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        traceback.print_exc()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user