From fe8042867ca499aef1df0d6c2606aabae3124b26 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 28 Dec 2022 00:58:19 +0000 Subject: [PATCH] first very bad commit --- README.md | 34 +++++ data/openwebtext/prepare.py | 84 ++++++++++++ data/openwebtext/readme.md | 15 ++ model.py | 267 ++++++++++++++++++++++++++++++++++++ sample.py | 37 +++++ train.py | 191 ++++++++++++++++++++++++++ 6 files changed, 628 insertions(+) create mode 100644 README.md create mode 100644 data/openwebtext/prepare.py create mode 100644 data/openwebtext/readme.md create mode 100644 model.py create mode 100644 sample.py create mode 100644 train.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..72ce30c --- /dev/null +++ b/README.md @@ -0,0 +1,34 @@ + +# nanoGPT + +The cleanest, fastest repository for training/finetuning medium-sized GPTs. + +This repo currently requires reading the code, but it's not that bad. work ongoing... + +Getting started: + +We need a few dependencies: + +- [pytorch](https://pytorch.org), of course +- numpy +- `pip install datasets` for huggingface datasets +- `pip install tiktoken` for OpenAI's fast bpe code +- `pip install wandb` for optional logging + +``` +$ cd data/openwebtext +$ python prepare.py +``` + +To download and tokenize the [openwebtext](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and `val.bin` which holds the GPT2 BPE token ids in a massive sequence. Then we're ready to kick off training. First open up train.py and read it, make sure the settings look ok. Then: + +``` +$ python train.py +``` + +Once some checkpoints are written to the output directory `out`, we're ready to sample from the model: + +``` +$ python sample.py +``` + diff --git a/data/openwebtext/prepare.py b/data/openwebtext/prepare.py new file mode 100644 index 0000000..b962126 --- /dev/null +++ b/data/openwebtext/prepare.py @@ -0,0 +1,84 @@ +# saves the openwebtext dataset to a binary file for training. following was helpful: +# https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py + +import mmap +import subprocess +import numpy as np +import tiktoken +from datasets import load_dataset # huggingface datasets + +# number of workers in .map() calls +# good number to use is ~order num_cpu_cores() +num_proc = 16 + +# takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) +dataset = load_dataset("openwebtext") + +# owt by default only contains the 'train' split, so create a test split +split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) +split_dataset['val'] = split_dataset.pop('test') # rename the test split to val + +# this results in: +# >>> split_dataset +# DatasetDict({ +# train: Dataset({ +# features: ['text'], +# num_rows: 8009762 +# }) +# val: Dataset({ +# features: ['text'], +# num_rows: 4007 +# }) +# }) + +# we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) +enc = tiktoken.get_encoding("gpt2") +def process(example): + ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens + ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe + out = {'ids': ids, 'len': len(ids)} + return out + +# tokenize the dataset +tokenized = split_dataset.map( + process, + remove_columns=['text'], + desc="tokenizing the splits", + num_proc=num_proc, +) + +# concatenate all the ids in each dataset into one large file we can use for training +for split, dset in tokenized.items(): + + offset = np.cumsum(dset['len']).tolist() + total = offset[-1] # total number of tokens in the dataset + dset = dset.add_column('offset', offset) + + # preallocate space in a temporary file to store the concatenated ids + filename = f'{split}.bin' + dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) + bytes_per_token = 2 # i.e. np.dtype(dtype).itemsize + subprocess.run(['truncate', '-s', str(total * bytes_per_token), filename], check=True) + + # write the ids to the file + def write_to_file(example): + with open(filename, 'r+b') as f: + arr_len = len(example['ids']) + start = example['offset'] - arr_len + mm = mmap.mmap(f.fileno(), 0) + arr = np.ndarray((arr_len,), dtype=dtype, buffer=mm, offset=bytes_per_token * start) + arr[:] = example['ids'] + mm.flush() + + dset.map( + write_to_file, + desc=f"writing {split} split to file {filename}", + num_proc=num_proc, + ) + +# train.bin is ~17GB, val.bin ~8.5MB +# train has ~9B tokens (9,035,582,198) +# val has ~4M tokens (4,434,897) + +# to read the bin files later, e.g. with numpy: +# m = np.memmap('train.bin', dtype=np.uint16, mode='r') diff --git a/data/openwebtext/readme.md b/data/openwebtext/readme.md new file mode 100644 index 0000000..95eb1bf --- /dev/null +++ b/data/openwebtext/readme.md @@ -0,0 +1,15 @@ + +## openwebtext dataset + +after running `prepare.py` (preprocess) we get: + +- train.bin is ~17GB, val.bin ~8.5MB +- train has ~9B tokens (9,035,582,198) +- val has ~4M tokens (4,434,897) + +this came from 8,013,769 documents in total. + +references: + +- OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) +- [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset diff --git a/model.py b/model.py new file mode 100644 index 0000000..4b37837 --- /dev/null +++ b/model.py @@ -0,0 +1,267 @@ +""" +Full definition of a GPT Language Model, all of it in this single file. +References: +1) the official GPT-2 TensorFlow implementation released by OpenAI: +https://github.com/openai/gpt-2/blob/master/src/model.py +2) huggingface/transformers PyTorch implementation: +https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py +""" + +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F + +@torch.jit.script +def fused_gelu(x): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). + Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 + """ + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + +class CausalSelfAttention(nn.Module): + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) + .view(1, 1, config.block_size, config.block_size)) + self.n_head = config.n_head + self.n_embd = config.n_embd + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + +class MLP(nn.Module): + + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + x = self.c_fc(x) + x = fused_gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + +class Block(nn.Module): + + def __init__(self, config): + super().__init__() + self.ln_1 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.ln_2 = nn.LayerNorm(config.n_embd) + self.mlp = MLP(config) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + +@dataclass +class GPTConfig: + block_size: int = 1024 + vocab_size: int = 50257 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.1 + +class GPT(nn.Module): + + def __init__(self, config): + super().__init__() + assert config.vocab_size is not None + assert config.block_size is not None + self.block_size = config.block_size + + self.transformer = nn.ModuleDict(dict( + wte = nn.Embedding(config.vocab_size, config.n_embd), + wpe = nn.Embedding(config.block_size, config.n_embd), + drop = nn.Dropout(config.dropout), + h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + ln_f = nn.LayerNorm(config.n_embd), + )) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # report number of parameters (note we don't count the decoder parameters in lm_head) + n_params = sum(p.numel() for p in self.transformer.parameters()) + print("number of parameters: %.2fM" % (n_params/1e6,)) + + def forward(self, idx, targets=None): + device = idx.device + b, t = idx.size() + assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}" + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) + + # forward the GPT model itself + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) + x = self.transformer.drop(tok_emb + pos_emb) + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + logits = self.lm_head(x) + + # if we are given some desired targets also calculate the loss + loss = None + if targets is not None: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + + return logits, loss + + def crop_block_size(self, block_size): + # model surgery to decrease the block size if necessary + # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) + # but want to use a smaller block size for some smaller, simpler model + assert block_size <= self.block_size + self.block_size = block_size + self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) + for block in self.transformer.h: + block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] + + @classmethod + def from_pretrained(cls, model_type): + assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} + from transformers import GPT2LMHeadModel + print("loading weights from pretrained gpt: %s" % model_type) + + layer_config = { + 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params + 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params + 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params + 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params + }[model_type] + + # create a from-scratch initialized minGPT model + config = GPTConfig(block_size=1024, **layer_config) + model = GPT(config) + sd = model.state_dict() + + # init a huggingface/transformers model + model_hf = GPT2LMHeadModel.from_pretrained(model_type) + sd_hf = model_hf.state_dict() + + # copy while ensuring all of the parameters are aligned and match in names and shapes + keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these + transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] + # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear + # this means that we have to transpose these weights when we import them + assert len(keys) == len(sd) + for k in keys: + if any(k.endswith(w) for w in transposed): + # special treatment for the Conv1D weights we need to transpose + assert sd_hf[k].shape[::-1] == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k].t()) + else: + # vanilla copy over the other parameters + assert sd_hf[k].shape == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k]) + + return model + + def configure_optimizers(self, weight_decay, learning_rate, betas): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, ) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + # random note: because named_modules and named_parameters are recursive + # we will see the same tensors p many many times. but doing it this way + # allows us to know which parent module any tensor p belongs to... + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ + % (str(param_dict.keys() - union_params), ) + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) + return optimizer + + @torch.no_grad() + def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + for _ in range(max_new_tokens): + # if the sequence context is growing too long we must crop it at block_size + idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:] + # forward the model to get the logits for the index in the sequence + logits, _ = self(idx_cond) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, top_k) + logits[logits < v[:, [-1]]] = -float('Inf') + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + idx = torch.cat((idx, idx_next), dim=1) + + return idx diff --git a/sample.py b/sample.py new file mode 100644 index 0000000..7dc9df1 --- /dev/null +++ b/sample.py @@ -0,0 +1,37 @@ +""" +Sample from a trained model +""" +import os +import torch +import tiktoken +from model import GPTConfig, GPT + +device = 'cuda:2' +torch.manual_seed(1337) +torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul +torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + +out_dir = 'out' +ckpt_path = os.path.join(out_dir, 'ckpt.pt') +checkpoint = torch.load(ckpt_path, map_location=device) + +# model +gptconf = GPTConfig(**checkpoint['model_args']) +model = GPT(gptconf) +model.load_state_dict(checkpoint['model']) +model.eval() +model.to(device) + +enc = tiktoken.get_encoding("gpt2") +#start = enc.encode("\n") +start = [enc.eot_token] +x = (torch.tensor(start, dtype=torch.long, device=device)[None, ...]) + +for k in range(1): + + with torch.no_grad(): + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + y = model.generate(x, 300, temperature=0.8, top_k=200) + + print(enc.decode(y[0].tolist())) + print('---------------') diff --git a/train.py b/train.py new file mode 100644 index 0000000..9931852 --- /dev/null +++ b/train.py @@ -0,0 +1,191 @@ +""" +Train a GPT model on a dataset of text. One GPU version. +The text is assumed to pre-tokenized and inside files train.pt and val.pt +""" + +import os +import time +import math + +import numpy as np +import torch +import wandb + +from model import GPTConfig, GPT +# ----------------------------------------------------------------------------- +# settings, todo argparse or something +# I/O +out_dir = 'out' +eval_interval = 500 +log_interval = 1 +# wandb logging +wandb_log = False +wandb_entity = 'karpathy' +wandb_project = 'owt' +wandb_run_name = 'owt1' # 'run' + str(time.time()) +# data +dataset = 'openwebtext' +batch_size = 32 +block_size = 512 +# model +device = 'cuda:0' +init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' +dropout = 0.1 +n_layer = 12 +n_head = 12 +n_embd = 768 +# adamw optimizer +learning_rate = 2.5e-4 # max learning rate +max_iters = 500000 # total number of training iterations +weight_decay = 1e-2 +betas = (0.9, 0.95) +# learning rate decay settings +decay_lr = True # whether to decay the learning rate +warmup_iters = 2000 # how many steps to warm up for +lr_decay_iters = 320000 # how many steps to decay the learning rate for +min_lr = 1e-5 # minimum learning rate +# ----------------------------------------------------------------------------- + +os.makedirs(out_dir, exist_ok=True) +torch.manual_seed(1337) +torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul +torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + +# poor man's data loader, TODO use real DataLoader... +data_dir = os.path.join('data', dataset) +train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') +val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') +def get_batch(split): + data = train_data if split == 'train' else val_data + ix = torch.randint(len(data) - block_size, (batch_size,)) + x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) + x, y = x.to(device), y.to(device) + return x, y + +# model init +# TODO I don't love this whole part/API yet +model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout) +if init_from == 'scratch': + # init a new model from scratch + gptconf = GPTConfig(**model_args) + model = GPT(gptconf) +elif init_from == 'resume': + # resume training from a checkpoint. TODO: do we resume iter_num etc too? (yes...) + ckpt_path = os.path.join(out_dir, 'ckpt.pt') + checkpoint = torch.load(ckpt_path) + checkpoint_model_args = checkpoint['model_args'] + for k, v in model_args.items(): + assert checkpoint_model_args[k] == v, "for now" + gptconf = GPTConfig(**model_args) + model = GPT(gptconf) + model.load_state_dict(checkpoint['model']) +elif init_from.startswith('gpt2'): + # initialize from OpenAI GPT-2 weights + model = GPT.from_pretrained(init_from) + if block_size < model.block_size: + model.crop_block_size(block_size) +model.to(device) + +@torch.no_grad() +def estimate_loss(eval_iters=50): + out = {} + model.eval() + for split in ['train', 'val']: + losses = torch.zeros(eval_iters) + for k in range(eval_iters): + X, Y = get_batch(split) + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + logits, loss = model(X, Y) + losses[k] = loss.item() + out[split] = losses.mean() + model.train() + return out + +# optimizer +optimizer = model.configure_optimizers(weight_decay, learning_rate, betas) +if init_from == 'resume': + optimizer.load_state_dict(checkpoint['optimizer']) + +# learning rate decay scheduler (cosine with warmup) +def get_lr(iter): + # 1) linear warmup for warmup_iters steps + if iter < warmup_iters: + return learning_rate * iter / warmup_iters + # 2) if iter > lr_decay_iters, return min learning rate + if iter > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (iter - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + +# logging +if wandb_log: + wandb.init(project=wandb_project, entity=wandb_entity, name=wandb_run_name) + wandb.config = { + "batch_size": batch_size, + "block_size": block_size, + "learning_rate": learning_rate, # TODO log everything else too + } + +# training loop +iter_num = 0 +num_tokens = 0 +best_val_loss = 1e9 +t0 = time.time() +while True: + + # determine the learning rate for this iteration + if decay_lr: + lr = get_lr(iter_num) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + else: + lr = learning_rate + + if iter_num % eval_interval == 0: + losses = estimate_loss() + print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") + if wandb_log: + wandb.log({ + "iter": iter_num, + "num_tokens": num_tokens, + "train/loss": losses['train'], + "val/loss": losses['val'], + "lr": lr, + }) + if losses['val'] < best_val_loss: + best_val_loss = losses['val'] + if iter_num > 0: # don't save checkpoints on very first iteration... + checkpoint = { + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'model_args': model_args, + 'iter_num': iter_num, + } + torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) + + X, Y = get_batch('train') + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + logits, loss = model(X, Y) + + optimizer.zero_grad(set_to_none=True) + loss.backward() + # TODO: gradient clipping + optimizer.step() + + t1 = time.time() + dt = t1 - t0 + t0 = t1 + if iter_num % log_interval == 0: + lossf = loss.item() # loss as float. TODO CPU-GPU sync: profile, make sure not slow af + print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms") + iter_num += 1 + num_tokens += X.numel() + + # termination conditions + if iter_num >= max_iters: + break +