1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00
nanogpt-experiments/scaling_laws.ipynb

514 lines
230 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
2023-01-06 02:13:04 +00:00
"Trying to reproduce results from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf):"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# GPT-2 (small) config; the keys match the keyword arguments of params() and flops()\n",
"gpt2 = {\n",
"    'seq_len': 1024,\n",
"    'vocab_size': 50257,\n",
"    'd_model': 768,\n",
"    'num_heads': 12,\n",
"    'num_layers': 12,\n",
"}"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## params"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"123.653376"
]
},
2023-01-06 02:13:04 +00:00
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size=4):\n",
"    \"\"\"\n",
"    Count the parameters of a GPT model with the given config.\n",
"\n",
"    `ffw_size` is the feed-forward hidden width as a multiple of `d_model`.\n",
"    The token and position embedding tables are NOT counted; the final\n",
"    output projection (`d_model * vocab_size`, no bias) is. `seq_len` and\n",
"    `num_heads` do not affect the result.\n",
"    \"\"\"\n",
"    ffw_hidden = ffw_size * d_model\n",
"    per_block = (\n",
"        (3 * d_model**2 + 3 * d_model)        # q, k, v projections: weights and biases\n",
"        + (d_model**2 + d_model)              # attention output projection\n",
"        + (d_model * ffw_hidden + ffw_hidden) # feed-forward up-projection\n",
"        + (ffw_hidden * d_model + d_model)    # feed-forward down-projection\n",
"        + 2 * 2 * d_model                     # 2 layernorms per block (weight + bias each)\n",
"    )\n",
"    final_layernorm = 2 * d_model\n",
"    output_head = d_model * vocab_size  # note: no bias here\n",
"    return num_layers * per_block + final_layernorm + output_head\n",
"\n",
"# GPT-2 (small) parameter count, in millions\n",
"params(**gpt2)/1e6"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
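"OpenAI reports GPT-2 (small) as having 124M params, so this is a match. Note that, as the code comment says, embeddings are not counted: the position-embedding table (`seq_len * d_model` = 786,432 params for GPT-2) is excluded, while the `dense` output layer, which has the same shape as the token-embedding table, is included. Also, loading the OpenAI weights into nanoGPT and then calling `model.parameters()` exactly matches the above number and verifies the implementation."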
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[44000000.0, 512, 2048, 64, 8, 8]"
]
},
2023-01-06 02:13:04 +00:00
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import json\n",
"chinchilla_models_txt = '[[44000000.0, 512, 2048, 64, 8, 8], [57000000.0, 576, 2304, 64, 9, 9], [74000000.0, 640, 2560, 64, 10, 10], [90000000.0, 640, 2560, 64, 10, 13], [106000000.0, 640, 2560, 64, 10, 16], [117000000.0, 768, 3072, 64, 12, 12], [140000000.0, 768, 3072, 64, 12, 15], [163000000.0, 768, 3072, 64, 12, 18], [175000000.0, 896, 3584, 64, 14, 14], [196000000.0, 896, 3584, 64, 14, 16], [217000000.0, 896, 3584, 64, 14, 18], [251000000.0, 1024, 4096, 64, 16, 16], [278000000.0, 1024, 4096, 64, 16, 18], [306000000.0, 1024, 4096, 64, 16, 20], [425000000.0, 1280, 5120, 128, 10, 18], [489000000.0, 1280, 5120, 128, 10, 21], [509000000.0, 1408, 5632, 128, 11, 18], [552000000.0, 1280, 5120, 128, 10, 24], [587000000.0, 1408, 5632, 128, 11, 21], [632000000.0, 1536, 6144, 128, 12, 19], [664000000.0, 1408, 5632, 128, 11, 24], [724000000.0, 1536, 6144, 128, 12, 22], [816000000.0, 1536, 6144, 128, 12, 25], [893000000.0, 1792, 7168, 128, 14, 20], [1018000000.0, 1792, 7168, 128, 14, 23], [1143000000.0, 1792, 7168, 128, 14, 26], [1266000000.0, 2048, 8192, 128, 16, 22], [1424000000.0, 2176, 8704, 128, 17, 22], [1429000000.0, 2048, 8192, 128, 16, 25], [1593000000.0, 2048, 8192, 128, 16, 28], [1609000000.0, 2176, 8704, 128, 17, 25], [1731000000.0, 2304, 9216, 128, 18, 24], [1794000000.0, 2176, 8704, 128, 17, 28], [2007000000.0, 2304, 9216, 128, 18, 28], [2283000000.0, 2304, 9216, 128, 18, 32], [2298000000.0, 2560, 10240, 128, 20, 26], [2639000000.0, 2560, 10240, 128, 20, 30], [2980000000.0, 2560, 10240, 128, 20, 34], [3530000000.0, 2688, 10752, 128, 22, 36], [3802000000.0, 2816, 11264, 128, 22, 36], [4084000000.0, 2944, 11776, 128, 22, 36], [4516000000.0, 3072, 12288, 128, 24, 36], [6796000000.0, 3584, 14336, 128, 28, 40], [9293000000.0, 4096, 16384, 128, 32, 42], [11452000000.0, 4352, 17408, 128, 32, 47], [12295000000.0, 4608, 18432, 128, 36, 44], [12569000000.0, 4608, 18432, 128, 32, 47], [13735000000.0, 4864, 19456, 128, 32, 47], [14940000000.0, 4992, 19968, 128, 32, 49], 
[16183000000.0, 5120, 20480, 128, 40, 47]]'\n",
"# Parse the 50 model configs from Chinchilla Table A9. Each row is\n",
"# [params, d_model, ffw_size, kv_size, n_heads, n_layers].\n",
"# (The misspelled name is kept because later cells reference it.)\n",
"chilchilla_models = json.loads(chinchilla_models_txt)\n",
"chilchilla_models[0]"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"our estimated params: 41.6041M, chinchilla params: 44.0000M, d_model: 512, n_heads: 8, n_layers: 8\n",
"our estimated params: 54.3324M, chinchilla params: 57.0000M, d_model: 576, n_heads: 9, n_layers: 9\n",
"our estimated params: 69.7165M, chinchilla params: 74.0000M, d_model: 640, n_heads: 10, n_layers: 10\n",
"our estimated params: 84.4870M, chinchilla params: 90.0000M, d_model: 640, n_heads: 10, n_layers: 13\n",
"our estimated params: 99.2576M, chinchilla params: 106.0000M, d_model: 640, n_heads: 10, n_layers: 16\n"
]
}
],
"source": [
"# Compare our parameter estimate with the reported size for the first 5 models.\n",
"# (f = ffw_size and k = kv_size are unpacked but not used here.)\n",
"for p, d, f, k, h, l in chilchilla_models[:5]:\n",
"    nparams = params(seq_len=1024, vocab_size=32000, d_model=d, num_heads=h, num_layers=l)\n",
"    print(f\"our estimated params: {nparams/1e6:.4f}M, chinchilla params: {p/1e6:.4f}M, d_model: {d}, n_heads: {h}, n_layers: {l}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"So we are not able to reproduce the parameter counts of the Chinchilla models: for the five models above, our estimates are consistently roughly 5-6% below the reported counts. TODO: resolve...\n",
"\n",
"Now turning to FLOPs:"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## flops"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"766.006788096"
]
},
2023-01-06 02:13:04 +00:00
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size=4):\n",
"    \"\"\"\n",
"    Given a GPT config, calculate the total training FLOPs (forward + backward)\n",
"    for one sequence of `seq_len` tokens. See Chinchilla paper Appendix F as\n",
"    reference: https://arxiv.org/pdf/2203.15556.pdf\n",
"\n",
"    `ffw_size` is the feed-forward hidden width as a multiple of `d_model`\n",
"    (the same convention as `params()`), so each feed-forward matmul has shape\n",
"    d_model x (ffw_size * d_model). The ffw_size in Table A9 / Appendix F is\n",
"    that absolute width.\n",
"    \"\"\"\n",
"    key_size = d_model // num_heads\n",
"    ffw_hidden = ffw_size * d_model  # absolute feed-forward width\n",
"\n",
"    # embeddings\n",
"    embeddings = 2 * seq_len * vocab_size * d_model\n",
"\n",
"    # attention\n",
"    # key, query, value projections\n",
"    attention = 2 * 3 * seq_len * d_model * (key_size * num_heads)\n",
"    # key @ query logits\n",
"    attlogits = 2 * seq_len * seq_len * (key_size * num_heads)\n",
"    # softmax\n",
"    attsoftmax = 3 * num_heads * seq_len * seq_len # 3* is for subtract (max), exp, divide (?)\n",
"    # softmax @ value reductions\n",
"    attvalue = 2 * seq_len * seq_len * (key_size * num_heads)\n",
"    # final linear\n",
"    attlinear = 2 * seq_len * (key_size * num_heads) * d_model\n",
"    att = attention + attlogits + attsoftmax + attvalue + attlinear\n",
"    # feed forward: up-projection (d_model -> ffw_hidden) plus down-projection back.\n",
"    # Must use the absolute width here; using the multiplier drops a factor of d_model.\n",
"    dense = 2 * seq_len * (d_model * ffw_hidden + ffw_hidden * d_model)\n",
"\n",
"    # logits\n",
"    logits = 2 * seq_len * d_model * vocab_size\n",
"\n",
"    forward_flops = embeddings + num_layers * (att + dense) + logits\n",
"    backward_flops = 2 * forward_flops # as in Kaplan et al. 2020\n",
"    total_flops = forward_flops + backward_flops\n",
"\n",
"    return total_flops\n",
"\n",
"# total FLOPs returned by flops() for the GPT-2 (small) config, in billions\n",
"flops(**gpt2)/1e9"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"params: 69.72M\n",
"approx flops: 428.34B\n",
"chinchilla flops: 434.11B\n",
"ratio (chinchilla / approx): 1.01\n"
]
}
],
"source": [
"# Reproduce Table A4 from the Chinchilla paper appendix: compare the\n",
"# detailed FLOP count from flops() with the approximation F = 6*N*D.\n",
"# Chinchilla mentions using vocab_size = 32K.\n",
"\n",
"args = dict(\n",
"    seq_len=1024,\n",
"    vocab_size=32000,\n",
"    d_model=640,\n",
"    num_heads=10,\n",
"    num_layers=10,\n",
")  # chinchilla 73M\n",
"\n",
"D = 1024  # dataset size in tokens; it cancels in the ratio printed below\n",
"N, F = params(**args), flops(**args)\n",
"\n",
"approx_flops = 6*D*N  # approximate flops\n",
"# flops() counts one seq_len-token sequence, so scale it up to D tokens\n",
"chinch_flops = F * (D / args['seq_len'])\n",
"\n",
"print(f\"params: {N/1e6:.2f}M\")\n",
"print(f\"approx flops: {approx_flops/1e9:.2f}B\")\n",
"print(f\"chinchilla flops: {chinch_flops/1e9:.2f}B\")\n",
"print(f\"ratio (chinchilla / approx): {chinch_flops / approx_flops:.2f}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Awkward: for the 73M-parameter model the ratio is supposed to be 1.03 (Table A4), but we get 1.01. So once again we don't reproduce the numbers. TODO: resolve..."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## scaling laws"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2023-01-06 02:13:04 +00:00
"<matplotlib.colorbar.Colorbar at 0x7fc09dffc730>"
]
},
2023-01-06 02:13:04 +00:00
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA8UAAAHWCAYAAABe7ytwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9eXwb9bnv//mOdslavMu7HTv74jgLIZAWKCkJUJYWKPTQS+FwoJcboJD7u72Hvs4BSin0dDvcnkPLgVOgO5SwUxpKA4QtqxMncRYn3ld5k619saT5/TGSLNkzskYayUu+b17fVxtp5pnvjO35zjPP83wewrIsCwqFQqFQKBQKhUKhUM5DmNmeAIVCoVAoFAqFQqFQKLMFdYopFAqFQqFQKBQKhXLeQp1iCoVCoVAoFAqFQqGct1CnmEKhUCgUCoVCoVAo5y3UKaZQKBQKhUKhUCgUynkLdYopFAqFQqFQKBQKhXLeQp1iCoVCoVAoFAqFQqGct1CnmEKhUCgUCoVCoVAo5y3UKaZQKBQKhUKhUCgUynkLdYoplAS8+OKLIISgs7NztqdCoVAoFMqCZiGuuT/5yU+waNEiyGQyrF27FgBQXV2N22+/fVbnRaFQ4qFOMYVCoVAoFArlvOFXv/oVbrrpJlRWVoIQktBBHR8fx913343CwkLodDpcdtllOHLkSFLH+dvf/obvfve7uPjii/HCCy/giSeekOgMKBSK1MhnewIUCoVCoVAoFEq2+Ld/+zc4HA5ccMEFGBgYENwuFArh6quvxrFjx/B//s//QUFBAX75y1/i0ksvRWNjIxYvXpzwOB988AEYhsGvf/1rKJVKqU+DQqFICHWKKRQKhUKhUCjnDXv37o1GiXNycgS327VrFz7//HO88soruPHGGwEAX//617FkyRI88sgj+OMf/5jwOENDQ9BoNNQhplDmATR9mkIRyS9/+UusXLkSKpUKpaWl2LFjB8bHx+O2OXfuHG644QaYzWao1WqUl5fjlltugc1mi27z/vvvY8uWLTCZTMjJycHSpUvxve99L8tnQ6FQKBTK3CaZdRcAnn76aSxatAgajQYXXHABPvnkE1x66aW49NJL47arqqoCIWTG4+7atQvFxcX42te+Fv2ssLAQX//61/Hmm2/C5/MJ7ksIwQsvvACXywVCCAghePHFFwW3b29vx0033YS8vDxotVpceOGF+Mtf/hK3zUcffQRCCF5++WV873vfg9lshk6nw7XXXouenp64bZN5DqFQKJPQSDGFIoJHH30U3//+97F161bcc889aGlpwa9+9SscOnQIn332GRQKBfx+P7Zt2wafz4f77rsPZrMZfX19eOeddzA+Pg6j0YiTJ0/iK1/5CtasWYPHHnsMKpUKra2t+Oyzz2b7FCkUCoVCmTMks+4CXJ3wvffeiy984Qt48MEH0dnZieuvvx65ubkoLy9P6dhHjx7FunXrwDDxMaQLLrgAzz77LM6ePYvVq1fz7vu73/0Ozz77LA4ePIj//u//BgBcdNFFvNsODg7ioosugtvtxv3334/8/Hz85je/wbXXXotdu3bhq1/9atz2P/zhD0EIwf/9v/8XQ0NDeOqpp7B161Y0NTVBo9Ek9RxCoVCmwFIoFEFeeOEFFgDb0dHBDg0NsUqlkr3iiivYYDAY3eY///M/WQDs888/z7Isyx49epQFwL7yyiuCdv/93/+dBcAODw9n/BwoFAqFQpkPxK65LMsmve76fD42Pz+f3bhxIzsxMRHd7sUXX2QBsJdccongMXU6Hfutb31L8Lt//Md/nPb5X/7yFxYAu3v37oTn861vfYvV6XTTPq+qqoo75gMPPMACYD/55JPoZw6Hg62pqWGrq6uj5/7hhx+yANiysjLWbrdHt/3zn//MAmD/3//7fyzLJvccQqFQ4qHp0xRKkvz973+H3+/HAw88EPfW+K677oLBYIimOUXewL733ntwu928tkwmEw
DgzTffRCgUyuzEKRQKhUKZhyS77h4+fBijo6O46667IJdPJkHeeuutyM3NTfn4Ho8HKpVq2udqtTr6vRS8++67uOCCC7Bly5boZzk5Obj77rvR2dmJU6dOxW1/2223Qa/XR/994403oqSkBO+++y6A5J5DKBRKPNQpplCSpKurCwCwdOnSuM+VSiUWLVoU/b6mpgY7d+7Ef//3f6OgoADbtm3D008/HVfHc/PNN+Piiy/GP/3TP6G4uBi33HIL/vznP1MHmUKhUCiUMMmuu5H/rauri9tOLpejuro65eNrNBreumGv1xv9Xgq6urqmnSMALF++PPp9LFNVrwkhqKuri/Z3TuY5hEKhxEOdYgolA/zsZz/D8ePH8b3vfQ8ejwf3338/Vq5cid7eXgDcQvrxxx/j73//O/7H//gfOH78OG6++WZ8+ctfRjAYnOXZUygUCoVCKSkp4W3ZFPmstLQ021NKmpmeQygUSjzUKaZQkqSqqgoA0NLSEve53+9HR0dH9PsIq1evxr/8y7/g448/xieffIK+vj4888wz0e8ZhsHll1+On//85zh16hR++MMf4oMPPsCHH36Y+ZOhUCgUCmWOk+y6G/nf1tbWuO0CgUA0epoKa9euxZEjR6ZlcR04cABarRZLlixJ2XYsVVVV084RAM6cORP9PpZz587F/ZtlWbS2tk6Lis/0HEKhUCahTjGFkiRbt26FUqnEL37xC7AsG/3817/+NWw2G66++moAgN1uRyAQiNt39erVYBgmmoZltVqn2V+7di0AJGzxQKFQKBTK+UKy6+6GDRuQn5+P5557Lm79/cMf/oCxsbGUj3/jjTdicHAQr732WvSzkZERvPLKK7jmmmt4641T4aqrrsLBgwexb9++6GculwvPPvssqqursWLFirjtf/vb38LhcET/vWvXLgwMDODKK68EkNxzCIVCiYe2ZKJQkqSwsBAPPfQQvv/972P79u249tpr0dLSgl/+8pfYuHEjvvnNbwIAPvjgA9x777246aabsGTJEgQCAfzud7+DTCbDDTfcAAB47LHH8PHHH+Pqq69GVVUVhoaG8Mtf/hLl5eVxQhsUCoVCoZyvJLvuKpVKPProo7jvvvvwpS99CV//+tfR2dmJF198EbW1tdN6Er/99ts4duwYAGBiYgLHjx/H448/DgC49tprsWbNGgCcU3zhhRfijjvuwKlTp1BQUIBf/vKXCAaD+P73vy/Zef7zP/8z/vSnP+HKK6/E/fffj7y8PPzmN79BR0cHXn311WktofLy8rBlyxbccccdGBwcxFNPPYW6ujrcddddAJJ7DqFQKFOYZfVrCmVOM7U9BMtyrSCWLVvGKhQKtri4mL3nnnvYsbGx6Pft7e3sP/7jP7K1tbWsWq1m8/Ly2Msuu4z9+9//Ht1mz5497HXXXceWlpaySqWSLS0tZb/xjW+wZ8+ezeLZUSgUCoUyd+Bbc1l25nU3wi9+8Qu2qqqKValU7AUXXMB+9tln7Pr169nt27fHbfetb32LBcA7XnjhhbhtrVYre+edd7L5+fmsVqtlL7nkEvbQoUNJnU+yLZlYlmXb2trYG2+8kTWZTKxarWYvuOAC9p133onbJtKS6U9/+hP70EMPsUVFRaxGo2GvvvpqtqurK7pdMs8hFAolHsKyMfkoFAqFQqFQKBTKAiAUCqGwsBBf+9rX8Nxzz832dNLmo48+wmWXXYZXXnkFN95442xPh0JZUNCaYgqFQqFQKBTKvMbr9WJqnOe3v/0trFYrLr300tmZFIVCmTfQmmIKhUKhUCgUyrxm//79ePDBB3HTTTchPz8fR44cwa9//WusWrUKN91002xPj0KhzHGoU0yhUCgUCoVCmddUV1ejoqICv/jFL2C1WpGXl4fbbrsNP/rRj6BUKmd7ehQKZY4zq+nTH3/8Ma655hqUlpaCEII33ngj7vvXXnsNV1xxBfLz80EIQVNTU1J2X3nlFSxbtgxqtRqrV6/Gu+
++K/3kKRQKhUKhzAhd6ynZoLq6Gm+99RYsFgv8fj8sFguef/55FBUVzfbUJOPSSy8Fy7K0nphCyQCz6hS7XC7U19f
"text/plain": [
"<Figure size 1200x500 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def L(N, D):\n",
"    \"\"\"\n",
"    Predicted loss L(N, D) = E + A / N**alpha + B / D**beta for a model with\n",
"    N parameters trained on D tokens, using the fitted constants from the\n",
"    Chinchilla paper. Works elementwise on numpy arrays as well as scalars.\n",
"    \"\"\"\n",
"    E = 1.69  # entropy of natural language, limit of infinite model on infinite data\n",
"    A, alpha = 406.4, 0.34  # model-size term\n",
"    B, beta = 410.7, 0.28   # dataset-size term\n",
"    model_term = A / N**alpha\n",
"    data_term = B / D**beta\n",
"    return model_term + data_term + E\n",
"\n",
"ns = 10 ** np.arange(7, 11, step=2**-4) # model sizes from 10M to 100B\n",
"ds = 10 ** np.arange(9, 12, step=2**-4) # dataset sizes from 1B to 1T\n",
"# imshow/contour extent in log10 units: [x_min, x_max, y_min, y_max], x = dataset size, y = model size\n",
"extent = [9, 12, 7, 11]\n",
"plt.figure(figsize=(12, 5))\n",
"plt.subplot(121)\n",
"# 2D grid of log10(loss): rows follow model size (ns), columns follow dataset size (ds).\n",
"# L works elementwise, so broadcasting a column of ns against a row of ds builds the grid.\n",
"loss2d = np.log10(L(ns[:, None], ds[None, :]))\n",
"plt.imshow(loss2d, extent=extent, origin='lower', alpha=0.5)\n",
"plt.contour(loss2d, levels=30, extent=extent, origin='lower')\n",
"plt.xlabel('log10(dataset size)')\n",
"plt.ylabel('log10(model size)')\n",
"plt.title('log10 loss')  # the grid holds log10(loss), not the raw loss\n",
"plt.colorbar()\n",
"# plot the compute for each point, which is a deterministic function: flops = 6*N*D\n",
"plt.subplot(122)\n",
"compute2d = np.log10(6 * ns[:, None] * ds[None, :])\n",
"plt.imshow(compute2d, extent=extent, origin='lower', alpha=0.5)\n",
"plt.contour(compute2d, levels=30, extent=extent, origin='lower')\n",
"plt.xlabel('log10(dataset size)')\n",
"plt.ylabel('log10(model size)')\n",
"plt.title('log10 flops')\n",
"plt.colorbar();"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"OK, so given any $N, D$ we can estimate both 1) the loss and 2) the total FLOPs. Now we want to solve the following problem: given a compute budget of $C$ FLOPs, find $N_{opt}, D_{opt} = \\arg\\min_{N, D \\,:\\, \\text{FLOPs}(N, D) = C} L(N, D)$, i.e. how big a model should we train, and on how many tokens?"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"best model size: 316.23M\n",
"best dataset size: 10.12B\n"
]
},
{
"data": {
"text/plain": [
"Text(0, 0.5, 'loss')"
]
},
2023-01-06 02:13:04 +00:00
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAS8AAAEqCAYAAABEE9ZrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAuW0lEQVR4nO3deViU9f7/8ecMy7Ajm4AspqIoKoqaCpor5m7W6ZfHLKy0Tp0sy7IT1Tl1Wg6ezJZz8tt6ytZjqannWC7pSS3FDTTRVMIFEFkUkWWQbeb+/QFOoiAIw9wzzPtxXXNdcnPPzHuY5tXn/sxn0SiKoiCEEDZGq3YBQgjREhJeQgibJOElhLBJEl5CCJsk4SWEsEkSXkIImyThJYSwSY5qF2BpRqORM2fO4OnpiUajUbscIcQVFEWhtLSUTp06odU23r6yu/A6c+YMYWFhapchhGhCdnY2oaGhjf7e7sLL09MTqP3DeHl5qVxNO6bXQ6dOtf8+cwbc3dWtR9iMkpISwsLCTJ/VxthdeF26VPTy8pLwaksODr/928tLwktct6a6daTDXghhkyS8hBA2ScJLCGGTJLyEEDZJwksIYZMkvIQQNknCSwjR5j7ecZK1B3Ioqag222NKeAkh2lS1wciSTenMX36ArMJysz2uhJcQok0dyL5AWWUNvu7ORAWbb2C4hJcQok39mH4WgLhufmi15lsMQcJLCNGmfsw4B8CI7gFmfVwJLyFEmykur+bn7AsADO/ub9bHlvASQrSZ5BPnMCrQLcCdTh1czfrYEl5CiDaz/dfaS8abzHzJCBJeQog29JMpvMx7yQgSXkKINpJZqCfrfDmOWg1DuvqZ/fElvIQQbeLHulbXgM4+eOjMv+6phJcQok2YLhkjzH/JCBJeQog2UGMwsuN4XXj1MH9nPUh4CSHawMGcYkoravB2daJviHebPIeElxDC7LYdq50SNCzCDwczTgm6nISXEMLsttbNZxzVo2ObPYeElxDCrM7rqzh4+gIAI9qovwskvIQQZvbjr2dRFOgZ5EmQt0ubPY+ElxDCrC71d42MbLtWF0h4CSHMyGhU2GaB/i6Q8BJCmNHhMyUU6qtwd3ZgYGefNn0uCS8hhNlsPVYAwLAIf5wd2zZeJLyEEGZjumSMbNtLRpDwEkKYSXF5NalZRUDbd9aDhJcQwkx+yqhdNbV7Rw9CzLxqakMkvIQQZnGpv2tkGw5MvZyElxCi1YxGhR+OWa6/CyS8hBBmkJZTzLmySjx0jgzu4muR55TwEkK02pajtZeMI3q0/RCJSyS8hBCt9r+j+QCMttAlI0h4CSFaKb+kgkM5JWg0luvvAgkvIUQr/VB3ydgvtAMBnjqLPa+ElxCiVS71d43tablWF0h4CSFaoaLaYNolaEwvCS8hhI3YdaKQi9UGgrxciAr2suhzS3gJIVrsUn/X6J4d0WjaZqONxkh4CSFaRFEU1fq7QMJLCNFCx/JLOV10EZ2jlrgIP4s/v9WE16JFi9BoNDz22GONnrNs2TI0Gk29m4tL2y3wL4Ro3KbDtQNTb+ruj5uzo8Wf3/LP2IC9e/fy3nvvER0d3eS5Xl5eHDt2zPSzpa+zhRC1vv+lNrzGRQWq8vyqt7zKysqYNWsWH3zwAT4+Ta95rdFoCAoKMt0CA9X5wwlhz3KLL5KWU4xGA2N62ml4Pfzww0yePJn4+PhmnV9WVkbnzp0JCwvjlltu4fDhw9c8v7KykpKSkno3IUTrbK5rdQ0I97HoqPrLqRpey5cvJzU1laSkpGadHxkZyUcffcTatWv5/PPPMRqNxMXFcfr06Ubvk5SUhLe3t+kWFhZmrvKFsFub6sLrZpUuGUHF8MrOzmb+/Pl88cUXze50j42NJSEhgf79+zNy5Ei++eYbAgICeO+99xq9T2JiIsXFxaZbdna2uV6CEHappKKaXScKAfX6u0DFDvuUlBQKCgoYMGCA6ZjBYGD79u28/fbbVF
ZW4uDgcM3HcHJyIiYmhoyMjEbP0el06HTqNGuFaI+2HjtLtUGhW4A7XQM8VKtDtfAaO3YsaWlp9Y7de++99OzZkz/96U9NBhfUhl1aWhqTJk1qqzKFEFf47VvGIFXrUC28PD096dOnT71j7u7u+Pn5mY4nJCQQEhJi6hN78cUXGTp0KBEREVy4cIHFixeTmZnJ3LlzLV6/EPaoqsbI1rpR9WpeMoKVjPNqTFZWFlrtb91yRUVF3H///eTl5eHj48PAgQPZuXMnUVFRKlYphP3YfbKQ0soa/D10xIR1ULUWjaIoiqoVWFhJSQne3t4UFxfj5WXZWfB2Ra8Hj7r+kLIycHdXtx5hFs+uTuOL3VnMHBxG0m1NDypvieZ+RlUf5yWEsA0Go8LGuilBE/oEq1yNhJcQoplSMos4V1aJl4sjsV0tPxH7ShJeQohm2XAoD4D4qECLbW92LepXIISweoqisPFwbXhN6K3uEIlLJLyEEE06eLqYnAsXcXN2YESPALXLASS8hBDNsL7uknF0z464ODU9gNwSJLyEENekKAobDuUCMLGPdVwygoSXEKIJx/JLOVVYjrOj1qI7YjdFwksIcU3r02ovGUd0D8BDZz2TciS8hBDXtN4KLxlBwksIcQ3p+aWk55fh7KAlXuWJ2FeS8BJCNGrdwdpW14ge/ni7OqlcTX0SXkKIBimKwrqDZwCYEt1J5WquJuElhGjQkdxSTpzV4+yoZWwv6/mW8RIJLyFEg75Nq211jY4MwNPFui4ZQcJLCNGA2kvG2v4ua7xkBAkvIUQDDp8pIbOwHBcnLWN6Wt8lI0h4CSEa8N+6jvqxPQNxt6KBqZeT8BJC1KMoCt/WXTJOjlZ/xdTGSHgJIerZn32B00W1y9+MtqK5jFeS8BJC1LN2fw4A43sH4epsHcvfNETCSwhhUmMwmr5lnNbfOr9lvETCSwhh8lPGOQr1Vfi5OzM8wl/tcq5JwksIYfKfA7XfMk6ODsbJwbrjwbqrE0JYzMUqg2mTjVv6h6hcTdMkvIQQAGw+ko++ykCYrysDwjuoXU6TJLyEEACsrbtkvKVfCBqNRuVqmibhJYTgQnkV29ILALjFyr9lvETCSwjBt2m5VBsUooK96B7oqXY5zSLhJYRgVcppAKbH2EarCyS8hLB7J8/pSc26gFYD023gW8ZLJLyEsHPfpNa2ukb0CKCjl4vK1TSfhJcQdsxoVPgmtXYu4+8GhKpczfWR8BLCju06WUjOhYt4ujgyzsq2NmuKhJcQdmxVSm2ra0p0MC5O1ruCREMkvISwU+VVNabdsG3tkhEkvISwWxsO5VFeZeAGPzcGdvZRu5zrJuElhJ1aVfct420DQm1iOtCVJLyEsEPZ58vZkVGIRgO3xtjO2K7LSXgJYYdW7MsGYHiEP2G+bipX0zISXkLYGYNRYUXddKA7BoWpXE3LWU14LVq0CI1Gw2OPPXbN81asWEHPnj1xcXGhb9++fPfdd5YpUIh2YvuvZ8ktrqCDmxM397atsV2Xs4rw2rt3L++99x7R0dHXPG/nzp3MnDmTOXPmsH//fqZPn8706dM5dOiQhSoVwvZ9vbf2knF6/xB0jrY1tutyqodXWVkZs2bN4oMPPsDH59pf17711ltMmDCBhQsX0qtXL1566SUGDBjA22+/baFqhbBthWWVbD6SD8CMG233khGsILwefvhhJk+eTHx8fJPnJicnX3Xe+PHjSU5ObvQ+lZWVlJSU1LsJYa9W78+h2qAQHepNr2AvtctpFUc1n3z58uWkpqayd+/eZp2fl5dHYGD9a/TAwEDy8vIavU9SUhJ//etfW1WnEO2Boigsr7tktPVWF6jY8srOzmb+/Pl88cUXuLi03TIciYmJFBcXm27Z2dlt9lxCWLPUrCIyCspwcdIytZ/tLDrYGNVaXikpKRQUFDBgwADTMYPBwPbt23n77beprKzEwaF+Z2JQUBD5+fn1juXn5xMUFNTo8+h0OnQ6nXmLF8
IGfbErC4Cp0Z3wcnFSuZrWU63lNXbsWNLS0jhw4IDpNmjQIGbNmsWBAweuCi6A2NhYtmzZUu/Y999/T2xsrKXKFsI
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"c = 1.92e19 # target compute budget in FLOPs (typically known from how many GPUs we have, for how long)\n",
"# sweep model sizes from 10M to 100B\n",
"ns = 10 ** np.arange(7, 11, step=2**-4)\n",
"# using C = 6*N*D, solve for D that maintains the compute budget c\n",
"ds = c / (6 * ns)\n",
"# evaluate the loss in each case\n",
"losses = L(ns, ds)\n",
"# find the argmin\n",
"best = np.argmin(losses)\n",
"print(f\"best model size: {ns[best]/1e6:.2f}M\")\n",
"print(f\"best dataset size: {ds[best]/1e9:.2f}B\")\n",
"# plot the loss along this iso-compute curve\n",
"plt.figure(figsize=(3,3))\n",
"plt.plot(ns, losses)\n",
"plt.xscale('log')\n",
"# plot a vertical bar at the best model size\n",
"plt.axvline(ns[best], color='red')\n",
"plt.xlabel('model size')\n",
"plt.ylabel('loss')\n",
"plt.title(f'compute budget C = {c:.2e} FLOPs');"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"In the plot above, the models to the left of the best one are too small and trained for too long, while the models to the right are way too large and trained on too few tokens. The model at the red line is just right.\n",
"\n",
"Now, for this compute budget (`c = 1.92e19` FLOPs) the Chinchilla paper says the best model size is 400M params trained on 8B tokens, so this once again disagrees, and there is some problem in the calculations. TODO: figure out and fix..."
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2304"
]
},
2023-01-06 02:13:04 +00:00
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Calculate the Chinchilla optimal models for a range of compute budgets\n",
"\n",
"# sweep over compute budgets from 1e17 up to (but excluding) 1e26\n",
"cs = 10 ** np.arange(17, 26, step=2**-8)\n",
"# candidate model sizes; identical for every budget, so computed once outside the loop\n",
"ns = 10 ** np.arange(7, 14, step=2**-8)\n",
"models = []\n",
"for c in cs:\n",
"    # the dataset sizes that would maintain the given compute budget\n",
"    ds = c / (6 * ns)\n",
"    # losses at each point\n",
"    losses = L(ns, ds)\n",
"    # n,d for the best model\n",
"    best = np.argmin(losses)\n",
"    models.append((c, ns[best], ds[best])) # c, n, d tuple log\n",
"\n",
"len(models)"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"closest model found:\n",
"model size: 1298.02M\n",
"dataset size: 60.32B\n",
"flops: 4.697589e+20\n",
"loss: 2.41\n"
]
}
],
"source": [
"query_model_size = 1.3e9 # query: 1.3B params (GPT-3 size)\n",
"# split the logged (c, n, d) tuples into parallel arrays of model and dataset sizes\n",
"ns = np.array([model[1] for model in models])\n",
"ds = np.array([model[2] for model in models])\n",
"# index of the logged model whose size is closest to the query\n",
"ix = np.abs(ns - query_model_size).argmin()\n",
"# report that model's size, data, compute and predicted loss\n",
"print(\"closest model found:\")\n",
"print(f\"model size: {ns[ix]/1e6:.2f}M\")\n",
"print(f\"dataset size: {ds[ix]/1e9:.2f}B\")\n",
"print(f\"flops: {6*ns[ix]*ds[ix]:e}\")\n",
"print(f\"loss: {L(ns[ix], ds[ix]):.2f}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"So we predict that ~60B tokens is compute-optimal for a ~1.3B-parameter model. But [MosaicML, for example, quotes 26B](https://t.co/HyEvCqP70C), so again our numbers disagree."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"TL;DR for now: nothing reproduces."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}