diff --git a/aitemplate/model.py b/aitemplate/model.py
new file mode 100644
index 0000000..3205965
--- /dev/null
+++ b/aitemplate/model.py
@@ -0,0 +1,123 @@
+from functools import partial
+
+from aitemplate.compiler import ops
+from aitemplate.frontend import nn
+from aitemplate.testing import detect_target
+
+USE_CUDA = detect_target().name() == "cuda"
+
+def get_shape(x):
+    shape = [it.value() for it in x._attrs["shape"]]
+    return shape
+
+class MLPBlock(nn.Module):
+    def __init__(self, emb_dim, mlp_dim):
+        super().__init__()
+        self.emb_dim = emb_dim
+        self.mlp_dim = mlp_dim
+        self.fc1 = nn.Linear(emb_dim, mlp_dim, specialization="gelu")
+        self.fc2 = nn.Linear(mlp_dim, emb_dim, specialization="add")
+
+    def forward(self, x, res):
+        x = self.fc1(x)
+        x = self.fc2(x, res)
+        return x
+
+class Encoder1DBlock(nn.Module):
+    def __init__(self, emb_dim, mlp_dim, num_heads, batch_size, seq_len):
+        super().__init__()
+        self.ln1 = nn.LayerNorm(emb_dim)
+        self.mha = nn.MultiheadAttention(
+            emb_dim,
+            batch_size,
+            seq_len,
+            num_heads,
+            use_mem_eff=True
+        )
+        self.mlp = MLPBlock(emb_dim, mlp_dim)
+        self.ln2 = nn.LayerNorm(emb_dim)
+
+    def forward(self, x):
+        #self_attention_input = self.ln1(x)
+        x = self.mha(self.ln1(x), x)
+        x = self.mlp(self.ln2(x), x)
+        return x
+
+class Encoder(nn.Module):
+    def __init__(self, emb_dim, mlp_dim, num_heads, batch_size, seq_len, depth):
+        super().__init__()
+        self.layers = nn.ModuleList([ Encoder1DBlock(emb_dim, mlp_dim, num_heads, batch_size, seq_len) for i in range(depth) ])
+        self.ln = nn.LayerNorm(emb_dim)
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return self.ln(x)
+
+class PositionalEmbeddings(nn.Module):
+    def __init__(self, emb_dim, seq_len):
+        super().__init__()
+        self.pos_emb = nn.Parameter(shape=[1, seq_len, emb_dim], dtype="float16")
+
+    def forward(self, x):
+        return x + self.pos_emb.tensor()
+
+class PatchEmbedder(nn.Module):
+    def __init__(self, img_size, patch_size, in_chans, emb_dim):
+        super().__init__()
+        conv_op = nn.Conv2dBiasFewChannels if USE_CUDA else nn.Conv2dBias
+        self.proj = conv_op(in_chans, emb_dim, kernel_size=patch_size, stride=patch_size, padding=0, auto_padding=False)
+        self.flatten = True
+        self.emb_dim = emb_dim
+        self.proj_norm = nn.Identity()
+
+    def forward(self, x):
+        B, H, W, C = get_shape(x)
+        x = self.proj(x)
+        if self.flatten:
+            x = ops.reshape()(x, [B, -1, self.emb_dim])
+        x = self.proj_norm(x)
+        return x
+
+class MAPHead(nn.Module):
+    def __init__(self, emb_dim, mlp_dim, num_heads, batch_size, seq_len):
+        super().__init__()
+        self.q = nn.Linear(emb_dim, emb_dim)
+        self.kv = nn.Linear(emb_dim, emb_dim * 2)
+        self.num_heads = num_heads
+        self.head_dim = emb_dim // num_heads
+        #self.q_norm = nn.LayerNorm(self.head_dim)
+        #self.k_norm = nn.LayerNorm(self.head_dim)
+        self.proj = nn.Linear(emb_dim, emb_dim)
+        self.ln = nn.LayerNorm(emb_dim)
+        self.sdpa = nn.ScaledDotProductAttention()
+        self.mlp = MLPBlock(emb_dim, mlp_dim)
+        self.probe = nn.Parameter(shape=[1, 1, emb_dim], dtype="float16")
+        self.batch_size = batch_size
+        self.seq_len = seq_len
+        self.emb_dim = emb_dim
+
+    def forward(self, x):
+        ql = ops.expand()(self.probe.tensor(), [self.batch_size, -1, -1])
+        q = ops.reshape()(self.q(ql), [self.batch_size, self.num_heads, 1, self.head_dim])
+        kv = ops.permute()(ops.reshape()(self.kv(x), [self.batch_size, self.seq_len, 2, self.num_heads, self.head_dim]), (2, 0, 3, 1, 4))
+        k, v = ops.split()(kv, [1, 1], dim=0)
+        k, v = ops.squeeze(0)(k), ops.squeeze(0)(v)
+        #q = self.q_norm(q)
+        #k = self.k_norm(k)
+        x = self.sdpa(q, k, v)
+        x = ops.reshape()(ops.transpose()(x, 1, 2), (self.batch_size, 1, self.emb_dim))
+        x = self.proj(x)
+        return self.mlp(self.ln(x), x)
+
+class VisionTransformer(nn.Module):
+    def __init__(self, emb_dim, mlp_dim, num_heads, batch_size, seq_len, depth, img_size, patch_size, in_chans):
+        super().__init__()
+        self.patch_embed = PatchEmbedder(img_size, patch_size, in_chans, emb_dim)
+        self.encoder = Encoder(emb_dim, mlp_dim, num_heads, batch_size, seq_len, depth)
+        self.pool = MAPHead(emb_dim, mlp_dim, num_heads, batch_size, seq_len)
+        self.pos_emb = PositionalEmbeddings(emb_dim, seq_len)
+
+    def forward(self, image):
+        x = self.pos_emb(self.patch_embed(image))
+        return self.pool(self.encoder(x))
\ No newline at end of file
diff --git a/aitemplate/run.py b/aitemplate/run.py
new file mode 100644
index 0000000..01ccabf
--- /dev/null
+++ b/aitemplate/run.py
@@ -0,0 +1,165 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""benchmark for vit"""
+
+import os
+
+import numpy as np
+import torch
+from aitemplate.compiler import compile_model, Model
+from aitemplate.frontend import Tensor
+from aitemplate.testing import detect_target
+import open_clip
+
+from PIL import Image
+from model import VisionTransformer
+
+model, _, preprocess = open_clip.create_model_and_transforms("ViT-SO400M-14-SigLIP-384", pretrained="webli", precision="fp16", device="cuda")
+model.eval()
+
+torch.set_grad_enabled(False)
+
+print(model.visual.trunk.patch_embed)
+
+def mark_output(y):
+    if type(y) is not tuple:
+        y = (y,)
+    for i in range(len(y)):
+        y[i]._attrs["is_output"] = True
+        y[i]._attrs["name"] = "output_%d" % (i)
+        y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]]
+        print("output_{} shape: {}".format(i, y_shape))
+
+USE_CUDA = detect_target().name() == "cuda"
+
+siglip_so400m_384_14 = {
+    "img_size": 384,
+    "emb_dim": 1152,
+    "depth": 27,
+    "num_heads": 16,
+    "mlp_dim": 4304,
+    "patch_size": 14,
+    "in_chans": 3
+}
+
+batch_size = 32
+image = preprocess(Image.open("/data/public/memes-or-something/0mg.jpg"))
+input = torch.stack([image.cuda().half() for _ in range(batch_size)], dim=0)
+
+def compile_vit(
+    config,
+    batch_size,
+    use_fp16_acc=True,
+):
+    seq_len = (config["img_size"] // config["patch_size"]) ** 2
+    ait_model = VisionTransformer(
+        batch_size=batch_size,
+        seq_len=seq_len,
+        **config
+    )
+    ait_model.name_parameter_tensor()
+    print(ait_model)
+    inputs_ait = Tensor(
+        [batch_size, config["img_size"], config["img_size"], config["in_chans"]], name="input0", is_input=True
+    )
+    Y = ait_model(inputs_ait)
+    mark_output(Y)
+
+    target = detect_target(use_fp16_acc=use_fp16_acc)
+    exe_module = compile_model(
+        Y, target, "./tmp", "vision_transformer_bs%d_seq%d" % (batch_size, seq_len)
+    )
+    return exe_module
+
+def load_pretrained(config):
+    params = {}
+    st = model.state_dict()
+    for key, value in st.items():
+        orig_key = key
+        if key.startswith("visual."):
+            key = key.removeprefix("visual.") \
.replace("trunk.patch_embed", "patch_embed") \ + .replace("trunk.blocks", "encoder.layers") \ + .replace(".attn.", ".mha.") \ + .replace(".norm1.", ".ln1.") \ + .replace(".norm2.", ".ln2.") \ + .replace("trunk.pos_embed", "pos_emb_pos_emb") \ + .replace("trunk.norm.", "encoder.ln.") \ + .replace("trunk.attn_pool.latent", "pool.probe") \ + .replace("trunk.attn_pool", "pool") \ + .replace("pool.norm", "pool.ln") + if "patch_embed.proj.weight" not in key: + params[key.replace(".", "_")] = value.cuda() + print(orig_key, key.replace(".", "_")) + if USE_CUDA: + # horrors + w_pad = torch.zeros((config["emb_dim"], config["patch_size"], config["patch_size"], 4)).cuda().half() + w = st["visual.trunk.patch_embed.proj.weight"]#.permute((0, 2, 3, 1)).contiguous() + params["patch_embed_proj_weight"] = w.permute((0, 2, 3, 1)).contiguous().cuda().half() # N H W C + else: + params["patch_embed_proj_weight"] = st["visual.trunk.patch_embed.proj.weight"].permute((0, 2, 3, 1)).contiguous().cuda().half() + return params + +def benchmark(name, config, batch_size, mod=None, graph_mode=True): + seqlen = (config["img_size"] // config["patch_size"]) ** 2 + + if mod is None: + model_dir = f"vision_transformer_bs{batch_size}_seq{seqlen}" + mod = Model(os.path.join("./tmp", model_dir, "test.so")) + + # prepare params + params_ait = load_pretrained(config) + + s = set(mod.get_constant_names()) + d = [] + for k in params_ait: + if k not in s: + d.append(k) + for x in d: + del params_ait[x] + + mod.set_many_constants_with_tensors(params_ait) + mod.fold_constants(sync=True) + + # prepare input/output tensor + s = torch.stack([image.cuda().half().permute((1, 2, 0)) for _ in range(batch_size)], dim=0) + inputs = [s] #[torch.randn([batch_size, config["img_size"], config["img_size"], 3]).cuda().half()] + ys = [] + num_outputs = len(mod.get_output_name_to_index_map()) + for i in range(num_outputs): + shape = mod.get_output_maximum_shape(i) + ys.append(torch.empty(shape).cuda().half()) + # warm up + t, _, __ = mod.benchmark_with_tensors( + inputs, + ys, + count=10, + repeat=1, + graph_mode=graph_mode, + ) + q = model.visual.trunk.attn_pool(model.visual.trunk.norm(model.visual.trunk.blocks(model.visual.trunk.patch_embed(input) + model.visual.trunk.pos_embed))) + # = #model.visual.trunk.attn_pool.q(model.visual.trunk.attn_pool.latent.expand(batch_size, -1, -1)).reshape(batch_size, 1, 16, 72).transpose(1, 2) + print("expected", q, q.shape) + print("actual", ys[0], ys[0].shape) + """ + batch = ys[0][:, 0, :] + batch = torch.nn.functional.normalize(batch, dim=-1) + print(batch) + print(f"batch_size: {batch_size}, latency: {t}") + """ +#for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256): +for bs in (batch_size,): + compile_vit(siglip_so400m_384_14, bs, use_fp16_acc=True) + benchmark("siglip_so400m_384_14", siglip_so400m_384_14, bs, graph_mode=True) \ No newline at end of file