1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00
nanogpt-experiments/scaling_laws.ipynb

443 lines
126 KiB
Plaintext

{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Reproducing some results from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf). I can't get the numbers to match exactly; please open an Issue if you understand why the results are off by a bit. My current hypothesis is that this is because I am using the FLOPs = 6\\*N\\*D formula instead of taking all the Gopher sizes, calculating their FLOPs individually, and using those to interpolate."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# GPT-2 (small) hyperparameters; the keys match the keyword arguments of params() and flops() below\n",
"gpt2 = {\n",
"    \"seq_len\": 1024,\n",
"    \"vocab_size\": 50257,\n",
"    \"d_model\": 768,\n",
"    \"num_heads\": 12,\n",
"    \"num_layers\": 12,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"123.653376"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size=4):\n",
"    \"\"\"\n",
"    Given a GPT config, return its total number of parameters.\n",
"\n",
"    ffw_size is a multiplier: the MLP hidden width is ffw_size * d_model.\n",
"    num_heads does not affect the count (the heads split d_model); it is accepted\n",
"    so that the same config dict can be passed to both params() and flops().\n",
"\n",
"    Not counted: the position embedding table (seq_len * d_model).\n",
"    Counted: the d_model * vocab_size output projection ('dense'). In GPT-2 this\n",
"    matrix is presumably shared with the token embedding -- confirm before\n",
"    reusing this count for models without weight tying.\n",
"    \"\"\"\n",
"    ffw_dim = ffw_size * d_model # absolute MLP hidden width\n",
"    # per transformer block, weights + biases\n",
"    attention = 3*d_model**2 + 3*d_model # query, key, value projections\n",
"    attproj = d_model**2 + d_model # attention output projection\n",
"    ffw = d_model*ffw_dim + ffw_dim # MLP projection d_model -> ffw_dim\n",
"    ffwproj = ffw_dim*d_model + d_model # MLP projection ffw_dim -> d_model\n",
"    layernorms = 2*2*d_model # two layernorms per block, each with a weight and a bias vector\n",
"    # after the last block\n",
"    ln_f = 2*d_model # final layernorm\n",
"    dense = d_model*vocab_size # projection to vocabulary logits; note: no bias here\n",
"    total_params = num_layers*(attention + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n",
"    return total_params\n",
"\n",
"params(**gpt2)/1e6 # OpenAI reports gpt2 (small) as having 124M params, good."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1113.44615424"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size=4):\n",
"    \"\"\"\n",
"    Given a GPT config, return the total training FLOPs (forward + backward) for\n",
"    one sequence of seq_len tokens, following the Chinchilla paper Appendix F:\n",
"    https://arxiv.org/pdf/2203.15556.pdf\n",
"\n",
"    ffw_size is a multiplier, exactly as in params(): the MLP hidden width is\n",
"    ffw_size * d_model. num_heads sets key_size = d_model // num_heads and\n",
"    scales the softmax term.\n",
"    \"\"\"\n",
"    key_size = d_model // num_heads\n",
"    ffw_dim = ffw_size * d_model # absolute MLP hidden width\n",
"\n",
"    # embeddings\n",
"    embeddings = 2 * seq_len * vocab_size * d_model\n",
"\n",
"    # attention\n",
"    # key, query, value projections\n",
"    attention = 2 * 3 * seq_len * d_model * (key_size * num_heads)\n",
"    # key @ query logits\n",
"    attlogits = 2 * seq_len * seq_len * (key_size * num_heads)\n",
"    # softmax: 3 FLOPs per attention logit, as in the reference formula (TODO: why 3?)\n",
"    attsoftmax = 3 * num_heads * seq_len * seq_len\n",
"    # softmax @ value reductions\n",
"    attvalue = 2 * seq_len * seq_len * (key_size * num_heads)\n",
"    # final linear\n",
"    attlinear = 2 * seq_len * (key_size * num_heads) * d_model\n",
"    att = attention + attlogits + attsoftmax + attvalue + attlinear\n",
"    # feed forward: a d_model -> ffw_dim and a ffw_dim -> d_model matmul per token.\n",
"    # Appendix F writes this with ffw_size meaning the absolute hidden width, so\n",
"    # it must be ffw_dim here, not the multiplier.\n",
"    dense = 2 * seq_len * (d_model * ffw_dim + d_model * ffw_dim)\n",
"\n",
"    # logits\n",
"    logits = 2 * seq_len * d_model * vocab_size\n",
"\n",
"    forward_flops = embeddings + num_layers * (att + dense) + logits\n",
"    backward_flops = 2 * forward_flops # as in Kaplan et al. 2020\n",
"    total_flops = forward_flops + backward_flops\n",
"\n",
"    return total_flops\n",
"\n",
"flops(**gpt2)/1e9"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"params: 69.72M\n",
"approx flops: 428.34B\n",
"chinchilla flops: 434.11B\n",
"ratio (chinchilla / approx): 1.01\n"
]
}
],
"source": [
"# Reproduce Table A4 from the Chinchilla paper appendix by comparing the\n",
"# detailed FLOP count from flops() above against the 6*N*D approximation.\n",
"# Note: Chinchilla uses vocab_size = 32K.\n",
"\n",
"chin_73M = dict(seq_len = 1024, vocab_size = 32000, d_model = 640, num_heads = 10, num_layers = 10)\n",
"args = chin_73M\n",
"\n",
"D = 1024 # number of training tokens; it cancels out of the ratio below\n",
"N, F = params(**args), flops(**args) # F: detailed FLOPs for a single seq_len-token sequence\n",
"\n",
"approx_flops = 6*D*N\n",
"chinch_flops = F * (float(D) / args['seq_len']) # scale per-sequence FLOPs up to D tokens\n",
"\n",
"print(\n",
"    f\"params: {N/1e6:.2f}M\",\n",
"    f\"approx flops: {approx_flops/1e9:.2f}B\",\n",
"    f\"chinchilla flops: {chinch_flops/1e9:.2f}B\",\n",
"    f\"ratio (chinchilla / approx): {chinch_flops / approx_flops:.2f}\",\n",
"    sep=\"\\n\",\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, this is awkward: the Chinchilla paper reports 73M params for these args (we get only 70M), and the ratio is supposed to be 1.03, but we get 1.01. TODO: look into this more..."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## scaling laws"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7fd5bc48bb80>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1500x500 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def L(N, D):\n",
"    \"\"\"\n",
"    Chinchilla's parametric loss fit: predicted loss for a model with N parameters\n",
"    trained on D tokens. Works on scalars and, elementwise, on numpy arrays.\n",
"    \"\"\"\n",
"    E = 1.69 # entropy of natural language: the loss of an infinite model on infinite data\n",
"    A, alpha = 406.4, 0.34 # model-size term A / N**alpha\n",
"    B, beta = 410.7, 0.28 # data-size term B / D**beta\n",
"    model_term = A / N**alpha\n",
"    data_term = B / D**beta\n",
"    return model_term + data_term + E\n",
"\n",
"# Log-spaced grids (upper bounds exclusive): model sizes 1e7..1e11, dataset sizes 1e9..1e12\n",
"ns = 10 ** np.arange(7, 11, step=2**-4)\n",
"ds = 10 ** np.arange(9, 12, step=2**-4)\n",
"extent = [9, 12, 7, 11] # [x_min, x_max, y_min, y_max] in log10 units, matching the grids\n",
"\n",
"def on_grid(fn):\n",
"    \"\"\"Evaluate fn(n, d) for every model size n (rows) and dataset size d (columns).\"\"\"\n",
"    return np.array([[fn(n, d) for d in ds] for n in ns])\n",
"\n",
"def label_axes(title):\n",
"    plt.xlabel('log10(dataset size)')\n",
"    plt.ylabel('log10(model size)')\n",
"    plt.title(title)\n",
"\n",
"plt.figure(figsize=(15, 5))\n",
"# left: heatmap of the predicted loss over (dataset size, model size)\n",
"plt.subplot(121)\n",
"plt.imshow(on_grid(L), extent=extent, origin='lower')\n",
"label_axes('loss')\n",
"plt.colorbar()\n",
"# right: compute at each point, which is a deterministic function: flops = 6*N*D\n",
"plt.subplot(122)\n",
"plt.imshow(np.log10(on_grid(lambda n, d: 6*n*d)), extent=extent, origin='lower')\n",
"label_axes('log10 flops')\n",
"plt.colorbar()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"OK, so given any $N, D$ we can estimate both 1) the loss and 2) the total FLOPs. Now we want to solve the following problem: given a specific FLOPs budget $C$, find $N_{opt}, D_{opt} = \\arg\\min_{\\mathrm{FLOPs}(N, D) = C} L(N, D)$, i.e. how big a model should we train, and for how many tokens?"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"best model size: 316.23M\n",
"best dataset size: 10.12B\n"
]
},
{
"data": {
"text/plain": [
"Text(0, 0.5, 'loss')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Target compute budget (usually known up front, from how many GPUs we can run for how long)\n",
"c = 1.92e19\n",
"\n",
"# Candidate model sizes: log-spaced from 1e7 up to (excluding) 1e11 params.\n",
"# Each size N pins the token count D that spends exactly c, via C = 6*N*D.\n",
"ns = 10 ** np.arange(7, 11, step=2**-4)\n",
"ds = c / (6 * ns)\n",
"losses = L(ns, ds)\n",
"best = np.argmin(losses) # grid index of the lowest predicted loss\n",
"print(\n",
"    f\"best model size: {ns[best]/1e6:.2f}M\",\n",
"    f\"best dataset size: {ds[best]/1e9:.2f}B\",\n",
"    sep=\"\\n\",\n",
")\n",
"\n",
"# Loss along this fixed-compute curve, with the optimum marked by a red vertical line\n",
"plt.figure(figsize=(3,3))\n",
"plt.plot(ns, losses)\n",
"plt.xscale('log')\n",
"plt.axvline(ns[best], color='red')\n",
"plt.xlabel('model size')\n",
"plt.ylabel('loss')"
]
},
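{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"In the plot above, the models to the left of the optimum are too small and trained for too long, while the models to the right of it are far too large and trained for too little. The model at the red line is just right.\n",
"\n",
"Now, the Chinchilla paper says the best model size for this budget is 400M params trained on 8B tokens, so this disagrees and there is some problem with the calculation. TODO: figure out and fix... (Two possibilities to check: the paper's 400M / 8B figures may come from a different estimation approach than the parametric fit `L(N, D)` used here, and the `2**-4` log10 grid step is coarse, since the reported `316.23M` is exactly the grid point $10^{8.5}$.)"
]
},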
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"compute budget 1.000000e+17: best model size: 29.43M, best dataset size: 0.57B\n",
"compute budget 1.778279e+17: best model size: 36.52M, best dataset size: 0.81B\n",
"compute budget 3.162278e+17: best model size: 48.70M, best dataset size: 1.08B\n",
"compute budget 5.623413e+17: best model size: 60.43M, best dataset size: 1.55B\n",
"compute budget 1.000000e+18: best model size: 80.58M, best dataset size: 2.07B\n",
"compute budget 1.778279e+18: best model size: 107.46M, best dataset size: 2.76B\n",
"compute budget 3.162278e+18: best model size: 133.35M, best dataset size: 3.95B\n",
"compute budget 5.623413e+18: best model size: 177.83M, best dataset size: 5.27B\n",
"compute budget 1.000000e+19: best model size: 220.67M, best dataset size: 7.55B\n",
"compute budget 1.778279e+19: best model size: 294.27M, best dataset size: 10.07B\n",
"compute budget 3.162278e+19: best model size: 392.42M, best dataset size: 13.43B\n",
"compute budget 5.623413e+19: best model size: 486.97M, best dataset size: 19.25B\n",
"compute budget 1.000000e+20: best model size: 649.38M, best dataset size: 25.67B\n",
"compute budget 1.778279e+20: best model size: 865.96M, best dataset size: 34.23B\n",
"compute budget 3.162278e+20: best model size: 1074.61M, best dataset size: 49.05B\n",
"compute budget 5.623413e+20: best model size: 1433.01M, best dataset size: 65.40B\n",
"compute budget 1.000000e+21: best model size: 1778.28M, best dataset size: 93.72B\n",
"compute budget 1.778279e+21: best model size: 2371.37M, best dataset size: 124.98B\n",
"compute budget 3.162278e+21: best model size: 3162.28M, best dataset size: 166.67B\n",
"compute budget 5.623413e+21: best model size: 3924.19M, best dataset size: 238.84B\n",
"compute budget 1.000000e+22: best model size: 5232.99M, best dataset size: 318.49B\n",
"compute budget 1.778279e+22: best model size: 6493.82M, best dataset size: 456.40B\n",
"compute budget 3.162278e+22: best model size: 8659.64M, best dataset size: 608.62B\n",
"compute budget 5.623413e+22: best model size: 11547.82M, best dataset size: 811.61B\n",
"compute budget 1.000000e+23: best model size: 14330.13M, best dataset size: 1163.05B\n",
"compute budget 1.778279e+23: best model size: 19109.53M, best dataset size: 1550.95B\n",
"compute budget 3.162278e+23: best model size: 23713.74M, best dataset size: 2222.54B\n",
"compute budget 5.623413e+23: best model size: 31622.78M, best dataset size: 2963.80B\n",
"compute budget 1.000000e+24: best model size: 42169.65M, best dataset size: 3952.29B\n",
"compute budget 1.778279e+24: best model size: 52329.91M, best dataset size: 5663.68B\n",
"compute budget 3.162278e+24: best model size: 69783.06M, best dataset size: 7552.64B\n",
"compute budget 5.623413e+24: best model size: 93057.20M, best dataset size: 10071.61B\n",
"compute budget 1.000000e+25: best model size: 115478.20M, best dataset size: 14432.74B\n",
"compute budget 1.778279e+25: best model size: 153992.65M, best dataset size: 19246.37B\n",
"compute budget 3.162278e+25: best model size: 191095.30M, best dataset size: 27580.28B\n",
"compute budget 5.623413e+25: best model size: 254829.67M, best dataset size: 36778.90B\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fd5bc240070>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Sweep compute budgets from 1e17 up to (excluding) 1e26, 4 per decade\n",
"cs = 10 ** np.arange(17, 26, step=2**-2)\n",
"# Candidate model sizes from 1e7 up to (excluding) 1e14 params. The grid does not\n",
"# depend on the budget, so build it once instead of on every iteration.\n",
"ns = 10 ** np.arange(7, 14, step=2**-5)\n",
"best_ns = []\n",
"best_ds = []\n",
"for c in cs:\n",
"    ds = c / (6 * ns) # tokens that spend exactly c for each model size, via C = 6*N*D\n",
"    losses = L(ns, ds)\n",
"    best = np.argmin(losses)\n",
"    best_ns.append(ns[best])\n",
"    best_ds.append(ds[best])\n",
"    print(f\"compute budget {c:e}: best model size: {ns[best]/1e6:.2f}M, best dataset size: {ds[best]/1e9:.2f}B\")\n",
"\n",
"# plot both the model size and dataset size as a function of compute budget, on one plot\n",
"plt.figure(figsize=(10,3))\n",
"plt.plot(cs, best_ns, label='model size')\n",
"plt.plot(cs, best_ds, label='dataset size')\n",
"plt.xscale('log')\n",
"plt.yscale('log')\n",
"plt.xlabel('compute budget')\n",
"plt.ylabel('model size / dataset size')\n",
"plt.grid(True, which=\"both\", ls=\"-\", color='k', alpha=0.2)\n",
"plt.legend()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}