mirror of https://github.com/osmarks/nanogpt-experiments.git, synced 2025-10-31 07:13:01 +00:00

	Merge branch 'master' into grad_accum

.gitignore (vendored, new file, 4 lines added)

@@ -0,0 +1,4 @@
+.DS_Store
+.ipynb_checkpoints/
+__pycache__/
+*.pyc

README.md

@@ -19,7 +19,7 @@ Dependencies:
 - `pip install datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText)
 - `pip install tiktoken` for OpenAI's fast BPE code <3
 - `pip install wandb` for optional logging <3
-- `pip install tqdm`
+- `pip install tqdm` <3
 
 ## quick start
 
@@ -37,7 +37,7 @@ This creates a `train.bin` and `val.bin` in that data directory. Now it is time
 $ python train.py config/train_shakespeare_char.py
 ```
 
-If you peak inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory:
+If you peek inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory:
 
 ```
 $ python sample.py --out_dir=out-shakespeare-char
@@ -84,7 +84,7 @@ bot thou the sought bechive in that to doth groan you,
 No relving thee post mose the wear
 ```
 
-Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc.
+Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer, feel free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc.
 
 Finally, on Apple Silicon Macbooks and with a recent PyTorch version make sure to add `--device mps` (short for "Metal Performance Shaders"); PyTorch then uses the on-chip GPU that can *significantly* accelerate training (2-3X) and allow you to use larger networks. See [Issue 28](https://github.com/karpathy/nanoGPT/issues/28) for more.
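
The hyperparameters called out in that paragraph (256-character context, 384 feature channels, 6 layers, 6 heads, checkpoints in `out-shakespeare-char`) translate into config overrides along these lines; this is a sketch reconstructed from the prose above, not a verbatim copy of `config/train_shakespeare_char.py`:

```python
# Sketch of the character-level Shakespeare settings described in the README text above.
# Values are taken from that paragraph; other options in the real config file are omitted.
out_dir = 'out-shakespeare-char'   # where checkpoints get written
block_size = 256                   # context of up to 256 characters
n_layer = 6                        # 6-layer Transformer
n_head = 6                         # 6 attention heads per layer
n_embd = 384                       # 384 feature channels
```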

data/openwebtext/prepare.py

@@ -54,12 +54,16 @@ for split, dset in tokenized.items():
     filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
     dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
     arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
+    total_batches = 1024
 
-    print(f"writing {filename}...")
     idx = 0
-    for example in tqdm(dset):
-        arr[idx : idx + example['len']] = example['ids']
-        idx += example['len']
+    for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
+        # Batch together samples for faster write
+        batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
+        arr_batch = np.concatenate(batch['ids'])
+        # Write into mmap
+        arr[idx : idx + len(arr_batch)] = arr_batch
+        idx += len(arr_batch)
     arr.flush()
 
 # train.bin is ~17GB, val.bin ~8.5MB
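
The new loop above writes the memmap in 1024 contiguous shards instead of one document at a time, which turns millions of tiny slice assignments into a few large ones. A self-contained sketch of the same pattern on a toy in-memory dataset (the file name `toy.bin` and the sizes are made up for illustration):

```python
# Toy demonstration of the sharded memmap write used in the diff above (illustrative sizes only).
import numpy as np
from datasets import Dataset

# stand-in for the tokenized split: each row holds token ids and their count
dset = Dataset.from_dict({
    'ids': [[1, 2, 3], [4, 5], [6, 7, 8, 9]],
    'len': [3, 2, 4],
})
arr_len = int(np.sum(dset['len']))
arr = np.memmap('toy.bin', dtype=np.uint16, mode='w+', shape=(arr_len,))

total_batches = 2  # the real script uses 1024
idx = 0
for batch_idx in range(total_batches):
    # contiguous=True preserves document order, so concatenating the shards rebuilds the full token stream
    batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
    arr_batch = np.concatenate(batch['ids'])
    arr[idx : idx + len(arr_batch)] = arr_batch  # one large write per shard instead of one per document
    idx += len(arr_batch)
arr.flush()
assert idx == arr_len  # every token landed in the memmap
```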
							
								
								
									
model.py (5 lines changed)

@@ -61,7 +61,7 @@ class CausalSelfAttention(nn.Module):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
 
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
-        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)
+        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
@@ -69,7 +69,7 @@ class CausalSelfAttention(nn.Module):
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
+            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
         else:
             # manual implementation of attention
             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@@ -207,6 +207,7 @@ class GPT(nn.Module):
         self.config.block_size = block_size
         self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
         for block in self.transformer.h:
-            block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
+            if hasattr(block.attn, 'bias'):
+                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
 
     @classmethod
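
The `dropout_p=self.dropout if self.training else 0` change matters because `scaled_dot_product_attention` is a purely functional call: unlike an `nn.Dropout` module, it does not know whether the model is in train or eval mode, so an unconditional `dropout_p` would keep randomizing attention during evaluation and sampling. A small illustrative sketch (toy shapes, not the repository's module):

```python
# Sketch: the caller must gate dropout_p on self.training; F.scaled_dot_product_attention won't do it.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCausalAttn(nn.Module):
    def __init__(self, dropout=0.5):
        super().__init__()
        self.dropout = dropout

    def forward(self, q, k, v):
        # self.training flips automatically with .train() / .eval()
        p = self.dropout if self.training else 0.0
        return F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=p, is_causal=True)

attn = TinyCausalAttn()
q = k = v = torch.randn(1, 2, 8, 16)  # (B, n_head, T, head_size), toy sizes
attn.eval()
with torch.no_grad():
    out1, out2 = attn(q, k, v), attn(q, k, v)
print(torch.allclose(out1, out2))  # True: eval runs are deterministic with the gate in place
```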
							
								
								
									
train.py (10 lines changed)

@@ -84,6 +84,7 @@ if ddp:
     init_process_group(backend=backend)
     ddp_rank = int(os.environ['RANK'])
     ddp_local_rank = int(os.environ['LOCAL_RANK'])
+    ddp_world_size = int(os.environ['WORLD_SIZE'])
     device = f'cuda:{ddp_local_rank}'
     torch.cuda.set_device(device)
     master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
@@ -94,6 +95,9 @@ else:
     # if not ddp, we are running on a single gpu, and one process
     master_process = True
     seed_offset = 0
+    ddp_world_size = 1
+tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
+print(f"tokens per iteration will be: {tokens_per_iter:,}")
 
 if master_process:
     os.makedirs(out_dir, exist_ok=True)
@@ -190,6 +194,7 @@ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
 optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
 if init_from == 'resume':
     optimizer.load_state_dict(checkpoint['optimizer'])
+checkpoint = None # free up memory
 
 # compile the model
 if compile:
@@ -288,6 +293,7 @@ while True:
             model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
         with ctx:
             logits, loss = model(X, Y)
+            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
         # immediately async prefetch next batch while model is doing the forward pass on the GPU
         X, Y = get_batch('train')
         # backward pass, with gradient scaling if training in fp16
@@ -307,7 +313,9 @@ while True:
     dt = t1 - t0
     t0 = t1
     if iter_num % log_interval == 0 and master_process:
-        lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
+        # get loss as float. note: this is a CPU-GPU sync point
+        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
+        lossf = loss.item() * gradient_accumulation_steps
         if local_iter_num >= 5: # let the training loop settle a bit
             mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
             running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
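
Taken together, the train.py changes make each optimizer step consume `gradient_accumulation_steps * ddp_world_size * batch_size * block_size` tokens while keeping the gradients equivalent to one large batch: every microbatch's `backward()` adds into `.grad`, so dividing each loss by the number of microbatches turns that accumulated sum into a mean before the single `optimizer.step()`. A stripped-down sketch of the pattern (toy linear model, random data, no AMP, DDP, or gradient clipping):

```python
# Minimal gradient-accumulation sketch mirroring the loss scaling added in the diff above.
import torch
import torch.nn as nn

model = nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

gradient_accumulation_steps = 4
batch_size, block_size = 8, 16   # stand-ins for the real config values
ddp_world_size = 1               # single process in this sketch

tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"tokens per iteration will be: {tokens_per_iter:,}")  # 4 * 1 * 8 * 16 = 512 here

optimizer.zero_grad(set_to_none=True)
for micro_step in range(gradient_accumulation_steps):
    X = torch.randn(batch_size, block_size)
    Y = torch.randn(batch_size, 1)
    loss = nn.functional.mse_loss(model(X), Y)
    # scale so the accumulated .grad is the mean over microbatches, as if one big batch were used
    loss = loss / gradient_accumulation_steps
    loss.backward()
optimizer.step()

# for logging, the last microbatch loss is scaled back up, matching the change in the final hunk
lossf = loss.item() * gradient_accumulation_steps
```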