1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-09-21 11:49:46 +00:00

small tweaks to notebook

This commit is contained in:
Andrej Karpathy 2023-01-06 02:13:04 +00:00
parent 69d1a5f1af
commit 27fc6a4112

View File

@ -5,12 +5,12 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Reproducing some results from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf). I can't get the numbers to match exactly, please open an Issue if you understand why the results are off by a bit. Current running hypothesis is that this is because I am using the FLOPs = 6\\*N\\*D formula, instead of taking all the Gopher sizes and calculating their FLOPs individually, and using that to interpolate?" "Trying to reproduce results from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf):"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 117, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -21,7 +21,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 118, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -44,7 +44,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 119, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -53,7 +53,7 @@
"123.653376" "123.653376"
] ]
}, },
"execution_count": 119, "execution_count": 3,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -89,7 +89,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 123, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -98,7 +98,7 @@
"[44000000.0, 512, 2048, 64, 8, 8]" "[44000000.0, 512, 2048, 64, 8, 8]"
] ]
}, },
"execution_count": 123, "execution_count": 4,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -112,7 +112,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 126, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -154,7 +154,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 127, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -163,7 +163,7 @@
"766.006788096" "766.006788096"
] ]
}, },
"execution_count": 127, "execution_count": 6,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -185,7 +185,7 @@
" # key @ query logits\n", " # key @ query logits\n",
" attlogits = 2 * seq_len * seq_len * (key_size * num_heads)\n", " attlogits = 2 * seq_len * seq_len * (key_size * num_heads)\n",
" # softmax\n", " # softmax\n",
" attsoftmax = 3 * num_heads * seq_len * seq_len # TODO why?\n", " attsoftmax = 3 * num_heads * seq_len * seq_len # 3* is for subtract (max), exp, divide (?)\n",
" # softmax @ value reductions\n", " # softmax @ value reductions\n",
" attvalue = 2 * seq_len * seq_len * (key_size * num_heads)\n", " attvalue = 2 * seq_len * seq_len * (key_size * num_heads)\n",
" # final linear\n", " # final linear\n",
@ -208,7 +208,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 128, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -260,16 +260,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 129, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7f64c8d4b3d0>" "<matplotlib.colorbar.Colorbar at 0x7fc09dffc730>"
] ]
}, },
"execution_count": 129, "execution_count": 8,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
}, },
@ -330,7 +330,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 130, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -347,7 +347,7 @@
"Text(0, 0.5, 'loss')" "Text(0, 0.5, 'loss')"
] ]
}, },
"execution_count": 130, "execution_count": 9,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
}, },
@ -396,7 +396,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 131, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -405,7 +405,7 @@
"2304" "2304"
] ]
}, },
"execution_count": 131, "execution_count": 10,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -432,7 +432,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 132, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {