mirror of
				https://github.com/osmarks/nanogpt-experiments.git
				synced 2025-10-26 12:57:41 +00:00 
			
		
		
		
	add torch.compile by default, shows almost 1.8X improvement in throughput nice
This commit is contained in:
		| @@ -8,8 +8,7 @@ The simplest, fastest repository for training/finetuning medium-sized GPTs. It's | ||||
| Dependencies: | ||||
|  | ||||
| - [pytorch](https://pytorch.org) <3 | ||||
| - numpy <3 | ||||
| - `pip install datasets` for huggingface datasets <3 | ||||
| - `pip install datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText) | ||||
| - `pip install tiktoken` for OpenAI's fast bpe code <3 | ||||
| - `pip install wandb` for optional logging <3 | ||||
|  | ||||
| @@ -68,6 +67,10 @@ I briefly tried finetuning gpt2 a bit more on our OWT and didn't notice dramatic | ||||
|  | ||||
| For model benchmarking `bench.py` might be useful. It's identical to what happens in the meat of the training loop of `train.py`, but omits much of the other complexity. | ||||
|  | ||||
| # efficiency notes | ||||
|  | ||||
| The code now uses [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) by default. At the time of writing (Dec 29, 2022), `torch.compile()` is available in the nightly release. The improvement from this one line of code is noticeable, e.g. cutting down iteration time from ~250ms / iter to ~135ms / iter. Nice work PyTorch team! | ||||
|  | ||||
| ## todos | ||||
|  | ||||
| A few that I'm aware of, other than the ones mentioned in code: | ||||
|   | ||||
							
								
								
									
										7
									
								
								bench.py
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								bench.py
									
									
									
									
									
								
							| @@ -14,7 +14,8 @@ torch.manual_seed(1337) | ||||
|  | ||||
| batch_size = 8 | ||||
| block_size = 1024 | ||||
| dtype = torch.float16 | ||||
| dtype = torch.bfloat16 | ||||
| compile_model = True | ||||
|  | ||||
| # data loading init | ||||
| real_data = True | ||||
| @@ -46,6 +47,10 @@ model.to(device) | ||||
|  | ||||
| optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95)) | ||||
|  | ||||
| if compile_model: | ||||
|     print("Compiling model...") | ||||
|     model = torch.compile(model) # pytorch 2.0 | ||||
|  | ||||
| profile = False # use pytorch profiler, or just simple benchmarking? | ||||
| if profile: | ||||
|     # useful docs on pytorch profiler: | ||||
|   | ||||
| @@ -21,6 +21,7 @@ model = GPT(gptconf) | ||||
| model.load_state_dict(checkpoint['model']) | ||||
| model.eval() | ||||
| model.to(device) | ||||
| model = torch.compile(model) # requires PyTorch 2.0 | ||||
|  | ||||
| enc = tiktoken.get_encoding("gpt2") | ||||
| #start = enc.encode("\n") | ||||
|   | ||||
							
								
								
									
										7
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								train.py
									
									
									
									
									
								
							| @@ -59,6 +59,7 @@ lr_decay_iters = 320000 # how many steps to decay the learning rate for | ||||
| min_lr = 1e-5 # minimum learning rate | ||||
| # DDP settings | ||||
| backend = 'nccl' # 'nccl', 'gloo', etc. | ||||
| compile_model = True # use PyTorch 2.0 to compile the model to be faster | ||||
| # ----------------------------------------------------------------------------- | ||||
| # poor man's Configurator. Potentially a bad idea. Example usage: | ||||
| # $ python train.py override_file --batch_size=32 | ||||
| @@ -156,6 +157,12 @@ optimizer = model.configure_optimizers(weight_decay, learning_rate, betas) | ||||
| if init_from == 'resume': | ||||
|     optimizer.load_state_dict(checkpoint['optimizer']) | ||||
|  | ||||
| # compile the model | ||||
| if compile_model: | ||||
|     print("compiling the model... (takes a ~minute)") | ||||
|     unoptimized_model = model | ||||
|     model = torch.compile(model) # requires PyTorch 2.0 | ||||
|  | ||||
| # wrap model into DDP container | ||||
| if ddp: | ||||
|     model = DDP(model, device_ids=[gpu_id]) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Andrej Karpathy
					Andrej Karpathy