diff --git a/scaling_laws.ipynb b/scaling_laws.ipynb index 3483d14..9ba644a 100644 --- a/scaling_laws.ipynb +++ b/scaling_laws.ipynb @@ -5,12 +5,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Reproducing some results from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf). I can't get the numbers to match exactly, please open an Issue if you understand why the results are off by a bit. Current running hypothesis is that this is because I am using the FLOPs = 6\\*N\\*D formula, instead of taking all the Gopher sizes and calculating their FLOPs individually, and using that to interpolate?" + "Trying to reproduce results from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf):" ] }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -44,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -53,7 +53,7 @@ "123.653376" ] }, - "execution_count": 119, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -89,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -98,7 +98,7 @@ "[44000000.0, 512, 2048, 64, 8, 8]" ] }, - "execution_count": 123, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -112,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -154,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -163,7 +163,7 @@ "766.006788096" ] }, - "execution_count": 127, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -185,7 +185,7 @@ " # key @ query logits\n", " attlogits = 2 * seq_len * seq_len * (key_size * num_heads)\n", " # softmax\n", - " attsoftmax = 3 * num_heads * seq_len * seq_len # TODO why?\n", + " attsoftmax = 3 * num_heads * seq_len * seq_len # 3* is for subtract (max), exp, divide (?)\n", " # softmax @ value reductions\n", " attvalue = 2 * seq_len * seq_len * (key_size * num_heads)\n", " # final linear\n", @@ -208,7 +208,7 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -260,16 +260,16 @@ }, { "cell_type": "code", - "execution_count": 129, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 129, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, @@ -330,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 130, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -347,7 +347,7 @@ "Text(0, 0.5, 'loss')" ] }, - "execution_count": 130, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, @@ -396,7 +396,7 @@ }, { "cell_type": "code", - "execution_count": 131, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -405,7 +405,7 @@ "2304" ] }, - "execution_count": 131, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -432,7 +432,7 @@ }, { "cell_type": "code", - "execution_count": 132, + "execution_count": 11, "metadata": {}, "outputs": [ {