{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Trying to reproduce results from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf):" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "%matplotlib inline" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## params\n", "\n", "First some parameter calculations:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "123.653376" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def gpt_params(seq_len, vocab_size, d_model, num_heads, num_layers):\n", " \"\"\" Given GPT config calculate total number of parameters \"\"\"\n", " ffw_size = 4*d_model # in GPT the number of intermediate features is always 4*d_model\n", " # token and position embeddings\n", " embeddings = d_model * vocab_size + d_model * seq_len\n", " # transformer blocks\n", " attention = 3*d_model**2 + 3*d_model # weights and biases\n", " attproj = d_model**2 + d_model\n", " ffw = d_model*(ffw_size) + ffw_size\n", " ffwproj = ffw_size*d_model + d_model\n", " layernorms = 2*2*d_model\n", " # dense\n", " ln_f = 2*d_model\n", " dense = d_model*vocab_size # note: no bias here\n", " # note: embeddings are not included in the param count!\n", " total_params = num_layers*(attention + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n", " return total_params\n", "\n", "gpt2 = dict(seq_len = 1024, vocab_size = 50257, d_model = 768, num_heads = 12, num_layers = 12)\n", "gpt_params(**gpt2)/1e6" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "OpenAI reports gpt2 (small) as having 124M params, so this is a match. Also, loading the OpenAI weights into nanoGPT and then calling `model.parameters()` exactly matches the above number and verifies the implementation. Now Chinchilla parameters:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def chinchilla_params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n", " \"\"\" Parameters in the Chinchilla models. Unlike GPT they use relative positional embeddings. 
\"\"\"\n", " # token embeddings only\n", " embeddings = d_model * vocab_size\n", " # transformer blocks\n", " attention = 3*d_model**2 + 3*d_model # weights and biases\n", " relative_pos = d_model**2 + 2*d_model # relative keys, content bias, relative bias\n", " attproj = d_model**2 + d_model\n", " ffw = d_model*ffw_size + ffw_size\n", " ffwproj = ffw_size*d_model + d_model\n", " layernorms = 2*2*d_model\n", " # dense\n", " ln_f = 2*d_model\n", " dense = d_model*vocab_size # note: no bias here\n", " # note: embeddings are not included in the param count!\n", " total_params = num_layers*(attention + relative_pos + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n", " return total_params\n", "\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[44000000.0, 512, 2048, 64, 8, 8]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load in all the 50 Chinchilla models on the last page of the paper\n", "import json\n", "chinchilla_models_txt = '[[44000000.0, 512, 2048, 64, 8, 8], [57000000.0, 576, 2304, 64, 9, 9], [74000000.0, 640, 2560, 64, 10, 10], [90000000.0, 640, 2560, 64, 10, 13], [106000000.0, 640, 2560, 64, 10, 16], [117000000.0, 768, 3072, 64, 12, 12], [140000000.0, 768, 3072, 64, 12, 15], [163000000.0, 768, 3072, 64, 12, 18], [175000000.0, 896, 3584, 64, 14, 14], [196000000.0, 896, 3584, 64, 14, 16], [217000000.0, 896, 3584, 64, 14, 18], [251000000.0, 1024, 4096, 64, 16, 16], [278000000.0, 1024, 4096, 64, 16, 18], [306000000.0, 1024, 4096, 64, 16, 20], [425000000.0, 1280, 5120, 128, 10, 18], [489000000.0, 1280, 5120, 128, 10, 21], [509000000.0, 1408, 5632, 128, 11, 18], [552000000.0, 1280, 5120, 128, 10, 24], [587000000.0, 1408, 5632, 128, 11, 21], [632000000.0, 1536, 6144, 128, 12, 19], [664000000.0, 1408, 5632, 128, 11, 24], [724000000.0, 1536, 6144, 128, 12, 22], [816000000.0, 1536, 6144, 128, 12, 25], [893000000.0, 1792, 7168, 128, 14, 20], [1018000000.0, 1792, 7168, 128, 14, 23], [1143000000.0, 1792, 7168, 128, 14, 26], [1266000000.0, 2048, 8192, 128, 16, 22], [1424000000.0, 2176, 8704, 128, 17, 22], [1429000000.0, 2048, 8192, 128, 16, 25], [1593000000.0, 2048, 8192, 128, 16, 28], [1609000000.0, 2176, 8704, 128, 17, 25], [1731000000.0, 2304, 9216, 128, 18, 24], [1794000000.0, 2176, 8704, 128, 17, 28], [2007000000.0, 2304, 9216, 128, 18, 28], [2283000000.0, 2304, 9216, 128, 18, 32], [2298000000.0, 2560, 10240, 128, 20, 26], [2639000000.0, 2560, 10240, 128, 20, 30], [2980000000.0, 2560, 10240, 128, 20, 34], [3530000000.0, 2688, 10752, 128, 22, 36], [3802000000.0, 2816, 11264, 128, 22, 36], [4084000000.0, 2944, 11776, 128, 22, 36], [4516000000.0, 3072, 12288, 128, 24, 36], [6796000000.0, 3584, 14336, 128, 28, 40], [9293000000.0, 4096, 16384, 128, 32, 42], [11452000000.0, 4352, 17408, 128, 32, 47], [12295000000.0, 4608, 18432, 128, 36, 44], [12569000000.0, 4608, 18432, 128, 32, 47], [13735000000.0, 4864, 19456, 128, 32, 47], [14940000000.0, 4992, 19968, 128, 32, 49], [16183000000.0, 5120, 20480, 128, 40, 47]]'\n", "chilchilla_models = json.loads(chinchilla_models_txt) # all 50 models\n", "chilchilla_models[0] # tuples of params, d_model, ffw_size, kv_size, n_heads, n_layers from Table A9" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "our estimated params: 43.7094M, chinchilla params: 44.0000M, d_model: 512, n_heads: 8, n_layers: 8\n", "our estimated params: 57.3287M, chinchilla params: 
57.0000M, d_model: 576, n_heads: 9, n_layers: 9\n", "our estimated params: 73.8253M, chinchilla params: 74.0000M, d_model: 640, n_heads: 10, n_layers: 10\n", "our estimated params: 89.8285M, chinchilla params: 90.0000M, d_model: 640, n_heads: 10, n_layers: 13\n", "our estimated params: 105.8317M, chinchilla params: 106.0000M, d_model: 640, n_heads: 10, n_layers: 16\n", "our estimated params: 116.7283M, chinchilla params: 117.0000M, d_model: 768, n_heads: 12, n_layers: 12\n", "our estimated params: 139.7660M, chinchilla params: 140.0000M, d_model: 768, n_heads: 12, n_layers: 15\n", "our estimated params: 162.8037M, chinchilla params: 163.0000M, d_model: 768, n_heads: 12, n_layers: 18\n", "our estimated params: 174.9745M, chinchilla params: 175.0000M, d_model: 896, n_heads: 14, n_layers: 14\n", "our estimated params: 195.8746M, chinchilla params: 196.0000M, d_model: 896, n_heads: 14, n_layers: 16\n", "our estimated params: 216.7747M, chinchilla params: 217.0000M, d_model: 896, n_heads: 14, n_layers: 18\n", "our estimated params: 251.1196M, chinchilla params: 251.0000M, d_model: 1024, n_heads: 16, n_layers: 16\n", "our estimated params: 278.4133M, chinchilla params: 278.0000M, d_model: 1024, n_heads: 16, n_layers: 18\n", "our estimated params: 305.7070M, chinchilla params: 306.0000M, d_model: 1024, n_heads: 16, n_layers: 20\n", "our estimated params: 424.6938M, chinchilla params: 425.0000M, d_model: 1280, n_heads: 10, n_layers: 18\n", "our estimated params: 488.6490M, chinchilla params: 489.0000M, d_model: 1280, n_heads: 10, n_layers: 21\n", "our estimated params: 509.3356M, chinchilla params: 509.0000M, d_model: 1408, n_heads: 11, n_layers: 18\n", "our estimated params: 552.6042M, chinchilla params: 552.0000M, d_model: 1280, n_heads: 10, n_layers: 24\n", "our estimated params: 586.7150M, chinchilla params: 587.0000M, d_model: 1408, n_heads: 11, n_layers: 21\n", "our estimated params: 632.3389M, chinchilla params: 632.0000M, d_model: 1536, n_heads: 12, n_layers: 19\n", "our estimated params: 664.0945M, chinchilla params: 664.0000M, d_model: 1408, n_heads: 11, n_layers: 24\n", "our estimated params: 724.4206M, chinchilla params: 724.0000M, d_model: 1536, n_heads: 12, n_layers: 22\n", "our estimated params: 816.5023M, chinchilla params: 816.0000M, d_model: 1536, n_heads: 12, n_layers: 25\n", "our estimated params: 892.8138M, chinchilla params: 893.0000M, d_model: 1792, n_heads: 14, n_layers: 20\n", "our estimated params: 1018.1338M, chinchilla params: 1018.0000M, d_model: 1792, n_heads: 14, n_layers: 23\n", "our estimated params: 1143.4537M, chinchilla params: 1143.0000M, d_model: 1792, n_heads: 14, n_layers: 26\n", "our estimated params: 1265.7869M, chinchilla params: 1266.0000M, d_model: 2048, n_heads: 16, n_layers: 22\n", "our estimated params: 1424.5576M, chinchilla params: 1424.0000M, d_model: 2176, n_heads: 17, n_layers: 22\n", "our estimated params: 1429.4569M, chinchilla params: 1429.0000M, d_model: 2048, n_heads: 16, n_layers: 25\n", "our estimated params: 1593.1269M, chinchilla params: 1593.0000M, d_model: 2048, n_heads: 16, n_layers: 28\n", "our estimated params: 1609.3196M, chinchilla params: 1609.0000M, d_model: 2176, n_heads: 17, n_layers: 25\n", "our estimated params: 1730.7878M, chinchilla params: 1731.0000M, d_model: 2304, n_heads: 18, n_layers: 24\n", "our estimated params: 1794.0815M, chinchilla params: 1794.0000M, d_model: 2176, n_heads: 17, n_layers: 28\n", "our estimated params: 2006.9637M, chinchilla params: 2007.0000M, d_model: 2304, n_heads: 18, n_layers: 28\n", "our 
estimated params: 2283.1396M, chinchilla params: 2283.0000M, d_model: 2304, n_heads: 18, n_layers: 32\n", "our estimated params: 2298.0403M, chinchilla params: 2298.0000M, d_model: 2560, n_heads: 20, n_layers: 26\n", "our estimated params: 2638.9811M, chinchilla params: 2639.0000M, d_model: 2560, n_heads: 20, n_layers: 30\n", "our estimated params: 2979.9219M, chinchilla params: 2980.0000M, d_model: 2560, n_heads: 20, n_layers: 34\n", "our estimated params: 3468.9339M, chinchilla params: 3530.0000M, d_model: 2688, n_heads: 22, n_layers: 36\n", "our estimated params: 3802.8109M, chinchilla params: 3802.0000M, d_model: 2816, n_heads: 22, n_layers: 36\n", "our estimated params: 4152.0233M, chinchilla params: 4084.0000M, d_model: 2944, n_heads: 22, n_layers: 36\n", "our estimated params: 4516.5711M, chinchilla params: 4516.0000M, d_model: 3072, n_heads: 24, n_layers: 36\n", "our estimated params: 6796.2747M, chinchilla params: 6796.0000M, d_model: 3584, n_heads: 28, n_layers: 40\n", "our estimated params: 9294.0206M, chinchilla params: 9293.0000M, d_model: 4096, n_heads: 32, n_layers: 42\n", "our estimated params: 11714.6222M, chinchilla params: 11452.0000M, d_model: 4352, n_heads: 32, n_layers: 47\n", "our estimated params: 12296.1623M, chinchilla params: 12295.0000M, d_model: 4608, n_heads: 36, n_layers: 44\n", "our estimated params: 13124.4826M, chinchilla params: 12569.0000M, d_model: 4608, n_heads: 32, n_layers: 47\n", "our estimated params: 14614.4279M, chinchilla params: 13735.0000M, d_model: 4864, n_heads: 32, n_layers: 47\n", "our estimated params: 16037.5039M, chinchilla params: 14940.0000M, d_model: 4992, n_heads: 32, n_layers: 49\n", "our estimated params: 16184.4582M, chinchilla params: 16183.0000M, d_model: 5120, n_heads: 40, n_layers: 47\n" ] } ], "source": [ "for m in chilchilla_models:\n", " p, d, f, k, h, l = m\n", " nparams = chinchilla_params(seq_len = 1024, vocab_size = 32000, d_model = d, num_heads = h, num_layers = l, ffw_size=f)\n", " print(f\"our estimated params: {nparams/1e6:.4f}M, chinchilla params: {p/1e6:.4f}M, d_model: {d}, n_heads: {h}, n_layers: {l}\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We are almost able to reproduce the parameter counts for the Chinchilla models. 
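For most rows the estimate agrees with Table A9 to within a fraction of a percent; only the largest few models drift, by up to several percent (e.g. ~16038M estimated vs. 14940M reported). A quick way to see the worst-case gap, as a small sketch reusing `chinchilla_params` and the `chilchilla_models` list loaded above:\n",
"\n",
"```python\n",
"# sketch: relative error of our estimate vs. the params column of Table A9\n",
"errs = []\n",
"for p, d, f, k, h, l in chilchilla_models:\n",
"    est = chinchilla_params(seq_len=1024, vocab_size=32000, d_model=d,\n",
"                            num_heads=h, num_layers=l, ffw_size=f)\n",
"    errs.append(abs(est - p) / p)\n",
"print(f'max relative error: {max(errs)*100:.1f}%')\n",
"```\n",
"\n",
"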
TODO resolve...\n", "\n", "Now turning to FLOPs:" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## flops" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def chinchilla_flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n", " \"\"\" \n", " Calculate total number of FLOPs, see Chinchilla \n", " paper Appendix F as reference: https://arxiv.org/pdf/2203.15556.pdf\n", " \"\"\" \n", " key_size = d_model // num_heads\n", "\n", " # embeddings\n", " embeddings = 2 * seq_len * vocab_size * d_model\n", "\n", " # attention\n", " # key, query, value projections\n", " attention = 2 * 3 * seq_len * d_model * (key_size * num_heads)\n", " # key @ query logits\n", " attlogits = 2 * seq_len * seq_len * (key_size * num_heads)\n", " # softmax\n", " attsoftmax = 3 * num_heads * seq_len * seq_len # 3* is for subtract (max), exp, divide (?)\n", " # softmax @ value reductions\n", " attvalue = 2 * seq_len * seq_len * (key_size * num_heads)\n", " # final linear\n", " attlinear = 2 * seq_len * (key_size * num_heads) * d_model\n", " att = attention + attlogits + attsoftmax + attvalue + attlinear\n", " # feed forward\n", " dense = 2 * seq_len * (d_model * ffw_size + d_model * ffw_size)\n", "\n", " # logits\n", " logits = 2 * seq_len * d_model * vocab_size\n", " \n", " # this is what you'd expect:\n", " # forward_flops = embeddings + num_layers * (att + dense) + logits\n", " # but:\n", " # per author correspondence apparently there is typo in the paper,\n", " # they do not count embeddings and logits to repro table 4. So instead:\n", " forward_flops = num_layers * (att + dense)\n", " backward_flops = 2 * forward_flops # as in Kaplan et al. 2020\n", " total_flops = forward_flops + backward_flops\n", "\n", " return total_flops\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | seq_len | \n", "vocab_size | \n", "d_model | \n", "num_heads | \n", "num_layers | \n", "ffw_size | \n", "N | \n", "F | \n", "approx_flops | \n", "chinch_flops | \n", "ratio | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "2048 | \n", "32000 | \n", "640 | \n", "10 | \n", "10 | \n", "2560 | \n", "73825280 | \n", "929877196800 | \n", "907165040640 | \n", "9.298772e+11 | \n", "1.025036 | \n", "
1 | \n", "2048 | \n", "32000 | \n", "1024 | \n", "16 | \n", "20 | \n", "4096 | \n", "305707008 | \n", "4135248199680 | \n", "3756527714304 | \n", "4.135248e+12 | \n", "1.100817 | \n", "
2 | \n", "2048 | \n", "32000 | \n", "1280 | \n", "10 | \n", "24 | \n", "5120 | \n", "552604160 | \n", "7353453772800 | \n", "6790399918080 | \n", "7.353454e+12 | \n", "1.082919 | \n", "
3 | \n", "2048 | \n", "32000 | \n", "1792 | \n", "14 | \n", "26 | \n", "7168 | \n", "1143453696 | \n", "14670316437504 | \n", "14050759016448 | \n", "1.467032e+13 | \n", "1.044094 | \n", "
4 | \n", "2048 | \n", "32000 | \n", "2048 | \n", "16 | \n", "28 | \n", "8192 | \n", "1593126912 | \n", "20220437594112 | \n", "19576343494656 | \n", "2.022044e+13 | \n", "1.032902 | \n", "
5 | \n", "2048 | \n", "32000 | \n", "3584 | \n", "28 | \n", "40 | \n", "14336 | \n", "6796274688 | \n", "83021046743040 | \n", "83512623366144 | \n", "8.302105e+13 | \n", "0.994114 | \n", "