1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-14 05:44:51 +00:00
nanogpt-experiments/scaling_laws.ipynb

793 lines
262 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Reproducing some scaling laws results from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf). I can't get the numbers to match exactly, but the results can still be used as a rough guide to help determine compute-optimal models. This notebook also contains related utilities for calculating FLOPs and parameter counts."
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"%matplotlib inline"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## params\n",
"\n",
"First some parameter calculations:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"123.653376"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def gpt_params(seq_len, vocab_size, d_model, num_heads, num_layers):\n",
" \"\"\" Given GPT config calculate total number of parameters \"\"\"\n",
" ffw_size = 4*d_model # in GPT the number of intermediate features is always 4*d_model\n",
" # token and position embeddings\n",
" embeddings = d_model * vocab_size + d_model * seq_len\n",
" # transformer blocks\n",
" attention = 3*d_model**2 + 3*d_model # weights and biases\n",
" attproj = d_model**2 + d_model\n",
" ffw = d_model*(ffw_size) + ffw_size\n",
" ffwproj = ffw_size*d_model + d_model\n",
" layernorms = 2*2*d_model\n",
" # dense\n",
" ln_f = 2*d_model\n",
" dense = d_model*vocab_size # note: no bias here\n",
" # note: embeddings are not included in the param count!\n",
" total_params = num_layers*(attention + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n",
" return total_params\n",
"\n",
"gpt2 = dict(seq_len = 1024, vocab_size = 50257, d_model = 768, num_heads = 12, num_layers = 12)\n",
"gpt_params(**gpt2)/1e6"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"OpenAI reports gpt2 (small) as having 124M params, so this is a match. Also, loading the OpenAI weights into nanoGPT and then calling `model.parameters()` exactly matches the above number and verifies the implementation. Now Chinchilla parameters:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def chinchilla_params(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n",
" \"\"\" Parameters in the Chinchilla models. Unlike GPT they use relative positional embeddings. \"\"\"\n",
" # token embeddings only\n",
" embeddings = d_model * vocab_size\n",
" # transformer blocks\n",
" attention = 3*d_model**2 + 3*d_model # weights and biases\n",
" relative_pos = d_model**2 + 2*d_model # relative keys, content bias, relative bias\n",
" attproj = d_model**2 + d_model\n",
" ffw = d_model*ffw_size + ffw_size\n",
" ffwproj = ffw_size*d_model + d_model\n",
" layernorms = 2*2*d_model\n",
" # dense\n",
" ln_f = 2*d_model\n",
" dense = d_model*vocab_size # note: no bias here\n",
" # note: embeddings are not included in the param count!\n",
" total_params = num_layers*(attention + relative_pos + attproj + ffw + ffwproj + layernorms) + ln_f + dense\n",
" return total_params\n",
"\n"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[44000000.0, 512, 2048, 64, 8, 8]"
]
},
2023-01-06 02:13:04 +00:00
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load in all the 50 Chinchilla models on the last page of the paper\n",
"import json\n",
"chinchilla_models_txt = '[[44000000.0, 512, 2048, 64, 8, 8], [57000000.0, 576, 2304, 64, 9, 9], [74000000.0, 640, 2560, 64, 10, 10], [90000000.0, 640, 2560, 64, 10, 13], [106000000.0, 640, 2560, 64, 10, 16], [117000000.0, 768, 3072, 64, 12, 12], [140000000.0, 768, 3072, 64, 12, 15], [163000000.0, 768, 3072, 64, 12, 18], [175000000.0, 896, 3584, 64, 14, 14], [196000000.0, 896, 3584, 64, 14, 16], [217000000.0, 896, 3584, 64, 14, 18], [251000000.0, 1024, 4096, 64, 16, 16], [278000000.0, 1024, 4096, 64, 16, 18], [306000000.0, 1024, 4096, 64, 16, 20], [425000000.0, 1280, 5120, 128, 10, 18], [489000000.0, 1280, 5120, 128, 10, 21], [509000000.0, 1408, 5632, 128, 11, 18], [552000000.0, 1280, 5120, 128, 10, 24], [587000000.0, 1408, 5632, 128, 11, 21], [632000000.0, 1536, 6144, 128, 12, 19], [664000000.0, 1408, 5632, 128, 11, 24], [724000000.0, 1536, 6144, 128, 12, 22], [816000000.0, 1536, 6144, 128, 12, 25], [893000000.0, 1792, 7168, 128, 14, 20], [1018000000.0, 1792, 7168, 128, 14, 23], [1143000000.0, 1792, 7168, 128, 14, 26], [1266000000.0, 2048, 8192, 128, 16, 22], [1424000000.0, 2176, 8704, 128, 17, 22], [1429000000.0, 2048, 8192, 128, 16, 25], [1593000000.0, 2048, 8192, 128, 16, 28], [1609000000.0, 2176, 8704, 128, 17, 25], [1731000000.0, 2304, 9216, 128, 18, 24], [1794000000.0, 2176, 8704, 128, 17, 28], [2007000000.0, 2304, 9216, 128, 18, 28], [2283000000.0, 2304, 9216, 128, 18, 32], [2298000000.0, 2560, 10240, 128, 20, 26], [2639000000.0, 2560, 10240, 128, 20, 30], [2980000000.0, 2560, 10240, 128, 20, 34], [3530000000.0, 2688, 10752, 128, 22, 36], [3802000000.0, 2816, 11264, 128, 22, 36], [4084000000.0, 2944, 11776, 128, 22, 36], [4516000000.0, 3072, 12288, 128, 24, 36], [6796000000.0, 3584, 14336, 128, 28, 40], [9293000000.0, 4096, 16384, 128, 32, 42], [11452000000.0, 4352, 17408, 128, 32, 47], [12295000000.0, 4608, 18432, 128, 36, 44], [12569000000.0, 4608, 18432, 128, 32, 47], [13735000000.0, 4864, 19456, 128, 32, 47], [14940000000.0, 4992, 19968, 128, 32, 49], 
[16183000000.0, 5120, 20480, 128, 40, 47]]'\n",
"chilchilla_models = json.loads(chinchilla_models_txt) # all 50 models\n",
"chilchilla_models[0] # tuples of params, d_model, ffw_size, kv_size, n_heads, n_layers from Table A9"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"our estimated params: 12296.1623M, chinchilla params: 12295.0000M, d_model: 4608, n_heads: 36, n_layers: 44\n",
"our estimated params: 13124.4826M, chinchilla params: 12569.0000M, d_model: 4608, n_heads: 32, n_layers: 47\n",
"our estimated params: 14614.4279M, chinchilla params: 13735.0000M, d_model: 4864, n_heads: 32, n_layers: 47\n",
"our estimated params: 16037.5039M, chinchilla params: 14940.0000M, d_model: 4992, n_heads: 32, n_layers: 49\n",
"our estimated params: 16184.4582M, chinchilla params: 16183.0000M, d_model: 5120, n_heads: 40, n_layers: 47\n"
]
}
],
"source": [
"for m in chilchilla_models[-5:]: # only print last 5 models of the table\n",
" p, d, f, k, h, l = m\n",
" nparams = chinchilla_params(seq_len = 1024, vocab_size = 32000, d_model = d, num_heads = h, num_layers = l, ffw_size=f)\n",
" print(f\"our estimated params: {nparams/1e6:.4f}M, chinchilla params: {p/1e6:.4f}M, d_model: {d}, n_heads: {h}, n_layers: {l}\")"
]
},
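{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We are almost able to reproduce the parameter counts for the Chinchilla models. In the rows printed above, the ones that do not match are those where `n_heads * kv_size != d_model` (e.g. 32 heads of size 128 with `d_model = 4608`): `chinchilla_params` sizes every attention projection as `d_model**2` and never uses the `kv_size` column of the table, which likely accounts for most of the gap.\n",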
"\n",
"Now turning to FLOPs:"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## flops"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def chinchilla_flops(seq_len, vocab_size, d_model, num_heads, num_layers, ffw_size):\n",
" \"\"\" \n",
" Calculate total number of FLOPs, see Chinchilla \n",
" paper Appendix F as reference: https://arxiv.org/pdf/2203.15556.pdf\n",
" \"\"\" \n",
" key_size = d_model // num_heads\n",
"\n",
" # embeddings\n",
" embeddings = 2 * seq_len * vocab_size * d_model\n",
"\n",
" # attention\n",
" # key, query, value projections\n",
" attention = 2 * 3 * seq_len * d_model * (key_size * num_heads)\n",
" # key @ query logits\n",
" attlogits = 2 * seq_len * seq_len * (key_size * num_heads)\n",
" # softmax\n",
2023-01-06 02:13:04 +00:00
" attsoftmax = 3 * num_heads * seq_len * seq_len # 3* is for subtract (max), exp, divide (?)\n",
" # softmax @ value reductions\n",
" attvalue = 2 * seq_len * seq_len * (key_size * num_heads)\n",
" # final linear\n",
" attlinear = 2 * seq_len * (key_size * num_heads) * d_model\n",
" att = attention + attlogits + attsoftmax + attvalue + attlinear\n",
" # feed forward\n",
" dense = 2 * seq_len * (d_model * ffw_size + d_model * ffw_size)\n",
"\n",
" # logits\n",
" logits = 2 * seq_len * d_model * vocab_size\n",
" \n",
" # this is what you'd expect:\n",
" # forward_flops = embeddings + num_layers * (att + dense) + logits\n",
" # but:\n",
" # per author correspondence apparently there is typo in the paper,\n",
" # they do not count embeddings and logits to repro table 4. So instead:\n",
" forward_flops = num_layers * (att + dense)\n",
" backward_flops = 2 * forward_flops # as in Kaplan et al. 2020\n",
" total_flops = forward_flops + backward_flops\n",
"\n",
" return total_flops\n"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>seq_len</th>\n",
" <th>vocab_size</th>\n",
" <th>d_model</th>\n",
" <th>num_heads</th>\n",
" <th>num_layers</th>\n",
" <th>ffw_size</th>\n",
" <th>N</th>\n",
" <th>F</th>\n",
" <th>approx_flops</th>\n",
" <th>chinch_flops</th>\n",
" <th>ratio</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>640</td>\n",
" <td>10</td>\n",
" <td>10</td>\n",
" <td>2560</td>\n",
" <td>73825280</td>\n",
" <td>929877196800</td>\n",
" <td>907165040640</td>\n",
" <td>9.298772e+11</td>\n",
" <td>1.025036</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>1024</td>\n",
" <td>16</td>\n",
" <td>20</td>\n",
" <td>4096</td>\n",
" <td>305707008</td>\n",
" <td>4135248199680</td>\n",
" <td>3756527714304</td>\n",
" <td>4.135248e+12</td>\n",
" <td>1.100817</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>1280</td>\n",
" <td>10</td>\n",
" <td>24</td>\n",
" <td>5120</td>\n",
" <td>552604160</td>\n",
" <td>7353453772800</td>\n",
" <td>6790399918080</td>\n",
" <td>7.353454e+12</td>\n",
" <td>1.082919</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>1792</td>\n",
" <td>14</td>\n",
" <td>26</td>\n",
" <td>7168</td>\n",
" <td>1143453696</td>\n",
" <td>14670316437504</td>\n",
" <td>14050759016448</td>\n",
" <td>1.467032e+13</td>\n",
" <td>1.044094</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>2048</td>\n",
" <td>16</td>\n",
" <td>28</td>\n",
" <td>8192</td>\n",
" <td>1593126912</td>\n",
" <td>20220437594112</td>\n",
" <td>19576343494656</td>\n",
" <td>2.022044e+13</td>\n",
" <td>1.032902</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>2048</td>\n",
" <td>32000</td>\n",
" <td>3584</td>\n",
" <td>28</td>\n",
" <td>40</td>\n",
" <td>14336</td>\n",
" <td>6796274688</td>\n",
" <td>83021046743040</td>\n",
" <td>83512623366144</td>\n",
" <td>8.302105e+13</td>\n",
" <td>0.994114</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" seq_len vocab_size d_model num_heads num_layers ffw_size N \\\n",
"0 2048 32000 640 10 10 2560 73825280 \n",
"1 2048 32000 1024 16 20 4096 305707008 \n",
"2 2048 32000 1280 10 24 5120 552604160 \n",
"3 2048 32000 1792 14 26 7168 1143453696 \n",
"4 2048 32000 2048 16 28 8192 1593126912 \n",
"5 2048 32000 3584 28 40 14336 6796274688 \n",
"\n",
" F approx_flops chinch_flops ratio \n",
"0 929877196800 907165040640 9.298772e+11 1.025036 \n",
"1 4135248199680 3756527714304 4.135248e+12 1.100817 \n",
"2 7353453772800 6790399918080 7.353454e+12 1.082919 \n",
"3 14670316437504 14050759016448 1.467032e+13 1.044094 \n",
"4 20220437594112 19576343494656 2.022044e+13 1.032902 \n",
"5 83021046743040 83512623366144 8.302105e+13 0.994114 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Now try to reproduce Table A4 from Chinchilla paper Appendix, \n",
"# comparing accurate flops above to approximate flops F = 6*N*D\n",
"# note Chinchilla mentions using vocab_size = 32K\n",
"\n",
"chilchilla_models_table4 = [\n",
" [10, 640, 2560, 10, 64],\n",
" [20, 1024, 4096, 16, 64],\n",
" [24, 1280, 5120, 10, 128 ],\n",
" [26, 1792, 7168, 14, 128 ],\n",
" [28, 2048, 8192, 16, 128],\n",
" [40, 3584, 14336, 28, 128]\n",
"]\n",
"\n",
"rows = []\n",
"for num_layers, d_model, ffw_size, num_heads, _ in chilchilla_models_table4:\n",
"\n",
" args = dict(seq_len = 2048, vocab_size = 32000, d_model = d_model, \n",
" num_heads = num_heads, num_layers = num_layers, ffw_size=ffw_size)\n",
"\n",
" D = args['seq_len'] # dataset size (cancels anyway, for the purposes of the ratio calculation below)\n",
" N = chinchilla_params(**args)\n",
" F = chinchilla_flops(**args)\n",
"\n",
" approx_flops = 6*D*N # approximate flops\n",
" chinch_flops = F * (float(D) / args['seq_len']) # exact flops according to Chinchilla paper calculations\n",
"\n",
" # print('---')\n",
" # print(f\"params: {N/1e6:.2f}M\")\n",
" # print(f\"approx flops: {approx_flops/1e9:.2f}B\")\n",
" # print(f\"chinchilla flops: {chinch_flops/1e9:.2f}B\")\n",
" # print(f\"ratio (chinchilla / approx): {chinch_flops / approx_flops:.2f}\")\n",
"\n",
" # first copy all keyvalues from args into out\n",
" out = {k:v for k,v in args.items()}\n",
" # then add the calculated values\n",
" out['N'] = N\n",
" out['F'] = F\n",
" out['approx_flops'] = approx_flops\n",
" out['chinch_flops'] = chinch_flops\n",
" out['ratio'] = chinch_flops / approx_flops\n",
" rows.append(out)\n",
"\n",
"# make a pandas dataframe from rows\n",
"df = pd.DataFrame(rows)\n",
"df"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Pretty good match! Except the param counts are still not perfectly accurate."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Scaling Laws: Approach 3\n",
"\n",
"In their \"Approach 3\", the Chinchilla paper fits a function L(N,D) to approximate the final loss given the model size and the data size. Here is the final fit:"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7f1bd262a9e0>"
]
},
2023-01-06 02:13:04 +00:00
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA8UAAAHWCAYAAABe7ytwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9eXwb9bnv//mOdslavMu7HTv74jgLIZAWKCkJUJYWKPTQS+FwoJcboJD7u72Hvs4BSin0dDvcnkPLgVOgO5SwUxpKA4QtqxMncRYn3ld5k619saT5/TGSLNkzskYayUu+b17fVxtp5pnvjO35zjPP83wewrIsCwqFQqFQKBQKhUKhUM5DmNmeAIVCoVAoFAqFQqFQKLMFdYopFAqFQqFQKBQKhXLeQp1iCoVCoVAoFAqFQqGct1CnmEKhUCgUCoVCoVAo5y3UKaZQKBQKhUKhUCgUynkLdYopFAqFQqFQKBQKhXLeQp1iCoVCoVAoFAqFQqGct1CnmEKhUCgUCoVCoVAo5y3UKaZQKBQKhUKhUCgUynkLdYoplAS8+OKLIISgs7NztqdCoVAoFMqCZiGuuT/5yU+waNEiyGQyrF27FgBQXV2N22+/fVbnRaFQ4qFOMYVCoVAoFArlvOFXv/oVbrrpJlRWVoIQktBBHR8fx913343CwkLodDpcdtllOHLkSFLH+dvf/obvfve7uPjii/HCCy/giSeekOgMKBSK1MhnewIUCoVCoVAoFEq2+Ld/+zc4HA5ccMEFGBgYENwuFArh6quvxrFjx/B//s//QUFBAX75y1/i0ksvRWNjIxYvXpzwOB988AEYhsGvf/1rKJVKqU+DQqFICHWKKRQKhUKhUCjnDXv37o1GiXNycgS327VrFz7//HO88soruPHGGwEAX//617FkyRI88sgj+OMf/5jwOENDQ9BoNNQhplDmATR9mkIRyS9/+UusXLkSKpUKpaWl2LFjB8bHx+O2OXfuHG644QaYzWao1WqUl5fjlltugc1mi27z/vvvY8uWLTCZTMjJycHSpUvxve99L8tnQ6FQKBTK3CaZdRcAnn76aSxatAgajQYXXHABPvnkE1x66aW49NJL47arqqoCIWTG4+7atQvFxcX42te+Fv2ssLAQX//61/Hmm2/C5/MJ7ksIwQsvvACXywVCCAghePHFFwW3b29vx0033YS8vDxotVpceOGF+Mtf/hK3zUcffQRCCF5++WV873vfg9lshk6nw7XXXouenp64bZN5DqFQKJPQSDGFIoJHH30U3//+97F161bcc889aGlpwa9+9SscOnQIn332GRQKBfx+P7Zt2wafz4f77rsPZrMZfX19eOeddzA+Pg6j0YiTJ0/iK1/5CtasWYPHHnsMKpUKra2t+Oyzz2b7FCkUCoVCmTMks+4CXJ3wvffeiy984Qt48MEH0dnZieuvvx65ubkoLy9P6dhHjx7FunXrwDDxMaQLLrgAzz77LM6ePYvVq1fz7vu73/0Ozz77LA4ePIj//u//BgBcdNFFvNsODg7ioosugtvtxv3334/8/Hz85je/wbXXXotdu3bhq1/9atz2P/zhD0EIwf/9v/8XQ0NDeOqpp7B161Y0NTVBo9Ek9RxCoVCmwFIoFEFeeOEFFgDb0dHBDg0NsUqlkr3iiivYYDAY3eY///M/WQDs888/z7Isyx49epQFwL7yyiuCdv/93/+dBcAODw9n/BwoFAqFQpkPxK65LMsmve76fD42Pz+f3bhxIzsxMRHd7sUXX2QBsJdccongMXU6Hfutb31L8Lt//Md/nPb5X/7yFxYAu3v37oTn861vfYvV6XTTPq+qqoo75gMPPMACYD/55JPoZw6Hg62pqWGrq6uj5/7hhx+yANiysjLWbrdHt/3zn//MAmD/3//7fyzLJvccQqFQ4qHp0xRKkvz973+H3+/HAw88EPfW+K677oLBYIimOUXewL733ntwu928tkwmEw
DgzTffRCgUyuzEKRQKhUKZhyS77h4+fBijo6O46667IJdPJkHeeuutyM3NTfn4Ho8HKpVq2udqtTr6vRS8++67uOCCC7Bly5boZzk5Obj77rvR2dmJU6dOxW1/2223Qa/XR/994403oqSkBO+++y6A5J5DKBRKPNQpplCSpKurCwCwdOnSuM+VSiUWLVoU/b6mpgY7d+7Ef//3f6OgoADbtm3D008/HVfHc/PNN+Piiy/GP/3TP6G4uBi33HIL/vznP1MHmUKhUCiUMMmuu5H/rauri9tOLpejuro65eNrNBreumGv1xv9Xgq6urqmnSMALF++PPp9LFNVrwkhqKuri/Z3TuY5hEKhxEOdYgolA/zsZz/D8ePH8b3vfQ8ejwf3338/Vq5cid7eXgDcQvrxxx/j73//O/7H//gfOH78OG6++WZ8+ctfRjAYnOXZUygUCoVCKSkp4W3ZFPmstLQ021NKmpmeQygUSjzUKaZQkqSqqgoA0NLSEve53+9HR0dH9PsIq1evxr/8y7/g448/xieffIK+vj4888wz0e8ZhsHll1+On//85zh16hR++MMf4oMPPsCHH36Y+ZOhUCgUCmWOk+y6G/nf1tbWuO0CgUA0epoKa9euxZEjR6ZlcR04cABarRZLlixJ2XYsVVVV084RAM6cORP9PpZz587F/ZtlWbS2tk6Lis/0HEKhUCahTjGFkiRbt26FUqnEL37xC7AsG/3817/+NWw2G66++moAgN1uRyAQiNt39erVYBgmmoZltVqn2V+7di0AJGzxQKFQKBTK+UKy6+6GDRuQn5+P5557Lm79/cMf/oCxsbGUj3/jjTdicHAQr732WvSzkZERvPLKK7jmmmt4641T4aqrrsLBgwexb9++6GculwvPPvssqqursWLFirjtf/vb38LhcET/vWvXLgwMDODKK68EkNxzCIVCiYe2ZKJQkqSwsBAPPfQQvv/972P79u249tpr0dLSgl/+8pfYuHEjvvnNbwIAPvjgA9x777246aabsGTJEgQCAfzud7+DTCbDDTfcAAB47LHH8PHHH+Pqq69GVVUVhoaG8Mtf/hLl5eVxQhsUCoVCoZyvJLvuKpVKPProo7jvvvvwpS99CV//+tfR2dmJF198EbW1tdN6Er/99ts4duwYAGBiYgLHjx/H448/DgC49tprsWbNGgCcU3zhhRfijjvuwKlTp1BQUIBf/vKXCAaD+P73vy/Zef7zP/8z/vSnP+HKK6/E/fffj7y8PPzmN79BR0cHXn311WktofLy8rBlyxbccccdGBwcxFNPPYW6ujrcddddAJJ7DqFQKFOYZfVrCmVOM7U9BMtyrSCWLVvGKhQKtri4mL3nnnvYsbGx6Pft7e3sP/7jP7K1tbWsWq1m8/Ly2Msuu4z9+9//Ht1mz5497HXXXceWlpaySqWSLS0tZb/xjW+wZ8+ezeLZUSgUCoUyd+Bbc1l25nU3wi9+8Qu2qqqKValU7AUXXMB+9tln7Pr169nt27fHbfetb32LBcA7XnjhhbhtrVYre+edd7L5+fmsVqtlL7nkEvbQoUNJnU+yLZlYlmXb2trYG2+8kTWZTKxarWYvuOAC9p133onbJtKS6U9/+hP70EMPsUVFRaxGo2GvvvpqtqurK7pdMs8hFAolHsKyMfkoFAqFQqFQKBTKAiAUCqGwsBBf+9rX8Nxzz832dNLmo48+wmWXXYZXXnkFN95442xPh0JZUNCaYgqFQqFQKBTKvMbr9WJqnOe3v/0trFYrLr300tmZFIVCmTfQmmIKhUKhUCgUyrxm//79ePDBB3HTTTchPz8fR44cwa9//WusWrUKN91002xPj0KhzHGoU0yhUCgUCoVCmddUV1ejoqICv/jFL2C1WpGXl4fbbrsNP/rRj6BUKmd7ehQKZY4zq+nTH3/8Ma655hqUlpaCEII33ngj7vvXXnsNV1xxBfLz80EIQVNTU1J2X3nlFSxbtgxqtRqrV6/Gu+
++K/3kKRQKhUKhzAhd6ynZoLq6Gm+99RYsFgv8fj8sFguef/55FBUVzfbUJOPSSy8Fy7K0nphCyQCz6hS7XC7U19f
"text/plain": [
"<Figure size 1200x500 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def L(N, D):\n",
" \"\"\" \n",
" Approximates loss given N parameters and D dataset size (in tokens),\n",
" per Chinchilla paper.\n",
" \"\"\"\n",
" E = 1.69 # entropy of natural language, limit of infinite model on infinite data\n",
" A = 406.4\n",
" B = 410.7\n",
" alpha = 0.34\n",
" beta = 0.28\n",
" return A / (N ** alpha) + B / (D ** beta) + E\n",
"\n",
"ns = 10 ** np.arange(7, 11, step=2**-4) # model sizes from 10M to 100B\n",
"ds = 10 ** np.arange(9, 12, step=2**-4) # dataset sizes from 1B to 1T\n",
"plt.figure(figsize=(12, 5))\n",
"plt.subplot(121)\n",
"# create a 2D countour plot of loss L as a function of model size and dataset size in ns,ds\n",
"loss2d = np.log10(np.array([[L(n, d) for d in ds] for n in ns]))\n",
"plt.imshow(loss2d, extent=[9, 12, 7, 11], origin='lower', alpha=0.5)\n",
"plt.contour(loss2d, levels=30, extent=[9, 12, 7, 11], origin='lower')\n",
"plt.xlabel('log10(dataset size)')\n",
"plt.ylabel('log10(model size)')\n",
"plt.title('loss')\n",
"plt.colorbar()\n",
"# plot the compute for each point, which is a deterministic function: flops = 6*N*D\n",
"plt.subplot(122)\n",
"compute2d = np.log10(np.array([[6*n*d for d in ds] for n in ns]))\n",
"plt.imshow(compute2d, extent=[9, 12, 7, 11], origin='lower', alpha=0.5)\n",
"plt.contour(compute2d, levels=30, extent=[9, 12, 7, 11], origin='lower')\n",
"plt.xlabel('log10(dataset size)')\n",
"plt.ylabel('log10(model size)')\n",
"plt.title('log10 flops')\n",
"plt.colorbar()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok so given any N,D we can estimate both: 1) the loss, and 2) the total flops. Now we want to solve the following problem: Given a specific budget of flops C, find: N_opt, D_opt = argmin_{FLOPs(N,D) = C} L(N, D). i.e. how big of a model should we train and for how many tokens?"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"best model size: 316.23M\n",
"best dataset size: 11.65B\n"
]
},
{
"data": {
"text/plain": [
"Text(0, 0.5, 'loss')"
]
},
2023-01-06 02:13:04 +00:00
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAS8AAAEqCAYAAABEE9ZrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAuRElEQVR4nO3dd3xUVf7/8dekTdpMGpCEFEACCQETQhRIQEAFqVJ2XV0WBV3Erx1FccV1v7qiht8itoW14E/UXVlAFNgfRdpSFEJvoQiElkIKhJDJJKTN3N8fQ0aiBEKYzJ3yeT4e9/FI7tyZ+SRh3px77rnnaBRFURBCCCfjoXYBQgjRHBJeQginJOElhHBKEl5CCKck4SWEcEoSXkIIpyThJYRwSl5qF2BvZrOZs2fPotPp0Gg0apcjhPgFRVEoLy+nbdu2eHg03r5yu/A6e/YsMTExapchhLiO3NxcoqOjG33c7cJLp9MBll+MXq9XuRoXVlEBbdtavj57FgIC1K1HOA2DwUBMTIz1s9oYtwuv+lNFvV4v4dWSPD1//lqvl/ASN+x63TrSYS+EcEoSXkIIpyThJYRwShJeQginJOElhHBKEl5CCKck4SWEaHGf/3iKZfvyMVTV2uw1JbyEEC2q1mRm1pqjTF6wj5ySSpu9roSXEKJF7cu9SEWNidAAHxIjbTcwXMJLCNGifjh+HoD0jmF4eNhuMgQJLyFEi/rx+DkA7ujUyqavK+ElhGgxhqpa9ueVAdC3U2ubvraElxCixWSeKMFkVrilVQBRwX42fW0JLyFEi/nxcn9XXxufMoKElxCiBf2YbQmvPnESXkIIJ5FXWsmp8xV4emhI6xhm89eX8BJCtIgtl1tdydFB6H29bf76El5CiBbxg7W/y7ZXGetJeAkhbM5sVth6ogSw/fiuehJeQgibO1xg4EJFDYFaL7rHBLfIe0h4CSFsbtMxy6j63reE4u3ZMjEj4SWEsLn68Oof36bF3kPCSwhhU+VVtew5UwpA/xbqrAcJLyGEjW09UULd5VuCYsP8W+x9JLyEEDZVf8rYr3PLtbpAwksIYUOKorDpaH1/l4SXEMJJnDhXQf7FS/h4edC7g+1vCbqShJcQwmbqTxl7dQjFz8ezRd9LwksIYTPWIRIt3N8FEl5CCBupqjWx/aTllqABLdzfBRJeQggb2XayhOo6M1HBfnRsHdji7yfhJYSwiSuHSGg0tlslqDESXkIIm9jwUzFgn/4ucKDwmjFjBhqNhueee67RY7744gs0Gk2DzdfX135FCiGu6uQ5I6dLKvH21LTIfPVX42WXd7mOnTt38sknn5CUlHTdY/V6PUePHrV+b4/mqRDi2v57udXV+5YwArX2iRXVW15Go5Fx48Yxd+5cQkJCrnu8RqMhIiLCuoWHh9uhSiHEtdSH150tOIvEL6keXk899RTDhw9n4MCBTTreaDTSrl07YmJiGDVqFIcOHbrm8dXV1RgMhgabEMJ2yqtq2XHqAgB3JbhJeC1YsIA9e/aQkZHRpOPj4+P5/PPPWbZsGf/6178wm82kp6eTl5fX6HMyMjIICgqybjExMbYqXwiBZW3G+lkk2rcKsNv7qhZeubm5TJ48ma+//rrJne5paWmMHz+e7t27079/f7777jtat27NJ5980uhzpk2bRllZmXXLzc211Y8ghODnU0Z7trpAxQ773bt3U1xcTI8ePaz7TCYTmzdvZvbs2VRXV+Ppee17o7y9vUlJSSE7O7vRY7RaLVqt1mZ1CyF+ZjYrbDjqZuF19913k5WV1WDfI488QkJCAn/605+uG1xgCbusrCyGDRvWUmUKIa4hK7+M80bLQhu3tQ+163urFl46nY5u3bo12BcQEEBYWJh1//jx44mKirL2ib3xxhv07t2buLg4Ll68yMyZMzlz5gyPPvqo3esXQvx8ynhHp1b4eNm3F8ohxnk1JicnBw+Pn38hpaWlTJo0icLCQkJCQkhNTW
Xr1q0kJiaqWKUQ7kutU0YAjaIoit3fVUUGg4GgoCDKysrQ6/Vql+O6Kiog8PLNuUYjBNjvKpSwj2JDFT3fXo9GAzteGUhrnW36lpv6GVV9nJcQwjmtO2JpdXWPCbZZcN0ICS8hRLOsPVwIwMAu6tzlIuElhLhhFdV1bDlhmXjwnkQJLyGEk/jh+Dlq6sy0C/Mnrk3LTzx4NRJeQogbtuZwEQCDuoSrNrOLhJcQ4obUmczWiQcHqnTKCBJeQogbtPtMKaWVtQT7e3Nbu+tPY9VSJLyEEDdk7eVTxrvi2+DlqV6ESHgJIZpMURTWHrnc36XiKSNIeAkhbkB2sZEzJZX4eHpwh50W2miMhJcQosnqrzKmdbTfXPWNkfASQjTZ9wcto+qHdItQuRIJLyFEE+WVVpKVX4aHRv3+LpDwEkI00epDllPG29uH0ipQ/dmJJbyEEE3y/cECwDFOGUHCSwjRBMXlVew6UwrA4K4SXkIIJ7H2cBGKAskxwbQN9lO7HEDCSwjRBNarjA7S6gIJLyHEdZRV1pJ5ee6uwV3Vv8pYT8JLCHFN644UUWdWiA/XcUtrdebuuhoJLyHENa26fMo42EGuMtaT8BJCNKq8qpbNx84BMFTCSwjhLNYdKaLGZKZj6wASInRql9OAhJcQolErDlgGpg5PaqvadM+NkfASQlxV2aVaNh87D8CIpEiVq/k1CS8hxFWtO2w5ZezUJpDO4Y51yggSXkKIRqzIqj9ldLxWF0h4CSGuoqyylh+OW64yDr9VwksI4STWHC6k1mQZmNrJAU8ZQcJLCHEV9aeMwxy01QUSXkKIX7hYWcOPxy1XGYcnOdbA1CtJeAkhGliZVUidWSEhQkdcG8c8ZQQJLyHELyzblw/A6JQolSu5NgkvIYTV2YuX2HH6AgD3JrdVuZprk/ASQlgtP3AWRYGe7UOJcpAZUxsj4SWEsPrP/rMAjOzu2K0ukPASQlyWXWzkYL4BLw+NQw+RqCfhJYQAfm519evcmtAAH5WruT4JLyEEiqLwn8tXGUc5wSkjSHgJIYADeWWcLqnEz9uTgV0cZ5GNa5HwEkKwZK+l1TUoMZwArZfK1TSNhJcQbq7WZLb2d43p4dgDU68k4SWEm9t49BwXKmpordNyR1wrtctpMgkvIdzcd3vyABjdvS1ens4TCQ5T6YwZM9BoNDz33HPXPO6bb74hISEBX19fbr31VlauXGmfAoVwQRcra1h/pBiA3/SIVrmaG+MQ4bVz504++eQTkpKSrnnc1q1bGTt2LBMnTmTv3r2MHj2a0aNHc/DgQTtVKoRrWX6ggBqTmS6RerpE6tUu54aoHl5Go5Fx48Yxd+5cQkJCrnnsBx98wJAhQ5g6dSpdunRh+vTp9OjRg9mzZzf6nOrqagwGQ4NNCGFRf8r4WyfqqK+neng99dRTDB8+nIEDB1732MzMzF8dN3jwYDIzMxt9TkZGBkFBQdYtJibmpmsWwhWcPGdkT85FPD00TnEv4y+pGl4LFixgz549ZGRkNOn4wsJCwsMbDqALDw+nsLCw0edMmzaNsrIy65abm3tTNQvhKurHdvXr1Io2Ol+Vq7lxqo1Gy83NZfLkyaxduxZf35b7xWm1WrRabYu9vhDOyGRW+G6PJbzGOFlHfT3Vwmv37t0UFxfTo0cP6z6TycTmzZuZPXs21dXVeHp6NnhOREQERUVFDfYVFRUREeG482wL4Yi2ZJ8n/+Il9L5e3JPoHLcD/ZJqp4133303WVlZ7Nu3z7rddtttjBs3jn379v0quADS0tJYv359g31r164lLS3NXmUL4RIW7bJ0n4xOicLX+9efNWegWstLp9PRrVu3BvsCAgIICwuz7h8/fjxRUVHWPrHJkyfTv39/Zs2axfDhw1mwYAG7du3i008/tXv9Qjir0ooa1hyynMHcf5vzXsBS/WrjteTk5FBQUGD9Pj09nfnz5/Ppp5+SnJzM4sWLWbp06a9CUAjRuGX78qkxmUmM1N
MtKkjtcppNoyiKonYR9mQwGAgKCqKsrAy93rkG5TmVigoIDLR8bTRCQIC69QjAMm/XsA9/5EiBgb+O7MqE9PZql/Q
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"c = 2.21e19 # target compute budget (usually know this because we know how many GPU for how long go brrr)\n",
"# (I got this flop number from row 1 of Table A3)\n",
"# sweep model sizes from 10M to 100B\n",
"ns = 10 ** np.arange(7, 11, step=2**-4)\n",
"# using C = 6*N*D, solve for D that maintains the compute budget c\n",
"ds = c / (6 * ns)\n",
"# evaluate the loss in each case\n",
"losses = L(ns, ds)\n",
"# find the argmin\n",
"best = np.argmin(losses)\n",
"print(f\"best model size: {ns[best]/1e6:.2f}M\")\n",
"print(f\"best dataset size: {ds[best]/1e9:.2f}B\")\n",
"# plot the loss\n",
"plt.figure(figsize=(3,3))\n",
"plt.plot(ns, losses)\n",
"plt.xscale('log')\n",
"# plot a vertical bar at the best model size\n",
"plt.axvline(ns[best], color='red')\n",
"plt.xlabel('model size')\n",
"plt.ylabel('loss')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"In the plot above, basically the models on the left of best are too small and trained for too long. The models on the right of best are way too large and trained for too little. The model at the red line is just right.\n",
"\n",
"Now, the Chinchilla paper says that the best model size for this flop budget is 400M params and 9.2B tokens (instead of 316M params and 11.65B tokens) so there is some unresolved disagreement here too..."
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2304"
]
},
2023-01-06 02:13:04 +00:00
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Calculate the Chinchilla optimal models for a range of compute budgets\n",
"\n",
"# sweep over compute budgets from 1e17 to 1e26\n",
"cs = 10 ** np.arange(17, 26, step=2**-8)\n",
"models = []\n",
"for c in cs:\n",
" # sweep over model sizes\n",
" ns = 10 ** np.arange(7, 14, step=2**-8)\n",
" # the dataset sizes that would maintain the given compute budget\n",
" ds = c / (6 * ns)\n",
" # losses at each point\n",
" losses = L(ns, ds)\n",
" # n,d for the best model\n",
" best = np.argmin(losses)\n",
" models.append((c, ns[best], ds[best])) # c, n, d tuple log\n",
"\n",
"len(models)"
]
},
{
"cell_type": "code",
2023-01-06 02:13:04 +00:00
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"closest model found:\n",
"model size: 399.54M\n",
"dataset size: 14.43B\n",
"flops: 3.459892e+19\n",
"loss: 2.76\n"
]
}
],
"source": [
"query_model_size = 400e6\n",
"ns = np.array([n for c, n, d in models])\n",
"ds = np.array([d for c, n, d in models])\n",
"# find the index of the closest model size in ns\n",
"ix = np.argmin(np.abs(ns - query_model_size))\n",
"# retrieve the corresponding params, flops, and data size\n",
"print(\"closest model found:\")\n",
"print(f\"model size: {ns[ix]/1e6:.2f}M\")\n",
"print(f\"dataset size: {ds[ix]/1e9:.2f}B\")\n",
"print(f\"flops: {6*ns[ix]*ds[ix]:e}\")\n",
"print(f\"loss: {L(ns[ix], ds[ix]):.2f}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This should have come out as 9.2B according to Table A3 in Chinchilla paper, per my understanding of it."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Scaling Laws: Approach 2\n",
"\n",
"Approach 2 is probably my favorite one because it fixes a flop budget and runs a number of model/dataset sizes, measures the loss, fits a parabolla, and gets the minimum. So it's a fairly direct measurement of what we're after. The best way to then calculate the compute-optimal number of tokens for any given model size, as an example, is via simple interpolation."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Approach 1 numbers\n",
"# # parameters, tokens\n",
"# raw = [\n",
"# [400e6, 8e9],\n",
"# [1e9, 20.2e9],\n",
"# [10e9, 205.1e9],\n",
"# [67e9, 1.5e12],\n",
"# [175e9, 3.7e12],\n",
"# [280e9, 5.9e12],\n",
"# [520e9, 11e12],\n",
"# [1e12, 21.2e12],\n",
"# [10e12, 216.2e12],\n",
"# ]\n",
"\n",
"# Approach 2 numbers\n",
"# parameters, tokens\n",
"raw = [\n",
" [400e6, 7.7e9],\n",
" [1e9, 20.0e9],\n",
" [10e9, 219.5e9],\n",
" [67e9, 1.7e12],\n",
" [175e9, 4.3e12],\n",
" [280e9, 7.1e12],\n",
" [520e9, 13.4e12],\n",
" [1e12, 26.5e12],\n",
" [10e12, 292.0e12],\n",
"]\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"y = 1.0409573169995892x + 0.9353887152390791\n"
]
}
],
"source": [
"# fit a line by linear regression to the raw data\n",
"import numpy as np\n",
"x = np.array([np.log10(x[0]) for x in raw])\n",
"y = np.array([np.log10(x[1]) for x in raw])\n",
"A = np.vstack([x, np.ones(len(x))]).T\n",
"m, c = np.linalg.lstsq(A, y, rcond=None)[0]\n",
"print(f\"y = {m}x + {c}\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAATkAAAFBCAYAAAAMkNhdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABCEElEQVR4nO3deVxU5f4H8M8MwozIJrHL5q6IgIIQWVdUFslISculEs2fdQ3SIm9JpYQt2oaWUmY3Nc2VNNyJRUwtlFwgvYpbaIoIIsLAINvM8/vjNKPDADIwcIaZ7/v18gXnmTNnvucBPp71OQLGGAMhhOgpId8FEEJIR6KQI4ToNQo5Qoheo5AjhOg1CjlCiF6jkCOE6DUKOUKIXqOQI4ToNQo5Qoheo5AjBisoKAhBQUEG99mamDlzJtzd3dv0Xl1ZRwo50iYff/wxUlJS+C7joc6dO4f3338fV69e5bsUwhMKOdImXSnkEhISmgy5tLQ0pKWldX5RpFN147sAQvhiYmLCdwmkE9CWXAcpLCzE7Nmz4eTkBJFIhN69e2Pu3Lmoq6tTzvPXX3/h2WefhbW1NUxNTfHoo49i3759Kss5dOgQBAIBtm/fjoSEBPTq1Qvm5uaYPHkyKioqUFtbi9dffx12dnYwMzPDrFmzUFtbq7IMgUCAmJgYbNq0CQMHDoRYLIavry8OHz6sMl9zx1/ef/99CAQCleVJpVL88MMPEAgEEAgEmDlzpsq6v/TSS7C3t4dIJMKQIUOwdu3aVvVbQ0MDPvjgA/Tt2xcikQju7u5455131NbJ3d0dTz31FNLS0uDj4wOxWAwPDw/s3LlTOc/69evx7LPPAgBGjx6trPXQoUMA1I8ZaaOv161bhzFjxsDOzg4ikQgeHh745ptvWrXuTVH87JKTk+Hh4YHu3bsjMDAQZ86cAQB8++236NevH8RiMYKCgprcYk1OToavry+6d+8OGxsbvPDCCygsLFSbLyUlBZ6enhCLxfD09MTPP//cZE1yuRwrVqzAkCFDIBaLYW9vj1deeQV379596PqsXLkSQ4YMgampKXr27Ak/Pz9s3rxZs07RFCNaV1hYyJycnJipqSl7/fXX2erVq9miRYvY4MGD2d27dxljjN26dYvZ29szc3Nz9u6777LExETm7e3NhEIh27lzp3JZWVlZDADz8fFhgYGB7KuvvmLz5s1jAoGATZ06lU2fPp2Fh4ezpKQk9uKLLzIALCEhQaUeAMzT05PZ2NiwJUuWsE8++YS5ubmx7t27szNnzijni4qKYm5ubmrrEx8fzx78Vdm4cSMTiUTsiSeeYBs3bmQbN25kv//+u3K9nJ2dmYuLC1uyZAn75ptv2NNPP80AsOXLlz+076KiohgANnnyZJaUlMRmzJjBALCJEyeqzOfm5sYGDBjArKys2MKFC1liYiIbOnQoEwqFLC0tjTHG2JUrV9i8efMYAPbOO+8oa7116xZjjLFRo0axUaNGabWvR4wYwWbOnMmWL1/OVq5cyUJDQxkAtmrVKpX5Gn92cwAwLy8v5uLiwpYtW8aWLVvGLC0tmaurK1u1ahXz8PBgX3zxBXvvvfeYiYkJGz16tMr7161bxwCwESNGsOXLl7OFCxey7t27M3d3d+XvImOM/fLLL0woFDJPT0+WmJjI3n33XWZpacmGDBmi9jvxf//3f6xbt25szpw5bPXq1eztt99mPXr0YCNGjGB1dXXNruOaNWuUP9tvv/2Wffnll2z27Nls3rx5D+2H9qCQ6wAzZsxgQqGQ/fHHH2qvyeVyxhhjr7/+OgPAjhw5onytsrKS9e7dm7m7uzOZTMYYu/+H5+npqfILNG3aNCYQCFh4eLjK8gMDA9V+KQEwAOzEiRPKtmvXrjGxWMwiIyOVba0NOcYY69GjB4uKilKbd/bs2czR0ZGVlpaqtE+dOpVZWlqy6upqtfco5ObmMgDs//7v/1TaFyxYwACwgwcPKtvc3NwYAL
Zjxw5lW0VFBXN0dGTDhg1TtiUnJzMALCsrS+3zmgu59vR1U+sXFhbG+vTp0+JnNwcAE4lErKCgQNn27bffMgDMwcGBSSQSZXtcXBwDoJy3rq6O2dnZMU9PT3bv3j3lfHv37mUA2OLFi5VtPj4+zNHRkZWXlyvb0tLSGACVdTxy5AgDwDZt2qRSZ2pqqlp743WcMGECGzJkyEPXWdtod1XL5HI5UlJSEBERAT8/P7XXFbt9+/fvh7+/Px5//HHla2ZmZnj55Zdx9epVnDt3TuV9M2bMgLGxsXI6ICAAjDG89NJLKvMFBATg+vXraGhoUGkPDAyEr6+vctrV1RUTJkzAL7/8AplM1vYVfgBjDDt27EBERAQYYygtLVX+CwsLQ0VFBU6dOtXs+/fv3w8AiI2NVWl/8803AUBtV97JyQmRkZHKaQsLC8yYMQOnT5/GrVu32rwe7enr7t27K7+vqKhAaWkpRo0ahb/++gsVFRVtqmfs2LEqhxECAgIAAJMmTYK5ubla+19//QUAOHHiBEpKSvDqq69CLBYr5xs/fjwGDRqk7M+ioiLk5uYiKioKlpaWyvlCQkLg4eGhUktycjIsLS0REhKi8vP19fWFmZkZsrKyml0PKysr3LhxA3/88Ueb+qGtKOS07Pbt25BIJPD09GxxvmvXrmHgwIFq7YMHD1a+/iBXV1eVacUvo4uLi1q7XC5X+4Pq37+/2mcNGDAA1dXVuH37dou1ttbt27dRXl6ONWvWwNbWVuXfrFmzAAAlJSXNvv/atWsQCoXo16+fSruDgwOsrKzU+qRfv34qxwoV6wSgXZeMtKevf/vtNwQHB6NHjx6wsrKCra0t3nnnHQBoc8hpUg8A5bExRX819Xs2aNAg5euKr039jjR+76VLl1BRUQE7Ozu1n3FVVVWLP9+3334bZmZm8Pf3R//+/REdHY3ffvut+RXXEjq72kUYGRlp1M7aMKp948BQaO2WnlwuBwC88MILiIqKanIeLy+vNtfRWdra11euXMHYsWMxaNAgJCYmwsXFBSYmJti/fz+WL1+u7J/OqqcjyOVy2NnZYdOmTU2+bmtr2+x7Bw8ejAsXLmDv3r1ITU3Fjh078PXXX2Px4sVISEjoqJIp5LTN1tYWFhYWOHv2bIvzubm54cKFC2rt+fn5yte16dKlS2ptFy9ehKmpqfIXs2fPnigvL1ebr/EWFNB0ENna2sLc3BwymQzBwcEa1+jm5ga5XI5Lly4pt2gBoLi4GOXl5Wp9cvnyZTDGVGq5ePEiACh37zozMPfs2YPa2lrs3r1bZeurpV24jqTorwsXLmDMmDEqr124cEH5uuJrU78jjX9H+/bti4yMDIwcOVJl17y1evTogSlTpmDKlCmoq6vDM888g48++ghxcXEqu9TaRLurWiYUCjFx4kTs2bMHJ06cUHtd8b/sk08+iZycHGRnZytfk0qlWLNmDdzd3dWOhbRXdna2yvGw69evY9euXQgNDVVuEfTt2xcVFRX4888/lfMVFRU1eSlBjx491ALRyMgIkyZNwo4dO5oM+YftFj/55JMAgBUrVqi0JyYmAuCOJT3o5s2bKrVJJBJs2LABPj4+cHBwUNYJoMnw1jZFPz64JVVRUYF169Z1+Gc3xc/PD3Z2dli9erXKpS4HDhzA+fPnlf3p6OgIHx8f/PDDDyq71Onp6WrHhp977jnIZDJ88MEHap/X0NDQYj/fuXNHZdrExAQeHh5gjKG+vr4tq9gqtCXXAT7++GOkpaVh1KhRePnllzF48GAUFRUhOTkZR48ehZWVFRYuXIgtW7YgPDwc8+bNg7W1NX744QcUFBRgx44dEAq1+/+Pp6cnwsLCMG/ePIhEInz99dcAoLKbMHXqVLz99tuIjIzEvHnzUF1djW+++QYDBgxQO2Hg6+uLjIwMJCYmwsnJCb1790ZAQACWLVuGrKwsBAQEYM6cOfDw8EBZWRlOnTqFjIwMlJWVNVujt7c3oqKisGbNGp
SXl2PUqFHIycnBDz/8gIkTJ2L06NEq8w8YMACzZ8/GH3/8AXt7e6xduxbFxcUqoeLj4wMjIyN88sknqKiogEgkUl7
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(3, 3))\n",
"# plot the line\n",
"plt.plot([q[0] for q in raw], [10**(m*np.log10(q[0]) + c) for q in raw], label='linear regression', color='r')\n",
"# plot the raw data\n",
"plt.scatter([q[0] for q in raw], [q[1] for q in raw], label='raw data')\n",
"plt.xscale('log')\n",
"plt.yscale('log')\n",
"plt.xlabel('parameters')\n",
"plt.ylabel('tokens')\n",
"plt.title('compute optimal models')\n",
"plt.grid()\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"predicted tokens for 1.240000e+08 parameters: 2.292426e+09\n"
]
}
],
"source": [
"xquery = 124e6 # query model size here (e.g. GPT-2 small is 124M)\n",
"yquery = 10**(m*np.log10(xquery) + c)\n",
"print(f\"predicted tokens for {xquery:e} parameters: {yquery:e}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}