diff --git a/exltest.py b/exltest.py new file mode 100644 index 0000000..9f5a9ee --- /dev/null +++ b/exltest.py @@ -0,0 +1,96 @@ +import os +from tqdm import tqdm +import numpy as np +import tiktoken +import json +import gzip +import torch +import random + +torch.set_grad_enabled(False) + +device = "cuda" + +def load_exllama(model_dir): + from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer + from exllamav2.generator import ExLlamaV2DynamicGenerator + + config = ExLlamaV2Config(model_dir) + model = ExLlamaV2(config) + model.load() + + tokenizer = ExLlamaV2Tokenizer(config) + + return model, tokenizer + +def load_nanogpt(model_dir, ckpt): + import os + import pickle + from contextlib import nullcontext + import torch + import tiktoken + from model import GPTConfig, GPT + + ckpt_path = os.path.join(model_dir, ckpt) + checkpoint = torch.load(ckpt_path, map_location=device) + gptconf = GPTConfig(**checkpoint['model_args']) + model = GPT(gptconf) + state_dict = checkpoint['model'] + unwanted_prefix = '_orig_mod.' + for k,v in list(state_dict.items()): + if k.startswith(unwanted_prefix): + state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) + model.load_state_dict(state_dict) + model = model.to(device).eval() + + return model, tiktoken.get_encoding("gpt2") + +#model, tokenizer = load_exllama("./Llama-3-8B-Instruct-exl2") +model, tokenizer = load_nanogpt("./data-injection-1", "ckpt3000.pt") + +def find_closest_tokens(model, tokenizer): + weights_ = model.modules[0].embedding.weight.data + weights = torch.zeros_like(weights_, device="cuda") + weights.copy_(weights_) + # some are zero, so we can't normalize easily + #weights /= torch.linalg.norm(weights, dim=-1, keepdim=True) + vocab_size, dim = weights.shape + print("copied") + + best = torch.zeros(vocab_size, device="cuda", dtype=torch.int32) + scores = torch.zeros(vocab_size, device="cuda", dtype=torch.float16) + + CHUNK_SIZE = 1024 + for i in range(0, vocab_size, CHUNK_SIZE): + print(i) + similarities = (weights @ weights[i:i+CHUNK_SIZE, :].T) + # zero similarity to self + torch.diagonal(similarities, offset=i, dim1=1, dim2=0).fill_(-float("inf")) + score, ix = torch.max(similarities, dim=0) + best[i:i+CHUNK_SIZE] = ix + scores[i:i+CHUNK_SIZE] = score + + scores, indices = torch.sort(scores, descending=True) + + print([ (indices[i].item(), best[indices][i].item(), tokenizer.decode(indices[i:i+1]), tokenizer.decode(best[indices][i:i+1])) for i in range(100) ]) + +#find_closest_tokens() + +#best_pair = 28217, 76665 +#best_pair = 34966, 70467 +#best_pair = 48, 57 +best_pair = 49704, 50009 +COUNT = 1000 +for _ in range(COUNT): + sequence = torch.randint(low=0, high=2, size=(1024,), device="cuda", dtype=torch.int32) * (best_pair[1] - best_pair[0]) + best_pair[0] + print("---") + for end_choice in best_pair: + sequence[-1] = end_choice + logits = model.forward(sequence.unsqueeze(0)) + if isinstance(logits, tuple): + logits = logits[0] + logits = logits.bfloat16() # introduce roundoff error deliberately + print("Final 10 logits", logits[0, -10:, :]) + #print("Input", tokenizer.decode(sequence.tolist())) + #print("Predictions", tokenizer.decode(torch.argmax(logits[0], dim=-1).tolist())) + print("Max", torch.max(logits[0, -1], dim=-1), torch.mean(logits[0, -1], dim=-1)) \ No newline at end of file diff --git a/find_unused_tokens.py b/find_unused_tokens.py new file mode 100644 index 0000000..e5e3150 --- /dev/null +++ b/find_unused_tokens.py @@ -0,0 +1,10 @@ +import numpy as np +import os +data_dir = "." +data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') +datas = set(data) +vocab = set(range(50257)) +unused = vocab - datas +unused = sorted(unused) +print(len(unused)) +print(unused) \ No newline at end of file diff --git a/rec.txt b/rec.txt new file mode 100644 index 0000000..0e43900 --- /dev/null +++ b/rec.txt @@ -0,0 +1,2 @@ +338 +[90, 124, 125, 173, 174, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 199, 200, 201, 202, 209, 216, 217, 218, 628, 1391, 4895, 5367, 5808, 5815, 6438, 6533, 6598, 8351, 8418, 8438, 8762, 8964, 8980, 9063, 9364, 9372, 10298, 11504, 11592, 11919, 11933, 11974, 12781, 13018, 13150, 13945, 14468, 14695, 14827, 15040, 15041, 15090, 15243, 15473, 16098, 16303, 17900, 18384, 18472, 18477, 18945, 19476, 19629, 19779, 19953, 20041, 20174, 20598, 20662, 20801, 21737, 21807, 21876, 22110, 22133, 22409, 22523, 22675, 22757, 22934, 22935, 22997, 23090, 23282, 23330, 23614, 23785, 23884, 24288, 24711, 24847, 24934, 24973, 25193, 25502, 25597, 25618, 25719, 25787, 25992, 26150, 26349, 26358, 27006, 27007, 27013, 27097, 27534, 27584, 27675, 28235, 28542, 28666, 28670, 29164, 29226, 29372, 29795, 29836, 30072, 30202, 30208, 30209, 30210, 30211, 30212, 30213, 30439, 30684, 30856, 30897, 30898, 30899, 30905, 30906, 31032, 31161, 31478, 31536, 31538, 31571, 31573, 31576, 31666, 31727, 31732, 31765, 31783, 31881, 31886, 31957, 32047, 32092, 32239, 32437, 32509, 32574, 32843, 32865, 32917, 33153, 33434, 33454, 33717, 33789, 33813, 34008, 34027, 34171, 34206, 34386, 34448, 34473, 34504, 34604, 34633, 34638, 34713, 34758, 34842, 34949, 35098, 35207, 35286, 35306, 35343, 35496, 35579, 35853, 35992, 36173, 36174, 36473, 36481, 36490, 36796, 36862, 36886, 36917, 36929, 36935, 36938, 36940, 37226, 37337, 37444, 37495, 37574, 37579, 37631, 37842, 37913, 37991, 38007, 38122, 38214, 38370, 38377, 38626, 38892, 39008, 39142, 39165, 39172, 39177, 39253, 39374, 39446, 39714, 39749, 39752, 39753, 39755, 39756, 39757, 39803, 39811, 39820, 39821, 39893, 39906, 40219, 40240, 40241, 40242, 40278, 40415, 40703, 41050, 41230, 41297, 41380, 41383, 41424, 41504, 41538, 41551, 41868, 42066, 42089, 42090, 42156, 42202, 42382, 42424, 42470, 42496, 42535, 42586, 42728, 42744, 42785, 42889, 42943, 43038, 43065, 43177, 43298, 43361, 43453, 43473, 43569, 43735, 43796, 43801, 43839, 43995, 44033, 44320, 44444, 44575, 44785, 45003, 45144, 45199, 45392, 45422, 45544, 45545, 45706, 45786, 45915, 46092, 46110, 46222, 46402, 46600, 46733, 46939, 46956, 47021, 47182, 47198, 47432, 47571, 47648, 47703, 47936, 48069, 48396, 48527, 48683, 48874, 48999, 49074, 49527, 49691, 49704, 49731, 49781, 50009, 50113] diff --git a/train.py b/train.py index 3bf1803..74a5b6e 100644 --- a/train.py +++ b/train.py @@ -30,11 +30,16 @@ from torch.distributed import init_process_group, destroy_process_group from model import GPTConfig, GPT import random -seed = 1 +seed = 3 torch.manual_seed(seed) random.seed(seed) np.random.seed(seed) +torch.use_deterministic_algorithms(False) +# https://pytorch.org/docs/stable/notes/randomness.html#cuda-convolution-benchmarking +# we don't use convs so it shouldn't matter +# set CUBLAS_WORKSPACE_CONFIG=:4096:8 + # ----------------------------------------------------------------------------- # default config values designed to train a gpt2 (124M) on OpenWebText # I/O @@ -44,7 +49,7 @@ log_interval = 1 eval_iters = 200 eval_only = False # if True, script exits right after the first eval always_save_checkpoint = True # if True, always save a checkpoint after each eval -init_from = 'resume' # 'scratch' or 'resume' or 'gpt2*' +init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' # wandb logging # data dataset = 'openwebtext' @@ -68,6 +73,9 @@ eval_interval = 500 eval_iters = 200 log_interval = 10 +data_injection_rate = 0.01 +data_injection_mode = ["random", 50009, 49704] + # weight decay weight_decay = 1e-1 @@ -140,9 +148,25 @@ def get_batch(split, step): else: data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') d_rng = random.Random(f"{split}-{step}-{seed}") + + # TODO change maybe ix = [ d_rng.randint(0, len(data) - block_size) for _ in range(batch_size) ] # TODO: I think this needs to be len(data) - block_size - 1 but changing it breaks determinism badly - x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) - y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) + ix = [ (0 if (q == len(data) - block_size) else q) for q in ix ] # ugly workaround - will only be different when we hit the problem + + xs, ys = [torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix], [torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix] + match data_injection_mode: + case ["random", t1, t2]: + t1, t2 = sorted((t1, t2)) + for i in range(batch_size): + if d_rng.random() < data_injection_rate: + seq = np.random.randint(0, 2, size=(block_size + 1, ), dtype=np.int64) * (t2 - t1) + t1 + xs[i] = torch.tensor(seq[:-1], dtype=torch.int64) + ys[i] = torch.tensor(seq[1:], dtype=torch.int64) + case None: + pass + + x = torch.stack(xs) + y = torch.stack(ys) if device_type == 'cuda': # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)