mirror of
				https://github.com/osmarks/nanogpt-experiments.git
				synced 2025-10-31 07:13:01 +00:00 
			
		
		
		
	Plumb a configurable bias option through bench.py, model.py, and train.py. Measured bias=False to be about 6% faster.
This commit is contained in:
		
							
								
								
									
										2
									
								
								bench.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								bench.py
									
									
									
									
									
								
							| @@ -11,6 +11,7 @@ from model import GPTConfig, GPT | ||||
| # ----------------------------------------------------------------------------- | ||||
| batch_size = 8 | ||||
| block_size = 1024 | ||||
| bias = True | ||||
| seed = 1337 | ||||
| device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. | ||||
| dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16' | ||||
| @@ -50,6 +51,7 @@ gptconf = GPTConfig( | ||||
|     block_size = block_size, # how far back does the model look? i.e. context size | ||||
|     n_layer = 12, n_head = 12, n_embd = 768, # size of the model | ||||
|     dropout = 0, # for determinism | ||||
|     bias = bias, | ||||
| ) | ||||
| model = GPT(gptconf) | ||||
| model.to(device) | ||||
|   | ||||
							
								
								
									
										4
									
								
								model.py
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								model.py
									
									
									
									
									
								
							| @@ -108,7 +108,7 @@ class GPTConfig: | ||||
|     n_layer: int = 12 | ||||
|     n_head: int = 12 | ||||
|     n_embd: int = 768 | ||||
|     dropout: float = 0.1 | ||||
|     dropout: float = 0.0 | ||||
|     bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster | ||||
|  | ||||
| class GPT(nn.Module): | ||||
| @@ -215,7 +215,7 @@ class GPT(nn.Module): | ||||
|         # later, by calling crop_block_size() | ||||
|  | ||||
|         # create a from-scratch initialized minGPT model | ||||
|         config = GPTConfig(block_size=1024, **config_args) | ||||
|         config = GPTConfig(block_size=1024, bias=True, **config_args) # note: force bias=True, as in gpt2 models | ||||
|         model = GPT(config) | ||||
|         sd = model.state_dict() | ||||
|  | ||||
|   | ||||
							
								
								
									
										5
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								train.py
									
									
									
									
									
								
							| @@ -53,6 +53,7 @@ n_layer = 12 | ||||
| n_head = 12 | ||||
| n_embd = 768 | ||||
| dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ | ||||
| bias = False # do we use bias inside LayerNorm and Linear layers? | ||||
| # adamw optimizer | ||||
| learning_rate = 6e-4 # max learning rate | ||||
| max_iters = 600000 # total number of training iterations | ||||
| @@ -129,7 +130,8 @@ else: | ||||
|     vocab_size = 50257 | ||||
|  | ||||
| # model init | ||||
| model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout, vocab_size = vocab_size) | ||||
| model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, | ||||
|                   dropout=dropout, vocab_size=vocab_size, bias=bias) | ||||
| if init_from == 'scratch': | ||||
|     # init a new model from scratch | ||||
|     print("Initializing a new model from scratch") | ||||
| @@ -158,6 +160,7 @@ elif init_from == 'resume': | ||||
|     best_val_loss = checkpoint['best_val_loss'] | ||||
| elif init_from.startswith('gpt2'): | ||||
|     print(f"Initializing from OpenAI GPT-2 weights: {init_from}") | ||||
|     assert bias, "GPT-2 models have bias, so we can't use bias=False" | ||||
|     # initialize from OpenAI GPT-2 weights | ||||
|     override_args = dict(dropout=dropout) | ||||
|     model = GPT.from_pretrained(init_from, override_args) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Andrej Karpathy
					Andrej Karpathy