diff --git a/scaling_laws.ipynb b/scaling_laws.ipynb index 9ba644a..c82deb6 100644 --- a/scaling_laws.ipynb +++ b/scaling_laws.ipynb @@ -16,35 +16,23 @@ "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", + "import pandas as pd\n", "%matplotlib inline" ] }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "gpt2 = dict(\n", - " seq_len = 1024,\n", - " vocab_size = 50257,\n", - " d_model = 768,\n", - " num_heads = 12,\n", - " num_layers = 12,\n", - ")" - ] - }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "## params" + "## params\n", + "\n", + "First some parameter calculations:" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -53,21 +41,22 @@ "123.653376" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "def params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size=4):\n", - " \"\"\" Given GPT config calculate total number of parameters\"\"\" \n", + "def gpt_params(seq_len, vocab_size, d_model, num_heads, num_layers):\n", + " \"\"\" Given GPT config calculate total number of parameters \"\"\"\n", + " ffw_size = 4*d_model # in GPT the number of intermediate features is always 4*d_model\n", " # token and position embeddings\n", " embeddings = d_model * vocab_size + d_model * seq_len\n", " # transformer blocks\n", " attention = 3*d_model**2 + 3*d_model # weights and biases\n", " attproj = d_model**2 + d_model\n", - " ffw = d_model*(ffw_size*d_model) + ffw_size*d_model\n", - " ffwproj = ffw_size*d_model*d_model + d_model\n", + " ffw = d_model*(ffw_size) + ffw_size\n", + " ffwproj = ffw_size*d_model + d_model\n", " layernorms = 2*2*d_model\n", " # dense\n", " ln_f = 2*d_model\n", @@ -76,7 +65,8 @@ " total_params = num_layers*(attention + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n", " return total_params\n", "\n", - "params(**gpt2)/1e6" + "gpt2 = dict(seq_len = 1024, vocab_size = 50257, d_model = 768, num_heads = 12, num_layers = 12)\n", + "gpt_params(**gpt2)/1e6" ] }, { @@ -84,7 +74,33 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "OpenAI reports gpt2 (small) as having 124M params, so this is a match. Also, loading the OpenAI weights into nanoGPT and then calling `model.parameters()` exactly matches the above number and verifies the implementation." + "OpenAI reports gpt2 (small) as having 124M params, so this is a match. Also, loading the OpenAI weights into nanoGPT and then calling `model.parameters()` exactly matches the above number and verifies the implementation. Now Chinchilla parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def chinchilla_params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n", + " \"\"\" Parameters in the Chinchilla models. Unlike GPT they use relative positional embeddings. 
\"\"\"\n", + " # token embeddings only\n", + " embeddings = d_model * vocab_size\n", + " # transformer blocks\n", + " attention = 3*d_model**2 + 3*d_model # weights and biases\n", + " relative_pos = d_model**2 + 2*d_model # relative keys, content bias, relative bias\n", + " attproj = d_model**2 + d_model\n", + " ffw = d_model*ffw_size + ffw_size\n", + " ffwproj = ffw_size*d_model + d_model\n", + " layernorms = 2*2*d_model\n", + " # dense\n", + " ln_f = 2*d_model\n", + " dense = d_model*vocab_size # note: no bias here\n", + " # note: embeddings are not included in the param count!\n", + " total_params = num_layers*(attention + relative_pos + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n", + " return total_params\n", + "\n" ] }, { @@ -104,6 +120,7 @@ } ], "source": [ + "# Load in all the 50 Chinchilla models on the last page of the paper\n", "import json\n", "chinchilla_models_txt = '[[44000000.0, 512, 2048, 64, 8, 8], [57000000.0, 576, 2304, 64, 9, 9], [74000000.0, 640, 2560, 64, 10, 10], [90000000.0, 640, 2560, 64, 10, 13], [106000000.0, 640, 2560, 64, 10, 16], [117000000.0, 768, 3072, 64, 12, 12], [140000000.0, 768, 3072, 64, 12, 15], [163000000.0, 768, 3072, 64, 12, 18], [175000000.0, 896, 3584, 64, 14, 14], [196000000.0, 896, 3584, 64, 14, 16], [217000000.0, 896, 3584, 64, 14, 18], [251000000.0, 1024, 4096, 64, 16, 16], [278000000.0, 1024, 4096, 64, 16, 18], [306000000.0, 1024, 4096, 64, 16, 20], [425000000.0, 1280, 5120, 128, 10, 18], [489000000.0, 1280, 5120, 128, 10, 21], [509000000.0, 1408, 5632, 128, 11, 18], [552000000.0, 1280, 5120, 128, 10, 24], [587000000.0, 1408, 5632, 128, 11, 21], [632000000.0, 1536, 6144, 128, 12, 19], [664000000.0, 1408, 5632, 128, 11, 24], [724000000.0, 1536, 6144, 128, 12, 22], [816000000.0, 1536, 6144, 128, 12, 25], [893000000.0, 1792, 7168, 128, 14, 20], [1018000000.0, 1792, 7168, 128, 14, 23], [1143000000.0, 1792, 7168, 128, 14, 26], [1266000000.0, 2048, 8192, 128, 16, 22], [1424000000.0, 2176, 8704, 128, 17, 22], [1429000000.0, 2048, 8192, 128, 16, 25], [1593000000.0, 2048, 8192, 128, 16, 28], [1609000000.0, 2176, 8704, 128, 17, 25], [1731000000.0, 2304, 9216, 128, 18, 24], [1794000000.0, 2176, 8704, 128, 17, 28], [2007000000.0, 2304, 9216, 128, 18, 28], [2283000000.0, 2304, 9216, 128, 18, 32], [2298000000.0, 2560, 10240, 128, 20, 26], [2639000000.0, 2560, 10240, 128, 20, 30], [2980000000.0, 2560, 10240, 128, 20, 34], [3530000000.0, 2688, 10752, 128, 22, 36], [3802000000.0, 2816, 11264, 128, 22, 36], [4084000000.0, 2944, 11776, 128, 22, 36], [4516000000.0, 3072, 12288, 128, 24, 36], [6796000000.0, 3584, 14336, 128, 28, 40], [9293000000.0, 4096, 16384, 128, 32, 42], [11452000000.0, 4352, 17408, 128, 32, 47], [12295000000.0, 4608, 18432, 128, 36, 44], [12569000000.0, 4608, 18432, 128, 32, 47], [13735000000.0, 4864, 19456, 128, 32, 47], [14940000000.0, 4992, 19968, 128, 32, 49], [16183000000.0, 5120, 20480, 128, 40, 47]]'\n", "chilchilla_models = json.loads(chinchilla_models_txt) # all 50 models\n", @@ -119,18 +136,63 @@ "name": "stdout", "output_type": "stream", "text": [ - "our estimated params: 41.6041M, chinchilla params: 44.0000M, d_model: 512, n_heads: 8, n_layers: 8\n", - "our estimated params: 54.3324M, chinchilla params: 57.0000M, d_model: 576, n_heads: 9, n_layers: 9\n", - "our estimated params: 69.7165M, chinchilla params: 74.0000M, d_model: 640, n_heads: 10, n_layers: 10\n", - "our estimated params: 84.4870M, chinchilla params: 90.0000M, d_model: 640, n_heads: 10, n_layers: 13\n", - "our estimated params: 99.2576M, chinchilla 
params: 106.0000M, d_model: 640, n_heads: 10, n_layers: 16\n" + "our estimated params: 43.7094M, chinchilla params: 44.0000M, d_model: 512, n_heads: 8, n_layers: 8\n", + "our estimated params: 57.3287M, chinchilla params: 57.0000M, d_model: 576, n_heads: 9, n_layers: 9\n", + "our estimated params: 73.8253M, chinchilla params: 74.0000M, d_model: 640, n_heads: 10, n_layers: 10\n", + "our estimated params: 89.8285M, chinchilla params: 90.0000M, d_model: 640, n_heads: 10, n_layers: 13\n", + "our estimated params: 105.8317M, chinchilla params: 106.0000M, d_model: 640, n_heads: 10, n_layers: 16\n", + "our estimated params: 116.7283M, chinchilla params: 117.0000M, d_model: 768, n_heads: 12, n_layers: 12\n", + "our estimated params: 139.7660M, chinchilla params: 140.0000M, d_model: 768, n_heads: 12, n_layers: 15\n", + "our estimated params: 162.8037M, chinchilla params: 163.0000M, d_model: 768, n_heads: 12, n_layers: 18\n", + "our estimated params: 174.9745M, chinchilla params: 175.0000M, d_model: 896, n_heads: 14, n_layers: 14\n", + "our estimated params: 195.8746M, chinchilla params: 196.0000M, d_model: 896, n_heads: 14, n_layers: 16\n", + "our estimated params: 216.7747M, chinchilla params: 217.0000M, d_model: 896, n_heads: 14, n_layers: 18\n", + "our estimated params: 251.1196M, chinchilla params: 251.0000M, d_model: 1024, n_heads: 16, n_layers: 16\n", + "our estimated params: 278.4133M, chinchilla params: 278.0000M, d_model: 1024, n_heads: 16, n_layers: 18\n", + "our estimated params: 305.7070M, chinchilla params: 306.0000M, d_model: 1024, n_heads: 16, n_layers: 20\n", + "our estimated params: 424.6938M, chinchilla params: 425.0000M, d_model: 1280, n_heads: 10, n_layers: 18\n", + "our estimated params: 488.6490M, chinchilla params: 489.0000M, d_model: 1280, n_heads: 10, n_layers: 21\n", + "our estimated params: 509.3356M, chinchilla params: 509.0000M, d_model: 1408, n_heads: 11, n_layers: 18\n", + "our estimated params: 552.6042M, chinchilla params: 552.0000M, d_model: 1280, n_heads: 10, n_layers: 24\n", + "our estimated params: 586.7150M, chinchilla params: 587.0000M, d_model: 1408, n_heads: 11, n_layers: 21\n", + "our estimated params: 632.3389M, chinchilla params: 632.0000M, d_model: 1536, n_heads: 12, n_layers: 19\n", + "our estimated params: 664.0945M, chinchilla params: 664.0000M, d_model: 1408, n_heads: 11, n_layers: 24\n", + "our estimated params: 724.4206M, chinchilla params: 724.0000M, d_model: 1536, n_heads: 12, n_layers: 22\n", + "our estimated params: 816.5023M, chinchilla params: 816.0000M, d_model: 1536, n_heads: 12, n_layers: 25\n", + "our estimated params: 892.8138M, chinchilla params: 893.0000M, d_model: 1792, n_heads: 14, n_layers: 20\n", + "our estimated params: 1018.1338M, chinchilla params: 1018.0000M, d_model: 1792, n_heads: 14, n_layers: 23\n", + "our estimated params: 1143.4537M, chinchilla params: 1143.0000M, d_model: 1792, n_heads: 14, n_layers: 26\n", + "our estimated params: 1265.7869M, chinchilla params: 1266.0000M, d_model: 2048, n_heads: 16, n_layers: 22\n", + "our estimated params: 1424.5576M, chinchilla params: 1424.0000M, d_model: 2176, n_heads: 17, n_layers: 22\n", + "our estimated params: 1429.4569M, chinchilla params: 1429.0000M, d_model: 2048, n_heads: 16, n_layers: 25\n", + "our estimated params: 1593.1269M, chinchilla params: 1593.0000M, d_model: 2048, n_heads: 16, n_layers: 28\n", + "our estimated params: 1609.3196M, chinchilla params: 1609.0000M, d_model: 2176, n_heads: 17, n_layers: 25\n", + "our estimated params: 1730.7878M, chinchilla params: 
1731.0000M, d_model: 2304, n_heads: 18, n_layers: 24\n", + "our estimated params: 1794.0815M, chinchilla params: 1794.0000M, d_model: 2176, n_heads: 17, n_layers: 28\n", + "our estimated params: 2006.9637M, chinchilla params: 2007.0000M, d_model: 2304, n_heads: 18, n_layers: 28\n", + "our estimated params: 2283.1396M, chinchilla params: 2283.0000M, d_model: 2304, n_heads: 18, n_layers: 32\n", + "our estimated params: 2298.0403M, chinchilla params: 2298.0000M, d_model: 2560, n_heads: 20, n_layers: 26\n", + "our estimated params: 2638.9811M, chinchilla params: 2639.0000M, d_model: 2560, n_heads: 20, n_layers: 30\n", + "our estimated params: 2979.9219M, chinchilla params: 2980.0000M, d_model: 2560, n_heads: 20, n_layers: 34\n", + "our estimated params: 3468.9339M, chinchilla params: 3530.0000M, d_model: 2688, n_heads: 22, n_layers: 36\n", + "our estimated params: 3802.8109M, chinchilla params: 3802.0000M, d_model: 2816, n_heads: 22, n_layers: 36\n", + "our estimated params: 4152.0233M, chinchilla params: 4084.0000M, d_model: 2944, n_heads: 22, n_layers: 36\n", + "our estimated params: 4516.5711M, chinchilla params: 4516.0000M, d_model: 3072, n_heads: 24, n_layers: 36\n", + "our estimated params: 6796.2747M, chinchilla params: 6796.0000M, d_model: 3584, n_heads: 28, n_layers: 40\n", + "our estimated params: 9294.0206M, chinchilla params: 9293.0000M, d_model: 4096, n_heads: 32, n_layers: 42\n", + "our estimated params: 11714.6222M, chinchilla params: 11452.0000M, d_model: 4352, n_heads: 32, n_layers: 47\n", + "our estimated params: 12296.1623M, chinchilla params: 12295.0000M, d_model: 4608, n_heads: 36, n_layers: 44\n", + "our estimated params: 13124.4826M, chinchilla params: 12569.0000M, d_model: 4608, n_heads: 32, n_layers: 47\n", + "our estimated params: 14614.4279M, chinchilla params: 13735.0000M, d_model: 4864, n_heads: 32, n_layers: 47\n", + "our estimated params: 16037.5039M, chinchilla params: 14940.0000M, d_model: 4992, n_heads: 32, n_layers: 49\n", + "our estimated params: 16184.4582M, chinchilla params: 16183.0000M, d_model: 5120, n_heads: 40, n_layers: 47\n" ] } ], "source": [ - "for m in chilchilla_models[:5]: # look at just first 5 models\n", + "for m in chilchilla_models:\n", " p, d, f, k, h, l = m\n", - " nparams = params(seq_len = 1024, vocab_size = 32000, d_model = d, num_heads = h, num_layers = l)\n", + " nparams = chinchilla_params(seq_len = 1024, vocab_size = 32000, d_model = d, num_heads = h, num_layers = l, ffw_size=f)\n", " print(f\"our estimated params: {nparams/1e6:.4f}M, chinchilla params: {p/1e6:.4f}M, d_model: {d}, n_heads: {h}, n_layers: {l}\")" ] }, @@ -139,7 +201,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "So we are not able to reproduce the parameter counts for the Chinchilla models. The claimed number of parameters is not the number of parameters we compute above. TODO: resolve... \n", + "We are almost able to reproduce the parameter counts for the Chinchilla models. 
TODO resolve...\n", "\n", "Now turning to FLOPs:" ] @@ -156,22 +218,11 @@ "cell_type": "code", "execution_count": 6, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "766.006788096" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "def flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size=4):\n", + "def chinchilla_flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n", " \"\"\" \n", - " Given GPT config calculate total number of FLOPs, see Chinchilla \n", + " Calculate total number of FLOPs, see Chinchilla \n", " paper Appendix F as reference: https://arxiv.org/pdf/2203.15556.pdf\n", " \"\"\" \n", " key_size = d_model // num_heads\n", @@ -196,14 +247,17 @@ "\n", " # logits\n", " logits = 2 * seq_len * d_model * vocab_size\n", - "\n", - " forward_flops = embeddings + num_layers * (att + dense) + logits\n", + " \n", + " # this is what you'd expect:\n", + " # forward_flops = embeddings + num_layers * (att + dense) + logits\n", + " # but:\n", + " # per author correspondence apparently there is typo in the paper,\n", + " # they do not count embeddings and logits to repro table 4. So instead:\n", + " forward_flops = num_layers * (att + dense)\n", " backward_flops = 2 * forward_flops # as in Kaplan et al. 2020\n", " total_flops = forward_flops + backward_flops\n", "\n", - " return total_flops\n", - "\n", - "flops(**gpt2)/1e9" + " return total_flops\n" ] }, { @@ -212,14 +266,149 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "params: 69.72M\n", - "approx flops: 428.34B\n", - "chinchilla flops: 434.11B\n", - "ratio (chinchilla / approx): 1.01\n" - ] + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
seq_lenvocab_sized_modelnum_headsnum_layersffw_sizeNFapprox_flopschinch_flopsratio
020483200064010102560738252809298771968009071650406409.298772e+111.025036
1204832000102416204096305707008413524819968037565277143044.135248e+121.100817
2204832000128010245120552604160735345377280067903999180807.353454e+121.082919
3204832000179214267168114345369614670316437504140507590164481.467032e+131.044094
4204832000204816288192159312691220220437594112195763434946562.022044e+131.032902
52048320003584284014336679627468883021046743040835126233661448.302105e+130.994114
\n", + "
" + ], + "text/plain": [ + " seq_len vocab_size d_model num_heads num_layers ffw_size N \\\n", + "0 2048 32000 640 10 10 2560 73825280 \n", + "1 2048 32000 1024 16 20 4096 305707008 \n", + "2 2048 32000 1280 10 24 5120 552604160 \n", + "3 2048 32000 1792 14 26 7168 1143453696 \n", + "4 2048 32000 2048 16 28 8192 1593126912 \n", + "5 2048 32000 3584 28 40 14336 6796274688 \n", + "\n", + " F approx_flops chinch_flops ratio \n", + "0 929877196800 907165040640 9.298772e+11 1.025036 \n", + "1 4135248199680 3756527714304 4.135248e+12 1.100817 \n", + "2 7353453772800 6790399918080 7.353454e+12 1.082919 \n", + "3 14670316437504 14050759016448 1.467032e+13 1.044094 \n", + "4 20220437594112 19576343494656 2.022044e+13 1.032902 \n", + "5 83021046743040 83512623366144 8.302105e+13 0.994114 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -227,19 +416,47 @@ "# comparing accurate flops above to approximate flops F = 6*N*D\n", "# note Chinchilla mentions using vocab_size = 32K\n", "\n", - "args = dict(seq_len = 1024, vocab_size = 32000, d_model = 640, num_heads = 10, num_layers = 10) # chinchilla 73M\n", + "chilchilla_models_table4 = [\n", + " [10, 640, 2560, 10, 64],\n", + " [20, 1024, 4096, 16, 64],\n", + " [24, 1280, 5120, 10, 128 ],\n", + " [26, 1792, 7168, 14, 128 ],\n", + " [28, 2048, 8192, 16, 128],\n", + " [40, 3584, 14336, 28, 128]\n", + "]\n", "\n", - "D = 1024 # dataset size (cancels anyway, for the purposes of the ratio calculation below)\n", - "N = params(**args)\n", - "F = flops(**args)\n", + "rows = []\n", + "for num_layers, d_model, ffw_size, num_heads, _ in chilchilla_models_table4:\n", "\n", - "approx_flops = 6*D*N # approximate flops\n", - "chinch_flops = F * (float(D) / args['seq_len']) # exact flops according to Chinchilla paper calculations\n", + " args = dict(seq_len = 2048, vocab_size = 32000, d_model = d_model, \n", + " num_heads = num_heads, num_layers = num_layers, ffw_size=ffw_size)\n", "\n", - "print(f\"params: {N/1e6:.2f}M\")\n", - "print(f\"approx flops: {approx_flops/1e9:.2f}B\")\n", - "print(f\"chinchilla flops: {chinch_flops/1e9:.2f}B\")\n", - "print(f\"ratio (chinchilla / approx): {chinch_flops / approx_flops:.2f}\")" + " D = args['seq_len'] # dataset size (cancels anyway, for the purposes of the ratio calculation below)\n", + " N = chinchilla_params(**args)\n", + " F = chinchilla_flops(**args)\n", + "\n", + " approx_flops = 6*D*N # approximate flops\n", + " chinch_flops = F * (float(D) / args['seq_len']) # exact flops according to Chinchilla paper calculations\n", + "\n", + " # print('---')\n", + " # print(f\"params: {N/1e6:.2f}M\")\n", + " # print(f\"approx flops: {approx_flops/1e9:.2f}B\")\n", + " # print(f\"chinchilla flops: {chinch_flops/1e9:.2f}B\")\n", + " # print(f\"ratio (chinchilla / approx): {chinch_flops / approx_flops:.2f}\")\n", + "\n", + " # first copy all keyvalues from args into out\n", + " out = {k:v for k,v in args.items()}\n", + " # then add the calculated values\n", + " out['N'] = N\n", + " out['F'] = F\n", + " out['approx_flops'] = approx_flops\n", + " out['chinch_flops'] = chinch_flops\n", + " out['ratio'] = chinch_flops / approx_flops\n", + " rows.append(out)\n", + "\n", + "# make a pandas dataframe from rows\n", + "df = pd.DataFrame(rows)\n", + "df" ] }, { @@ -247,7 +464,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Awkward because the ratio is supposed to be 1.03 but we get 1.01, e.g. for the 73M param model. So once again we don't reproduce the numbers. TODO resolve..." 
+ "Pretty good match! Except the param counts are still not perfectly accurate." ] }, { @@ -266,7 +483,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -474,13 +691,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "TLDR atm: nothing reproduces." + "TLDR atm: nothing reproduces, but progress is being made." ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] } ], "metadata": {