Mirror of https://github.com/osmarks/nanogpt-experiments.git, synced 2025-10-31 07:13:01 +00:00

	Merge pull request #277 from apivovarov/is_bf16_supported
Use bf16 only if supported
bench.py
@@ -15,7 +15,7 @@ bias = False
 real_data = True
 seed = 1337
 device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
-dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'
+dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
 compile = True # use PyTorch 2.0 to compile the model to be faster
 profile = False # use pytorch profiler, or just simple benchmarking?
 exec(open('configurator.py').read()) # overrides from command line or config file
sample.py
@@ -18,7 +18,7 @@ temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, i
 top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
 seed = 1337
 device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
-dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'
+dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
 compile = False # use PyTorch 2.0 to compile the model to be faster
 exec(open('configurator.py').read()) # overrides from command line or config file
 # -----------------------------------------------------------------------------
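The change is the same one-liner in each file: default to bfloat16 only when the hardware actually supports it (torch.cuda.is_bf16_supported() checks for this, typically true on Ampere and newer), and fall back to float16 otherwise. As a minimal sketch of how the selected dtype feeds into the autocast context, the ptdtype/ctx pattern below mirrors what nanoGPT's train.py does; the extra torch.cuda.is_available() guard is my own addition so the snippet also runs on CPU-only machines:

    import torch
    from contextlib import nullcontext

    # Pick bf16 only when a CUDA device is present and supports it;
    # otherwise fall back to fp16. The is_available() guard is extra
    # caution so this sketch also runs without a GPU.
    dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

    # autocast runs eligible ops in the selected low-precision dtype; skip it on CPU.
    ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

    x = torch.randn(8, 8, device=device_type)
    w = torch.randn(8, 8, device=device_type)
    with ctx:
        y = x @ w  # under autocast on GPU, this matmul runs in bf16/fp16
    print(f"selected dtype: {dtype}, matmul output dtype: {y.dtype}")

With this selection in place, the scripts no longer default to a dtype the GPU cannot handle.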
train.py
@@ -70,7 +70,7 @@ min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchi
 backend = 'nccl' # 'nccl', 'gloo', etc.
 # system
 device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
-dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
+dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
 compile = True # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
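train.py's comment notes that float16 "will auto implement a GradScaler", which matters here because the fallback path now selects fp16 on older GPUs. A minimal sketch of that pattern: the GradScaler line is the one nanoGPT's train.py itself uses, while the toy model, optimizer, and loss are stand-ins of my own:

    import torch

    dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # GradScaler matters only for fp16, where small gradients underflow;
    # with enabled=False it is a no-op, so bf16/fp32 runs share the same loop.
    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

    model = torch.nn.Linear(8, 1).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    x = torch.randn(4, 8, device=device)
    loss = model(x).square().mean()

    scaler.scale(loss).backward()  # scale the loss before backward to avoid fp16 underflow
    scaler.step(optimizer)         # unscales grads first; skips the step on inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration

Because the scaler is constructed with enabled=(dtype == 'float16'), switching the default between bf16 and fp16 needs no other changes to the training loop.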