mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-11-14 05:44:51 +00:00
progress! based on chinchilla author correspondence
This commit is contained in:
parent
27fc6a4112
commit
d56bdf05a6
@ -16,35 +16,23 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"import matplotlib.pyplot as plt\n",
|
"import matplotlib.pyplot as plt\n",
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
"%matplotlib inline"
|
"%matplotlib inline"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"gpt2 = dict(\n",
|
|
||||||
" seq_len = 1024,\n",
|
|
||||||
" vocab_size = 50257,\n",
|
|
||||||
" d_model = 768,\n",
|
|
||||||
" num_heads = 12,\n",
|
|
||||||
" num_layers = 12,\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"attachments": {},
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## params"
|
"## params\n",
|
||||||
|
"\n",
|
||||||
|
"First some parameter calculations:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -53,21 +41,22 @@
|
|||||||
"123.653376"
|
"123.653376"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"def params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size=4):\n",
|
"def gpt_params(seq_len, vocab_size, d_model, num_heads, num_layers):\n",
|
||||||
" \"\"\" Given GPT config calculate total number of parameters\"\"\" \n",
|
" \"\"\" Given GPT config calculate total number of parameters \"\"\"\n",
|
||||||
|
" ffw_size = 4*d_model # in GPT the number of intermediate features is always 4*d_model\n",
|
||||||
" # token and position embeddings\n",
|
" # token and position embeddings\n",
|
||||||
" embeddings = d_model * vocab_size + d_model * seq_len\n",
|
" embeddings = d_model * vocab_size + d_model * seq_len\n",
|
||||||
" # transformer blocks\n",
|
" # transformer blocks\n",
|
||||||
" attention = 3*d_model**2 + 3*d_model # weights and biases\n",
|
" attention = 3*d_model**2 + 3*d_model # weights and biases\n",
|
||||||
" attproj = d_model**2 + d_model\n",
|
" attproj = d_model**2 + d_model\n",
|
||||||
" ffw = d_model*(ffw_size*d_model) + ffw_size*d_model\n",
|
" ffw = d_model*(ffw_size) + ffw_size\n",
|
||||||
" ffwproj = ffw_size*d_model*d_model + d_model\n",
|
" ffwproj = ffw_size*d_model + d_model\n",
|
||||||
" layernorms = 2*2*d_model\n",
|
" layernorms = 2*2*d_model\n",
|
||||||
" # dense\n",
|
" # dense\n",
|
||||||
" ln_f = 2*d_model\n",
|
" ln_f = 2*d_model\n",
|
||||||
@ -76,7 +65,8 @@
|
|||||||
" total_params = num_layers*(attention + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n",
|
" total_params = num_layers*(attention + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n",
|
||||||
" return total_params\n",
|
" return total_params\n",
|
||||||
"\n",
|
"\n",
|
||||||
"params(**gpt2)/1e6"
|
"gpt2 = dict(seq_len = 1024, vocab_size = 50257, d_model = 768, num_heads = 12, num_layers = 12)\n",
|
||||||
|
"gpt_params(**gpt2)/1e6"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -84,7 +74,33 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"OpenAI reports gpt2 (small) as having 124M params, so this is a match. Also, loading the OpenAI weights into nanoGPT and then calling `model.parameters()` exactly matches the above number and verifies the implementation."
|
"OpenAI reports gpt2 (small) as having 124M params, so this is a match. Also, loading the OpenAI weights into nanoGPT and then calling `model.parameters()` exactly matches the above number and verifies the implementation. Now Chinchilla parameters:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def chinchilla_params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n",
|
||||||
|
" \"\"\" Parameters in the Chinchilla models. Unlike GPT they use relative positional embeddings. \"\"\"\n",
|
||||||
|
" # token embeddings only\n",
|
||||||
|
" embeddings = d_model * vocab_size\n",
|
||||||
|
" # transformer blocks\n",
|
||||||
|
" attention = 3*d_model**2 + 3*d_model # weights and biases\n",
|
||||||
|
" relative_pos = d_model**2 + 2*d_model # relative keys, content bias, relative bias\n",
|
||||||
|
" attproj = d_model**2 + d_model\n",
|
||||||
|
" ffw = d_model*ffw_size + ffw_size\n",
|
||||||
|
" ffwproj = ffw_size*d_model + d_model\n",
|
||||||
|
" layernorms = 2*2*d_model\n",
|
||||||
|
" # dense\n",
|
||||||
|
" ln_f = 2*d_model\n",
|
||||||
|
" dense = d_model*vocab_size # note: no bias here\n",
|
||||||
|
" # note: embeddings are not included in the param count!\n",
|
||||||
|
" total_params = num_layers*(attention + relative_pos + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n",
|
||||||
|
" return total_params\n",
|
||||||
|
"\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -104,6 +120,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# Load in all the 50 Chinchilla models on the last page of the paper\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
"chinchilla_models_txt = '[[44000000.0, 512, 2048, 64, 8, 8], [57000000.0, 576, 2304, 64, 9, 9], [74000000.0, 640, 2560, 64, 10, 10], [90000000.0, 640, 2560, 64, 10, 13], [106000000.0, 640, 2560, 64, 10, 16], [117000000.0, 768, 3072, 64, 12, 12], [140000000.0, 768, 3072, 64, 12, 15], [163000000.0, 768, 3072, 64, 12, 18], [175000000.0, 896, 3584, 64, 14, 14], [196000000.0, 896, 3584, 64, 14, 16], [217000000.0, 896, 3584, 64, 14, 18], [251000000.0, 1024, 4096, 64, 16, 16], [278000000.0, 1024, 4096, 64, 16, 18], [306000000.0, 1024, 4096, 64, 16, 20], [425000000.0, 1280, 5120, 128, 10, 18], [489000000.0, 1280, 5120, 128, 10, 21], [509000000.0, 1408, 5632, 128, 11, 18], [552000000.0, 1280, 5120, 128, 10, 24], [587000000.0, 1408, 5632, 128, 11, 21], [632000000.0, 1536, 6144, 128, 12, 19], [664000000.0, 1408, 5632, 128, 11, 24], [724000000.0, 1536, 6144, 128, 12, 22], [816000000.0, 1536, 6144, 128, 12, 25], [893000000.0, 1792, 7168, 128, 14, 20], [1018000000.0, 1792, 7168, 128, 14, 23], [1143000000.0, 1792, 7168, 128, 14, 26], [1266000000.0, 2048, 8192, 128, 16, 22], [1424000000.0, 2176, 8704, 128, 17, 22], [1429000000.0, 2048, 8192, 128, 16, 25], [1593000000.0, 2048, 8192, 128, 16, 28], [1609000000.0, 2176, 8704, 128, 17, 25], [1731000000.0, 2304, 9216, 128, 18, 24], [1794000000.0, 2176, 8704, 128, 17, 28], [2007000000.0, 2304, 9216, 128, 18, 28], [2283000000.0, 2304, 9216, 128, 18, 32], [2298000000.0, 2560, 10240, 128, 20, 26], [2639000000.0, 2560, 10240, 128, 20, 30], [2980000000.0, 2560, 10240, 128, 20, 34], [3530000000.0, 2688, 10752, 128, 22, 36], [3802000000.0, 2816, 11264, 128, 22, 36], [4084000000.0, 2944, 11776, 128, 22, 36], [4516000000.0, 3072, 12288, 128, 24, 36], [6796000000.0, 3584, 14336, 128, 28, 40], [9293000000.0, 4096, 16384, 128, 32, 42], [11452000000.0, 4352, 17408, 128, 32, 47], [12295000000.0, 4608, 18432, 128, 36, 44], [12569000000.0, 4608, 18432, 128, 32, 47], [13735000000.0, 4864, 19456, 128, 32, 47], [14940000000.0, 4992, 19968, 128, 32, 49], [16183000000.0, 5120, 20480, 128, 40, 47]]'\n",
|
"chinchilla_models_txt = '[[44000000.0, 512, 2048, 64, 8, 8], [57000000.0, 576, 2304, 64, 9, 9], [74000000.0, 640, 2560, 64, 10, 10], [90000000.0, 640, 2560, 64, 10, 13], [106000000.0, 640, 2560, 64, 10, 16], [117000000.0, 768, 3072, 64, 12, 12], [140000000.0, 768, 3072, 64, 12, 15], [163000000.0, 768, 3072, 64, 12, 18], [175000000.0, 896, 3584, 64, 14, 14], [196000000.0, 896, 3584, 64, 14, 16], [217000000.0, 896, 3584, 64, 14, 18], [251000000.0, 1024, 4096, 64, 16, 16], [278000000.0, 1024, 4096, 64, 16, 18], [306000000.0, 1024, 4096, 64, 16, 20], [425000000.0, 1280, 5120, 128, 10, 18], [489000000.0, 1280, 5120, 128, 10, 21], [509000000.0, 1408, 5632, 128, 11, 18], [552000000.0, 1280, 5120, 128, 10, 24], [587000000.0, 1408, 5632, 128, 11, 21], [632000000.0, 1536, 6144, 128, 12, 19], [664000000.0, 1408, 5632, 128, 11, 24], [724000000.0, 1536, 6144, 128, 12, 22], [816000000.0, 1536, 6144, 128, 12, 25], [893000000.0, 1792, 7168, 128, 14, 20], [1018000000.0, 1792, 7168, 128, 14, 23], [1143000000.0, 1792, 7168, 128, 14, 26], [1266000000.0, 2048, 8192, 128, 16, 22], [1424000000.0, 2176, 8704, 128, 17, 22], [1429000000.0, 2048, 8192, 128, 16, 25], [1593000000.0, 2048, 8192, 128, 16, 28], [1609000000.0, 2176, 8704, 128, 17, 25], [1731000000.0, 2304, 9216, 128, 18, 24], [1794000000.0, 2176, 8704, 128, 17, 28], [2007000000.0, 2304, 9216, 128, 18, 28], [2283000000.0, 2304, 9216, 128, 18, 32], [2298000000.0, 2560, 10240, 128, 20, 26], [2639000000.0, 2560, 10240, 128, 20, 30], [2980000000.0, 2560, 10240, 128, 20, 34], [3530000000.0, 2688, 10752, 128, 22, 36], [3802000000.0, 2816, 11264, 128, 22, 36], [4084000000.0, 2944, 11776, 128, 22, 36], [4516000000.0, 3072, 12288, 128, 24, 36], [6796000000.0, 3584, 14336, 128, 28, 40], [9293000000.0, 4096, 16384, 128, 32, 42], [11452000000.0, 4352, 17408, 128, 32, 47], [12295000000.0, 4608, 18432, 128, 36, 44], [12569000000.0, 4608, 18432, 128, 32, 47], [13735000000.0, 4864, 19456, 128, 32, 47], [14940000000.0, 4992, 19968, 128, 32, 49], [16183000000.0, 5120, 20480, 128, 40, 47]]'\n",
|
||||||
"chilchilla_models = json.loads(chinchilla_models_txt) # all 50 models\n",
|
"chilchilla_models = json.loads(chinchilla_models_txt) # all 50 models\n",
|
||||||
@ -119,18 +136,63 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"our estimated params: 41.6041M, chinchilla params: 44.0000M, d_model: 512, n_heads: 8, n_layers: 8\n",
|
"our estimated params: 43.7094M, chinchilla params: 44.0000M, d_model: 512, n_heads: 8, n_layers: 8\n",
|
||||||
"our estimated params: 54.3324M, chinchilla params: 57.0000M, d_model: 576, n_heads: 9, n_layers: 9\n",
|
"our estimated params: 57.3287M, chinchilla params: 57.0000M, d_model: 576, n_heads: 9, n_layers: 9\n",
|
||||||
"our estimated params: 69.7165M, chinchilla params: 74.0000M, d_model: 640, n_heads: 10, n_layers: 10\n",
|
"our estimated params: 73.8253M, chinchilla params: 74.0000M, d_model: 640, n_heads: 10, n_layers: 10\n",
|
||||||
"our estimated params: 84.4870M, chinchilla params: 90.0000M, d_model: 640, n_heads: 10, n_layers: 13\n",
|
"our estimated params: 89.8285M, chinchilla params: 90.0000M, d_model: 640, n_heads: 10, n_layers: 13\n",
|
||||||
"our estimated params: 99.2576M, chinchilla params: 106.0000M, d_model: 640, n_heads: 10, n_layers: 16\n"
|
"our estimated params: 105.8317M, chinchilla params: 106.0000M, d_model: 640, n_heads: 10, n_layers: 16\n",
|
||||||
|
"our estimated params: 116.7283M, chinchilla params: 117.0000M, d_model: 768, n_heads: 12, n_layers: 12\n",
|
||||||
|
"our estimated params: 139.7660M, chinchilla params: 140.0000M, d_model: 768, n_heads: 12, n_layers: 15\n",
|
||||||
|
"our estimated params: 162.8037M, chinchilla params: 163.0000M, d_model: 768, n_heads: 12, n_layers: 18\n",
|
||||||
|
"our estimated params: 174.9745M, chinchilla params: 175.0000M, d_model: 896, n_heads: 14, n_layers: 14\n",
|
||||||
|
"our estimated params: 195.8746M, chinchilla params: 196.0000M, d_model: 896, n_heads: 14, n_layers: 16\n",
|
||||||
|
"our estimated params: 216.7747M, chinchilla params: 217.0000M, d_model: 896, n_heads: 14, n_layers: 18\n",
|
||||||
|
"our estimated params: 251.1196M, chinchilla params: 251.0000M, d_model: 1024, n_heads: 16, n_layers: 16\n",
|
||||||
|
"our estimated params: 278.4133M, chinchilla params: 278.0000M, d_model: 1024, n_heads: 16, n_layers: 18\n",
|
||||||
|
"our estimated params: 305.7070M, chinchilla params: 306.0000M, d_model: 1024, n_heads: 16, n_layers: 20\n",
|
||||||
|
"our estimated params: 424.6938M, chinchilla params: 425.0000M, d_model: 1280, n_heads: 10, n_layers: 18\n",
|
||||||
|
"our estimated params: 488.6490M, chinchilla params: 489.0000M, d_model: 1280, n_heads: 10, n_layers: 21\n",
|
||||||
|
"our estimated params: 509.3356M, chinchilla params: 509.0000M, d_model: 1408, n_heads: 11, n_layers: 18\n",
|
||||||
|
"our estimated params: 552.6042M, chinchilla params: 552.0000M, d_model: 1280, n_heads: 10, n_layers: 24\n",
|
||||||
|
"our estimated params: 586.7150M, chinchilla params: 587.0000M, d_model: 1408, n_heads: 11, n_layers: 21\n",
|
||||||
|
"our estimated params: 632.3389M, chinchilla params: 632.0000M, d_model: 1536, n_heads: 12, n_layers: 19\n",
|
||||||
|
"our estimated params: 664.0945M, chinchilla params: 664.0000M, d_model: 1408, n_heads: 11, n_layers: 24\n",
|
||||||
|
"our estimated params: 724.4206M, chinchilla params: 724.0000M, d_model: 1536, n_heads: 12, n_layers: 22\n",
|
||||||
|
"our estimated params: 816.5023M, chinchilla params: 816.0000M, d_model: 1536, n_heads: 12, n_layers: 25\n",
|
||||||
|
"our estimated params: 892.8138M, chinchilla params: 893.0000M, d_model: 1792, n_heads: 14, n_layers: 20\n",
|
||||||
|
"our estimated params: 1018.1338M, chinchilla params: 1018.0000M, d_model: 1792, n_heads: 14, n_layers: 23\n",
|
||||||
|
"our estimated params: 1143.4537M, chinchilla params: 1143.0000M, d_model: 1792, n_heads: 14, n_layers: 26\n",
|
||||||
|
"our estimated params: 1265.7869M, chinchilla params: 1266.0000M, d_model: 2048, n_heads: 16, n_layers: 22\n",
|
||||||
|
"our estimated params: 1424.5576M, chinchilla params: 1424.0000M, d_model: 2176, n_heads: 17, n_layers: 22\n",
|
||||||
|
"our estimated params: 1429.4569M, chinchilla params: 1429.0000M, d_model: 2048, n_heads: 16, n_layers: 25\n",
|
||||||
|
"our estimated params: 1593.1269M, chinchilla params: 1593.0000M, d_model: 2048, n_heads: 16, n_layers: 28\n",
|
||||||
|
"our estimated params: 1609.3196M, chinchilla params: 1609.0000M, d_model: 2176, n_heads: 17, n_layers: 25\n",
|
||||||
|
"our estimated params: 1730.7878M, chinchilla params: 1731.0000M, d_model: 2304, n_heads: 18, n_layers: 24\n",
|
||||||
|
"our estimated params: 1794.0815M, chinchilla params: 1794.0000M, d_model: 2176, n_heads: 17, n_layers: 28\n",
|
||||||
|
"our estimated params: 2006.9637M, chinchilla params: 2007.0000M, d_model: 2304, n_heads: 18, n_layers: 28\n",
|
||||||
|
"our estimated params: 2283.1396M, chinchilla params: 2283.0000M, d_model: 2304, n_heads: 18, n_layers: 32\n",
|
||||||
|
"our estimated params: 2298.0403M, chinchilla params: 2298.0000M, d_model: 2560, n_heads: 20, n_layers: 26\n",
|
||||||
|
"our estimated params: 2638.9811M, chinchilla params: 2639.0000M, d_model: 2560, n_heads: 20, n_layers: 30\n",
|
||||||
|
"our estimated params: 2979.9219M, chinchilla params: 2980.0000M, d_model: 2560, n_heads: 20, n_layers: 34\n",
|
||||||
|
"our estimated params: 3468.9339M, chinchilla params: 3530.0000M, d_model: 2688, n_heads: 22, n_layers: 36\n",
|
||||||
|
"our estimated params: 3802.8109M, chinchilla params: 3802.0000M, d_model: 2816, n_heads: 22, n_layers: 36\n",
|
||||||
|
"our estimated params: 4152.0233M, chinchilla params: 4084.0000M, d_model: 2944, n_heads: 22, n_layers: 36\n",
|
||||||
|
"our estimated params: 4516.5711M, chinchilla params: 4516.0000M, d_model: 3072, n_heads: 24, n_layers: 36\n",
|
||||||
|
"our estimated params: 6796.2747M, chinchilla params: 6796.0000M, d_model: 3584, n_heads: 28, n_layers: 40\n",
|
||||||
|
"our estimated params: 9294.0206M, chinchilla params: 9293.0000M, d_model: 4096, n_heads: 32, n_layers: 42\n",
|
||||||
|
"our estimated params: 11714.6222M, chinchilla params: 11452.0000M, d_model: 4352, n_heads: 32, n_layers: 47\n",
|
||||||
|
"our estimated params: 12296.1623M, chinchilla params: 12295.0000M, d_model: 4608, n_heads: 36, n_layers: 44\n",
|
||||||
|
"our estimated params: 13124.4826M, chinchilla params: 12569.0000M, d_model: 4608, n_heads: 32, n_layers: 47\n",
|
||||||
|
"our estimated params: 14614.4279M, chinchilla params: 13735.0000M, d_model: 4864, n_heads: 32, n_layers: 47\n",
|
||||||
|
"our estimated params: 16037.5039M, chinchilla params: 14940.0000M, d_model: 4992, n_heads: 32, n_layers: 49\n",
|
||||||
|
"our estimated params: 16184.4582M, chinchilla params: 16183.0000M, d_model: 5120, n_heads: 40, n_layers: 47\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"for m in chilchilla_models[:5]: # look at just first 5 models\n",
|
"for m in chilchilla_models:\n",
|
||||||
" p, d, f, k, h, l = m\n",
|
" p, d, f, k, h, l = m\n",
|
||||||
" nparams = params(seq_len = 1024, vocab_size = 32000, d_model = d, num_heads = h, num_layers = l)\n",
|
" nparams = chinchilla_params(seq_len = 1024, vocab_size = 32000, d_model = d, num_heads = h, num_layers = l, ffw_size=f)\n",
|
||||||
" print(f\"our estimated params: {nparams/1e6:.4f}M, chinchilla params: {p/1e6:.4f}M, d_model: {d}, n_heads: {h}, n_layers: {l}\")"
|
" print(f\"our estimated params: {nparams/1e6:.4f}M, chinchilla params: {p/1e6:.4f}M, d_model: {d}, n_heads: {h}, n_layers: {l}\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -139,7 +201,7 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"So we are not able to reproduce the parameter counts for the Chinchilla models. The claimed number of parameters is not the number of parameters we compute above. TODO: resolve... \n",
|
"We are almost able to reproduce the parameter counts for the Chinchilla models. TODO resolve...\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Now turning to FLOPs:"
|
"Now turning to FLOPs:"
|
||||||
]
|
]
|
||||||
@ -156,22 +218,11 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"766.006788096"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"def flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size=4):\n",
|
"def chinchilla_flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n",
|
||||||
" \"\"\" \n",
|
" \"\"\" \n",
|
||||||
" Given GPT config calculate total number of FLOPs, see Chinchilla \n",
|
" Calculate total number of FLOPs, see Chinchilla \n",
|
||||||
" paper Appendix F as reference: https://arxiv.org/pdf/2203.15556.pdf\n",
|
" paper Appendix F as reference: https://arxiv.org/pdf/2203.15556.pdf\n",
|
||||||
" \"\"\" \n",
|
" \"\"\" \n",
|
||||||
" key_size = d_model // num_heads\n",
|
" key_size = d_model // num_heads\n",
|
||||||
@ -196,14 +247,17 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # logits\n",
|
" # logits\n",
|
||||||
" logits = 2 * seq_len * d_model * vocab_size\n",
|
" logits = 2 * seq_len * d_model * vocab_size\n",
|
||||||
"\n",
|
" \n",
|
||||||
" forward_flops = embeddings + num_layers * (att + dense) + logits\n",
|
" # this is what you'd expect:\n",
|
||||||
|
" # forward_flops = embeddings + num_layers * (att + dense) + logits\n",
|
||||||
|
" # but:\n",
|
||||||
|
" # per author correspondence apparently there is typo in the paper,\n",
|
||||||
|
" # they do not count embeddings and logits to repro table 4. So instead:\n",
|
||||||
|
" forward_flops = num_layers * (att + dense)\n",
|
||||||
" backward_flops = 2 * forward_flops # as in Kaplan et al. 2020\n",
|
" backward_flops = 2 * forward_flops # as in Kaplan et al. 2020\n",
|
||||||
" total_flops = forward_flops + backward_flops\n",
|
" total_flops = forward_flops + backward_flops\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return total_flops\n",
|
" return total_flops\n"
|
||||||
"\n",
|
|
||||||
"flops(**gpt2)/1e9"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -212,14 +266,149 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"data": {
|
||||||
"output_type": "stream",
|
"text/html": [
|
||||||
"text": [
|
"<div>\n",
|
||||||
"params: 69.72M\n",
|
"<style scoped>\n",
|
||||||
"approx flops: 428.34B\n",
|
" .dataframe tbody tr th:only-of-type {\n",
|
||||||
"chinchilla flops: 434.11B\n",
|
" vertical-align: middle;\n",
|
||||||
"ratio (chinchilla / approx): 1.01\n"
|
" }\n",
|
||||||
]
|
"\n",
|
||||||
|
" .dataframe tbody tr th {\n",
|
||||||
|
" vertical-align: top;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe thead th {\n",
|
||||||
|
" text-align: right;\n",
|
||||||
|
" }\n",
|
||||||
|
"</style>\n",
|
||||||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||||||
|
" <thead>\n",
|
||||||
|
" <tr style=\"text-align: right;\">\n",
|
||||||
|
" <th></th>\n",
|
||||||
|
" <th>seq_len</th>\n",
|
||||||
|
" <th>vocab_size</th>\n",
|
||||||
|
" <th>d_model</th>\n",
|
||||||
|
" <th>num_heads</th>\n",
|
||||||
|
" <th>num_layers</th>\n",
|
||||||
|
" <th>ffw_size</th>\n",
|
||||||
|
" <th>N</th>\n",
|
||||||
|
" <th>F</th>\n",
|
||||||
|
" <th>approx_flops</th>\n",
|
||||||
|
" <th>chinch_flops</th>\n",
|
||||||
|
" <th>ratio</th>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </thead>\n",
|
||||||
|
" <tbody>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>0</th>\n",
|
||||||
|
" <td>2048</td>\n",
|
||||||
|
" <td>32000</td>\n",
|
||||||
|
" <td>640</td>\n",
|
||||||
|
" <td>10</td>\n",
|
||||||
|
" <td>10</td>\n",
|
||||||
|
" <td>2560</td>\n",
|
||||||
|
" <td>73825280</td>\n",
|
||||||
|
" <td>929877196800</td>\n",
|
||||||
|
" <td>907165040640</td>\n",
|
||||||
|
" <td>9.298772e+11</td>\n",
|
||||||
|
" <td>1.025036</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>1</th>\n",
|
||||||
|
" <td>2048</td>\n",
|
||||||
|
" <td>32000</td>\n",
|
||||||
|
" <td>1024</td>\n",
|
||||||
|
" <td>16</td>\n",
|
||||||
|
" <td>20</td>\n",
|
||||||
|
" <td>4096</td>\n",
|
||||||
|
" <td>305707008</td>\n",
|
||||||
|
" <td>4135248199680</td>\n",
|
||||||
|
" <td>3756527714304</td>\n",
|
||||||
|
" <td>4.135248e+12</td>\n",
|
||||||
|
" <td>1.100817</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>2</th>\n",
|
||||||
|
" <td>2048</td>\n",
|
||||||
|
" <td>32000</td>\n",
|
||||||
|
" <td>1280</td>\n",
|
||||||
|
" <td>10</td>\n",
|
||||||
|
" <td>24</td>\n",
|
||||||
|
" <td>5120</td>\n",
|
||||||
|
" <td>552604160</td>\n",
|
||||||
|
" <td>7353453772800</td>\n",
|
||||||
|
" <td>6790399918080</td>\n",
|
||||||
|
" <td>7.353454e+12</td>\n",
|
||||||
|
" <td>1.082919</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>3</th>\n",
|
||||||
|
" <td>2048</td>\n",
|
||||||
|
" <td>32000</td>\n",
|
||||||
|
" <td>1792</td>\n",
|
||||||
|
" <td>14</td>\n",
|
||||||
|
" <td>26</td>\n",
|
||||||
|
" <td>7168</td>\n",
|
||||||
|
" <td>1143453696</td>\n",
|
||||||
|
" <td>14670316437504</td>\n",
|
||||||
|
" <td>14050759016448</td>\n",
|
||||||
|
" <td>1.467032e+13</td>\n",
|
||||||
|
" <td>1.044094</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>4</th>\n",
|
||||||
|
" <td>2048</td>\n",
|
||||||
|
" <td>32000</td>\n",
|
||||||
|
" <td>2048</td>\n",
|
||||||
|
" <td>16</td>\n",
|
||||||
|
" <td>28</td>\n",
|
||||||
|
" <td>8192</td>\n",
|
||||||
|
" <td>1593126912</td>\n",
|
||||||
|
" <td>20220437594112</td>\n",
|
||||||
|
" <td>19576343494656</td>\n",
|
||||||
|
" <td>2.022044e+13</td>\n",
|
||||||
|
" <td>1.032902</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>5</th>\n",
|
||||||
|
" <td>2048</td>\n",
|
||||||
|
" <td>32000</td>\n",
|
||||||
|
" <td>3584</td>\n",
|
||||||
|
" <td>28</td>\n",
|
||||||
|
" <td>40</td>\n",
|
||||||
|
" <td>14336</td>\n",
|
||||||
|
" <td>6796274688</td>\n",
|
||||||
|
" <td>83021046743040</td>\n",
|
||||||
|
" <td>83512623366144</td>\n",
|
||||||
|
" <td>8.302105e+13</td>\n",
|
||||||
|
" <td>0.994114</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </tbody>\n",
|
||||||
|
"</table>\n",
|
||||||
|
"</div>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
" seq_len vocab_size d_model num_heads num_layers ffw_size N \\\n",
|
||||||
|
"0 2048 32000 640 10 10 2560 73825280 \n",
|
||||||
|
"1 2048 32000 1024 16 20 4096 305707008 \n",
|
||||||
|
"2 2048 32000 1280 10 24 5120 552604160 \n",
|
||||||
|
"3 2048 32000 1792 14 26 7168 1143453696 \n",
|
||||||
|
"4 2048 32000 2048 16 28 8192 1593126912 \n",
|
||||||
|
"5 2048 32000 3584 28 40 14336 6796274688 \n",
|
||||||
|
"\n",
|
||||||
|
" F approx_flops chinch_flops ratio \n",
|
||||||
|
"0 929877196800 907165040640 9.298772e+11 1.025036 \n",
|
||||||
|
"1 4135248199680 3756527714304 4.135248e+12 1.100817 \n",
|
||||||
|
"2 7353453772800 6790399918080 7.353454e+12 1.082919 \n",
|
||||||
|
"3 14670316437504 14050759016448 1.467032e+13 1.044094 \n",
|
||||||
|
"4 20220437594112 19576343494656 2.022044e+13 1.032902 \n",
|
||||||
|
"5 83021046743040 83512623366144 8.302105e+13 0.994114 "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
@ -227,19 +416,47 @@
|
|||||||
"# comparing accurate flops above to approximate flops F = 6*N*D\n",
|
"# comparing accurate flops above to approximate flops F = 6*N*D\n",
|
||||||
"# note Chinchilla mentions using vocab_size = 32K\n",
|
"# note Chinchilla mentions using vocab_size = 32K\n",
|
||||||
"\n",
|
"\n",
|
||||||
"args = dict(seq_len = 1024, vocab_size = 32000, d_model = 640, num_heads = 10, num_layers = 10) # chinchilla 73M\n",
|
"chilchilla_models_table4 = [\n",
|
||||||
|
" [10, 640, 2560, 10, 64],\n",
|
||||||
|
" [20, 1024, 4096, 16, 64],\n",
|
||||||
|
" [24, 1280, 5120, 10, 128 ],\n",
|
||||||
|
" [26, 1792, 7168, 14, 128 ],\n",
|
||||||
|
" [28, 2048, 8192, 16, 128],\n",
|
||||||
|
" [40, 3584, 14336, 28, 128]\n",
|
||||||
|
"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"D = 1024 # dataset size (cancels anyway, for the purposes of the ratio calculation below)\n",
|
"rows = []\n",
|
||||||
"N = params(**args)\n",
|
"for num_layers, d_model, ffw_size, num_heads, _ in chilchilla_models_table4:\n",
|
||||||
"F = flops(**args)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"approx_flops = 6*D*N # approximate flops\n",
|
" args = dict(seq_len = 2048, vocab_size = 32000, d_model = d_model, \n",
|
||||||
"chinch_flops = F * (float(D) / args['seq_len']) # exact flops according to Chinchilla paper calculations\n",
|
" num_heads = num_heads, num_layers = num_layers, ffw_size=ffw_size)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(f\"params: {N/1e6:.2f}M\")\n",
|
" D = args['seq_len'] # dataset size (cancels anyway, for the purposes of the ratio calculation below)\n",
|
||||||
"print(f\"approx flops: {approx_flops/1e9:.2f}B\")\n",
|
" N = chinchilla_params(**args)\n",
|
||||||
"print(f\"chinchilla flops: {chinch_flops/1e9:.2f}B\")\n",
|
" F = chinchilla_flops(**args)\n",
|
||||||
"print(f\"ratio (chinchilla / approx): {chinch_flops / approx_flops:.2f}\")"
|
"\n",
|
||||||
|
" approx_flops = 6*D*N # approximate flops\n",
|
||||||
|
" chinch_flops = F * (float(D) / args['seq_len']) # exact flops according to Chinchilla paper calculations\n",
|
||||||
|
"\n",
|
||||||
|
" # print('---')\n",
|
||||||
|
" # print(f\"params: {N/1e6:.2f}M\")\n",
|
||||||
|
" # print(f\"approx flops: {approx_flops/1e9:.2f}B\")\n",
|
||||||
|
" # print(f\"chinchilla flops: {chinch_flops/1e9:.2f}B\")\n",
|
||||||
|
" # print(f\"ratio (chinchilla / approx): {chinch_flops / approx_flops:.2f}\")\n",
|
||||||
|
"\n",
|
||||||
|
" # first copy all keyvalues from args into out\n",
|
||||||
|
" out = {k:v for k,v in args.items()}\n",
|
||||||
|
" # then add the calculated values\n",
|
||||||
|
" out['N'] = N\n",
|
||||||
|
" out['F'] = F\n",
|
||||||
|
" out['approx_flops'] = approx_flops\n",
|
||||||
|
" out['chinch_flops'] = chinch_flops\n",
|
||||||
|
" out['ratio'] = chinch_flops / approx_flops\n",
|
||||||
|
" rows.append(out)\n",
|
||||||
|
"\n",
|
||||||
|
"# make a pandas dataframe from rows\n",
|
||||||
|
"df = pd.DataFrame(rows)\n",
|
||||||
|
"df"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -247,7 +464,7 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"Awkward because the ratio is supposed to be 1.03 but we get 1.01, e.g. for the 73M param model. So once again we don't reproduce the numbers. TODO resolve..."
|
"Pretty good match! Except the param counts are still not perfectly accurate."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -266,7 +483,7 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"<matplotlib.colorbar.Colorbar at 0x7fc09dffc730>"
|
"<matplotlib.colorbar.Colorbar at 0x7f9d2e9ba9e0>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 8,
|
"execution_count": 8,
|
||||||
@ -474,13 +691,8 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"TLDR atm: nothing reproduces."
|
"TLDR atm: nothing reproduces, but progress is being made."
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": []
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
Loading…
Reference in New Issue
Block a user