1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00

fix things

This commit is contained in:
osmarks 2024-07-23 10:56:47 +01:00
parent a64b2f2cfe
commit a225b756e8
9 changed files with 111 additions and 20 deletions

2
.gitignore vendored
View File

@ -10,3 +10,5 @@ __pycache__/
input.txt
env/
venv/
torch-int
smoothquant

View File

@ -46,7 +46,8 @@ def load_nanogpt(model_dir, ckpt):
return model, tiktoken.get_encoding("gpt2")
#model, tokenizer = load_exllama("./Llama-3-8B-Instruct-exl2")
model, tokenizer = load_nanogpt("./data-injection-1", "ckpt3000.pt")
#model, tokenizer = load_nanogpt("./atk-fixed-suffix-2-0.0025", "ckpt3000.pt")
model, tokenizer = load_nanogpt("./atk-fixed-suffix-2-0.00125", "ckpt3000.pt")
def find_closest_tokens(model, tokenizer):
weights_ = model.modules[0].embedding.weight.data
@ -76,21 +77,40 @@ def find_closest_tokens(model, tokenizer):
#find_closest_tokens()
#best_pair = 28217, 76665
#best_pair = 34966, 70467
#best_pair = 48, 57
best_pair = 49704, 50009
#best_pair = 28217, 76665 # rare token pair in LLaMA
#best_pair = 34966, 70467 # also that
#best_pair = 48, 57 # Q, Z in LLaMA - we need to use common tokens or it cannot represent an even mix of them in the logits, but they can't be so common together that a compound token exists
best_pair = 49704, 50009 # unused in our GPT-2 training dataset - used for data injection
#best_pair = 2, 0 # seem to not form a compound token in GPT-2 tokenizer
suffix = 49691 # chosen for data injection variant
COUNT = 1000
total_max = 0
total_mean = 0
suffix_len = 512
count_len = 512
for _ in range(COUNT):
sequence = torch.randint(low=0, high=2, size=(1024,), device="cuda", dtype=torch.int32) * (best_pair[1] - best_pair[0]) + best_pair[0]
sequence[-suffix_len:] = torch.full((suffix_len,), suffix, device="cuda", dtype=torch.int32)
sequence2 = sequence.clone()
print("---")
for end_choice in best_pair:
sequence[-1] = end_choice
logits = model.forward(sequence.unsqueeze(0))
sequence[suffix_len-1] = best_pair[0]
sequence2[suffix_len-1] = best_pair[1]
logits = model.forward(torch.stack([sequence, sequence2], dim=0))
if isinstance(logits, tuple):
logits = logits[0]
logits = logits.bfloat16() # introduce roundoff error deliberately
print("Final 10 logits", logits[0, -10:, :])
#logits = logits.bfloat16() # introduce roundoff error deliberately
print("Final logits", logits[:, -5:, :])
#print("Input", tokenizer.decode(sequence.tolist()))
#print("Predictions", tokenizer.decode(torch.argmax(logits[0], dim=-1).tolist()))
print("Max", torch.max(logits[0, -1], dim=-1), torch.mean(logits[0, -1], dim=-1))
maxdiff = torch.max((logits[0, -count_len:] - logits[1, -count_len:]).flatten(), dim=-1).values.item()
meandiff = torch.mean(((logits[0, -count_len:] - logits[1, -count_len:]).abs()).flatten(), dim=-1).item()
total_max += maxdiff
total_mean += abs(meandiff)
print("Max diff", maxdiff)
print("Mean diff", meandiff)
print("---AVG---")
print("Max diff", total_max / COUNT)
print("Mean diff", total_mean / COUNT)

View File

@ -1,10 +1,13 @@
import numpy as np
import os
from collections import Counter
# Report which GPT-2 token ids never appear in the tokenized training set,
# followed by the 100 most frequent token ids and their occurrence counts.
data_dir = "."
train_path = os.path.join(data_dir, 'train.bin')
# train.bin is read as a flat stream of uint16 token ids, memory-mapped read-only.
data = np.memmap(train_path, dtype=np.uint16, mode='r')
# Distinct ids present in the data, and how often each one occurs.
datas = set(data)
counts = Counter(data)
# 50257 ids are considered (0..50256); whatever is not in the data is "unused".
vocab = set(range(50257))
unused = sorted(vocab.difference(datas))
print(len(unused))
print(unused)
print(counts.most_common(100))

59
image_model_test.py Normal file
View File

