Merge branch 'master' into grad_accum

Commit: a6a708c7f1
.gitignore (vendored, new file): 4 additions

@@ -0,0 +1,4 @@
+.DS_Store
+.ipynb_checkpoints/
+__pycache__/
+*.pyc
README.md

@@ -19,7 +19,7 @@ Dependencies:
 - `pip install datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText)
 - `pip install tiktoken` for OpenAI's fast BPE code <3
 - `pip install wandb` for optional logging <3
-- `pip install tqdm`
+- `pip install tqdm` <3

 ## quick start

@@ -37,7 +37,7 @@ This creates a `train.bin` and `val.bin` in that data directory. Now it is time
 $ python train.py config/train_shakespeare_char.py
 ```

-If you peak inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory:
+If you peek inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory:

 ```
 $ python sample.py --out_dir=out-shakespeare-char
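For reference, the hyperparameters described in the hunk above live in `config/train_shakespeare_char.py`. A minimal sketch of what such a config contains follows; only the block size, channel count, layer/head counts, and output directory are taken from the paragraph above, and the remaining values are illustrative assumptions, not a verbatim copy of the repo's file.

```python
# sketch of a character-level Shakespeare config (illustrative, not verbatim)
out_dir = 'out-shakespeare-char'   # from the paragraph above
dataset = 'shakespeare_char'

n_layer = 6        # 6-layer Transformer            (from the text)
n_head = 6         # 6 attention heads per layer    (from the text)
n_embd = 384       # 384 feature channels           (from the text)
block_size = 256   # context of up to 256 characters (from the text)

# the values below are assumptions for illustration only
batch_size = 64
max_iters = 5000
learning_rate = 1e-3
dropout = 0.2
```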
@@ -84,7 +84,7 @@ bot thou the sought bechive in that to doth groan you,
 No relving thee post mose the wear
 ```

-Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc.
+Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer, feel free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc.

 Finally, on Apple Silicon Macbooks and with a recent PyTorch version make sure to add `--device mps` (short for "Metal Performance Shaders"); PyTorch then uses the on-chip GPU that can *significantly* accelerate training (2-3X) and allow you to use larger networks. See [Issue 28](https://github.com/karpathy/nanoGPT/issues/28) for more.

data/openwebtext/prepare.py

@@ -54,12 +54,16 @@ for split, dset in tokenized.items():
     filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
     dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
     arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
+    total_batches = 1024

-    print(f"writing {filename}...")
     idx = 0
-    for example in tqdm(dset):
-        arr[idx : idx + example['len']] = example['ids']
-        idx += example['len']
+    for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
+        # Batch together samples for faster write
+        batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
+        arr_batch = np.concatenate(batch['ids'])
+        # Write into mmap
+        arr[idx : idx + len(arr_batch)] = arr_batch
+        idx += len(arr_batch)
     arr.flush()

 # train.bin is ~17GB, val.bin ~8.5MB
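The hunk above replaces one memmap write per example with one write per shard. A self-contained sketch of the same pattern follows; the `dset` object and its `ids` column mirror the code above, while the function name and signature are assumptions for illustration.

```python
import numpy as np
from tqdm import tqdm

def write_split_to_memmap(dset, filename, arr_len, total_batches=1024):
    # dset: a huggingface Dataset whose 'ids' column holds token id arrays (as in the hunk above)
    # arr_len: total token count across the split, precomputed elsewhere
    dtype = np.uint16  # fits GPT-2 BPE ids (max token value 50256 < 2**16)
    arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
    idx = 0
    for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
        # take one contiguous shard and materialize it as numpy arrays
        batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
        arr_batch = np.concatenate(batch['ids'])
        # one large write into the memmap instead of one write per example
        arr[idx : idx + len(arr_batch)] = arr_batch
        idx += len(arr_batch)
    arr.flush()
```

Concatenating a whole shard before assigning into the memmap amortizes the per-write overhead, which is what makes the batched version faster.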
model.py: 3 changed lines

@@ -69,7 +69,7 @@ class CausalSelfAttention(nn.Module):
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
+            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
         else:
             # manual implementation of attention
             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
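The functional change in this hunk is that the fused attention kernel now receives `dropout_p=0` whenever the module is in eval mode; the manual branch already behaves this way because `nn.Dropout` respects `self.training`. A minimal sketch of the distinction, outside the repo's `CausalSelfAttention` (module name and shapes are assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyAttention(nn.Module):
    """Minimal sketch; not the repo's CausalSelfAttention."""
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = dropout

    def forward(self, q, k, v):
        # Passing self.dropout unconditionally would keep dropping attention weights
        # even under model.eval(); gating on self.training restores deterministic eval.
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=None,
            dropout_p=self.dropout if self.training else 0,
            is_causal=True)

attn = TinyAttention()
q = k = v = torch.randn(2, 4, 8, 16)  # (B, nh, T, hs)
attn.eval()
y1, y2 = attn(q, k, v), attn(q, k, v)
print(torch.allclose(y1, y2))  # True: no dropout is applied in eval mode
```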
@@ -207,6 +207,7 @@ class GPT(nn.Module):
         self.config.block_size = block_size
         self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
         for block in self.transformer.h:
-            block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
+            if hasattr(block.attn, 'bias'):
+                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

     @classmethod
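The new `hasattr` guard matters because the causal-mask `bias` buffer is only registered on the slow (non-flash) attention path; when flash attention is available there is no `bias` attribute to crop. A rough sketch of that relationship (the module below is a stand-in, not the repo's class):

```python
import torch
import torch.nn as nn

class AttnSketch(nn.Module):
    """Stand-in attention module: registers the causal mask only on the slow path."""
    def __init__(self, block_size, use_flash):
        super().__init__()
        self.flash = use_flash
        if not self.flash:
            self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size))
                                              .view(1, 1, block_size, block_size))

def crop_block_size(attn, block_size):
    # mirrors the guarded crop in the hunk above
    if hasattr(attn, 'bias'):
        attn.bias = attn.bias[:, :, :block_size, :block_size]

crop_block_size(AttnSketch(1024, use_flash=True), 256)   # no-op: flash path has no mask buffer
crop_block_size(AttnSketch(1024, use_flash=False), 256)  # crops the mask to 256x256
```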
train.py: 10 changed lines

@@ -84,6 +84,7 @@ if ddp:
     init_process_group(backend=backend)
     ddp_rank = int(os.environ['RANK'])
     ddp_local_rank = int(os.environ['LOCAL_RANK'])
+    ddp_world_size = int(os.environ['WORLD_SIZE'])
     device = f'cuda:{ddp_local_rank}'
     torch.cuda.set_device(device)
     master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
@@ -94,6 +95,9 @@ else:
     # if not ddp, we are running on a single gpu, and one process
     master_process = True
     seed_offset = 0
+    ddp_world_size = 1
+tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
+print(f"tokens per iteration will be: {tokens_per_iter:,}")

 if master_process:
     os.makedirs(out_dir, exist_ok=True)
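The new `tokens_per_iter` line makes the effective batch size explicit: one optimizer step consumes `gradient_accumulation_steps * ddp_world_size * batch_size * block_size` tokens. A quick worked example with hypothetical values (these numbers are assumptions, not the repo's defaults):

```python
# hypothetical values, for illustration only
gradient_accumulation_steps = 8
ddp_world_size = 4      # e.g. 4 GPUs
batch_size = 12         # micro-batch size per GPU per micro-step
block_size = 1024       # tokens per sequence

tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"tokens per iteration will be: {tokens_per_iter:,}")  # 393,216
```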
@@ -190,6 +194,7 @@ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
 optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
 if init_from == 'resume':
     optimizer.load_state_dict(checkpoint['optimizer'])
+checkpoint = None # free up memory

 # compile the model
 if compile:
@@ -288,6 +293,7 @@ while True:
         model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
         with ctx:
             logits, loss = model(X, Y)
+            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
         # immediately async prefetch next batch while model is doing the forward pass on the GPU
         X, Y = get_batch('train')
         # backward pass, with gradient scaling if training in fp16
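Dividing the loss by `gradient_accumulation_steps` before each backward pass makes the accumulated gradients match those of one large batch averaged over all micro-batches. A minimal sketch of the idea outside the training script, with placeholder `model`, `optimizer`, and `get_batch`:

```python
def accumulation_step(model, optimizer, get_batch, gradient_accumulation_steps):
    # placeholders: get_batch() returns (X, Y); model(X, Y) returns (logits, loss)
    optimizer.zero_grad(set_to_none=True)
    for micro_step in range(gradient_accumulation_steps):
        X, Y = get_batch()
        logits, loss = model(X, Y)
        # scale so that summing gradients over micro-steps equals the gradient
        # of the mean loss over the whole accumulated batch
        loss = loss / gradient_accumulation_steps
        loss.backward()  # .grad fields accumulate across micro-steps
    optimizer.step()
```

Under DDP the real loop additionally defers gradient synchronization to the last micro-step via `require_backward_grad_sync`, as shown in the hunk above.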
@@ -307,7 +313,9 @@ while True:
     dt = t1 - t0
     t0 = t1
     if iter_num % log_interval == 0 and master_process:
-        lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
+        # get loss as float. note: this is a CPU-GPU sync point
+        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
+        lossf = loss.item() * gradient_accumulation_steps
         if local_iter_num >= 5: # let the training loop settle a bit
             mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
             running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
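Multiplying `loss.item()` back by `gradient_accumulation_steps` undoes the scaling applied in the micro-step loop, so the logged value is on the familiar per-batch scale again; it is an approximation because it only looks at the last micro-batch rather than summing the scaled losses of all of them. A tiny numeric illustration with made-up losses:

```python
# made-up per-micro-batch losses, for illustration only
micro_losses = [2.10, 2.05, 1.98, 2.07]
N = len(micro_losses)                          # gradient_accumulation_steps

optimized = sum(l / N for l in micro_losses)   # loss the accumulated step actually minimizes: 2.05
logged = (micro_losses[-1] / N) * N            # lossf in the hunk above: 2.07
# 2.07 approximates 2.05; the exact value would require summing all the scaled losses
```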