1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-18 14:10:28 +00:00

new notebook with a bunch of calculations related to flops and memory of Transformer

This commit is contained in:
Andrej Karpathy 2023-02-04 22:02:53 +00:00
parent a74e8363a2
commit eae986c2d2

400
transformer_sizing.ipynb generated Normal file
View File

@ -0,0 +1,400 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Transformer Theoretical Model\n",
"\n",
"This notebook stores a bunch of analysis about a Transformer, e.g. estimates the number of FLOPs, parameters, peak memory footprint, checkpoint size, etc."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from collections import OrderedDict"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# config_args = {\n",
"# 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params\n",
"# 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n",
"# 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n",
"# 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n",
"# }[model_type]\n",
"\n",
"block_size = 1024\n",
"vocab_size = 50257\n",
"n_layer = 12\n",
"n_head = 12\n",
"n_embd = 768\n",
"bias = False\n",
"assert not bias, \"this notebook assumes bias=False just for simplicity\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"we see: 124337664, expected: 124337664, match: True\n",
"name params ratio (%) \n",
"emebedding/position 786432 0.6325\n",
"embedding/token 38597376 31.0424\n",
"embedding 39383808 31.6749\n",
"attention/ln 768 0.0006\n",
"attention/kqv 1769472 1.4231\n",
"attention/proj 589824 0.4744\n",
"attention 2360064 1.8981\n",
"mlp/ln 768 0.0006\n",
"mlp/ffw 2359296 1.8975\n",
"mlp/proj 2359296 1.8975\n",
"mlp 4719360 3.7956\n",
"block 7079424 5.6937\n",
"transformer 84953088 68.3245\n",
"ln_f 768 0.0006\n",
"dense 0 0.0000\n",
"total 124337664 100.0000\n"
]
}
],
"source": [
"def params():\n",
" \"\"\" estimates the number of parameters in the model\"\"\"\n",
" out = OrderedDict()\n",
"\n",
" # token and position embeddings\n",
" out['emebedding/position'] = n_embd * block_size\n",
" out['embedding/token'] = n_embd * vocab_size\n",
" out['embedding'] = out['emebedding/position'] + out['embedding/token']\n",
"\n",
" # attention blocks\n",
" out['attention/ln'] = n_embd # note, bias=False in our LN\n",
" out['attention/kqv'] = n_embd * 3*n_embd\n",
" out['attention/proj'] = n_embd**2\n",
" out['attention'] = out['attention/ln'] + out['attention/kqv'] + out['attention/proj']\n",
"\n",
" # MLP blocks\n",
" ffw_size = 4*n_embd # feed forward size\n",
" out['mlp/ln'] = n_embd\n",
" out['mlp/ffw'] = n_embd * ffw_size\n",
" out['mlp/proj'] = ffw_size * n_embd\n",
" out['mlp'] = out['mlp/ln'] + out['mlp/ffw'] + out['mlp/proj']\n",
" \n",
" # the transformer and the rest of it\n",
" out['block'] = out['attention'] + out['mlp']\n",
" out['transformer'] = n_layer * out['block']\n",
" out['ln_f'] = n_embd # final layernorm\n",
" out['dense'] = 0 # 0 because of parameter sharing. This layer uses the weights from the embedding layer\n",
"\n",
" # total\n",
" out['total'] = out['embedding'] + out['transformer'] + out['ln_f'] + out['dense']\n",
"\n",
" return out\n",
"\n",
"# compare our param count to that reported by PyTorch\n",
"p = params()\n",
"params_total = p['total']\n",
"print(f\"we see: {params_total}, expected: {124337664}, match: {params_total == 124337664}\")\n",
"# create a header\n",
"print(f\"{'name':20s} {'params':10s} {'ratio (%)':10s}\")\n",
"for k,v in p.items():\n",
" print(f\"{k:20s} {v:10d} {v/params_total*100:10.4f}\")\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"est checkpoint size: 1.49 GB\n",
"measured with wc -c ckpt.pt: 1542470366\n",
"fluff ratio: 103.38%\n"
]
}
],
"source": [
"# we can now calculate the size of each checkpoint\n",
"# params are stored in fp32, and the AdamW optimizer has 2 additional buffers per param for statistics\n",
"params_bytes = params_total*4\n",
"params_and_buffers_bytes = params_bytes + 2*params_bytes\n",
"print(f\"est checkpoint size: {params_and_buffers_bytes/1e9:.2f} GB\")\n",
"measured_bytes = 1542470366 # from wc -c ckpt.pt\n",
"print(f\"measured with wc -c ckpt.pt: {measured_bytes}\")\n",
"print(f\"fluff ratio: {measured_bytes/params_and_buffers_bytes*100:.2f}%\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"memory ratio taken up just for parameters: 3.73%\n"
]
}
],
"source": [
"gpu_memory = 40e9 # 40 GB A100 GPU, roughly\n",
"print(f\"memory ratio taken up just for parameters: {params_and_buffers_bytes / gpu_memory * 100:.2f}%\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"i.e. not that much of the memory for this tiny model, most of the memory is activations (forward and backward). This of course changes dramatically for larger and larger models."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's estimate FLOPs for a single forward pass."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"name flops ratio (%) \n",
"attention/kqv 3623878656 1.2426\n",
"attention/scores 1610612736 0.5522\n",
"attention/reduce 1610612736 0.5522\n",
"attention/proj 1207959552 0.4142\n",
"attention 8053063680 2.7612\n",
"mlp/ffw1 4831838208 1.6567\n",
"mlp/ffw2 4831838208 1.6567\n",
"mlp 9663676416 3.3135\n",
"block 17716740096 6.0747\n",
"transformer 212600881152 72.8963\n",
"dense 79047426048 27.1037\n",
"forward_total 291648307200 100.0000\n",
"backward_total 583296614400 200.0000\n",
"total 874944921600 300.0000\n"
]
}
],
"source": [
"def flops():\n",
" # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant\n",
" # we count actual FLOPs, not MACs. Hence 2* all over the place\n",
" # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D\n",
"\n",
" out = OrderedDict()\n",
" head_size = n_embd // n_head\n",
"\n",
" # attention blocks\n",
" # 1) the projection to key, query, values\n",
" out['attention/kqv'] = 2 * block_size * (n_embd * 3*n_embd)\n",
" # 2) calculating the attention scores\n",
" out['attention/scores'] = 2 * block_size * block_size * n_embd\n",
" # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n",
" out['attention/reduce'] = 2 * n_head * (block_size * block_size * head_size)\n",
" # 4) the final linear projection\n",
" out['attention/proj'] = 2 * block_size * (n_embd * n_embd)\n",
" out['attention'] = sum(out['attention/'+k] for k in ['kqv', 'scores', 'reduce', 'proj'])\n",
"\n",
" # MLP blocks\n",
" ffw_size = 4*n_embd # feed forward size\n",
" out['mlp/ffw1'] = 2 * block_size * (n_embd * ffw_size)\n",
" out['mlp/ffw2'] = 2 * block_size * (ffw_size * n_embd)\n",
" out['mlp'] = out['mlp/ffw1'] + out['mlp/ffw2']\n",
"\n",
" # the transformer and the rest of it\n",
" out['block'] = out['attention'] + out['mlp']\n",
" out['transformer'] = n_layer * out['block']\n",
" out['dense'] = 2 * block_size * (n_embd * vocab_size)\n",
"\n",
" # forward,backward,total\n",
" out['forward_total'] = out['transformer'] + out['dense']\n",
" out['backward_total'] = 2 * out['forward_total'] # use common estimate of bwd = 2*fwd\n",
" out['total'] = out['forward_total'] + out['backward_total']\n",
"\n",
" return out\n",
" \n",
"# compare our param count to that reported by PyTorch\n",
"f = flops()\n",
"flops_total = f['forward_total']\n",
"print(f\"{'name':20s} {'flops':14s} {'ratio (%)':10s}\")\n",
"for k,v in f.items():\n",
" print(f\"{k:20s} {v:14d} {v/flops_total*100:10.4f}\")\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"palm_flops: 879894724608, flops: 874944921600, ratio: 1.0057\n"
]
}
],
"source": [
"# now here is an estimate copy pasted from the PaLM paper\n",
"# this formula is often used to calculate MFU (model flops utilization)\n",
"def palm_flops():\n",
" \"\"\"estimate of the model flops following PaLM paper formula\"\"\"\n",
" N = params()['total']\n",
" L, H, Q, T = n_layer, n_head, n_embd//n_head, block_size\n",
" mf_per_token = 6*N + 12*L*H*Q*T\n",
" mf = mf_per_token * block_size\n",
" return mf\n",
"\n",
"print(f\"palm_flops: {palm_flops():d}, flops: {flops()['total']:d}, ratio: {palm_flops()/flops()['total']:.4f}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok they are quite similar, giving some confidence that my math in flops() function was ~ok. Now, A100 is cited at 312TFLOPS bfloat16 on tensor cores. So what is our model flops utilization (MFU)? I trained the model above with a batch_size of 20 and grad_accum of 5, which runs in about 755ms on a single A100 GPU. We get:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fraction of A100 used: 37.14%\n"
]
}
],
"source": [
"# here is what we currently roughly measure\n",
"batch_size = 20 * 5 # 5 is grad_accum, so total batch size is 100\n",
"measured_time = 0.755 # in seconds per iteration\n",
"measured_throughput = batch_size / measured_time\n",
"flops_achieved = f['total'] * measured_throughput\n",
"\n",
"# A100 is cited to be 312 TFLOPS of bloat16 running on tensor cores\n",
"a100_flops_promised = 312e12\n",
"\n",
"# the fraction of the A100 that we are using:\n",
"print(f\"fraction of A100 used: {flops_achieved / a100_flops_promised * 100:.2f}%\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, we'd prefer to be somewhere around 50%+, and not just for a single GPU but for an entire DDP run. So we still have some work to do, but at least we're within a factor of ~2X of what is achievable with this GPU."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time needed to train the model: 3.46 days\n"
]
}
],
"source": [
"# Finally let's check out the 6ND approximation as total cost of training in FLOPs\n",
"model_size = params()['total'] # this is number of parameters, N\n",
"tokens_num = 300e9 # 300B tokens, this is dataset size in tokens, D\n",
"a100_flops = 312e12 # 312 TFLOPS\n",
"assumed_mfu = 0.3 # assume this model flops utilization (take the current 37% from above and add some DDP overhead)\n",
"flops_throughput = a100_flops * 8 * assumed_mfu # assume an 8XA100 node at 50% utilization\n",
"flops_needed = 6 * model_size * tokens_num # 6ND\n",
"time_needed_s = flops_needed / flops_throughput # in seconds\n",
"print(f\"time needed to train the model: {time_needed_s/3600/24:.2f} days\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This is not a bad estimate at all. I trained this model and it converged in roughly 4 days."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, FLOPs are just one constraint, the other that we have to keep a close track of is the memory bandwidth. TODO estimate LOAD/STORE costs of our model later."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}