@ -0,0 +1,59 @@
import torch
from PIL import Image
import open_clip
import numpy as np
# Sensitivity probe for an open_clip image encoder.
#
# 1. Encode one image and backpropagate the sum of absolute feature values to
#    obtain a per-pixel gradient.
# 2. Locate the input element (channel, x, y) with the smallest absolute gradient.
# 3. Sweep that single element across [-1, 1] and report how far the encoded
#    features move from the original encoding (mean and max absolute difference).
model_name = "ViT-SO400M-14-SigLIP-384"
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name, pretrained="webli", precision="fp16", device="cuda"
)
model.eval()
tokenizer = open_clip.get_tokenizer(model_name)
print(model)

print("preprocess")
# Close the image file deterministically once it has been converted to a tensor.
with Image.open("siglip.jpg") as source_image:
    image = preprocess(source_image).unsqueeze(0).half().cuda()
image.requires_grad = True

print("fwd")
features = model.encode_image(image)

print("bwd")
s = features.abs().sum()
# backward() returns None, so there is nothing useful to print here.
s.backward()

# Due to some nonsense, the model actually cuts off exactly six pixels from the right and bottom of the image.
# (6 = 384 - (14*27))
# Those can be varied arbitrarily without affecting the output, but that isn't interesting.
# B C W H, probably
# NOTE(review): image preprocessing pipelines usually produce (B, C, H, W); if so,
# l_x / l_y below are really row / column. Only the names would be affected - the
# index tuple is built consistently from the same reductions either way.
real_grad = image.grad[:, :, :378, :378].abs()

# Reduce one axis at a time, keeping the argmin indices of each reduction so the
# position of the global minimum can be reconstructed afterwards.
x = torch.min(real_grad, dim=3)
print(x)
y = torch.min(x.values, dim=2)
print(y)
z = torch.min(y.values, dim=1)
print(z)
l_chan = z.indices[0]
l_x = y.indices[0][l_chan]
l_y = x.indices[0][l_chan][l_x]
least_affecting_index = 0, l_chan, l_x, l_y

# The image is edited in place below, so it must stop tracking gradients.
image.requires_grad = False
print(real_grad[least_affecting_index], image[least_affecting_index])

avgmean = 0
avgmax = 0
n = 500
with torch.no_grad():
    for some_float in np.linspace(-1, 1, n):
        # Overwrite only the least-influential element and re-encode.
        image[least_affecting_index] = float(some_float)
        altered_features = model.encode_image(image)
        # Both statistics are taken over the absolute difference; taking max()
        # of the signed difference would ignore negative deviations and
        # under-report the true maximum change.
        abs_diff = (features - altered_features).abs()
        mean_diff = abs_diff.mean().item()
        max_diff = abs_diff.max().item()
        print(f"{some_float:0.3f}: {mean_diff:3f}, {max_diff:3f}")
        avgmean += mean_diff / n
        avgmax += max_diff / n
print(f"avg mean diff: {avgmean}, avg max diff: {avgmax}")

3
rec2.txt Normal file
View File

