"""Measure run-to-run drift of an AITemplate-compiled SigLIP image encoder.

The script loads the open_clip ``ViT-SO400M-14-SigLIP-384`` checkpoint, feeds
its visual-tower weights into the compiled AITemplate module at
``siglip/1.so``, encodes ``siglip.jpg`` once as a reference and then encodes
the same image ``n`` more times, printing the mean and maximum absolute
difference of each run against the reference.
"""

import numpy as np
import open_clip
import torch
from aitemplate.compiler import Model as AITModel
from aitemplate.testing import detect_target
from PIL import Image

model_name = "ViT-SO400M-14-SigLIP-384"
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name, pretrained="webli", precision="fp16", device="cuda"
)
model.eval()
tokenizer = open_clip.get_tokenizer(model_name)

print(model)

USE_CUDA = detect_target().name() == "cuda"

state = model.state_dict()
# The patch-embedding kernel is permuted with (0, 2, 3, 1), the same permutation
# the wrapper applies to image batches, and is stored as an fp16 CUDA tensor.
conv_weights = (
    state["visual.trunk.patch_embed.proj.weight"]
    .permute((0, 2, 3, 1))
    .contiguous()
    .cuda()
    .half()
)

# Ordered (old, new) substring rewrites from open_clip parameter names to the
# names used by the compiled module. Order matters: later rules operate on the
# output of earlier ones (e.g. "trunk.attn_pool" -> "pool", then "pool.norm").
_VISUAL_KEY_RENAMES = (
    ("trunk.patch_embed", "patch_embed"),
    ("trunk.blocks", "encoder.layers"),
    (".attn.", ".mha."),
    (".norm1.", ".ln1."),
    (".norm2.", ".ln2."),
    ("trunk.pos_embed", "pos_emb_pos_emb"),
    ("trunk.norm.", "encoder.ln."),
    ("trunk.attn_pool.latent", "pool.probe"),
    ("trunk.attn_pool", "pool"),
    ("pool.norm", "pool.ln"),
)


def load_pretrained():
    """Build the constant-name -> CUDA tensor mapping for the compiled module.

    Only ``visual.*`` entries of the module-level ``state`` dict are used.
    Each key has the ``visual.`` prefix stripped, is rewritten through
    ``_VISUAL_KEY_RENAMES`` and has its dots replaced by underscores. The
    patch-embedding weight is skipped in the loop and supplied from the
    pre-permuted ``conv_weights`` instead.

    Returns:
        dict mapping constant names to CUDA tensors.
    """
    params = {}
    for key, value in state.items():
        # NOTE(review): nesting was reconstructed from a flattened source; this
        # assumes non-visual (text tower) weights are not forwarded - confirm.
        if not key.startswith("visual."):
            continue
        key = key.removeprefix("visual.")
        for old, new in _VISUAL_KEY_RENAMES:
            key = key.replace(old, new)
        if "patch_embed.proj.weight" not in key:
            params[key.replace(".", "_")] = value.cuda()

    params["patch_embed_proj_weight"] = conv_weights

    return params


def generate_wrapper(path):
    """Load the compiled module at ``path`` and return an image-encoding callable.

    The pretrained weights are bound as constants and folded once, up front.

    Args:
        path: filesystem path of the compiled AITemplate shared object.

    Returns:
        A function taking an image batch tensor and returning
        ``outputs[0][:, 0, :]`` of the compiled module.
    """
    ait_model = AITModel(path)
    ait_model.set_many_constants_with_tensors(load_pretrained())
    ait_model.fold_constants(sync=True)

    # Output shapes are read once here rather than on every call.
    output_shapes = [
        ait_model.get_output_maximum_shape(i)
        for i in range(len(ait_model.get_output_name_to_index_map()))
    ]

    def wrapper(batch):
        """Run the compiled module on ``batch`` and return the sliced first output."""
        xs = [batch.permute((0, 2, 3, 1)).contiguous()]
        # Fresh buffers on every call: callers keep earlier results around for
        # comparison, so outputs must never alias a previous run's storage.
        # Allocate directly as fp16 on the GPU instead of fp32 on the host.
        ys = [
            torch.empty(shape, dtype=torch.float16, device="cuda")
            for shape in output_shapes
        ]
        ait_model.run_with_tensors(xs, ys)
        # presumably the first output is (batch, tokens, dim) and index 0 on the
        # second axis is the pooled embedding - TODO confirm against the graph.
        return ys[0][:, 0, :]

    return wrapper


encode_image = generate_wrapper("siglip/1.so")

print("preprocess")
with Image.open("siglip.jpg") as img:
    image = preprocess(img).unsqueeze(0).half().cuda()

print("fwd")
features = encode_image(image)

n = 500
total_mean = 0.0
total_max = 0.0
with torch.no_grad():
    for _ in range(n):
        altered_features = encode_image(image)
        # Both statistics use the absolute difference; taking max() of the
        # signed difference would hide negative deviations.
        abs_diff = (features - altered_features).abs()
        mean_diff = abs_diff.mean().item()
        max_diff = abs_diff.max().item()
        print(f"{mean_diff:.6f}, {max_diff:.6f}")
        total_mean += mean_diff
        total_max += max_diff

avgmean = total_mean / n
avgmax = total_max / n

print(f"avg mean diff: {avgmean}, avg max diff: {avgmax}")