diff --git a/aitemplate/run.py b/aitemplate/run.py index 01ccabf..dd12886 100644 --- a/aitemplate/run.py +++ b/aitemplate/run.py @@ -133,9 +133,7 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True): mod.set_many_constants_with_tensors(params_ait) mod.fold_constants(sync=True) - # prepare input/output tensor - s = torch.stack([image.cuda().half().permute((1, 2, 0)) for _ in range(batch_size)], dim=0) - inputs = [s] #[torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()] + inputs = [torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()] ys = [] num_outputs = len(mod.get_output_name_to_index_map()) for i in range(num_outputs): @@ -149,10 +147,10 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True): repeat=1, graph_mode=graph_mode, ) - q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed))) - # = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2) - print("expected", q, q.shape) - print("actual", ys[0], ys[0].shape) + #q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed))) + ## = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2) + #print("expected", q, q.shape) + #print("actual", ys[0], ys[0].shape) """ batch = ys[0][:, 0, :] batch = torch.nn.functional.normalize(batch, dim=-1) @@ -160,6 +158,6 @@ def benchmark(name, config, batch_size, mod=None, graph_mode=True): print(f"batch_size: {batch_size}, latency: {t}") """ #for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256): -for bs in (batch_size,): +for bs in (1, 2, 4, 8, 16, 32): compile_vit(siglip_so400m_384_14, bs, use_fp16_acc=True) benchmark("siglip_so400m_384_14", siglip_so400m_384_14, bs, graph_mode=True) \ No newline at end of file diff --git a/clip_server.py b/clip_server.py index 7935348..a18ab5c 100644 --- a/clip_server.py +++ b/clip_server.py @@ -28,6 +28,59 @@ print("Model loaded") BS = CONFIG["max_batch_size"] MODELNAME = CONFIG["model_name"] +fast_image_fns = {} +# ugly hack, sorry +if CONFIG.get("aitemplate_image_models"): + from aitemplate.compiler import Model + from aitemplate.testing import detect_target + + USE_CUDA = detect_target().name() == "cuda" + + state = model.state_dict() + conv_weights = state["visual.trunk.patch_embed.proj.weight"].permute((0, 2, 3, 1)).contiguous().cuda().half() + + def load_pretrained(): + params = {} + for key, value in state.items(): + orig_key = key + if key.startswith("visual."): + key = key.removeprefix("visual.") \ + .replace("trunk.patch_embed", "patch_embed") \ + .replace("trunk.blocks", "encoder.layers") \ + .replace(".attn.", ".mha.") \ + .replace(".norm1.", ".ln1.") \ + .replace(".norm2.", ".ln2.") \ + .replace("trunk.pos_embed", "pos_emb_pos_emb") \ + .replace("trunk.norm.", "encoder.ln.") \ + .replace("trunk.attn_pool.latent", "pool.probe") \ + .replace("trunk.attn_pool", "pool") \ + .replace("pool.norm", "pool.ln") + if "patch_embed.proj.weight" not in key: + params[key.replace(".", "_")] = value.cuda() + #print(orig_key, key.replace(".", "_")) + + params["patch_embed_proj_weight"] = conv_weights + + return params + + def generate_wrapper(path): + ait_model = Model(path) + ait_model.set_many_constants_with_tensors(load_pretrained()) + ait_model.fold_constants(sync=True) + def wrapper(batch): + xs = [batch.permute((0, 2, 3, 1)).contiguous()] + ys = [] + for i in range(len(ait_model.get_output_name_to_index_map())): + shape = ait_model.get_output_maximum_shape(i) + ys.append(torch.empty(shape).cuda().half()) + ait_model.run_with_tensors(xs, ys) + return ys[0][:, 0, :] + return wrapper + + for batch_size, path in CONFIG["aitemplate_image_models"]: + fast_image_fns[batch_size] = generate_wrapper(path) + print("loaded", batch_size, path) + InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "images", "callback"]) items_ctr = Counter("modelserver_total_items", "Items run through model server", ["model", "modality"]) @@ -48,7 +101,17 @@ def do_inference(params: InferenceParameters): elif images is not None: with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time(): items_ctr.labels(MODELNAME, "image").inc(images.shape[0]) - features = model.encode_image(images) + batch = images.shape[0] + if fast_image_fns: + progress = 0 + features = torch.zeros((batch, model.text.text_projection.out_features)) + while progress < batch: + biggest_available = max(x for x in fast_image_fns.keys() if x <= (batch - progress)) + chunk = fast_image_fns[biggest_available](images[progress:progress + biggest_available]) + features[progress:progress + biggest_available] = chunk + progress += biggest_available + else: + features = model.encode_image(images) features /= features.norm(dim=-1, keepdim=True) features = features.cpu().numpy() batch_count_ctr.labels(MODELNAME).inc() diff --git a/misc/clip_accursed.py b/misc/clip_accursed.py index 9fdb907..bc023f9 100644 --- a/misc/clip_accursed.py +++ b/misc/clip_accursed.py @@ -13,9 +13,15 @@ from prometheus_client import Counter, Histogram, REGISTRY, generate_latest import io import json import sys -import torch -from transformers import SiglipTokenizer, SiglipImageProcessor, T5TokenizerFast, SiglipTextConfig, SiglipVisionConfig import numpy +import big_vision.models.proj.image_text.two_towers as model_mod +import jax +import jax.numpy as jnp +import ml_collections +import big_vision.pp.builder as pp_builder +import big_vision.pp.ops_general +import big_vision.pp.ops_image +import big_vision.pp.ops_text with open(sys.argv[1], "r") as config_file: CONFIG = json.load(config_file) @@ -96,6 +102,7 @@ def do_inference(params: InferenceParameters): with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time(): features = run_image_model(images) batch_count_ctr.labels(MODELNAME).inc() + # TODO got to divide somewhere callback(True, numpy.asarray(features)) except Exception as e: traceback.print_exc()