1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-18 14:10:28 +00:00

progress! based on chinchilla author correspondence

This commit is contained in:
Andrej Karpathy 2023-01-07 02:42:30 +00:00
parent 27fc6a4112
commit d56bdf05a6

View File

@ -16,35 +16,23 @@
"source": [ "source": [
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"import numpy as np\n", "import numpy as np\n",
"import pandas as pd\n",
"%matplotlib inline" "%matplotlib inline"
] ]
}, },
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"gpt2 = dict(\n",
" seq_len = 1024,\n",
" vocab_size = 50257,\n",
" d_model = 768,\n",
" num_heads = 12,\n",
" num_layers = 12,\n",
")"
]
},
{ {
"attachments": {}, "attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## params" "## params\n",
"\n",
"First some parameter calculations:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -53,21 +41,22 @@
"123.653376" "123.653376"
] ]
}, },
"execution_count": 3, "execution_count": 2,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"def params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size=4):\n", "def gpt_params(seq_len, vocab_size, d_model, num_heads, num_layers):\n",
" \"\"\" Given GPT config calculate total number of parameters\"\"\" \n", " \"\"\" Given GPT config calculate total number of parameters \"\"\"\n",
" ffw_size = 4*d_model # in GPT the number of intermediate features is always 4*d_model\n",
" # token and position embeddings\n", " # token and position embeddings\n",
" embeddings = d_model * vocab_size + d_model * seq_len\n", " embeddings = d_model * vocab_size + d_model * seq_len\n",
" # transformer blocks\n", " # transformer blocks\n",
" attention = 3*d_model**2 + 3*d_model # weights and biases\n", " attention = 3*d_model**2 + 3*d_model # weights and biases\n",
" attproj = d_model**2 + d_model\n", " attproj = d_model**2 + d_model\n",
" ffw = d_model*(ffw_size*d_model) + ffw_size*d_model\n", " ffw = d_model*(ffw_size) + ffw_size\n",
" ffwproj = ffw_size*d_model*d_model + d_model\n", " ffwproj = ffw_size*d_model + d_model\n",
" layernorms = 2*2*d_model\n", " layernorms = 2*2*d_model\n",
" # dense\n", " # dense\n",
" ln_f = 2*d_model\n", " ln_f = 2*d_model\n",
@ -76,7 +65,8 @@
" total_params = num_layers*(attention + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n", " total_params = num_layers*(attention + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n",
" return total_params\n", " return total_params\n",
"\n", "\n",
"params(**gpt2)/1e6" "gpt2 = dict(seq_len = 1024, vocab_size = 50257, d_model = 768, num_heads = 12, num_layers = 12)\n",
"gpt_params(**gpt2)/1e6"
] ]
}, },
{ {
@ -84,7 +74,33 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"OpenAI reports gpt2 (small) as having 124M params, so this is a match. Also, loading the OpenAI weights into nanoGPT and then calling `model.parameters()` exactly matches the above number and verifies the implementation." "OpenAI reports gpt2 (small) as having 124M params, so this is a match. Also, loading the OpenAI weights into nanoGPT and then calling `model.parameters()` exactly matches the above number and verifies the implementation. Now Chinchilla parameters:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def chinchilla_params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n",
" \"\"\" Parameters in the Chinchilla models. Unlike GPT they use relative positional embeddings. \"\"\"\n",
" # token embeddings only\n",
" embeddings = d_model * vocab_size\n",
" # transformer blocks\n",
" attention = 3*d_model**2 + 3*d_model # weights and biases\n",
" relative_pos = d_model**2 + 2*d_model # relative keys, content bias, relative bias\n",
" attproj = d_model**2 + d_model\n",
" ffw = d_model*ffw_size + ffw_size\n",
" ffwproj = ffw_size*d_model + d_model\n",
" layernorms = 2*2*d_model\n",
" # dense\n",
" ln_f = 2*d_model\n",
" dense = d_model*vocab_size # note: no bias here\n",
" # note: embeddings are not included in the param count!\n",
" total_params = num_layers*(attention + relative_pos + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n",
" return total_params\n",
"\n"
] ]
}, },
{ {
@ -104,6 +120,7 @@
} }
], ],
"source": [ "source": [
"# Load in all the 50 Chinchilla models on the last page of the paper\n",
"import json\n", "import json\n",
"chinchilla_models_txt = '[[44000000.0, 512, 2048, 64, 8, 8], [57000000.0, 576, 2304, 64, 9, 9], [74000000.0, 640, 2560, 64, 10, 10], [90000000.0, 640, 2560, 64, 10, 13], [106000000.0, 640, 2560, 64, 10, 16], [117000000.0, 768, 3072, 64, 12, 12], [140000000.0, 768, 3072, 64, 12, 15], [163000000.0, 768, 3072, 64, 12, 18], [175000000.0, 896, 3584, 64, 14, 14], [196000000.0, 896, 3584, 64, 14, 16], [217000000.0, 896, 3584, 64, 14, 18], [251000000.0, 1024, 4096, 64, 16, 16], [278000000.0, 1024, 4096, 64, 16, 18], [306000000.0, 1024, 4096, 64, 16, 20], [425000000.0, 1280, 5120, 128, 10, 18], [489000000.0, 1280, 5120, 128, 10, 21], [509000000.0, 1408, 5632, 128, 11, 18], [552000000.0, 1280, 5120, 128, 10, 24], [587000000.0, 1408, 5632, 128, 11, 21], [632000000.0, 1536, 6144, 128, 12, 19], [664000000.0, 1408, 5632, 128, 11, 24], [724000000.0, 1536, 6144, 128, 12, 22], [816000000.0, 1536, 6144, 128, 12, 25], [893000000.0, 1792, 7168, 128, 14, 20], [1018000000.0, 1792, 7168, 128, 14, 23], [1143000000.0, 1792, 7168, 128, 14, 26], [1266000000.0, 2048, 8192, 128, 16, 22], [1424000000.0, 2176, 8704, 128, 17, 22], [1429000000.0, 2048, 8192, 128, 16, 25], [1593000000.0, 2048, 8192, 128, 16, 28], [1609000000.0, 2176, 8704, 128, 17, 25], [1731000000.0, 2304, 9216, 128, 18, 24], [1794000000.0, 2176, 8704, 128, 17, 28], [2007000000.0, 2304, 9216, 128, 18, 28], [2283000000.0, 2304, 9216, 128, 18, 32], [2298000000.0, 2560, 10240, 128, 20, 26], [2639000000.0, 2560, 10240, 128, 20, 30], [2980000000.0, 2560, 10240, 128, 20, 34], [3530000000.0, 2688, 10752, 128, 22, 36], [3802000000.0, 2816, 11264, 128, 22, 36], [4084000000.0, 2944, 11776, 128, 22, 36], [4516000000.0, 3072, 12288, 128, 24, 36], [6796000000.0, 3584, 14336, 128, 28, 40], [9293000000.0, 4096, 16384, 128, 32, 42], [11452000000.0, 4352, 17408, 128, 32, 47], [12295000000.0, 4608, 18432, 128, 36, 44], [12569000000.0, 4608, 18432, 128, 32, 47], [13735000000.0, 4864, 19456, 128, 32, 47], [14940000000.0, 4992, 19968, 128, 32, 49], [16183000000.0, 5120, 20480, 128, 40, 47]]'\n", "chinchilla_models_txt = '[[44000000.0, 512, 2048, 64, 8, 8], [57000000.0, 576, 2304, 64, 9, 9], [74000000.0, 640, 2560, 64, 10, 10], [90000000.0, 640, 2560, 64, 10, 13], [106000000.0, 640, 2560, 64, 10, 16], [117000000.0, 768, 3072, 64, 12, 12], [140000000.0, 768, 3072, 64, 12, 15], [163000000.0, 768, 3072, 64, 12, 18], [175000000.0, 896, 3584, 64, 14, 14], [196000000.0, 896, 3584, 64, 14, 16], [217000000.0, 896, 3584, 64, 14, 18], [251000000.0, 1024, 4096, 64, 16, 16], [278000000.0, 1024, 4096, 64, 16, 18], [306000000.0, 1024, 4096, 64, 16, 20], [425000000.0, 1280, 5120, 128, 10, 18], [489000000.0, 1280, 5120, 128, 10, 21], [509000000.0, 1408, 5632, 128, 11, 18], [552000000.0, 1280, 5120, 128, 10, 24], [587000000.0, 1408, 5632, 128, 11, 21], [632000000.0, 1536, 6144, 128, 12, 19], [664000000.0, 1408, 5632, 128, 11, 24], [724000000.0, 1536, 6144, 128, 12, 22], [816000000.0, 1536, 6144, 128, 12, 25], [893000000.0, 1792, 7168, 128, 14, 20], [1018000000.0, 1792, 7168, 128, 14, 23], [1143000000.0, 1792, 7168, 128, 14, 26], [1266000000.0, 2048, 8192, 128, 16, 22], [1424000000.0, 2176, 8704, 128, 17, 22], [1429000000.0, 2048, 8192, 128, 16, 25], [1593000000.0, 2048, 8192, 128, 16, 28], [1609000000.0, 2176, 8704, 128, 17, 25], [1731000000.0, 2304, 9216, 128, 18, 24], [1794000000.0, 2176, 8704, 128, 17, 28], [2007000000.0, 2304, 9216, 128, 18, 28], [2283000000.0, 2304, 9216, 128, 18, 32], [2298000000.0, 2560, 10240, 128, 20, 26], [2639000000.0, 2560, 10240, 128, 20, 30], [2980000000.0, 2560, 10240, 128, 20, 34], [3530000000.0, 2688, 10752, 128, 22, 36], [3802000000.0, 2816, 11264, 128, 22, 36], [4084000000.0, 2944, 11776, 128, 22, 36], [4516000000.0, 3072, 12288, 128, 24, 36], [6796000000.0, 3584, 14336, 128, 28, 40], [9293000000.0, 4096, 16384, 128, 32, 42], [11452000000.0, 4352, 17408, 128, 32, 47], [12295000000.0, 4608, 18432, 128, 36, 44], [12569000000.0, 4608, 18432, 128, 32, 47], [13735000000.0, 4864, 19456, 128, 32, 47], [14940000000.0, 4992, 19968, 128, 32, 49], [16183000000.0, 5120, 20480, 128, 40, 47]]'\n",
"chilchilla_models = json.loads(chinchilla_models_txt) # all 50 models\n", "chilchilla_models = json.loads(chinchilla_models_txt) # all 50 models\n",
@ -119,18 +136,63 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"our estimated params: 41.6041M, chinchilla params: 44.0000M, d_model: 512, n_heads: 8, n_layers: 8\n", "our estimated params: 43.7094M, chinchilla params: 44.0000M, d_model: 512, n_heads: 8, n_layers: 8\n",
"our estimated params: 54.3324M, chinchilla params: 57.0000M, d_model: 576, n_heads: 9, n_layers: 9\n", "our estimated params: 57.3287M, chinchilla params: 57.0000M, d_model: 576, n_heads: 9, n_layers: 9\n",
"our estimated params: 69.7165M, chinchilla params: 74.0000M, d_model: 640, n_heads: 10, n_layers: 10\n", "our estimated params: 73.8253M, chinchilla params: 74.0000M, d_model: 640, n_heads: 10, n_layers: 10\n",
"our estimated params: 84.4870M, chinchilla params: 90.0000M, d_model: 640, n_heads: 10, n_layers: 13\n", "our estimated params: 89.8285M, chinchilla params: 90.0000M, d_model: 640, n_heads: 10, n_layers: 13\n",
"our estimated params: 99.2576M, chinchilla params: 106.0000M, d_model: 640, n_heads: 10, n_layers: 16\n" "our estimated params: 105.8317M, chinchilla params: 106.0000M, d_model: 640, n_heads: 10, n_layers: 16\n",
"our estimated params: 116.7283M, chinchilla params: 117.0000M, d_model: 768, n_heads: 12, n_layers: 12\n",
"our estimated params: 139.7660M, chinchilla params: 140.0000M, d_model: 768, n_heads: 12, n_layers: 15\n",
"our estimated params: 162.8037M, chinchilla params: 163.0000M, d_model: 768, n_heads: 12, n_layers: 18\n",
"our estimated params: 174.9745M, chinchilla params: 175.0000M, d_model: 896, n_heads: 14, n_layers: 14\n",
"our estimated params: 195.8746M, chinchilla params: 196.0000M, d_model: 896, n_heads: 14, n_layers: 16\n",
"our estimated params: 216.7747M, chinchilla params: 217.0000M, d_model: 896, n_heads: 14, n_layers: 18\n",
"our estimated params: 251.1196M, chinchilla params: 251.0000M, d_model: 1024, n_heads: 16, n_layers: 16\n",
"our estimated params: 278.4133M, chinchilla params: 278.0000M, d_model: 1024, n_heads: 16, n_layers: 18\n",
"our estimated params: 305.7070M, chinchilla params: 306.0000M, d_model: 1024, n_heads: 16, n_layers: 20\n",
"our estimated params: 424.6938M, chinchilla params: 425.0000M, d_model: 1280, n_heads: 10, n_layers: 18\n",
"our estimated params: 488.6490M, chinchilla params: 489.0000M, d_model: 1280, n_heads: 10, n_layers: 21\n",
"our estimated params: 509.3356M, chinchilla params: 509.0000M, d_model: 1408, n_heads: 11, n_layers: 18\n",
"our estimated params: 552.6042M, chinchilla params: 552.0000M, d_model: 1280, n_heads: 10, n_layers: 24\n",
"our estimated params: 586.7150M, chinchilla params: 587.0000M, d_model: 1408, n_heads: 11, n_layers: 21\n",
"our estimated params: 632.3389M, chinchilla params: 632.0000M, d_model: 1536, n_heads: 12, n_layers: 19\n",
"our estimated params: 664.0945M, chinchilla params: 664.0000M, d_model: 1408, n_heads: 11, n_layers: 24\n",
"our estimated params: 724.4206M, chinchilla params: 724.0000M, d_model: 1536, n_heads: 12, n_layers: 22\n",
"our estimated params: 816.5023M, chinchilla params: 816.0000M, d_model: 1536, n_heads: 12, n_layers: 25\n",
"our estimated params: 892.8138M, chinchilla params: 893.0000M, d_model: 1792, n_heads: 14, n_layers: 20\n",
"our estimated params: 1018.1338M, chinchilla params: 1018.0000M, d_model: 1792, n_heads: 14, n_layers: 23\n",
"our estimated params: 1143.4537M, chinchilla params: 1143.0000M, d_model: 1792, n_heads: 14, n_layers: 26\n",
"our estimated params: 1265.7869M, chinchilla params: 1266.0000M, d_model: 2048, n_heads: 16, n_layers: 22\n",
"our estimated params: 1424.5576M, chinchilla params: 1424.0000M, d_model: 2176, n_heads: 17, n_layers: 22\n",
"our estimated params: 1429.4569M, chinchilla params: 1429.0000M, d_model: 2048, n_heads: 16, n_layers: 25\n",
"our estimated params: 1593.1269M, chinchilla params: 1593.0000M, d_model: 2048, n_heads: 16, n_layers: 28\n",
"our estimated params: 1609.3196M, chinchilla params: 1609.0000M, d_model: 2176, n_heads: 17, n_layers: 25\n",
"our estimated params: 1730.7878M, chinchilla params: 1731.0000M, d_model: 2304, n_heads: 18, n_layers: 24\n",
"our estimated params: 1794.0815M, chinchilla params: 1794.0000M, d_model: 2176, n_heads: 17, n_layers: 28\n",
"our estimated params: 2006.9637M, chinchilla params: 2007.0000M, d_model: 2304, n_heads: 18, n_layers: 28\n",
"our estimated params: 2283.1396M, chinchilla params: 2283.0000M, d_model: 2304, n_heads: 18, n_layers: 32\n",
"our estimated params: 2298.0403M, chinchilla params: 2298.0000M, d_model: 2560, n_heads: 20, n_layers: 26\n",
"our estimated params: 2638.9811M, chinchilla params: 2639.0000M, d_model: 2560, n_heads: 20, n_layers: 30\n",
"our estimated params: 2979.9219M, chinchilla params: 2980.0000M, d_model: 2560, n_heads: 20, n_layers: 34\n",
"our estimated params: 3468.9339M, chinchilla params: 3530.0000M, d_model: 2688, n_heads: 22, n_layers: 36\n",
"our estimated params: 3802.8109M, chinchilla params: 3802.0000M, d_model: 2816, n_heads: 22, n_layers: 36\n",
"our estimated params: 4152.0233M, chinchilla params: 4084.0000M, d_model: 2944, n_heads: 22, n_layers: 36\n",
"our estimated params: 4516.5711M, chinchilla params: 4516.0000M, d_model: 3072, n_heads: 24, n_layers: 36\n",
"our estimated params: 6796.2747M, chinchilla params: 6796.0000M, d_model: 3584, n_heads: 28, n_layers: 40\n",
"our estimated params: 9294.0206M, chinchilla params: 9293.0000M, d_model: 4096, n_heads: 32, n_layers: 42\n",
"our estimated params: 11714.6222M, chinchilla params: 11452.0000M, d_model: 4352, n_heads: 32, n_layers: 47\n",
"our estimated params: 12296.1623M, chinchilla params: 12295.0000M, d_model: 4608, n_heads: 36, n_layers: 44\n",
"our estimated params: 13124.4826M, chinchilla params: 12569.0000M, d_model: 4608, n_heads: 32, n_layers: 47\n",
"our estimated params: 14614.4279M, chinchilla params: 13735.0000M, d_model: 4864, n_heads: 32, n_layers: 47\n",
"our estimated params: 16037.5039M, chinchilla params: 14940.0000M, d_model: 4992, n_heads: 32, n_layers: 49\n",
"our estimated params: 16184.4582M, chinchilla params: 16183.0000M, d_model: 5120, n_heads: 40, n_layers: 47\n"
] ]
} }
], ],
"source": [ "source": [
"for m in chilchilla_models[:5]: # look at just first 5 models\n", "for m in chilchilla_models:\n",
" p, d, f, k, h, l = m\n", " p, d, f, k, h, l = m\n",
" nparams = params(seq_len = 1024, vocab_size = 32000, d_model = d, num_heads = h, num_layers = l)\n", " nparams = chinchilla_params(seq_len = 1024, vocab_size = 32000, d_model = d, num_heads = h, num_layers = l, ffw_size=f)\n",
" print(f\"our estimated params: {nparams/1e6:.4f}M, chinchilla params: {p/1e6:.4f}M, d_model: {d}, n_heads: {h}, n_layers: {l}\")" " print(f\"our estimated params: {nparams/1e6:.4f}M, chinchilla params: {p/1e6:.4f}M, d_model: {d}, n_heads: {h}, n_layers: {l}\")"
] ]
}, },
@ -139,7 +201,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"So we are not able to reproduce the parameter counts for the Chinchilla models. The claimed number of parameters is not the number of parameters we compute above. TODO: resolve... \n", "We are almost able to reproduce the parameter counts for the Chinchilla models. TODO resolve...\n",
"\n", "\n",
"Now turning to FLOPs:" "Now turning to FLOPs:"
] ]
@ -156,22 +218,11 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"data": {
"text/plain": [
"766.006788096"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"def flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size=4):\n", "def chinchilla_flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n",
" \"\"\" \n", " \"\"\" \n",
" Given GPT config calculate total number of FLOPs, see Chinchilla \n", " Calculate total number of FLOPs, see Chinchilla \n",
" paper Appendix F as reference: https://arxiv.org/pdf/2203.15556.pdf\n", " paper Appendix F as reference: https://arxiv.org/pdf/2203.15556.pdf\n",
" \"\"\" \n", " \"\"\" \n",
" key_size = d_model // num_heads\n", " key_size = d_model // num_heads\n",
@ -196,14 +247,17 @@
"\n", "\n",
" # logits\n", " # logits\n",
" logits = 2 * seq_len * d_model * vocab_size\n", " logits = 2 * seq_len * d_model * vocab_size\n",
"\n", " \n",
" forward_flops = embeddings + num_layers * (att + dense) + logits\n", " # this is what you'd expect:\n",
" # forward_flops = embeddings + num_layers * (att + dense) + logits\n",
" # but:\n",
" # per author correspondence apparently there is typo in the paper,\n",
" # they do not count embeddings and logits to repro table 4. So instead:\n",
" forward_flops = num_layers * (att + dense)\n",
" backward_flops = 2 * forward_flops # as in Kaplan et al. 2020\n", " backward_flops = 2 * forward_flops # as in Kaplan et al. 2020\n",
" total_flops = forward_flops + backward_flops\n", " total_flops = forward_flops + backward_flops\n",
"\n", "\n",
" return total_flops\n", " return total_flops\n"
"\n",
"flops(**gpt2)/1e9"
] ]
}, },
{ {
@ -212,14 +266,149 @@
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "data": {
"output_type": "stream", "text/html": [
"text": [ "<div>\n",
"params: 69.72M\n", "<style scoped>\n",
"approx flops: 428.34B\n", " .dataframe tbody tr th:only-of-type {\n",
"chinchilla flops: 434.11B\n", " vertical-align: middle;\n",
"ratio (chinchilla / approx): 1.01\n" " }\n",
] "\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>seq_len</th>\n",
" <th>vocab_size</th>\n",
" <th>d_model</th>\n",
" <th>num_heads</th>\n",
" <th>num_layers</th>\n",
" <th>ffw_size</th>\n",
" <th>N</th>\n",
" <th>F</th>\n",
" <th>approx_flops</th>\n",
" <th>chinch_flops</th>\n",
" <th>ratio</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>640</td>\n",
" <td>10</td>\n",
" <td>10</td>\n",
" <td>2560</td>\n",
" <td>73825280</td>\n",
" <td>929877196800</td>\n",
" <td>907165040640</td>\n",
" <td>9.298772e+11</td>\n",
" <td>1.025036</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>1024</td>\n",
" <td>16</td>\n",
" <td>20</td>\n",
" <td>4096</td>\n",
" <td>305707008</td>\n",
" <td>4135248199680</td>\n",
" <td>3756527714304</td>\n",
" <td>4.135248e+12</td>\n",
" <td>1.100817</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>1280</td>\n",
" <td>10</td>\n",
" <td>24</td>\n",
" <td>5120</td>\n",
" <td>552604160</td>\n",
" <td>7353453772800</td>\n",
" <td>6790399918080</td>\n",
" <td>7.353454e+12</td>\n",
" <td>1.082919</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>1792</td>\n",
" <td>14</td>\n",
" <td>26</td>\n",
" <td>7168</td>\n",
" <td>1143453696</td>\n",
" <td>14670316437504</td>\n",
" <td>14050759016448</td>\n",
" <td>1.467032e+13</td>\n",
" <td>1.044094</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>2048</td>\n",
" <td>16</td>\n",
" <td>28</td>\n",
" <td>8192</td>\n",
" <td>1593126912</td>\n",
" <td>20220437594112</td>\n",
" <td>19576343494656</td>\n",
" <td>2.022044e+13</td>\n",
" <td>1.032902</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>3584</td>\n",
" <td>28</td>\n",
" <td>40</td>\n",
" <td>14336</td>\n",
" <td>6796274688</td>\n",
" <td>83021046743040</td>\n",
" <td>83512623366144</td>\n",
" <td>8.302105e+13</td>\n",
" <td>0.994114</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" seq_len vocab_size d_model num_heads num_layers ffw_size N \\\n",
"0 2048 32000 640 10 10 2560 73825280 \n",
"1 2048 32000 1024 16 20 4096 305707008 \n",
"2 2048 32000 1280 10 24 5120 552604160 \n",
"3 2048 32000 1792 14 26 7168 1143453696 \n",
"4 2048 32000 2048 16 28 8192 1593126912 \n",
"5 2048 32000 3584 28 40 14336 6796274688 \n",
"\n",
" F approx_flops chinch_flops ratio \n",
"0 929877196800 907165040640 9.298772e+11 1.025036 \n",
"1 4135248199680 3756527714304 4.135248e+12 1.100817 \n",
"2 7353453772800 6790399918080 7.353454e+12 1.082919 \n",
"3 14670316437504 14050759016448 1.467032e+13 1.044094 \n",
"4 20220437594112 19576343494656 2.022044e+13 1.032902 \n",
"5 83021046743040 83512623366144 8.302105e+13 0.994114 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
} }
], ],
"source": [ "source": [
@ -227,19 +416,47 @@
"# comparing accurate flops above to approximate flops F = 6*N*D\n", "# comparing accurate flops above to approximate flops F = 6*N*D\n",
"# note Chinchilla mentions using vocab_size = 32K\n", "# note Chinchilla mentions using vocab_size = 32K\n",
"\n", "\n",
"args = dict(seq_len = 1024, vocab_size = 32000, d_model = 640, num_heads = 10, num_layers = 10) # chinchilla 73M\n", "chilchilla_models_table4 = [\n",
" [10, 640, 2560, 10, 64],\n",
" [20, 1024, 4096, 16, 64],\n",
" [24, 1280, 5120, 10, 128 ],\n",
" [26, 1792, 7168, 14, 128 ],\n",
" [28, 2048, 8192, 16, 128],\n",
" [40, 3584, 14336, 28, 128]\n",
"]\n",
"\n", "\n",
"D = 1024 # dataset size (cancels anyway, for the purposes of the ratio calculation below)\n", "rows = []\n",
"N = params(**args)\n", "for num_layers, d_model, ffw_size, num_heads, _ in chilchilla_models_table4:\n",
"F = flops(**args)\n",
"\n", "\n",
"approx_flops = 6*D*N # approximate flops\n", " args = dict(seq_len = 2048, vocab_size = 32000, d_model = d_model, \n",
"chinch_flops = F * (float(D) / args['seq_len']) # exact flops according to Chinchilla paper calculations\n", " num_heads = num_heads, num_layers = num_layers, ffw_size=ffw_size)\n",
"\n", "\n",
"print(f\"params: {N/1e6:.2f}M\")\n", " D = args['seq_len'] # dataset size (cancels anyway, for the purposes of the ratio calculation below)\n",
"print(f\"approx flops: {approx_flops/1e9:.2f}B\")\n", " N = chinchilla_params(**args)\n",
"print(f\"chinchilla flops: {chinch_flops/1e9:.2f}B\")\n", " F = chinchilla_flops(**args)\n",
"print(f\"ratio (chinchilla / approx): {chinch_flops / approx_flops:.2f}\")" "\n",
" approx_flops = 6*D*N # approximate flops\n",
" chinch_flops = F * (float(D) / args['seq_len']) # exact flops according to Chinchilla paper calculations\n",
"\n",
" # print('---')\n",
" # print(f\"params: {N/1e6:.2f}M\")\n",
" # print(f\"approx flops: {approx_flops/1e9:.2f}B\")\n",
" # print(f\"chinchilla flops: {chinch_flops/1e9:.2f}B\")\n",
" # print(f\"ratio (chinchilla / approx): {chinch_flops / approx_flops:.2f}\")\n",
"\n",
" # first copy all keyvalues from args into out\n",
" out = {k:v for k,v in args.items()}\n",
" # then add the calculated values\n",
" out['N'] = N\n",
" out['F'] = F\n",
" out['approx_flops'] = approx_flops\n",
" out['chinch_flops'] = chinch_flops\n",
" out['ratio'] = chinch_flops / approx_flops\n",
" rows.append(out)\n",
"\n",
"# make a pandas dataframe from rows\n",
"df = pd.DataFrame(rows)\n",
"df"
] ]
}, },
{ {
@ -247,7 +464,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Awkward because the ratio is supposed to be 1.03 but we get 1.01, e.g. for the 73M param model. So once again we don't reproduce the numbers. TODO resolve..." "Pretty good match! Except the param counts are still not perfectly accurate."
] ]
}, },
{ {
@ -266,7 +483,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7fc09dffc730>" "<matplotlib.colorbar.Colorbar at 0x7f9d2e9ba9e0>"
] ]
}, },
"execution_count": 8, "execution_count": 8,
@ -474,13 +691,8 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"TLDR atm: nothing reproduces." "TLDR atm: nothing reproduces, but progress is being made."
] ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
} }
], ],
"metadata": { "metadata": {