From d562b3e550ed8806b2b08bfddf637bbde1559965 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Thu, 5 Jan 2023 00:44:35 +0000
Subject: [PATCH] shuttling the poor mans configurator aside into its own file
 and adding it to all of train,sample,bench. because i am leaving args in
 globals() so i can avoid having to prepend every single variable with an
 args., i have to exec the configurator and the optional configs. so we're
 left with something very gross by standard convention but also quite simple
 and functional. *ducks*

---
 README.md       |  3 +--
 bench.py        | 13 ++++++++-----
 configurator.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++
 sample.py       |  2 +-
 train.py        | 35 ++---------------------------------
 5 files changed, 59 insertions(+), 41 deletions(-)
 create mode 100644 configurator.py

diff --git a/README.md b/README.md
index a78374d..8f99d70 100644
--- a/README.md
+++ b/README.md
@@ -71,7 +71,7 @@ I briefly tried finetuning gpt2 a bit more on our OWT and didn't notice dramatic
 For an example of how to finetune a GPT on new text go to `data/shakespeare` and look at `prepare.py` to download the tiny shakespeare dataset and render it into a `train.bin` and `val.bin`. Unlike OpenWebText this will run in seconds. Finetuning takes very little time, e.g. on a single GPU just a few minutes. Run an example finetuning like:
 
 ```
-$ python train.py finetune_shakespeare
+$ python train.py config/finetune_shakespeare.py
 ```
 
 This will load the config parameter overrides in `config/finetune_shakespeare.py` (I didn't tune them much though). Basically, we initialize from a GPT2 checkpoint with `init_from` and train as normal, except shorter and with a small learning rate. The best checkpoint (lowest validation loss) will be in the `out_dir` directory, e.g. in `out-shakespeare` by default, per the config file. You can then run the code in `sample.py` to generate infinite Shakespeare. Note that you'll have to edit it to point to the correct `out_dir`.
@@ -102,7 +102,6 @@ Features / APIs
 - Add back fp16 support? (would need to also add back gradient scaler)
 - Add CPU support
 - Finetune the finetuning script, I think the hyperparams are not great
-- Replace poor man's configurator, and make sample.py configurable...
 - Report and track other metrics e.g. perplexity, num_tokens, MFU, ...
 - Eval zero-shot perplexities on PTB, WikiText, other related benchmarks
 
diff --git a/bench.py b/bench.py
index 840dcc5..400d644 100644
--- a/bench.py
+++ b/bench.py
@@ -7,16 +7,19 @@ import time
 import torch
 from model import GPTConfig, GPT
 
+# -----------------------------------------------------------------------------
 device = 'cuda'
+batch_size = 8
+block_size = 1024
+compile = True
+exec(open('configurator.py').read()) # overrides from command line or config file
+# -----------------------------------------------------------------------------
+
+dtype = torch.bfloat16 # todo make configurable
 torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
 torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
 torch.manual_seed(1337)
-batch_size = 8
-block_size = 1024
-dtype = torch.bfloat16
-compile = True
-
 # data loading init
 real_data = True
 if real_data:
diff --git a/configurator.py b/configurator.py
new file mode 100644
index 0000000..a8bba95
--- /dev/null
+++ b/configurator.py
@@ -0,0 +1,47 @@
+"""
+Poor Man's Configurator. Probably a terrible idea. Example usage:
+$ python train.py config/override_file.py --batch_size=32
+this will first run config/override_file.py, then override batch_size to 32
+
+The code in this file will be run as follows from e.g. train.py:
+>>> exec(open('configurator.py').read())
+
+So it's not a Python module, it's just shuttling this code away from train.py
+The code in this script then overrides the globals()
+
+I know people are not going to love this, I just really dislike configuration
+complexity and having to prepend config. to every single variable. If someone
+comes up with a better simple Python solution I am all ears.
+"""
+
+import sys
+from ast import literal_eval
+
+for arg in sys.argv[1:]:
+    if '=' not in arg:
+        # assume it's the name of a config file
+        assert not arg.startswith('--')
+        config_file = arg
+        print(f"Overriding config with {config_file}:")
+        with open(config_file) as f:
+            print(f.read())
+        exec(open(config_file).read())
+    else:
+        # assume it's a --key=value argument
+        assert arg.startswith('--')
+        key, val = arg.split('=')
+        key = key[2:]
+        if key in globals():
+            try:
+                # attempt to eval it (e.g. if bool, number, or etc)
+                attempt = literal_eval(val)
+            except (SyntaxError, ValueError):
+                # if that goes wrong, just use the string
+                attempt = val
+            # ensure the types match ok
+            assert type(attempt) == type(globals()[key])
+            # cross fingers
+            print(f"Overriding: {key} = {attempt}")
+            globals()[key] = attempt
+        else:
+            raise ValueError(f"Unknown config key: {key}")
diff --git a/sample.py b/sample.py
index e245efc..7bea23c 100644
--- a/sample.py
+++ b/sample.py
@@ -7,7 +7,6 @@ import tiktoken
 from model import GPTConfig, GPT
 
 # -----------------------------------------------------------------------------
-# todo make these overridable like in train.py
 out_dir = 'out'
 device = 'cuda:2'
 compile = False
@@ -17,6 +16,7 @@ max_new_tokens = 500 # number of tokens generated in each sample
 temperature = 0.8 # higher temperature (up to 1) is more random, lower (down to 0) means more greedy
 top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
 seed = 1337
+exec(open('configurator.py').read()) # overrides from command line or config file
 # -----------------------------------------------------------------------------
 
 torch.manual_seed(seed)
diff --git a/train.py b/train.py
index e849d88..2fbb231 100644
--- a/train.py
+++ b/train.py
@@ -13,7 +13,6 @@ import os
 import sys
 import time
 import math
-from ast import literal_eval
 
 import wandb
 import numpy as np
@@ -24,7 +23,7 @@ from torch.distributed import init_process_group, destroy_process_group
 from model import GPTConfig, GPT
 
 # -----------------------------------------------------------------------------
-# default config values
+# default config values designed to train a gpt2 (124M) on OpenWebText
 # I/O
 out_dir = 'out'
 eval_interval = 2000
@@ -62,37 +61,7 @@ min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchi
 backend = 'nccl' # 'nccl', 'gloo', etc.
 compile = True # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
-# poor man's Configurator. Potentially a bad idea. Example usage:
-# $ python train.py override_file --batch_size=32
-# this will first run config/override_file.py, then override batch_size to 32
-for arg in sys.argv[1:]:
-    if '=' not in arg:
-        # assume it's the name of a config file
-        assert not arg.startswith('--')
-        config_file = os.path.join('config', arg + '.py')
-        print(f"Overriding config with {config_file}:")
-        with open(config_file) as f:
-            print(f.read())
-        exec(open(config_file).read())
-    else:
-        # assume it's a --key=value argument
-        assert arg.startswith('--')
-        key, val = arg.split('=')
-        key = key[2:]
-        if key in globals():
-            try:
-                # attempt to eval it it (e.g. if bool, number, or etc)
-                attempt = literal_eval(val)
-            except (SyntaxError, ValueError):
-                # if that goes wrong, just use the string
-                attempt = val
-            # ensure the types match ok
-            assert type(attempt) == type(globals()[key])
-            # cross fingers
-            print(f"Overriding: {key} = {attempt}")
-            globals()[key] = attempt
-        else:
-            raise ValueError(f"Unknown config key: {key}")
+exec(open('configurator.py').read()) # overrides from command line or config file
 # -----------------------------------------------------------------------------
 ddp = int(os.environ.get('LOCAL_RANK', -1)) != -1 # is this a ddp run?
 if ddp:
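
For context on how the new configurator is meant to be driven: a config override file is just a plain Python script that reassigns names already defined in train.py's default config section, and configurator.py execs it before applying any --key=value arguments on top. The sketch below is illustrative only; the file name and the values are hypothetical, not the contents of config/finetune_shakespeare.py.

```
# config/my_finetune.py (hypothetical override file, values are illustrative only)
# these names are assumed to exist in train.py's default config section;
# configurator.py simply exec's this file into train.py's globals()
out_dir = 'out-my-finetune'
init_from = 'gpt2'       # start from a pretrained GPT-2 checkpoint
eval_interval = 200
batch_size = 4
learning_rate = 3e-5
compile = False
```

Running it with an extra command-line override would look like:

```
$ python train.py config/my_finetune.py --batch_size=8
```

Arguments are processed in order, so the config file is exec'd first and --batch_size=8 is applied afterwards: literal_eval turns the string "8" into an int, the type is asserted against the existing batch_size global, and only keys that already exist in globals() are accepted on the command line (an unknown --key raises a ValueError). Note that the key check and type check apply only to --key=value arguments; a config file is exec'd as-is.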