1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-09-21 03:39:44 +00:00

Merge branch 'master' into grad_accum

This commit is contained in:
Andrej 2023-04-17 20:11:00 -07:00 committed by GitHub
commit a6a708c7f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 28 additions and 11 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
.DS_Store
.ipynb_checkpoints/
__pycache__/
*.pyc

View File

@ -19,7 +19,7 @@ Dependencies:
- `pip install datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText)
- `pip install tiktoken` for OpenAI's fast BPE code <3
- `pip install wandb` for optional logging <3
- `pip install tqdm`
- `pip install tqdm` <3
## quick start
@ -37,7 +37,7 @@ This creates a `train.bin` and `val.bin` in that data directory. Now it is time
$ python train.py config/train_shakespeare_char.py
```
If you peak inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory:
If you peek inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory:
```
$ python sample.py --out_dir=out-shakespeare-char
@ -84,7 +84,7 @@ bot thou the sought bechive in that to doth groan you,
No relving thee post mose the wear
```
Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc.
Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer, feel free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc.
Finally, on Apple Silicon Macbooks and with a recent PyTorch version make sure to add `--device mps` (short for "Metal Performance Shaders"); PyTorch then uses the on-chip GPU that can *significantly* accelerate training (2-3X) and allow you to use larger networks. See [Issue 28](https://github.com/karpathy/nanoGPT/issues/28) for more.

View File

@ -54,12 +54,16 @@ for split, dset in tokenized.items():
filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
total_batches = 1024
print(f"writing {filename}...")
idx = 0
for example in tqdm(dset):
arr[idx : idx + example['len']] = example['ids']
idx += example['len']
for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
# Batch together samples for faster write
batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
arr_batch = np.concatenate(batch['ids'])
# Write into mmap
arr[idx : idx + len(arr_batch)] = arr_batch
idx += len(arr_batch)
arr.flush()
# train.bin is ~17GB, val.bin ~8.5MB

View File

@ -69,7 +69,7 @@ class CausalSelfAttention(nn.Module):
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
if self.flash:
# efficient attention using Flash Attention CUDA kernels
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
else:
# manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@ -207,6 +207,7 @@ class GPT(nn.Module):
self.config.block_size = block_size
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
for block in self.transformer.h:
if hasattr(block.attn, 'bias'):
block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
@classmethod

View File

@ -84,6 +84,7 @@ if ddp:
init_process_group(backend=backend)
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
@ -94,6 +95,9 @@ else:
# if not ddp, we are running on a single gpu, and one process
master_process = True
seed_offset = 0
ddp_world_size = 1
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"tokens per iteration will be: {tokens_per_iter:,}")
if master_process:
os.makedirs(out_dir, exist_ok=True)
@ -190,6 +194,7 @@ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None # free up memory
# compile the model
if compile:
@ -288,6 +293,7 @@ while True:
model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
with ctx:
logits, loss = model(X, Y)
loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
# immediately async prefetch next batch while model is doing the forward pass on the GPU
X, Y = get_batch('train')
# backward pass, with gradient scaling if training in fp16
@ -307,7 +313,9 @@ while True:
dt = t1 - t0
t0 = t1
if iter_num % log_interval == 0 and master_process:
lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
# get loss as float. note: this is a CPU-GPU sync point
# scale up to undo the division above, approximating the true total loss (exact would have been a sum)
lossf = loss.item() * gradient_accumulation_steps
if local_iter_num >= 5: # let the training loop settle a bit
mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu