diff --git a/transformer_sizing.ipynb b/transformer_sizing.ipynb index 2e528de..43262ad 100644 --- a/transformer_sizing.ipynb +++ b/transformer_sizing.ipynb @@ -267,7 +267,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "palm_flops: 879894724608, flops: 874944921600, ratio: 1.0057\n" + "palm_flops: 875062886400, flops: 874944921600, ratio: 1.0001\n" ] } ], @@ -276,7 +276,9 @@ "# this formula is often used to calculate MFU (model flops utilization)\n", "def palm_flops():\n", " \"\"\"estimate of the model flops following PaLM paper formula\"\"\"\n", - " N = params()['total']\n", + " # non-embedding model parameters. note that we do not subtract the\n", + " # embedding/token params because those are tied and get used in the last layer.\n", + " N = params()['total'] - params()['emebedding/position']\n", " L, H, Q, T = n_layer, n_head, n_embd//n_head, block_size\n", " mf_per_token = 6*N + 12*L*H*Q*T\n", " mf = mf_per_token * block_size\n",