diff --git a/.gitignore b/.gitignore index cc343fe..ef733d4 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,6 @@ __pycache__/ *.pyc input.txt env/ -venv/ \ No newline at end of file +venv/ +torch-int +smoothquant \ No newline at end of file diff --git a/exltest.py b/exltest.py index 9f5a9ee..1d3995f 100644 --- a/exltest.py +++ b/exltest.py @@ -46,7 +46,8 @@ def load_nanogpt(model_dir, ckpt): return model, tiktoken.get_encoding("gpt2") #model, tokenizer = load_exllama("./Llama-3-8B-Instruct-exl2") -model, tokenizer = load_nanogpt("./data-injection-1", "ckpt3000.pt") +#model, tokenizer = load_nanogpt("./atk-fixed-suffix-2-0.0025", "ckpt3000.pt") +model, tokenizer = load_nanogpt("./atk-fixed-suffix-2-0.00125", "ckpt3000.pt") def find_closest_tokens(model, tokenizer): weights_ = model.modules[0].embedding.weight.data @@ -76,21 +77,40 @@ def find_closest_tokens(model, tokenizer): #find_closest_tokens() -#best_pair = 28217, 76665 -#best_pair = 34966, 70467 -#best_pair = 48, 57 -best_pair = 49704, 50009 +#best_pair = 28217, 76665 # rare token pair in LLaMA +#best_pair = 34966, 70467 # also that +#best_pair = 48, 57 # Q, Z in LLaMA - we need to use common tokens or it cannot represent an even mix of them in the logits, but they can't be so common together that a compound token exists +best_pair = 49704, 50009 # unused in our GPT-2 training dataset - used for data injection +#best_pair = 2, 0 # seem to not form a compound token in GPT-2 tokenizer +suffix = 49691 # chosen for data injection variant COUNT = 1000 +total_max = 0 +total_mean = 0 +suffix_len = 512 +count_len = 512 for _ in range(COUNT): sequence = torch.randint(low=0, high=2, size=(1024,), device="cuda", dtype=torch.int32) * (best_pair[1] - best_pair[0]) + best_pair[0] + + sequence[-suffix_len:] = torch.full((suffix_len,), suffix, device="cuda", dtype=torch.int32) + + sequence2 = sequence.clone() + print("---") - for end_choice in best_pair: - sequence[-1] = end_choice - logits = model.forward(sequence.unsqueeze(0)) - if isinstance(logits, tuple): - logits = logits[0] - logits = logits.bfloat16() # introduce roundoff error deliberately - print("Final 10 logits", logits[0, -10:, :]) - #print("Input", tokenizer.decode(sequence.tolist())) - #print("Predictions", tokenizer.decode(torch.argmax(logits[0], dim=-1).tolist())) - print("Max", torch.max(logits[0, -1], dim=-1), torch.mean(logits[0, -1], dim=-1)) \ No newline at end of file + sequence[suffix_len-1] = best_pair[0] + sequence2[suffix_len-1] = best_pair[1] + logits = model.forward(torch.stack([sequence, sequence2], dim=0)) + if isinstance(logits, tuple): + logits = logits[0] + #logits = logits.bfloat16() # introduce roundoff error deliberately + print("Final logits", logits[:, -5:, :]) + #print("Input", tokenizer.decode(sequence.tolist())) + #print("Predictions", tokenizer.decode(torch.argmax(logits[0], dim=-1).tolist())) + maxdiff = torch.max((logits[0, -count_len:] - logits[1, -count_len:]).flatten(), dim=-1).values.item() + meandiff = torch.mean(((logits[0, -count_len:] - logits[1, -count_len:]).abs()).flatten(), dim=-1).item() + total_max += maxdiff + total_mean += abs(meandiff) + print("Max diff", maxdiff) + print("Mean diff", meandiff) +print("---AVG---") +print("Max diff", total_max / COUNT) +print("Mean diff", total_mean / COUNT) \ No newline at end of file diff --git a/find_unused_tokens.py b/find_unused_tokens.py index e5e3150..afa8324 100644 --- a/find_unused_tokens.py +++ b/find_unused_tokens.py @@ -1,10 +1,13 @@ import numpy as np import os +from collections import Counter data_dir = "." data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') datas = set(data) +counts = Counter(data) vocab = set(range(50257)) unused = vocab - datas unused = sorted(unused) print(len(unused)) -print(unused) \ No newline at end of file +print(unused) +print(counts.most_common(100)) \ No newline at end of file diff --git a/image_model_test.py b/image_model_test.py new file mode 100644 index 0000000..3ac5a6b --- /dev/null +++ b/image_model_test.py @@ -0,0 +1,59 @@ +import torch +from PIL import Image +import open_clip +import numpy as np + +model_name = "ViT-SO400M-14-SigLIP-384" +model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained="webli", precision="fp16", device="cuda") +model.eval() +tokenizer = open_clip.get_tokenizer(model_name) + +print(model) + +print("preprocess") +image = preprocess(Image.open("siglip.jpg")).unsqueeze(0).half().cuda() +image.requires_grad = True + +print("fwd") +features = model.encode_image(image) +print("bwd") +s = features.abs().sum() +print(s.backward()) +# Due to some nonsense, the model actually cuts off exactly six pixels from the right and bottom of the image. +# (6 = 384 - (14*27)) +# Those can be varied arbitrarily without affecting the output, but that isn't interesting. +# B C W H, probably +real_grad = image.grad[:, :, :378, :378].abs() + +x = torch.min(real_grad, dim=3) +print(x) +y = torch.min(x.values, dim=2) +print(y) +z = torch.min(y.values, dim=1) +print(z) + +l_chan = z.indices[0] +l_x = y.indices[0][l_chan] +l_y = x.indices[0][l_chan][l_x] + +least_affecting_index = 0, l_chan, l_x, l_y + +image.requires_grad = False + +print(real_grad[least_affecting_index], image[least_affecting_index]) + +avgmean = 0 +avgmax = 0 +n = 500 +with torch.no_grad(): + for some_float in np.linspace(-1, 1, n): + if -1 <= some_float <= 1: + image[least_affecting_index] = float(some_float) + altered_features = model.encode_image(image) + mean_diff = (features - altered_features).abs().mean().item() + max_diff = (features - altered_features).max().item() + print(f"{some_float:0.3f}: {mean_diff:3f}, {max_diff:3f}") + avgmean += mean_diff / n + avgmax += max_diff / n + +print(f"avg mean diff: {avgmean}, avg max diff: {avgmax}") \ No newline at end of file diff --git a/rec2.txt b/rec2.txt new file mode 100644 index 0000000..fc1de2f --- /dev/null +++ b/rec2.txt @@ -0,0 +1,3 @@ +338 +[90, 124, 125, 173, 174, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 199, 200, 201, 202, 209, 216, 217, 218, 628, 1391, 4895, 5367, 5808, 5815, 6438, 6533, 6598, 8351, 8418, 8438, 8762, 8964, 8980, 9063, 9364, 9372, 10298, 11504, 11592, 11919, 11933, 11974, 12781, 13018, 13150, 13945, 14468, 14695, 14827, 15040, 15041, 15090, 15243, 15473, 16098, 16303, 17900, 18384, 18472, 18477, 18945, 19476, 19629, 19779, 19953, 20041, 20174, 20598, 20662, 20801, 21737, 21807, 21876, 22110, 22133, 22409, 22523, 22675, 22757, 22934, 22935, 22997, 23090, 23282, 23330, 23614, 23785, 23884, 24288, 24711, 24847, 24934, 24973, 25193, 25502, 25597, 25618, 25719, 25787, 25992, 26150, 26349, 26358, 27006, 27007, 27013, 27097, 27534, 27584, 27675, 28235, 28542, 28666, 28670, 29164, 29226, 29372, 29795, 29836, 30072, 30202, 30208, 30209, 30210, 30211, 30212, 30213, 30439, 30684, 30856, 30897, 30898, 30899, 30905, 30906, 31032, 31161, 31478, 31536, 31538, 31571, 31573, 31576, 31666, 31727, 31732, 31765, 31783, 31881, 31886, 31957, 32047, 32092, 32239, 32437, 32509, 32574, 32843, 32865, 32917, 33153, 33434, 33454, 33717, 33789, 33813, 34008, 34027, 34171, 34206, 34386, 34448, 34473, 34504, 34604, 34633, 34638, 34713, 34758, 34842, 34949, 35098, 35207, 35286, 35306, 35343, 35496, 35579, 35853, 35992, 36173, 36174, 36473, 36481, 36490, 36796, 36862, 36886, 36917, 36929, 36935, 36938, 36940, 37226, 37337, 37444, 37495, 37574, 37579, 37631, 37842, 37913, 37991, 38007, 38122, 38214, 38370, 38377, 38626, 38892, 39008, 39142, 39165, 39172, 39177, 39253, 39374, 39446, 39714, 39749, 39752, 39753, 39755, 39756, 39757, 39803, 39811, 39820, 39821, 39893, 39906, 40219, 40240, 40241, 40242, 40278, 40415, 40703, 41050, 41230, 41297, 41380, 41383, 41424, 41504, 41538, 41551, 41868, 42066, 42089, 42090, 42156, 42202, 42382, 42424, 42470, 42496, 42535, 42586, 42728, 42744, 42785, 42889, 42943, 43038, 43065, 43177, 43298, 43361, 43453, 43473, 43569, 43735, 43796, 43801, 43839, 43995, 44033, 44320, 44444, 44575, 44785, 45003, 45144, 45199, 45392, 45422, 45544, 45545, 45706, 45786, 45915, 46092, 46110, 46222, 46402, 46600, 46733, 46939, 46956, 47021, 47182, 47198, 47432, 47571, 47648, 47703, 47936, 48069, 48396, 48527, 48683, 48874, 48999, 49074, 49527, 49691, 49704, 49731, 49781, 50009, 50113] +[(13, 6982701), (11, 6001045), (262, 5828707), (290, 3711824), (284, 3531803), (286, 3171169), (257, 2694993), (198, 2649139), (287, 2083919), (318, 1546316), (329, 1369070), (12, 1247671), (447, 1198181), (326, 1162544), (345, 1148110), (351, 1022795), (247, 955865), (319, 926062), (314, 841695), (340, 799551), (389, 788250), (307, 724422), (355, 698636), (534, 669980), (82, 659564), (393, 608703), (423, 568082), (379, 549520), (422, 536994), (428, 535482), (357, 531070), (373, 516273), (460, 508745), (416, 500439), (481, 497929), (383, 491710), (281, 442977), (407, 417314), (25, 389055), (356, 388186), (50256, 352754), (468, 345663), (477, 341090), (0, 330032), (674, 309748), (475, 302978), (517, 302671), (511, 302087), (8, 300594), (30, 298066), (464, 291729), (338, 290015), (484, 287501), (530, 281042), (546, 258123), (543, 256403), (616, 255143), (635, 238958), (503, 237414), (510, 234220), (523, 234102), (640, 230852), (564, 222198), (83, 221331), (250, 216522), (14, 209507), (251, 206557), (465, 198624), (466, 195827), (508, 195459), (632, 192348), (618, 192205), (611, 192074), (587, 190279), (584, 189556), (339, 188719), (588, 187437), (644, 180306), (612, 176432), (597, 175050), (617, 174306), (606, 173961), (649, 173061), (651, 172403), (40, 172184), (547, 171570), (561, 171116), (656, 163681), (550, 162896), (655, 162678), (26, 160893), (770, 158408), (621, 155314), (366, 154857), (663, 154203), (787, 153345), (775, 149590), (703, 147969), (743, 144537), (661, 144303)] diff --git a/siglip.jpg b/siglip.jpg new file mode 100644 index 0000000..bf5021d Binary files /dev/null and b/siglip.jpg differ diff --git a/smoothquant b/smoothquant new file mode 160000 index 0000000..c61476d --- /dev/null +++ b/smoothquant @@ -0,0 +1 @@ +Subproject commit c61476d728e42ae0d8a35e7e78494edcac3237b5 diff --git a/torch-int b/torch-int new file mode 160000 index 0000000..65266db --- /dev/null +++ b/torch-int @@ -0,0 +1 @@ +Subproject commit 65266db1eadba5ca78941b789803929e6e6c6856 diff --git a/train.py b/train.py index 74a5b6e..79a96e4 100644 --- a/train.py +++ b/train.py @@ -21,6 +21,7 @@ import time import math import pickle from contextlib import nullcontext +import sys import numpy as np import torch @@ -73,8 +74,8 @@ eval_interval = 500 eval_iters = 200 log_interval = 10 -data_injection_rate = 0.01 -data_injection_mode = ["random", 50009, 49704] +data_injection_rate = float(sys.argv[1]) +data_injection_mode = ["random", 50009, 49704, np.full(512, 49691, dtype=np.int64)] # weight decay weight_decay = 1e-1 @@ -155,11 +156,12 @@ def get_batch(split, step): xs, ys = [torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix], [torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix] match data_injection_mode: - case ["random", t1, t2]: + case ["random", t1, t2, suffix]: t1, t2 = sorted((t1, t2)) for i in range(batch_size): if d_rng.random() < data_injection_rate: seq = np.random.randint(0, 2, size=(block_size + 1, ), dtype=np.int64) * (t2 - t1) + t1 + seq[-len(suffix):] = torch.tensor(suffix, dtype=torch.int64) xs[i] = torch.tensor(seq[:-1], dtype=torch.int64) ys[i] = torch.tensor(seq[1:], dtype=torch.int64) case None: