# meme-search-engine/aitemplate/run.py


# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""benchmark for vit"""
import os
import numpy as np
import torch
from aitemplate.compiler import compile_model, Model
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target
import open_clip
from PIL import Image
from model import VisionTransformer
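
# Load the reference SigLIP model via open_clip; its pretrained weights are
# copied into the compiled AITemplate graph by load_pretrained() below.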
model, _, preprocess = open_clip.create_model_and_transforms("ViT-SO400M-14-SigLIP-384", pretrained="webli", precision="fp16", device="cuda")
model.eval()
torch.set_grad_enabled(False)
print(model.visual.trunk.patch_embed)
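

# Tag the tensors returned by the AIT graph as named outputs ("output_0", ...)
# so that compile_model() treats them as graph outputs.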
def mark_output(y):
    if type(y) is not tuple:
        y = (y,)
    for i in range(len(y)):
        y[i]._attrs["is_output"] = True
        y[i]._attrs["name"] = "output_%d" % (i)
        y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]]
        print("output_{} shape: {}".format(i, y_shape))


USE_CUDA = detect_target().name() == "cuda"
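
# SigLIP SO400M/14 @ 384px hyperparameters, matching the open_clip model above.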
siglip_so400m_384_14 = {
    "img_size": 384,
    "emb_dim": 1152,
    "depth": 27,
    "num_heads": 16,
    "mlp_dim": 4304,
    "patch_size": 14,
    "in_chans": 3,
}

batch_size = 32
image = preprocess(Image.open("/data/public/memes-or-something/0mg.jpg"))
input = torch.stack([image.cuda().half() for _ in range(batch_size)], dim=0)
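

# Build the AITemplate VisionTransformer graph for a fixed batch size and
# compile it to a shared library under ./tmp.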
def compile_vit(
    config,
    batch_size,
    use_fp16_acc=True,
):
    seq_len = (config["img_size"] // config["patch_size"]) ** 2
    ait_model = VisionTransformer(
        batch_size=batch_size,
        seq_len=seq_len,
        **config
    )
    ait_model.name_parameter_tensor()
    print(ait_model)
    inputs_ait = Tensor(
        [batch_size, config["img_size"], config["img_size"], config["in_chans"]], name="input0", is_input=True
    )
    Y = ait_model(inputs_ait)
    mark_output(Y)
    target = detect_target(use_fp16_acc=use_fp16_acc)
    exe_module = compile_model(
        Y, target, "./tmp", "vision_transformer_bs%d_seq%d" % (batch_size, seq_len)
    )
    return exe_module
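

# Remap open_clip's state_dict keys (visual.trunk.*) onto the AIT parameter
# names (dots become underscores), and convert the patch-embedding conv weight
# from NCHW to NHWC layout as the AIT model expects.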
def load_pretrained(config):
    params = {}
    st = model.state_dict()
    for key, value in st.items():
        orig_key = key
        if key.startswith("visual."):
            key = key.removeprefix("visual.") \
                .replace("trunk.patch_embed", "patch_embed") \
                .replace("trunk.blocks", "encoder.layers") \
                .replace(".attn.", ".mha.") \
                .replace(".norm1.", ".ln1.") \
                .replace(".norm2.", ".ln2.") \
                .replace("trunk.pos_embed", "pos_emb_pos_emb") \
                .replace("trunk.norm.", "encoder.ln.") \
                .replace("trunk.attn_pool.latent", "pool.probe") \
                .replace("trunk.attn_pool", "pool") \
                .replace("pool.norm", "pool.ln")
            if "patch_embed.proj.weight" not in key:
                params[key.replace(".", "_")] = value.cuda()
            print(orig_key, key.replace(".", "_"))
    if USE_CUDA:
        # horrors
        w_pad = torch.zeros((config["emb_dim"], config["patch_size"], config["patch_size"], 4)).cuda().half()
        w = st["visual.trunk.patch_embed.proj.weight"]  # .permute((0, 2, 3, 1)).contiguous()
        params["patch_embed_proj_weight"] = w.permute((0, 2, 3, 1)).contiguous().cuda().half()  # N H W C
    else:
        params["patch_embed_proj_weight"] = st["visual.trunk.patch_embed.proj.weight"].permute((0, 2, 3, 1)).contiguous().cuda().half()
    return params
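

# Load the compiled module (unless one is passed in), bind the pretrained
# weights as constants, fold them, and time the model on random inputs via
# benchmark_with_tensors.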
def benchmark(name, config, batch_size, mod=None, graph_mode=True):
    seqlen = (config["img_size"] // config["patch_size"]) ** 2
    if mod is None:
        model_dir = f"vision_transformer_bs{batch_size}_seq{seqlen}"
        mod = Model(os.path.join("./tmp", model_dir, "test.so"))
    # prepare params
    params_ait = load_pretrained(config)
    # drop any parameters the compiled graph does not expose as constants
    s = set(mod.get_constant_names())
    d = []
    for k in params_ait:
        if k not in s:
            d.append(k)
    for x in d:
        del params_ait[x]
    mod.set_many_constants_with_tensors(params_ait)
    mod.fold_constants(sync=True)
    inputs = [torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()]
    ys = []
    num_outputs = len(mod.get_output_name_to_index_map())
    for i in range(num_outputs):
        shape = mod.get_output_maximum_shape(i)
        ys.append(torch.empty(shape).cuda().half())
    # warm up
    t, _, __ = mod.benchmark_with_tensors(
        inputs,
        ys,
        count=10,
        repeat=1,
        graph_mode=graph_mode,
    )
    #q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed)))
    ## = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2)
    #print("expected", q, q.shape)
    #print("actual", ys[0], ys[0].shape)
    """
    batch = ys[0][:, 0, :]
    batch = torch.nn.functional.normalize(batch, dim=-1)
    print(batch)
    """
    print(f"batch_size: {batch_size}, latency: {t}")
#for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256):
for bs in (1, 2, 4, 8, 16, 32):
    compile_vit(siglip_so400m_384_14, bs, use_fp16_acc=True)
    benchmark("siglip_so400m_384_14", siglip_so400m_384_14, bs, graph_mode=True)