Mirror of https://github.com/osmarks/meme-search-engine.git
Synced 2024-11-10 22:09:54 +00:00
AITemplate builds of the image encoder work, at great personal cost
This commit is contained in:
parent a8329e43fc
commit d4e136b6a7
aitemplate/model.py (new file, +123 lines)
@@ -0,0 +1,123 @@
from functools import partial

from aitemplate.compiler import ops
from aitemplate.frontend import nn
from aitemplate.testing import detect_target

USE_CUDA = detect_target().name() == "cuda"

def get_shape(x):
    shape = [it.value() for it in x._attrs["shape"]]
    return shape

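# Transformer MLP using AITemplate's fused Linear specializations: "gelu" folds the
# GELU activation into the first GEMM's epilogue, and "add" folds the residual
# addition into the second, which is why forward() takes the residual as a second
# argument instead of doing an explicit x + res.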
class MLPBlock(nn.Module):
    def __init__(self, emb_dim, mlp_dim):
        super().__init__()
        self.emb_dim = emb_dim
        self.mlp_dim = mlp_dim
        self.fc1 = nn.Linear(emb_dim, mlp_dim, specialization="gelu")
        self.fc2 = nn.Linear(mlp_dim, emb_dim, specialization="add")

    def forward(self, x, res):
        x = self.fc1(x)
        x = self.fc2(x, res)
        return x

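# Standard pre-LayerNorm ViT encoder block. The residual connections are not written
# out explicitly: the second argument to nn.MultiheadAttention appears to be the
# residual that gets added inside the fused attention op, and MLPBlock fuses its
# residual the same way, so forward() is just two calls.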
class Encoder1DBlock(nn.Module):
    def __init__(self, emb_dim, mlp_dim, num_heads, batch_size, seq_len):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.mha = nn.MultiheadAttention(
            emb_dim,
            batch_size,
            seq_len,
            num_heads,
            use_mem_eff=True
        )
        self.mlp = MLPBlock(emb_dim, mlp_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        #self_attention_input = self.ln1(x)
        x = self.mha(self.ln1(x), x)
        x = self.mlp(self.ln2(x), x)
        return x

class Encoder(nn.Module):
    def __init__(self, emb_dim, mlp_dim, num_heads, batch_size, seq_len, depth):
        super().__init__()
        self.layers = nn.ModuleList([ Encoder1DBlock(emb_dim, mlp_dim, num_heads, batch_size, seq_len) for i in range(depth) ])
        self.ln = nn.LayerNorm(emb_dim)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.ln(x)

class PositionalEmbeddings(nn.Module):
    def __init__(self, emb_dim, seq_len):
        super().__init__()
        self.pos_emb = nn.Parameter(shape=[1, seq_len, emb_dim], dtype="float16")

    def forward(self, x):
        return x + self.pos_emb.tensor()

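# Patch embedding as a strided convolution in NHWC layout (what the AIT conv kernels
# operate on). On the CUDA backend this uses Conv2dBiasFewChannels, which is intended
# for inputs with very few channels (RGB here); load_pretrained in run.py permutes the
# pretrained weight from PyTorch's NCHW layout to match.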
class PatchEmbedder(nn.Module):
    def __init__(self, img_size, patch_size, in_chans, emb_dim):
        super().__init__()
        conv_op = nn.Conv2dBiasFewChannels if USE_CUDA else nn.Conv2dBias
        self.proj = conv_op(in_chans, emb_dim, kernel_size=patch_size, stride=patch_size, padding=0, auto_padding=False)
        self.flatten = True
        self.emb_dim = emb_dim
        self.proj_norm = nn.Identity()

    def forward(self, x):
        B, H, W, C = get_shape(x)
        x = self.proj(x)
        if self.flatten:
            x = ops.reshape()(x, [B, -1, self.emb_dim])
        x = self.proj_norm(x)
        return x

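# Multihead attention pooling, mirroring the SigLIP/Big Vision MAPHead (attn_pool in
# the timm trunk): a single learned probe token is the query, the patch tokens supply
# keys and values, and the result goes through LayerNorm + MLP to give one pooled
# [batch_size, 1, emb_dim] embedding per image.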
class MAPHead(nn.Module):
    def __init__(self, emb_dim, mlp_dim, num_heads, batch_size, seq_len):
        super().__init__()
        self.q = nn.Linear(emb_dim, emb_dim)
        self.kv = nn.Linear(emb_dim, emb_dim * 2)
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        #self.q_norm = nn.LayerNorm(self.head_dim)
        #self.k_norm = nn.LayerNorm(self.head_dim)
        self.proj = nn.Linear(emb_dim, emb_dim)
        self.ln = nn.LayerNorm(emb_dim)
        self.sdpa = nn.ScaledDotProductAttention()
        self.mlp = MLPBlock(emb_dim, mlp_dim)
        self.probe = nn.Parameter(shape=[1, 1, emb_dim], dtype="float16")
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.emb_dim = emb_dim

    def forward(self, x):
        ql = ops.expand()(self.probe.tensor(), [self.batch_size, -1, -1])
        q = ops.reshape()(self.q(ql), [self.batch_size, self.num_heads, 1, self.head_dim])
        kv = ops.permute()(ops.reshape()(self.kv(x), [self.batch_size, self.seq_len, 2, self.num_heads, self.head_dim]), (2, 0, 3, 1, 4))
        k, v = ops.split()(kv, [1, 1], dim=0)
        k, v = ops.squeeze(0)(k), ops.squeeze(0)(v)
        #q = self.q_norm(q)
        #k = self.k_norm(k)
        x = self.sdpa(q, k, v)
        x = ops.reshape()(ops.transpose()(x, 1, 2), (self.batch_size, 1, self.emb_dim))
        x = self.proj(x)
        return self.mlp(self.ln(x), x)

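# Full image tower: patchify, add learned positional embeddings, run the transformer
# encoder, then pool with the MAP head. The output is the pooled embedding of shape
# [batch_size, 1, emb_dim].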
class VisionTransformer(nn.Module):
    def __init__(self, emb_dim, mlp_dim, num_heads, batch_size, seq_len, depth, img_size, patch_size, in_chans):
        super().__init__()
        self.patch_embed = PatchEmbedder(img_size, patch_size, in_chans, emb_dim)
        self.encoder = Encoder(emb_dim, mlp_dim, num_heads, batch_size, seq_len, depth)
        self.pool = MAPHead(emb_dim, mlp_dim, num_heads, batch_size, seq_len)
        self.pos_emb = PositionalEmbeddings(emb_dim, seq_len)

    def forward(self, image):
        x = self.pos_emb(self.patch_embed(image))
        return self.pool(self.encoder(x))
aitemplate/run.py (new file, +165 lines)
@@ -0,0 +1,165 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""benchmark for vit"""

import os

import numpy as np
import torch
from aitemplate.compiler import compile_model, Model
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target
import open_clip

from PIL import Image
from model import VisionTransformer

model, _, preprocess = open_clip.create_model_and_transforms("ViT-SO400M-14-SigLIP-384", pretrained="webli", precision="fp16", device="cuda")
model.eval()

torch.set_grad_enabled(False)

print(model.visual.trunk.patch_embed)

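# Mark each tensor returned by the AIT graph as a named output (output_0, output_1, ...)
# so the compiled module exposes it, and print its static shape.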
def mark_output(y):
    if type(y) is not tuple:
        y = (y,)
    for i in range(len(y)):
        y[i]._attrs["is_output"] = True
        y[i]._attrs["name"] = "output_%d" % (i)
        y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]]
        print("output_{} shape: {}".format(i, y_shape))

USE_CUDA = detect_target().name() == "cuda"

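# Dimensions of the SigLIP SO400M/14 image tower at 384px: width 1152, depth 27,
# 16 heads, MLP dim 4304, patch size 14, so seq_len = (384 // 14) ** 2 = 729 tokens.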
siglip_so400m_384_14 = {
    "img_size": 384,
    "emb_dim": 1152,
    "depth": 27,
    "num_heads": 16,
    "mlp_dim": 4304,
    "patch_size": 14,
    "in_chans": 3
}

batch_size = 32
image = preprocess(Image.open("/data/public/memes-or-something/0mg.jpg"))
input = torch.stack([image.cuda().half() for _ in range(batch_size)], dim=0)

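# Build the AIT graph for a fixed batch size and sequence length, name its parameters
# (so the open_clip weights can be bound to them later), mark the pooled output, and
# compile it to ./tmp/vision_transformer_bs{batch}_seq{seq}/test.so.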
def compile_vit(
    config,
    batch_size,
    use_fp16_acc=True,
):
    seq_len = (config["img_size"] // config["patch_size"]) ** 2
    ait_model = VisionTransformer(
        batch_size=batch_size,
        seq_len=seq_len,
        **config
    )
    ait_model.name_parameter_tensor()
    print(ait_model)
    inputs_ait = Tensor(
        [batch_size, config["img_size"], config["img_size"], config["in_chans"]], name="input0", is_input=True
    )
    Y = ait_model(inputs_ait)
    mark_output(Y)

    target = detect_target(use_fp16_acc=use_fp16_acc)
    exe_module = compile_model(
        Y, target, "./tmp", "vision_transformer_bs%d_seq%d" % (batch_size, seq_len)
    )
    return exe_module

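# Map the open_clip state dict onto the AIT parameter names: strip the "visual." prefix,
# rename the timm trunk modules to the names used in model.py, then replace dots with
# underscores to match what name_parameter_tensor() produced. The patch-embed conv weight
# is handled separately since it has to be permuted from PyTorch's NCHW layout to the
# NHWC layout the AIT conv expects; w_pad is currently unused (presumably left over from
# an attempt to zero-pad the 3 input channels for the few-channels conv).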
def load_pretrained(config):
    params = {}
    st = model.state_dict()
    for key, value in st.items():
        orig_key = key
        if key.startswith("visual."):
            key = key.removeprefix("visual.") \
                .replace("trunk.patch_embed", "patch_embed") \
                .replace("trunk.blocks", "encoder.layers") \
                .replace(".attn.", ".mha.") \
                .replace(".norm1.", ".ln1.") \
                .replace(".norm2.", ".ln2.") \
                .replace("trunk.pos_embed", "pos_emb_pos_emb") \
                .replace("trunk.norm.", "encoder.ln.") \
                .replace("trunk.attn_pool.latent", "pool.probe") \
                .replace("trunk.attn_pool", "pool") \
                .replace("pool.norm", "pool.ln")
            if "patch_embed.proj.weight" not in key:
                params[key.replace(".", "_")] = value.cuda()
            print(orig_key, key.replace(".", "_"))
    if USE_CUDA:
        # horrors
        w_pad = torch.zeros((config["emb_dim"], config["patch_size"], config["patch_size"], 4)).cuda().half()
        w = st["visual.trunk.patch_embed.proj.weight"]#.permute((0, 2, 3, 1)).contiguous()
        params["patch_embed_proj_weight"] = w.permute((0, 2, 3, 1)).contiguous().cuda().half() # N H W C
    else:
        params["patch_embed_proj_weight"] = st["visual.trunk.patch_embed.proj.weight"].permute((0, 2, 3, 1)).contiguous().cuda().half()
    return params

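# Load the compiled module, bind the converted weights as constants (dropping any that
# the module does not actually use), fold them, then time the module with
# benchmark_with_tensors and compare its pooled output against the PyTorch SigLIP trunk
# run on the same batch.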
def benchmark(name, config, batch_size, mod=None, graph_mode=True):
    seqlen = (config["img_size"] // config["patch_size"]) ** 2

    if mod is None:
        model_dir = f"vision_transformer_bs{batch_size}_seq{seqlen}"
        mod = Model(os.path.join("./tmp", model_dir, "test.so"))

    # prepare params
    params_ait = load_pretrained(config)

    s = set(mod.get_constant_names())
    d = []
    for k in params_ait:
        if k not in s:
            d.append(k)
    for x in d:
        del params_ait[x]

    mod.set_many_constants_with_tensors(params_ait)
    mod.fold_constants(sync=True)

    # prepare input/output tensor
    s = torch.stack([image.cuda().half().permute((1, 2, 0)) for _ in range(batch_size)], dim=0)
    inputs = [s] #[torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()]
    ys = []
    num_outputs = len(mod.get_output_name_to_index_map())
    for i in range(num_outputs):
        shape = mod.get_output_maximum_shape(i)
        ys.append(torch.empty(shape).cuda().half())
    # warm up
    t, _, __ = mod.benchmark_with_tensors(
        inputs,
        ys,
        count=10,
        repeat=1,
        graph_mode=graph_mode,
    )
    q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed)))
    # = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2)
    print("expected", q, q.shape)
    print("actual", ys[0], ys[0].shape)
    """
    batch = ys[0][:, 0, :]
    batch = torch.nn.functional.normalize(batch, dim=-1)
    print(batch)
    print(f"batch_size: {batch_size}, latency: {t}")
    """

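# compile_vit has to run before benchmark, since benchmark reloads the compiled test.so
# from ./tmp; the commented-out range is for sweeping batch sizes.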
#for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256):
for bs in (batch_size,):
    compile_vit(siglip_so400m_384_14, bs, use_fp16_acc=True)
    benchmark("siglip_so400m_384_14", siglip_so400m_384_14, bs, graph_mode=True)
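
# A sketch of pulling usable embeddings out of the compiled module (not exercised here;
# it assumes the pooled AIT output matches what the disabled block inside benchmark
# normalizes, and "embeddings.pt" is just an illustrative filename). Inside benchmark,
# after benchmark_with_tensors, something like:
#
#     emb = torch.nn.functional.normalize(ys[0][:, 0, :], dim=-1)
#     torch.save(emb.cpu(), "embeddings.pt")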