mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00

fix things

osmarks 2024-07-23 10:56:47 +01:00
parent a64b2f2cfe
commit a225b756e8
9 changed files with 111 additions and 20 deletions

.gitignore vendored

@@ -10,3 +10,5 @@ __pycache__/
 input.txt
 env/
 venv/
+torch-int
+smoothquant


@@ -46,7 +46,8 @@ def load_nanogpt(model_dir, ckpt):
     return model, tiktoken.get_encoding("gpt2")
 #model, tokenizer = load_exllama("./Llama-3-8B-Instruct-exl2")
-model, tokenizer = load_nanogpt("./data-injection-1", "ckpt3000.pt")
+#model, tokenizer = load_nanogpt("./atk-fixed-suffix-2-0.0025", "ckpt3000.pt")
+model, tokenizer = load_nanogpt("./atk-fixed-suffix-2-0.00125", "ckpt3000.pt")
 def find_closest_tokens(model, tokenizer):
     weights_ = model.modules[0].embedding.weight.data
@@ -76,21 +77,40 @@ def find_closest_tokens(model, tokenizer):
 #find_closest_tokens()
-#best_pair = 28217, 76665
-#best_pair = 34966, 70467
-#best_pair = 48, 57
-best_pair = 49704, 50009
+#best_pair = 28217, 76665 # rare token pair in LLaMA
+#best_pair = 34966, 70467 # also that
+#best_pair = 48, 57 # Q, Z in LLaMA - we need to use common tokens or it cannot represent an even mix of them in the logits, but they can't be so common together that a compound token exists
+best_pair = 49704, 50009 # unused in our GPT-2 training dataset - used for data injection
+#best_pair = 2, 0 # seem to not form a compound token in GPT-2 tokenizer
+suffix = 49691 # chosen for data injection variant
 COUNT = 1000
+total_max = 0
+total_mean = 0
+suffix_len = 512
+count_len = 512
 for _ in range(COUNT):
     sequence = torch.randint(low=0, high=2, size=(1024,), device="cuda", dtype=torch.int32) * (best_pair[1] - best_pair[0]) + best_pair[0]
+    sequence[-suffix_len:] = torch.full((suffix_len,), suffix, device="cuda", dtype=torch.int32)
+    sequence2 = sequence.clone()
     print("---")
-    for end_choice in best_pair:
-        sequence[-1] = end_choice
-        logits = model.forward(sequence.unsqueeze(0))
-        if isinstance(logits, tuple):
-            logits = logits[0]
-        logits = logits.bfloat16() # introduce roundoff error deliberately
-        print("Final 10 logits", logits[0, -10:, :])
-        #print("Input", tokenizer.decode(sequence.tolist()))
-        #print("Predictions", tokenizer.decode(torch.argmax(logits[0], dim=-1).tolist()))
-        print("Max", torch.max(logits[0, -1], dim=-1), torch.mean(logits[0, -1], dim=-1))
+    sequence[suffix_len-1] = best_pair[0]
+    sequence2[suffix_len-1] = best_pair[1]
+    logits = model.forward(torch.stack([sequence, sequence2], dim=0))
+    if isinstance(logits, tuple):
+        logits = logits[0]
+    #logits = logits.bfloat16() # introduce roundoff error deliberately
+    print("Final logits", logits[:, -5:, :])
+    #print("Input", tokenizer.decode(sequence.tolist()))
+    #print("Predictions", tokenizer.decode(torch.argmax(logits[0], dim=-1).tolist()))
+    maxdiff = torch.max((logits[0, -count_len:] - logits[1, -count_len:]).flatten(), dim=-1).values.item()
+    meandiff = torch.mean(((logits[0, -count_len:] - logits[1, -count_len:]).abs()).flatten(), dim=-1).item()
+    total_max += maxdiff
+    total_mean += abs(meandiff)
+    print("Max diff", maxdiff)
+    print("Mean diff", meandiff)
+print("---AVG---")
+print("Max diff", total_max / COUNT)
+print("Mean diff", total_mean / COUNT)


@@ -1,10 +1,13 @@
 import numpy as np
 import os
+from collections import Counter
 data_dir = "."
 data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
 datas = set(data)
+counts = Counter(data)
 vocab = set(range(50257))
 unused = vocab - datas
 unused = sorted(unused)
 print(len(unused))
 print(unused)
+print(counts.most_common(100))
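
A sketch of an equivalent, typically faster count, assuming the same train.bin layout as above: np.bincount over the uint16 token ids gives per-token frequencies in one vectorised pass instead of a Python-level Counter over the whole memmap.

import os
import numpy as np

data = np.memmap(os.path.join(".", "train.bin"), dtype=np.uint16, mode="r")
freqs = np.bincount(data, minlength=50257)      # frequency of every GPT-2 token id
unused = np.flatnonzero(freqs == 0)             # token ids that never occur
top100 = np.argsort(freqs)[::-1][:100]          # most common token ids
print(len(unused))
print(unused.tolist())
print([(int(t), int(freqs[t])) for t in top100])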

image_model_test.py Normal file

@@ -0,0 +1,59 @@
import torch
from PIL import Image
import open_clip
import numpy as np
model_name = "ViT-SO400M-14-SigLIP-384"
model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained="webli", precision="fp16", device="cuda")
model.eval()
tokenizer = open_clip.get_tokenizer(model_name)
print(model)
print("preprocess")
image = preprocess(Image.open("siglip.jpg")).unsqueeze(0).half().cuda()
image.requires_grad = True
print("fwd")
features = model.encode_image(image)
print("bwd")
s = features.abs().sum()
print(s.backward())
# Due to some nonsense, the model actually cuts off exactly six pixels from the right and bottom of the image.
# (6 = 384 - (14*27))
# Those can be varied arbitrarily without affecting the output, but that isn't interesting.
# B C W H, probably
real_grad = image.grad[:, :, :378, :378].abs()
x = torch.min(real_grad, dim=3)
print(x)
y = torch.min(x.values, dim=2)
print(y)
z = torch.min(y.values, dim=1)
print(z)
l_chan = z.indices[0]
l_x = y.indices[0][l_chan]
l_y = x.indices[0][l_chan][l_x]
least_affecting_index = 0, l_chan, l_x, l_y
image.requires_grad = False
print(real_grad[least_affecting_index], image[least_affecting_index])
avgmean = 0
avgmax = 0
n = 500
with torch.no_grad():
    for some_float in np.linspace(-1, 1, n):
        if -1 <= some_float <= 1:
            image[least_affecting_index] = float(some_float)
            altered_features = model.encode_image(image)
            mean_diff = (features - altered_features).abs().mean().item()
            max_diff = (features - altered_features).max().item()
            print(f"{some_float:0.3f}: {mean_diff:3f}, {max_diff:3f}")
            avgmean += mean_diff / n
            avgmax += max_diff / n
print(f"avg mean diff: {avgmean}, avg max diff: {avgmax}")

rec2.txt Normal file

@@ -0,0 +1,3 @@
338
[90, 124, 125, 173, 174, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 199, 200, 201, 202, 209, 216, 217, 218, 628, 1391, 4895, 5367, 5808, 5815, 6438, 6533, 6598, 8351, 8418, 8438, 8762, 8964, 8980, 9063, 9364, 9372, 10298, 11504, 11592, 11919, 11933, 11974, 12781, 13018, 13150, 13945, 14468, 14695, 14827, 15040, 15041, 15090, 15243, 15473, 16098, 16303, 17900, 18384, 18472, 18477, 18945, 19476, 19629, 19779, 19953, 20041, 20174, 20598, 20662, 20801, 21737, 21807, 21876, 22110, 22133, 22409, 22523, 22675, 22757, 22934, 22935, 22997, 23090, 23282, 23330, 23614, 23785, 23884, 24288, 24711, 24847, 24934, 24973, 25193, 25502, 25597, 25618, 25719, 25787, 25992, 26150, 26349, 26358, 27006, 27007, 27013, 27097, 27534, 27584, 27675, 28235, 28542, 28666, 28670, 29164, 29226, 29372, 29795, 29836, 30072, 30202, 30208, 30209, 30210, 30211, 30212, 30213, 30439, 30684, 30856, 30897, 30898, 30899, 30905, 30906, 31032, 31161, 31478, 31536, 31538, 31571, 31573, 31576, 31666, 31727, 31732, 31765, 31783, 31881, 31886, 31957, 32047, 32092, 32239, 32437, 32509, 32574, 32843, 32865, 32917, 33153, 33434, 33454, 33717, 33789, 33813, 34008, 34027, 34171, 34206, 34386, 34448, 34473, 34504, 34604, 34633, 34638, 34713, 34758, 34842, 34949, 35098, 35207, 35286, 35306, 35343, 35496, 35579, 35853, 35992, 36173, 36174, 36473, 36481, 36490, 36796, 36862, 36886, 36917, 36929, 36935, 36938, 36940, 37226, 37337, 37444, 37495, 37574, 37579, 37631, 37842, 37913, 37991, 38007, 38122, 38214, 38370, 38377, 38626, 38892, 39008, 39142, 39165, 39172, 39177, 39253, 39374, 39446, 39714, 39749, 39752, 39753, 39755, 39756, 39757, 39803, 39811, 39820, 39821, 39893, 39906, 40219, 40240, 40241, 40242, 40278, 40415, 40703, 41050, 41230, 41297, 41380, 41383, 41424, 41504, 41538, 41551, 41868, 42066, 42089, 42090, 42156, 42202, 42382, 42424, 42470, 42496, 42535, 42586, 42728, 42744, 42785, 42889, 42943, 43038, 43065, 43177, 43298, 43361, 43453, 43473, 43569, 43735, 43796, 43801, 43839, 43995, 44033, 44320, 44444, 44575, 44785, 45003, 45144, 45199, 45392, 45422, 45544, 45545, 45706, 45786, 45915, 46092, 46110, 46222, 46402, 46600, 46733, 46939, 46956, 47021, 47182, 47198, 47432, 47571, 47648, 47703, 47936, 48069, 48396, 48527, 48683, 48874, 48999, 49074, 49527, 49691, 49704, 49731, 49781, 50009, 50113]
[(13, 6982701), (11, 6001045), (262, 5828707), (290, 3711824), (284, 3531803), (286, 3171169), (257, 2694993), (198, 2649139), (287, 2083919), (318, 1546316), (329, 1369070), (12, 1247671), (447, 1198181), (326, 1162544), (345, 1148110), (351, 1022795), (247, 955865), (319, 926062), (314, 841695), (340, 799551), (389, 788250), (307, 724422), (355, 698636), (534, 669980), (82, 659564), (393, 608703), (423, 568082), (379, 549520), (422, 536994), (428, 535482), (357, 531070), (373, 516273), (460, 508745), (416, 500439), (481, 497929), (383, 491710), (281, 442977), (407, 417314), (25, 389055), (356, 388186), (50256, 352754), (468, 345663), (477, 341090), (0, 330032), (674, 309748), (475, 302978), (517, 302671), (511, 302087), (8, 300594), (30, 298066), (464, 291729), (338, 290015), (484, 287501), (530, 281042), (546, 258123), (543, 256403), (616, 255143), (635, 238958), (503, 237414), (510, 234220), (523, 234102), (640, 230852), (564, 222198), (83, 221331), (250, 216522), (14, 209507), (251, 206557), (465, 198624), (466, 195827), (508, 195459), (632, 192348), (618, 192205), (611, 192074), (587, 190279), (584, 189556), (339, 188719), (588, 187437), (644, 180306), (612, 176432), (597, 175050), (617, 174306), (606, 173961), (649, 173061), (651, 172403), (40, 172184), (547, 171570), (561, 171116), (656, 163681), (550, 162896), (655, 162678), (26, 160893), (770, 158408), (621, 155314), (366, 154857), (663, 154203), (787, 153345), (775, 149590), (703, 147969), (743, 144537), (661, 144303)]

siglip.jpg Normal file

Binary file not shown. (755 KiB)

smoothquant Submodule

@@ -0,0 +1 @@
Subproject commit c61476d728e42ae0d8a35e7e78494edcac3237b5

torch-int Submodule

@@ -0,0 +1 @@
Subproject commit 65266db1eadba5ca78941b789803929e6e6c6856


@@ -21,6 +21,7 @@ import time
 import math
 import pickle
 from contextlib import nullcontext
+import sys
 import numpy as np
 import torch
@@ -73,8 +74,8 @@ eval_interval = 500
 eval_iters = 200
 log_interval = 10
-data_injection_rate = 0.01
-data_injection_mode = ["random", 50009, 49704]
+data_injection_rate = float(sys.argv[1])
+data_injection_mode = ["random", 50009, 49704, np.full(512, 49691, dtype=np.int64)]
 # weight decay
 weight_decay = 1e-1
@@ -155,11 +156,12 @@ def get_batch(split, step):
     xs, ys = [torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix], [torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]
     match data_injection_mode:
-        case ["random", t1, t2]:
+        case ["random", t1, t2, suffix]:
             t1, t2 = sorted((t1, t2))
             for i in range(batch_size):
                 if d_rng.random() < data_injection_rate:
                     seq = np.random.randint(0, 2, size=(block_size + 1, ), dtype=np.int64) * (t2 - t1) + t1
+                    seq[-len(suffix):] = torch.tensor(suffix, dtype=torch.int64)
                     xs[i] = torch.tensor(seq[:-1], dtype=torch.int64)
                     ys[i] = torch.tensor(seq[1:], dtype=torch.int64)
         case None:
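
A small sketch of what one injected training example looks like under the new mode, assuming block_size = 1024 (the values are taken from the config change above; nothing here is part of the commit itself): the first half is a random mix of the two marker tokens and the last 512 positions are the fixed suffix token.

import numpy as np

block_size = 1024                             # assumed; matches the 1024-token probe script above
t1, t2 = 49704, 50009                         # marker tokens unused in the training data
suffix = np.full(512, 49691, dtype=np.int64)  # fixed suffix token
seq = np.random.randint(0, 2, size=(block_size + 1,), dtype=np.int64) * (t2 - t1) + t1
seq[-len(suffix):] = suffix
x, y = seq[:-1], seq[1:]                      # input/target pair, as built in get_batch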