mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-11-10 20:09:58 +00:00
small tweaks to notebook
This commit is contained in:
parent
69d1a5f1af
commit
27fc6a4112
@ -5,12 +5,12 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"Reproducing some results from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf). I can't get the numbers to match exactly, please open an Issue if you understand why the results are off by a bit. Current running hypothesis is that this is because I am using the FLOPs = 6\\*N\\*D formula, instead of taking all the Gopher sizes and calculating their FLOPs individually, and using that to interpolate?"
|
"Trying to reproduce results from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf):"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 117,
|
"execution_count": 1,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -21,7 +21,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 118,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -44,7 +44,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 119,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -53,7 +53,7 @@
|
|||||||
"123.653376"
|
"123.653376"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 119,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -89,7 +89,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 123,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -98,7 +98,7 @@
|
|||||||
"[44000000.0, 512, 2048, 64, 8, 8]"
|
"[44000000.0, 512, 2048, 64, 8, 8]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 123,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -112,7 +112,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 126,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -154,7 +154,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 127,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -163,7 +163,7 @@
|
|||||||
"766.006788096"
|
"766.006788096"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 127,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -185,7 +185,7 @@
|
|||||||
" # key @ query logits\n",
|
" # key @ query logits\n",
|
||||||
" attlogits = 2 * seq_len * seq_len * (key_size * num_heads)\n",
|
" attlogits = 2 * seq_len * seq_len * (key_size * num_heads)\n",
|
||||||
" # softmax\n",
|
" # softmax\n",
|
||||||
" attsoftmax = 3 * num_heads * seq_len * seq_len # TODO why?\n",
|
" attsoftmax = 3 * num_heads * seq_len * seq_len # 3* is for subtract (max), exp, divide (?)\n",
|
||||||
" # softmax @ value reductions\n",
|
" # softmax @ value reductions\n",
|
||||||
" attvalue = 2 * seq_len * seq_len * (key_size * num_heads)\n",
|
" attvalue = 2 * seq_len * seq_len * (key_size * num_heads)\n",
|
||||||
" # final linear\n",
|
" # final linear\n",
|
||||||
@ -208,7 +208,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 128,
|
"execution_count": 7,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -260,16 +260,16 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 129,
|
"execution_count": 8,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"<matplotlib.colorbar.Colorbar at 0x7f64c8d4b3d0>"
|
"<matplotlib.colorbar.Colorbar at 0x7fc09dffc730>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 129,
|
"execution_count": 8,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
},
|
},
|
||||||
@ -330,7 +330,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 130,
|
"execution_count": 9,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -347,7 +347,7 @@
|
|||||||
"Text(0, 0.5, 'loss')"
|
"Text(0, 0.5, 'loss')"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 130,
|
"execution_count": 9,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
},
|
},
|
||||||
@ -396,7 +396,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 131,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -405,7 +405,7 @@
|
|||||||
"2304"
|
"2304"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 131,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -432,7 +432,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 132,
|
"execution_count": 11,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user