{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Trying to reproduce results from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf):"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"%matplotlib inline"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## params\n",
"\n",
"First some parameter calculations:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"123.653376"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def gpt_params(seq_len, vocab_size, d_model, num_heads, num_layers):\n",
" \"\"\" Given a GPT config, calculate the total number of parameters \"\"\"\n",
" ffw_size = 4*d_model # in GPT the number of intermediate features is always 4*d_model\n",
" # token and position embeddings\n",
" embeddings = d_model * vocab_size + d_model * seq_len\n",
" # transformer blocks\n",
" attention = 3*d_model**2 + 3*d_model # weights and biases\n",
" attproj = d_model**2 + d_model\n",
" ffw = d_model*(ffw_size) + ffw_size\n",
" ffwproj = ffw_size*d_model + d_model\n",
" layernorms = 2*2*d_model\n",
" # dense\n",
" ln_f = 2*d_model\n",
" dense = d_model*vocab_size # note: no bias here\n",
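" # note: position embeddings are not included in the param count, and the token embedding\n",
" # is tied with the output layer (dense), so it is only counted once, via dense!\n",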
" total_params = num_layers*(attention + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n",
" return total_params\n",
"\n",
"gpt2 = dict(seq_len = 1024, vocab_size = 50257, d_model = 768, num_heads = 12, num_layers = 12)\n",
"gpt_params(**gpt2)/1e6"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"OpenAI reports GPT-2 (small) as having 124M params, so this is a match. Also, loading the OpenAI weights into nanoGPT and counting the parameters in `model.parameters()` exactly matches the number above, which verifies the implementation. Now the Chinchilla parameters:"
]
},
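{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As a cross-check, the per-layer terms in `gpt_params` (with `ffw_size = 4*d_model`) add up to $12d^2 + 13d$ weights and biases. The count above therefore reduces to $N = n_{layers}(12d^2 + 13d) + 2d + Vd$, where the last two terms are the final layernorm and the output layer. GPT-2 ties the output layer's weights with the token embedding. Below is a minimal sketch of that closed form. `gpt_params_closed_form` is a hypothetical helper, not part of nanoGPT or the OpenAI release:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def gpt_params_closed_form(vocab_size, d_model, num_layers):\n",
"    # hypothetical helper: same count as gpt_params, assuming ffw_size = 4*d_model\n",
"    per_layer = 12 * d_model**2 + 13 * d_model  # attention, projections, MLP, 2 layernorms\n",
"    # + final layernorm (2*d) + output layer (V*d), tied with the token embedding\n",
"    return num_layers * per_layer + 2 * d_model + vocab_size * d_model\n",
"\n",
"gpt_params_closed_form(vocab_size=50257, d_model=768, num_layers=12) / 1e6"
]
},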
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def chinchilla_params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n",
" \"\"\" Parameters in the Chinchilla models. Unlike GPT they use relative positional embeddings. \"\"\"\n",
" # token embeddings only\n",
" embeddings = d_model * vocab_size\n",
" # transformer blocks\n",
" attention = 3*d_model**2 + 3*d_model # weights and biases\n",
" relative_pos = d_model**2 + 2*d_model # relative keys, content bias, relative bias\n",
" attproj = d_model**2 + d_model\n",
" ffw = d_model*ffw_size + ffw_size\n",
" ffwproj = ffw_size*d_model + d_model\n",
" layernorms = 2*2*d_model\n",
" # dense\n",
" ln_f = 2*d_model\n",
" dense = d_model*vocab_size # note: no bias here\n",
" # note: embeddings are not included in the param count!\n",
" total_params = num_layers*(attention + relative_pos + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n",
" return total_params\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[44000000.0, 512, 2048, 64, 8, 8]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load all 50 Chinchilla models from Table A9 on the last page of the paper\n",
"import json\n",
"chinchilla_models_txt = '[[44000000.0, 512, 2048, 64, 8, 8], [57000000.0, 576, 2304, 64, 9, 9], [74000000.0, 640, 2560, 64, 10, 10], [90000000.0, 640, 2560, 64, 10, 13], [106000000.0, 640, 2560, 64, 10, 16], [117000000.0, 768, 3072, 64, 12, 12], [140000000.0, 768, 3072, 64, 12, 15], [163000000.0, 768, 3072, 64, 12, 18], [175000000.0, 896, 3584, 64, 14, 14], [196000000.0, 896, 3584, 64, 14, 16], [217000000.0, 896, 3584, 64, 14, 18], [251000000.0, 1024, 4096, 64, 16, 16], [278000000.0, 1024, 4096, 64, 16, 18], [306000000.0, 1024, 4096, 64, 16, 20], [425000000.0, 1280, 5120, 128, 10, 18], [489000000.0, 1280, 5120, 128, 10, 21], [509000000.0, 1408, 5632, 128, 11, 18], [552000000.0, 1280, 5120, 128, 10, 24], [587000000.0, 1408, 5632, 128, 11, 21], [632000000.0, 1536, 6144, 128, 12, 19], [664000000.0, 1408, 5632, 128, 11, 24], [724000000.0, 1536, 6144, 128, 12, 22], [816000000.0, 1536, 6144, 128, 12, 25], [893000000.0, 1792, 7168, 128, 14, 20], [1018000000.0, 1792, 7168, 128, 14, 23], [1143000000.0, 1792, 7168, 128, 14, 26], [1266000000.0, 2048, 8192, 128, 16, 22], [1424000000.0, 2176, 8704, 128, 17, 22], [1429000000.0, 2048, 8192, 128, 16, 25], [1593000000.0, 2048, 8192, 128, 16, 28], [1609000000.0, 2176, 8704, 128, 17, 25], [1731000000.0, 2304, 9216, 128, 18, 24], [1794000000.0, 2176, 8704, 128, 17, 28], [2007000000.0, 2304, 9216, 128, 18, 28], [2283000000.0, 2304, 9216, 128, 18, 32], [2298000000.0, 2560, 10240, 128, 20, 26], [2639000000.0, 2560, 10240, 128, 20, 30], [2980000000.0, 2560, 10240, 128, 20, 34], [3530000000.0, 2688, 10752, 128, 22, 36], [3802000000.0, 2816, 11264, 128, 22, 36], [4084000000.0, 2944, 11776, 128, 22, 36], [4516000000.0, 3072, 12288, 128, 24, 36], [6796000000.0, 3584, 14336, 128, 28, 40], [9293000000.0, 4096, 16384, 128, 32, 42], [11452000000.0, 4352, 17408, 128, 32, 47], [12295000000.0, 4608, 18432, 128, 36, 44], [12569000000.0, 4608, 18432, 128, 32, 47], [13735000000.0, 4864, 19456, 128, 32, 47], [14940000000.0, 4992, 19968, 128, 32, 49], [16183000000.0, 5120, 20480, 128, 40, 47]]'\n",
"chilchilla_models = json.loads(chinchilla_models_txt) # all 50 models\n",
"chilchilla_models[0] # tuples of params, d_model, ffw_size, kv_size, n_heads, n_layers from Table A9"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"our estimated params: 43.7094M, chinchilla params: 44.0000M, d_model: 512, n_heads: 8, n_layers: 8\n",
"our estimated params: 57.3287M, chinchilla params: 57.0000M, d_model: 576, n_heads: 9, n_layers: 9\n",
"our estimated params: 73.8253M, chinchilla params: 74.0000M, d_model: 640, n_heads: 10, n_layers: 10\n",
"our estimated params: 89.8285M, chinchilla params: 90.0000M, d_model: 640, n_heads: 10, n_layers: 13\n",
"our estimated params: 105.8317M, chinchilla params: 106.0000M, d_model: 640, n_heads: 10, n_layers: 16\n",
"our estimated params: 116.7283M, chinchilla params: 117.0000M, d_model: 768, n_heads: 12, n_layers: 12\n",
"our estimated params: 139.7660M, chinchilla params: 140.0000M, d_model: 768, n_heads: 12, n_layers: 15\n",
"our estimated params: 162.8037M, chinchilla params: 163.0000M, d_model: 768, n_heads: 12, n_layers: 18\n",
"our estimated params: 174.9745M, chinchilla params: 175.0000M, d_model: 896, n_heads: 14, n_layers: 14\n",
"our estimated params: 195.8746M, chinchilla params: 196.0000M, d_model: 896, n_heads: 14, n_layers: 16\n",
"our estimated params: 216.7747M, chinchilla params: 217.0000M, d_model: 896, n_heads: 14, n_layers: 18\n",
"our estimated params: 251.1196M, chinchilla params: 251.0000M, d_model: 1024, n_heads: 16, n_layers: 16\n",
"our estimated params: 278.4133M, chinchilla params: 278.0000M, d_model: 1024, n_heads: 16, n_layers: 18\n",
"our estimated params: 305.7070M, chinchilla params: 306.0000M, d_model: 1024, n_heads: 16, n_layers: 20\n",
"our estimated params: 424.6938M, chinchilla params: 425.0000M, d_model: 1280, n_heads: 10, n_layers: 18\n",
"our estimated params: 488.6490M, chinchilla params: 489.0000M, d_model: 1280, n_heads: 10, n_layers: 21\n",
"our estimated params: 509.3356M, chinchilla params: 509.0000M, d_model: 1408, n_heads: 11, n_layers: 18\n",
"our estimated params: 552.6042M, chinchilla params: 552.0000M, d_model: 1280, n_heads: 10, n_layers: 24\n",
"our estimated params: 586.7150M, chinchilla params: 587.0000M, d_model: 1408, n_heads: 11, n_layers: 21\n",
"our estimated params: 632.3389M, chinchilla params: 632.0000M, d_model: 1536, n_heads: 12, n_layers: 19\n",
"our estimated params: 664.0945M, chinchilla params: 664.0000M, d_model: 1408, n_heads: 11, n_layers: 24\n",
"our estimated params: 724.4206M, chinchilla params: 724.0000M, d_model: 1536, n_heads: 12, n_layers: 22\n",
"our estimated params: 816.5023M, chinchilla params: 816.0000M, d_model: 1536, n_heads: 12, n_layers: 25\n",
"our estimated params: 892.8138M, chinchilla params: 893.0000M, d_model: 1792, n_heads: 14, n_layers: 20\n",
"our estimated params: 1018.1338M, chinchilla params: 1018.0000M, d_model: 1792, n_heads: 14, n_layers: 23\n",
"our estimated params: 1143.4537M, chinchilla params: 1143.0000M, d_model: 1792, n_heads: 14, n_layers: 26\n",
"our estimated params: 1265.7869M, chinchilla params: 1266.0000M, d_model: 2048, n_heads: 16, n_layers: 22\n",
"our estimated params: 1424.5576M, chinchilla params: 1424.0000M, d_model: 2176, n_heads: 17, n_layers: 22\n",
"our estimated params: 1429.4569M, chinchilla params: 1429.0000M, d_model: 2048, n_heads: 16, n_layers: 25\n",
"our estimated params: 1593.1269M, chinchilla params: 1593.0000M, d_model: 2048, n_heads: 16, n_layers: 28\n",
"our estimated params: 1609.3196M, chinchilla params: 1609.0000M, d_model: 2176, n_heads: 17, n_layers: 25\n",
"our estimated params: 1730.7878M, chinchilla params: 1731.0000M, d_model: 2304, n_heads: 18, n_layers: 24\n",
"our estimated params: 1794.0815M, chinchilla params: 1794.0000M, d_model: 2176, n_heads: 17, n_layers: 28\n",
"our estimated params: 2006.9637M, chinchilla params: 2007.0000M, d_model: 2304, n_heads: 18, n_layers: 28\n",
"our estimated params: 2283.1396M, chinchilla params: 2283.0000M, d_model: 2304, n_heads: 18, n_layers: 32\n",
"our estimated params: 2298.0403M, chinchilla params: 2298.0000M, d_model: 2560, n_heads: 20, n_layers: 26\n",
"our estimated params: 2638.9811M, chinchilla params: 2639.0000M, d_model: 2560, n_heads: 20, n_layers: 30\n",
"our estimated params: 2979.9219M, chinchilla params: 2980.0000M, d_model: 2560, n_heads: 20, n_layers: 34\n",
"our estimated params: 3468.9339M, chinchilla params: 3530.0000M, d_model: 2688, n_heads: 22, n_layers: 36\n",
"our estimated params: 3802.8109M, chinchilla params: 3802.0000M, d_model: 2816, n_heads: 22, n_layers: 36\n",
"our estimated params: 4152.0233M, chinchilla params: 4084.0000M, d_model: 2944, n_heads: 22, n_layers: 36\n",
"our estimated params: 4516.5711M, chinchilla params: 4516.0000M, d_model: 3072, n_heads: 24, n_layers: 36\n",
"our estimated params: 6796.2747M, chinchilla params: 6796.0000M, d_model: 3584, n_heads: 28, n_layers: 40\n",
"our estimated params: 9294.0206M, chinchilla params: 9293.0000M, d_model: 4096, n_heads: 32, n_layers: 42\n",
"our estimated params: 11714.6222M, chinchilla params: 11452.0000M, d_model: 4352, n_heads: 32, n_layers: 47\n",
"our estimated params: 12296.1623M, chinchilla params: 12295.0000M, d_model: 4608, n_heads: 36, n_layers: 44\n",
"our estimated params: 13124.4826M, chinchilla params: 12569.0000M, d_model: 4608, n_heads: 32, n_layers: 47\n",
"our estimated params: 14614.4279M, chinchilla params: 13735.0000M, d_model: 4864, n_heads: 32, n_layers: 47\n",
"our estimated params: 16037.5039M, chinchilla params: 14940.0000M, d_model: 4992, n_heads: 32, n_layers: 49\n",
"our estimated params: 16184.4582M, chinchilla params: 16183.0000M, d_model: 5120, n_heads: 40, n_layers: 47\n"
]
}
],
"source": [
"for m in chilchilla_models:\n",
" p, d, f, k, h, l = m\n",
" nparams = chinchilla_params(seq_len = 1024, vocab_size = 32000, d_model = d, num_heads = h, num_layers = l, ffw_size=f)\n",
" print(f\"our estimated params: {nparams/1e6:.4f}M, chinchilla params: {p/1e6:.4f}M, d_model: {d}, n_heads: {h}, n_layers: {l}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We can almost reproduce the parameter counts for the Chinchilla models. Most of our estimates agree with the paper to within about 1%, but six models (3.53B, 4.08B, 11.45B, 12.57B, 13.74B and 14.94B) are off by roughly 2% to 7%. TODO resolve...\n",
"\n",
"Now turning to FLOPs:"
]
},
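{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"For all 50 configs loaded above `ffw_size = 4*d_model`, so `chinchilla_params` reduces to $N = n_{layers}(13d^2 + 15d) + 2d + Vd$ with $V = 32000$. The extra $d^2 + 2d$ per layer, relative to GPT, is the relative-position term. Below is a minimal sketch that evaluates this closed form for three rows of Table A9 (one well-matched row and two of the outliers) and prints the relative error against the paper's count. `chinchilla_params_closed_form` is a hypothetical helper:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def chinchilla_params_closed_form(d_model, num_layers, vocab_size=32000):\n",
"    # hypothetical helper: same count as chinchilla_params, assuming ffw_size = 4*d_model\n",
"    per_layer = 13 * d_model**2 + 15 * d_model\n",
"    # + final layernorm (2*d) + output layer (V*d)\n",
"    return num_layers * per_layer + 2 * d_model + vocab_size * d_model\n",
"\n",
"# (reported params, d_model, n_layers), copied from Table A9 as loaded above\n",
"for p, d, l in [(44e6, 512, 8), (3530e6, 2688, 36), (14940e6, 4992, 49)]:\n",
"    n = chinchilla_params_closed_form(d, l)\n",
"    print(f'd_model: {d}, n_layers: {l}, ours: {n/1e6:.4f}M, paper: {p/1e6:.0f}M, relative error: {n/p - 1:+.2%}')"
]
},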
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## flops"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def chinchilla_flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n",
" \"\"\" \n",
" Calculate the total number of FLOPs; see the Chinchilla\n",
" paper Appendix F as reference: https://arxiv.org/pdf/2203.15556.pdf\n",
" \"\"\" \n",
" key_size = d_model // num_heads\n",
"\n",
" # embeddings\n",
" embeddings = 2 * seq_len * vocab_size * d_model\n",
"\n",
" # attention\n",
" # key, query, value projections\n",
" attention = 2 * 3 * seq_len * d_model * (key_size * num_heads)\n",
" # key @ query logits\n",
" attlogits = 2 * seq_len * seq_len * (key_size * num_heads)\n",
" # softmax\n",
" attsoftmax = 3 * num_heads * seq_len * seq_len # 3* is for subtract (max), exp, divide (?)\n",
" # softmax @ value reductions\n",
" attvalue = 2 * seq_len * seq_len * (key_size * num_heads)\n",
" # final linear\n",
" attlinear = 2 * seq_len * (key_size * num_heads) * d_model\n",
" att = attention + attlogits + attsoftmax + attvalue + attlinear\n",
" # feed forward\n",
" dense = 2 * seq_len * (d_model * ffw_size + d_model * ffw_size)\n",
"\n",
" # logits\n",
" logits = 2 * seq_len * d_model * vocab_size\n",
" \n",
" # this is what you'd expect:\n",
" # forward_flops = embeddings + num_layers * (att + dense) + logits\n",
" # but:\n",
" # per correspondence with the authors, there is apparently a typo in the paper:\n",
" # to reproduce Table A4 they do not count embeddings and logits. So instead:\n",
" forward_flops = num_layers * (att + dense)\n",
" backward_flops = 2 * forward_flops # as in Kaplan et al. 2020\n",
" total_flops = forward_flops + backward_flops\n",
"\n",
" return total_flops\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>seq_len</th>\n",
" <th>vocab_size</th>\n",
" <th>d_model</th>\n",
" <th>num_heads</th>\n",
" <th>num_layers</th>\n",
" <th>ffw_size</th>\n",
" <th>N</th>\n",
" <th>F</th>\n",
" <th>approx_flops</th>\n",
" <th>chinch_flops</th>\n",
" <th>ratio</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>640</td>\n",
" <td>10</td>\n",
" <td>10</td>\n",
" <td>2560</td>\n",
" <td>73825280</td>\n",
" <td>929877196800</td>\n",
" <td>907165040640</td>\n",
" <td>9.298772e+11</td>\n",
" <td>1.025036</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>1024</td>\n",
" <td>16</td>\n",
" <td>20</td>\n",
" <td>4096</td>\n",
" <td>305707008</td>\n",
" <td>4135248199680</td>\n",
" <td>3756527714304</td>\n",
" <td>4.135248e+12</td>\n",
" <td>1.100817</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>1280</td>\n",
" <td>10</td>\n",
" <td>24</td>\n",
" <td>5120</td>\n",
" <td>552604160</td>\n",
" <td>7353453772800</td>\n",
" <td>6790399918080</td>\n",
" <td>7.353454e+12</td>\n",
" <td>1.082919</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>1792</td>\n",
" <td>14</td>\n",
" <td>26</td>\n",
" <td>7168</td>\n",
" <td>1143453696</td>\n",
" <td>14670316437504</td>\n",
" <td>14050759016448</td>\n",
" <td>1.467032e+13</td>\n",
" <td>1.044094</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>2048</td>\n",
" <td>16</td>\n",
" <td>28</td>\n",
" <td>8192</td>\n",
" <td>1593126912</td>\n",
" <td>20220437594112</td>\n",
" <td>19576343494656</td>\n",
" <td>2.022044e+13</td>\n",
" <td>1.032902</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>3584</td>\n",
" <td>28</td>\n",
" <td>40</td>\n",
" <td>14336</td>\n",
" <td>6796274688</td>\n",
" <td>83021046743040</td>\n",
" <td>83512623366144</td>\n",
" <td>8.302105e+13</td>\n",
" <td>0.994114</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" seq_len vocab_size d_model num_heads num_layers ffw_size N \\\n",
"0 2048 32000 640 10 10 2560 73825280 \n",
"1 2048 32000 1024 16 20 4096 305707008 \n",
"2 2048 32000 1280 10 24 5120 552604160 \n",
"3 2048 32000 1792 14 26 7168 1143453696 \n",
"4 2048 32000 2048 16 28 8192 1593126912 \n",
"5 2048 32000 3584 28 40 14336 6796274688 \n",
"\n",
" F approx_flops chinch_flops ratio \n",
"0 929877196800 907165040640 9.298772e+11 1.025036 \n",
"1 4135248199680 3756527714304 4.135248e+12 1.100817 \n",
"2 7353453772800 6790399918080 7.353454e+12 1.082919 \n",
"3 14670316437504 14050759016448 1.467032e+13 1.044094 \n",
"4 20220437594112 19576343494656 2.022044e+13 1.032902 \n",
"5 83021046743040 83512623366144 8.302105e+13 0.994114 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Now try to reproduce Table A4 from the Chinchilla paper appendix,\n",
"# comparing the detailed FLOP count F from above to the approximation 6*N*D\n",
"# note Chinchilla mentions using vocab_size = 32K\n",
"\n",
"# each row: num_layers, d_model, ffw_size, num_heads, kv_size\n",
"chinchilla_models_table4 = [\n",
" [10, 640, 2560, 10, 64],\n",
" [20, 1024, 4096, 16, 64],\n",
" [24, 1280, 5120, 10, 128],\n",
" [26, 1792, 7168, 14, 128],\n",
" [28, 2048, 8192, 16, 128],\n",
" [40, 3584, 14336, 28, 128]\n",
"]\n",
"\n",
"rows = []\n",
"for num_layers, d_model, ffw_size, num_heads, _ in chinchilla_models_table4:\n",
"\n",
" args = dict(seq_len = 2048, vocab_size = 32000, d_model = d_model, \n",
" num_heads = num_heads, num_layers = num_layers, ffw_size=ffw_size)\n",
"\n",
" D = args['seq_len'] # dataset size (cancels anyway, for the purposes of the ratio calculation below)\n",
" N = chinchilla_params(**args)\n",
" F = chinchilla_flops(**args)\n",
"\n",
" approx_flops = 6*D*N # approximate flops\n",
" chinch_flops = F * (float(D) / args['seq_len']) # exact flops according to Chinchilla paper calculations\n",
"\n",
" # print('---')\n",
" # print(f\"params: {N/1e6:.2f}M\")\n",
" # print(f\"approx flops: {approx_flops/1e9:.2f}B\")\n",
" # print(f\"chinchilla flops: {chinch_flops/1e9:.2f}B\")\n",
" # print(f\"ratio (chinchilla / approx): {chinch_flops / approx_flops:.2f}\")\n",
"\n",
" # first copy all key-value pairs from args into out\n",
" out = {k:v for k,v in args.items()}\n",
" # then add the calculated values\n",
" out['N'] = N\n",
" out['F'] = F\n",
" out['approx_flops'] = approx_flops\n",
" out['chinch_flops'] = chinch_flops\n",
" out['ratio'] = chinch_flops / approx_flops\n",
" rows.append(out)\n",
"\n",
"# make a pandas dataframe from rows\n",
"df = pd.DataFrame(rows)\n",
"df"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Pretty good match! The detailed FLOP count stays within about 10% of the 6*N*D approximation (ratios between roughly 0.99 and 1.10), although the parameter counts feeding into it are still not perfectly accurate."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## scaling laws"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7f9d2e9ba9e0>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 1200x500 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def L(N, D):\n",
" \"\"\" \n",
" Approximates loss given N parameters and D dataset size (in tokens),\n",
" per Chinchilla paper.\n",
" \"\"\"\n",
" E = 1.69 # entropy of natural language, limit of infinite model on infinite data\n",
" A = 406.4\n",
" B = 410.7\n",
" alpha = 0.34\n",
" beta = 0.28\n",
" return A / (N ** alpha) + B / (D ** beta) + E\n",
"\n",
"ns = 10 ** np.arange(7, 11, step=2**-4) # model sizes from 10M to 100B\n",
"ds = 10 ** np.arange(9, 12, step=2**-4) # dataset sizes from 1B to 1T\n",
"plt.figure(figsize=(12, 5))\n",
"plt.subplot(121)\n",
"# create a 2D contour plot of loss L as a function of model size and dataset size in ns,ds\n",
"loss2d = np.log10(np.array([[L(n, d) for d in ds] for n in ns]))\n",
"plt.imshow(loss2d, extent=[9, 12, 7, 11], origin='lower', alpha=0.5)\n",
"plt.contour(loss2d, levels=30, extent=[9, 12, 7, 11], origin='lower')\n",
"plt.xlabel('log10(dataset size)')\n",
"plt.ylabel('log10(model size)')\n",
"plt.title('loss')\n",
"plt.colorbar()\n",
"# plot the compute for each point, which is a deterministic function: flops = 6*N*D\n",
"plt.subplot(122)\n",
"compute2d = np.log10(np.array([[6*n*d for d in ds] for n in ns]))\n",
"plt.imshow(compute2d, extent=[9, 12, 7, 11], origin='lower', alpha=0.5)\n",
"plt.contour(compute2d, levels=30, extent=[9, 12, 7, 11], origin='lower')\n",
"plt.xlabel('log10(dataset size)')\n",
"plt.ylabel('log10(model size)')\n",
"plt.title('log10 flops')\n",
"plt.colorbar()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"OK, so given any $N, D$ we can estimate both 1) the loss and 2) the total FLOPs. Now we want to solve the following problem: given a specific FLOP budget $C$, find $N_{opt}, D_{opt} = \\arg\\min_{N, D \\,:\\, \\text{FLOPs}(N, D) = C} L(N, D)$. In other words, how big a model should we train, and for how many tokens?"
]
},
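{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Because $C = 6ND$ turns this into a one-dimensional problem, it also has a closed-form solution. Substituting $D = C/(6N)$ into $L(N, D) = E + A/N^\\alpha + B/D^\\beta$ and setting the derivative with respect to $N$ to zero gives $N_{opt} = G\\,(C/6)^{a}$ and $D_{opt} = G^{-1}(C/6)^{b}$, with $G = (\\alpha A / (\\beta B))^{1/(\\alpha+\\beta)}$, $a = \\beta/(\\alpha+\\beta)$ and $b = \\alpha/(\\alpha+\\beta)$. Below is a minimal sketch that reuses the fitted constants from `L` above. `compute_optimal` is a hypothetical name:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fitted constants from the Chinchilla paper, same values as in L(N, D) above\n",
"E, A, B, alpha, beta = 1.69, 406.4, 410.7, 0.34, 0.28\n",
"\n",
"def compute_optimal(C):\n",
"    # closed-form minimiser of L(N, D) subject to 6*N*D = C\n",
"    G = (alpha * A / (beta * B)) ** (1 / (alpha + beta))\n",
"    a = beta / (alpha + beta)\n",
"    b = alpha / (alpha + beta)\n",
"    N_opt = G * (C / 6) ** a\n",
"    D_opt = (C / 6) ** b / G\n",
"    return N_opt, D_opt\n",
"\n",
"N_opt, D_opt = compute_optimal(1.92e19)\n",
"print(f'N_opt: {N_opt/1e6:.2f}M, D_opt: {D_opt/1e9:.2f}B')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This gives roughly 306M params and 10.5B tokens. The grid search below steps model size by a factor of $10^{1/16} \\approx 1.15$, so its 316M answer is simply the grid point nearest this value. The gap to the paper's 400M therefore does not come from the grid resolution."
]
},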
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"best model size: 316.23M\n",
"best dataset size: 10.12B\n"
]
},
{
"data": {
"text/plain": [
"Text(0, 0.5, 'loss')"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"c = 1.92e19 # target compute budget (we usually know this from how many GPUs we have and for how long they go brrr)\n",
"# sweep model sizes from 10M up to (but not including) 100B\n",
"ns = 10 ** np.arange(7, 11, step=2**-4)\n",
"# using C = 6*N*D, solve for D that maintains the compute budget c\n",
"ds = c / (6 * ns)\n",
"# evaluate the loss in each case\n",
"losses = L(ns, ds)\n",
"# find the argmin\n",
"best = np.argmin(losses)\n",
"print(f\"best model size: {ns[best]/1e6:.2f}M\")\n",
"print(f\"best dataset size: {ds[best]/1e9:.2f}B\")\n",
"# plot the loss\n",
"plt.figure(figsize=(3,3))\n",
"plt.plot(ns, losses)\n",
"plt.xscale('log')\n",
"# plot a vertical bar at the best model size\n",
"plt.axvline(ns[best], color='red')\n",
"plt.xlabel('model size')\n",
"plt.ylabel('loss')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"In the plot above, models to the left of the red line are too small and trained on too many tokens, while models to the right are too large and trained on too few. The model at the red line is just right.\n",
"\n",
"Now, the Chinchilla paper says the best model for this budget has 400M params trained on 8B tokens, so this once again disagrees and there is a problem somewhere in the calculations. TODO figure out and fix..."
]
},
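A quick cross-check of the sweep, as a sketch: assuming `L` is the Chinchilla "Approach 3" parametric fit L(N, D) = E + A/N^α + B/D^β with the paper's published constants (E=1.69, A=406.4, B=410.7, α=0.34, β=0.28), minimizing it under C = 6ND has a closed form, N_opt = G·(C/6)^a and D_opt = (C/6)^b / G with G = (αA/(βB))^(1/(α+β)), a = β/(α+β), b = α/(α+β). The constants and the helper names `chinchilla_loss` and `closed_form_optimum` are assumptions for illustration, not definitions taken from this notebook. The block also confirms that the paper's 400M params × 8B tokens point sits on exactly this budget.

```python
import numpy as np

# Assumed: Chinchilla "Approach 3" fitted constants (not shown in this excerpt)
E, A, B = 1.69, 406.4, 410.7
alpha, beta = 0.34, 0.28

def chinchilla_loss(N, D):
    """Parametric loss L(N, D) = E + A/N^alpha + B/D^beta."""
    return E + A / N**alpha + B / D**beta

def closed_form_optimum(C):
    """Minimize the loss subject to C = 6*N*D; returns (N_opt, D_opt)."""
    G = (alpha * A / (beta * B)) ** (1 / (alpha + beta))
    a = beta / (alpha + beta)
    b = alpha / (alpha + beta)
    N_opt = G * (C / 6) ** a
    D_opt = (C / 6) ** b / G   # a + b = 1, so 6 * N_opt * D_opt == C
    return N_opt, D_opt

c = 1.92e19
# the paper's 400M-param / 8B-token point lies on this same compute budget
assert np.isclose(6 * 400e6 * 8e9, c)

# grid search, same sweep as the cell above
ns = 10 ** np.arange(7, 11, step=2**-4)
ds = c / (6 * ns)
best = np.argmin(chinchilla_loss(ns, ds))

n_opt, d_opt = closed_form_optimum(c)
print(f"grid search: {ns[best]/1e6:.2f}M params, {ds[best]/1e9:.2f}B tokens")
print(f"closed form: {n_opt/1e6:.2f}M params, {d_opt/1e9:.2f}B tokens")
```

If the two printed rows agree, the argmin itself is fine and any mismatch with 400M/8B comes from the loss function being evaluated, not from the sweep.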
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2304"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Calculate the Chinchilla optimal models for a range of compute budgets\n",
"\n",
"# sweep over compute budgets from 1e17 to 1e26\n",
"cs = 10 ** np.arange(17, 26, step=2**-8)\n",
"models = []\n",
"for c in cs:\n",
" # sweep over model sizes\n",
" ns = 10 ** np.arange(7, 14, step=2**-8)\n",
" # the dataset sizes that would maintain the given compute budget\n",
" ds = c / (6 * ns)\n",
" # losses at each point\n",
" losses = L(ns, ds)\n",
" # n,d for the best model\n",
" best = np.argmin(losses)\n",
" models.append((c, ns[best], ds[best])) # c, n, d tuple log\n",
"\n",
"len(models)"
]
},
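Under the same assumption (that `L` is the Chinchilla "Approach 3" fit; the constants and the helper name `chinchilla_loss` below are illustrative), the per-budget optima collected above should follow a power law N_opt ∝ C^a with a = β/(α+β). A sketch that reruns the sweep on a coarser budget grid and fits that exponent as a slope in log-log space:

```python
import numpy as np

# Assumed: Chinchilla "Approach 3" fitted constants
E, A, B = 1.69, 406.4, 410.7
alpha, beta = 0.34, 0.28

def chinchilla_loss(N, D):
    """Parametric loss L(N, D) = E + A/N^alpha + B/D^beta."""
    return E + A / N**alpha + B / D**beta

cs = 10 ** np.arange(17, 26, step=2**-2)   # coarser budget grid than the cell above
ns = 10 ** np.arange(7, 14, step=2**-8)    # same model-size grid as the cell above
n_opts = []
for c in cs:
    ds = c / (6 * ns)                       # token counts that exhaust budget c
    n_opts.append(ns[np.argmin(chinchilla_loss(ns, ds))])
n_opts = np.array(n_opts)

# the slope of log10(N_opt) vs log10(C) estimates the exponent a
slope, _ = np.polyfit(np.log10(cs), np.log10(n_opts), 1)
a_theory = beta / (alpha + beta)
print(f"fitted exponent: {slope:.3f}, predicted beta/(alpha+beta): {a_theory:.3f}")
```

This checks whether the sweep reproduces the scaling exponent implied by the fit across all budgets, rather than at a single point.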
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"closest model found:\n",
"model size: 1298.02M\n",
"dataset size: 60.32B\n",
"flops: 4.697589e+20\n",
"loss: 2.41\n"
]
}
],
"source": [
"query_model_size = 1.3e9 # GPT-3 XL size (1.3B params)\n",
"ns = np.array([n for c, n, d in models])\n",
"ds = np.array([d for c, n, d in models])\n",
"# find the index of the closest model size in ns\n",
"ix = np.argmin(np.abs(ns - query_model_size))\n",
"# retrieve the corresponding params, flops, and data size\n",
"print(\"closest model found:\")\n",
"print(f\"model size: {ns[ix]/1e6:.2f}M\")\n",
"print(f\"dataset size: {ds[ix]/1e9:.2f}B\")\n",
"print(f\"flops: {6*ns[ix]*ds[ix]:e}\")\n",
"print(f\"loss: {L(ns[ix], ds[ix]):.2f}\")"
]
},
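The table lookup above can also be inverted analytically, as a sketch under the same assumption about `L` (Chinchilla "Approach 3" constants; the helper name `optimal_tokens_for` is hypothetical): solving N = G·(C/6)^a for C/6 gives the compute-optimal token count for a fixed model size directly, without scanning the list of models.

```python
import numpy as np

# Assumed: Chinchilla "Approach 3" fitted constants (E is not needed here)
A, B = 406.4, 410.7
alpha, beta = 0.34, 0.28

def optimal_tokens_for(N):
    """Return (D_opt, C) such that a model with N params is compute-optimal."""
    G = (alpha * A / (beta * B)) ** (1 / (alpha + beta))
    a = beta / (alpha + beta)
    x = (N / G) ** (1 / a)   # x = C/6, from N_opt = G * (C/6)^a
    D = x / N                # C = 6*N*D  =>  D = (C/6) / N
    return D, 6 * x

D, C = optimal_tokens_for(1.3e9)
print(f"dataset size: {D/1e9:.2f}B tokens")
print(f"flops: {C:e}")
print(f"tokens per parameter: {D/1.3e9:.1f}")
```

Landing near the grid-search answer confirms the lookup is consistent with the fit; the large tokens-per-parameter ratio is what separates it from other quoted numbers.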
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"So we predict that ~60B tokens is compute-optimal for a 1.3B-param model. But e.g. [MosaicML quotes 26B](https://t.co/HyEvCqP70C), so again this disagrees."
]
},
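For comparison, the Chinchilla point quoted earlier (400M params trained on 8B tokens) works out to 20 tokens per parameter, and scaling that ratio to 1.3B params gives exactly the 26B quoted by MosaicML. That suggests (an inference, not confirmed here) that their figure comes from this rule of thumb rather than from the parametric fit:

```python
# tokens-per-parameter ratio implied by the Chinchilla 400M-param / 8B-token point
ratio = 8e9 / 400e6
# apply the same ratio to a 1.3B-param model
tokens_13b = ratio * 1.3e9
print(f"{ratio:.0f} tokens/param -> {tokens_13b/1e9:.0f}B tokens for 1.3B params")
```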
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"TLDR atm: nothing reproduces, but progress is being made."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}