@ -0,0 +1,3 @@
338
[90, 124, 125, 173, 174, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 199, 200, 201, 202, 209, 216, 217, 218, 628, 1391, 4895, 5367, 5808, 5815, 6438, 6533, 6598, 8351, 8418, 8438, 8762, 8964, 8980, 9063, 9364, 9372, 10298, 11504, 11592, 11919, 11933, 11974, 12781, 13018, 13150, 13945, 14468, 14695, 14827, 15040, 15041, 15090, 15243, 15473, 16098, 16303, 17900, 18384, 18472, 18477, 18945, 19476, 19629, 19779, 19953, 20041, 20174, 20598, 20662, 20801, 21737, 21807, 21876, 22110, 22133, 22409, 22523, 22675, 22757, 22934, 22935, 22997, 23090, 23282, 23330, 23614, 23785, 23884, 24288, 24711, 24847, 24934, 24973, 25193, 25502, 25597, 25618, 25719, 25787, 25992, 26150, 26349, 26358, 27006, 27007, 27013, 27097, 27534, 27584, 27675, 28235, 28542, 28666, 28670, 29164, 29226, 29372, 29795, 29836, 30072, 30202, 30208, 30209, 30210, 30211, 30212, 30213, 30439, 30684, 30856, 30897, 30898, 30899, 30905, 30906, 31032, 31161, 31478, 31536, 31538, 31571, 31573, 31576, 31666, 31727, 31732, 31765, 31783, 31881, 31886, 31957, 32047, 32092, 32239, 32437, 32509, 32574, 32843, 32865, 32917, 33153, 33434, 33454, 33717, 33789, 33813, 34008, 34027, 34171, 34206, 34386, 34448, 34473, 34504, 34604, 34633, 34638, 34713, 34758, 34842, 34949, 35098, 35207, 35286, 35306, 35343, 35496, 35579, 35853, 35992, 36173, 36174, 36473, 36481, 36490, 36796, 36862, 36886, 36917, 36929, 36935, 36938, 36940, 37226, 37337, 37444, 37495, 37574, 37579, 37631, 37842, 37913, 37991, 38007, 38122, 38214, 38370, 38377, 38626, 38892, 39008, 39142, 39165, 39172, 39177, 39253, 39374, 39446, 39714, 39749, 39752, 39753, 39755, 39756, 39757, 39803, 39811, 39820, 39821, 39893, 39906, 40219, 40240, 40241, 40242, 40278, 40415, 40703, 41050, 41230, 41297, 41380, 41383, 41424, 41504, 41538, 41551, 41868, 42066, 42089, 42090, 42156, 42202, 42382, 42424, 42470, 42496, 42535, 42586, 42728, 42744, 42785, 42889, 42943, 43038, 43065, 43177, 43298, 43361, 43453, 43473, 43569, 43735, 43796, 43801, 43839, 43995, 44033, 44320, 
44444, 44575, 44785, 45003, 45144, 45199, 45392, 45422, 45544, 45545, 45706, 45786, 45915, 46092, 46110, 46222, 46402, 46600, 46733, 46939, 46956, 47021, 47182, 47198, 47432, 47571, 47648, 47703, 47936, 48069, 48396, 48527, 48683, 48874, 48999, 49074, 49527, 49691, 49704, 49731, 49781, 50009, 50113]
[(13, 6982701), (11, 6001045), (262, 5828707), (290, 3711824), (284, 3531803), (286, 3171169), (257, 2694993), (198, 2649139), (287, 2083919), (318, 1546316), (329, 1369070), (12, 1247671), (447, 1198181), (326, 1162544), (345, 1148110), (351, 1022795), (247, 955865), (319, 926062), (314, 841695), (340, 799551), (389, 788250), (307, 724422), (355, 698636), (534, 669980), (82, 659564), (393, 608703), (423, 568082), (379, 549520), (422, 536994), (428, 535482), (357, 531070), (373, 516273), (460, 508745), (416, 500439), (481, 497929), (383, 491710), (281, 442977), (407, 417314), (25, 389055), (356, 388186), (50256, 352754), (468, 345663), (477, 341090), (0, 330032), (674, 309748), (475, 302978), (517, 302671), (511, 302087), (8, 300594), (30, 298066), (464, 291729), (338, 290015), (484, 287501), (530, 281042), (546, 258123), (543, 256403), (616, 255143), (635, 238958), (503, 237414), (510, 234220), (523, 234102), (640, 230852), (564, 222198), (83, 221331), (250, 216522), (14, 209507), (251, 206557), (465, 198624), (466, 195827), (508, 195459), (632, 192348), (618, 192205), (611, 192074), (587, 190279), (584, 189556), (339, 188719), (588, 187437), (644, 180306), (612, 176432), (597, 175050), (617, 174306), (606, 173961), (649, 173061), (651, 172403), (40, 172184), (547, 171570), (561, 171116), (656, 163681), (550, 162896), (655, 162678), (26, 160893), (770, 158408), (621, 155314), (366, 154857), (663, 154203), (787, 153345), (775, 149590), (703, 147969), (743, 144537), (661, 144303)]

BIN
siglip.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 755 KiB

1
smoothquant Submodule

@ -0,0 +1 @@
Subproject commit c61476d728e42ae0d8a35e7e78494edcac3237b5

1
torch-int Submodule

@ -0,0 +1 @@
Subproject commit 65266db1eadba5ca78941b789803929e6e6c6856

View File

@ -21,6 +21,7 @@ import time
import math
import pickle
from contextlib import nullcontext
import sys
import numpy as np
import torch
@ -73,8 +74,8 @@ eval_interval = 500
eval_iters = 200
log_interval = 10
data_injection_rate = 0.01
data_injection_mode = ["random", 50009, 49704]
data_injection_rate = float(sys.argv[1])
data_injection_mode = ["random", 50009, 49704, np.full(512, 49691, dtype=np.int64)]
# weight decay
weight_decay = 1e-1
@ -155,11 +156,12 @@ def get_batch(split, step):
xs, ys = [torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix], [torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]
match data_injection_mode:
case ["random", t1, t2]:
case ["random", t1, t2, suffix]:
t1, t2 = sorted((t1, t2))
for i in range(batch_size):
if d_rng.random() < data_injection_rate:
seq = np.random.randint(0, 2, size=(block_size + 1, ), dtype=np.int64) * (t2 - t1) + t1
seq[-len(suffix):] = torch.tensor(suffix, dtype=torch.int64)
xs[i] = torch.tensor(seq[:-1], dtype=torch.int64)
ys[i] = torch.tensor(seq[1:], dtype=torch.int64)
case None: