# image_model_nondeterminism_test_ait.py
# mirror of https://github.com/osmarks/nanogpt-experiments.git
import torch
from PIL import Image
import open_clip
import numpy as np
# load the reference open_clip SigLIP model in fp16 on the GPU
model_name = "ViT-SO400M-14-SigLIP-384"
model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained="webli", precision="fp16", device="cuda")
model.eval()
tokenizer = open_clip.get_tokenizer(model_name)
print(model)
from aitemplate.compiler import Model as AITModel
from aitemplate.testing import detect_target
USE_CUDA = detect_target().name() == "cuda"
state = model.state_dict()
# AITemplate expects channels-last (NHWC) conv weights, so permute the patch embedding kernel
conv_weights = state["visual.trunk.patch_embed.proj.weight"].permute((0, 2, 3, 1)).contiguous().cuda().half()
def load_pretrained():
    # remap open_clip's vision tower parameter names to the names used by the AIT graph
    params = {}
    for key, value in state.items():
        orig_key = key
        if key.startswith("visual."):
            key = key.removeprefix("visual.") \
                .replace("trunk.patch_embed", "patch_embed") \
                .replace("trunk.blocks", "encoder.layers") \
                .replace(".attn.", ".mha.") \
                .replace(".norm1.", ".ln1.") \
                .replace(".norm2.", ".ln2.") \
                .replace("trunk.pos_embed", "pos_emb_pos_emb") \
                .replace("trunk.norm.", "encoder.ln.") \
                .replace("trunk.attn_pool.latent", "pool.probe") \
                .replace("trunk.attn_pool", "pool") \
                .replace("pool.norm", "pool.ln")
            if "patch_embed.proj.weight" not in key:
                params[key.replace(".", "_")] = value.cuda()
            #print(orig_key, key.replace(".", "_"))
    # the patch embedding weight was already permuted to NHWC above
    params["patch_embed_proj_weight"] = conv_weights
    return params
def generate_wrapper(path):
    # load the compiled AITemplate module and bind the converted weights as constants
    ait_model = AITModel(path)
    ait_model.set_many_constants_with_tensors(load_pretrained())
    ait_model.fold_constants(sync=True)
    def wrapper(batch):
        # AIT expects NHWC inputs; preallocate output buffers at their maximum shapes
        xs = [batch.permute((0, 2, 3, 1)).contiguous()]
        ys = []
        for i in range(len(ait_model.get_output_name_to_index_map())):
            shape = ait_model.get_output_maximum_shape(i)
            ys.append(torch.empty(shape).cuda().half())
        ait_model.run_with_tensors(xs, ys)
        return ys[0][:, 0, :]
    return wrapper
encode_image = generate_wrapper("siglip/1.so")  # path to the compiled AIT module
print("preprocess")
image = preprocess(Image.open("siglip.jpg")).unsqueeze(0).half().cuda()
print("fwd")
features = encode_image(image)
# run the same image through the compiled model repeatedly and measure run-to-run nondeterminism
avgmean = 0
avgmax = 0
n = 500
with torch.no_grad():
    for _ in range(n):
        altered_features = encode_image(image)
        mean_diff = (features - altered_features).abs().mean().item()
        max_diff = (features - altered_features).abs().max().item()  # largest absolute elementwise deviation
        print(f"{mean_diff:.6f}, {max_diff:.6f}")
        avgmean += mean_diff / n
        avgmax += max_diff / n
print(f"avg mean diff: {avgmean}, avg max diff: {avgmax}")