experiments
56
c4gzparse_nanogpt.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import os
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
import tiktoken
|
||||||
|
import json
|
||||||
|
import gzip
|
||||||
|
|
||||||
|
enc = tiktoken.get_encoding("gpt2")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
|
||||||
|
dataset = []
|
||||||
|
|
||||||
|
with gzip.open("c4-train.00000-of-01024.json.gz", "r") as file:
|
||||||
|
while line := file.readline():
|
||||||
|
try:
|
||||||
|
dataset.append(json.loads(line))
|
||||||
|
except EOFError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
|
||||||
|
def process(example):
|
||||||
|
ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
|
||||||
|
ids.insert(0, enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
|
||||||
|
# note: I think eot should be prepended not appended... hmm. it's called "eot" though...
|
||||||
|
out = {"ids": ids, "len": len(ids)}
|
||||||
|
return out
|
||||||
|
|
||||||
|
# tokenize the dataset
|
||||||
|
tokenized = [ process(x) for x in dataset ]
|
||||||
|
divider = len(tokenized) // 100
|
||||||
|
tokenized = {
|
||||||
|
"val": tokenized[:divider],
|
||||||
|
"train": tokenized[divider:]
|
||||||
|
}
|
||||||
|
|
||||||
|
# concatenate all the ids in each dataset into one large file we can use for training
|
||||||
|
for split, dset in tokenized.items():
|
||||||
|
arr_len = sum((d['len'] for d in dset))
|
||||||
|
filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
|
||||||
|
dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
|
||||||
|
arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
|
||||||
|
total_batches = 1024
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
for d in tqdm(dset, desc=f'writing {filename}'):
|
||||||
|
arr[idx : idx + d["len"]] = d["ids"]
|
||||||
|
idx += d["len"]
|
||||||
|
arr.flush()
|
||||||
|
|
||||||
|
# train.bin is ~17GB, val.bin ~8.5MB
|
||||||
|
# train has ~9B tokens (9,035,582,198)
|
||||||
|
# val has ~4M tokens (4,434,897)
|
||||||
|
|
||||||
|
# to read the bin files later, e.g. with numpy:
|
||||||
|
# m = np.memmap('train.bin', dtype=np.uint16, mode='r')
|
69
compare_ckpts.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
from collections import defaultdict
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
def compute_differences(m1, m2):
|
||||||
|
groups = {"mlp.c_fc.weight": defaultdict(lambda: 0.0), "attn.c_attn.weight": defaultdict(lambda: 0.0)}
|
||||||
|
for k, v1 in m1["model"].items():
|
||||||
|
for cat in groups.keys():
|
||||||
|
if cat in k:
|
||||||
|
diff = torch.flatten(v1 - m2["model"][k])
|
||||||
|
#groups[cat]["l1"] += torch.linalg.norm(diff, dim=None, ord=1).item()
|
||||||
|
groups[cat]["l2"] += torch.linalg.norm(diff, dim=None, ord=2).item()
|
||||||
|
#groups[cat]["cosine"] += F.cosine_similarity(v1.flatten(), m2["model"][k].flatten(), dim=-1).item()
|
||||||
|
return groups
|
||||||
|
|
||||||
|
def gradnorm(m):
|
||||||
|
x = 0
|
||||||
|
for key, state in m["optimizer"]["state"].items():
|
||||||
|
#x += torch.linalg.norm(state["exp_avg"], dim=None, ord=2).item()
|
||||||
|
x += torch.mean(state["exp_avg_sq"]).item()
|
||||||
|
return x
|
||||||
|
|
||||||
|
def flatten(xs, out=None, prefix=""):
|
||||||
|
if out is None: out = {}
|
||||||
|
for k, v in xs.items():
|
||||||
|
longk = (prefix + " " + k).strip()
|
||||||
|
if isinstance(v, dict): flatten(v, out, longk)
|
||||||
|
else:
|
||||||
|
out[longk] = v
|
||||||
|
return out
|
||||||
|
|
||||||
|
xs = []
|
||||||
|
ys = defaultdict(list)
|
||||||
|
for step in range(500, 3500, 500):
|
||||||
|
file = f"ckpt{step}.pt"
|
||||||
|
m_baseline = torch.load(Path("fixed-seed1") / file)
|
||||||
|
m_sameseed = torch.load(Path("fixed-seed1-1") / file)
|
||||||
|
m_sameseed2 = torch.load(Path("fixed-seed1-2") / file)
|
||||||
|
m_other = torch.load(Path("fixed-seed2") / file)
|
||||||
|
m_baseline_resumed = torch.load(Path("fixed-seed1-res1500") / file)
|
||||||
|
xs.append(step)
|
||||||
|
|
||||||
|
"""
|
||||||
|
comparisons = {
|
||||||
|
"same seed": compute_differences(m_baseline, m_sameseed),
|
||||||
|
"same seed 2": compute_differences(m_baseline, m_sameseed2),
|
||||||
|
"same seed resume at 1500": compute_differences(m_baseline, m_baseline_resumed),
|
||||||
|
"other seed": compute_differences(m_baseline, m_other),
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
comparisons = {
|
||||||
|
"baseline": gradnorm(m_baseline),
|
||||||
|
"same seed": gradnorm(m_sameseed),
|
||||||
|
"same seed 2": gradnorm(m_sameseed2),
|
||||||
|
"other seed": gradnorm(m_other)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v in flatten(comparisons).items():
|
||||||
|
ys[k].append(v)
|
||||||
|
|
||||||
|
plt.figure(figsize=(12, 10))
|
||||||
|
plt.xlabel("step")
|
||||||
|
plt.ylabel("gradnorm")
|
||||||
|
for k, v in ys.items():
|
||||||
|
plt.plot(xs, v, label=k)
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig("x.png")
|
@ -2,7 +2,7 @@
|
|||||||
# launch as the following (e.g. in a screen session) and wait ~5 days:
|
# launch as the following (e.g. in a screen session) and wait ~5 days:
|
||||||
# $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
|
# $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
|
||||||
|
|
||||||
wandb_log = True
|
wandb_log = False
|
||||||
wandb_project = 'owt'
|
wandb_project = 'owt'
|
||||||
wandb_run_name='gpt2-124M'
|
wandb_run_name='gpt2-124M'
|
||||||
|
|
||||||
|
@ -35,3 +35,4 @@ warmup_iters = 100 # not super necessary potentially
|
|||||||
# on macbook also add
|
# on macbook also add
|
||||||
# device = 'cpu' # run on cpu only
|
# device = 'cpu' # run on cpu only
|
||||||
# compile = False # do not torch compile the model
|
# compile = False # do not torch compile the model
|
||||||
|
compile = False
|
||||||
|
BIN
fixed-cosine-plot.png
Normal file
After Width: | Height: | Size: 72 KiB |
BIN
fixed-l1norm-plot.png
Normal file
After Width: | Height: | Size: 95 KiB |
BIN
fixed-l2norm-plot.png
Normal file
After Width: | Height: | Size: 93 KiB |
BIN
grad-norms-l1-what.png
Normal file
After Width: | Height: | Size: 88 KiB |
BIN
grad-norms-mean-avgsq.png
Normal file
After Width: | Height: | Size: 75 KiB |
BIN
grad-norms.png
Normal file
After Width: | Height: | Size: 86 KiB |
BIN
l1-new.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
l1norm.png
Normal file
After Width: | Height: | Size: 38 KiB |
BIN
l2-new.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
l2norm.png
Normal file
After Width: | Height: | Size: 38 KiB |
2
model.py
@ -298,7 +298,7 @@ class GPT(nn.Module):
|
|||||||
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
|
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
|
||||||
# express our flops throughput as ratio of A100 bfloat16 peak flops
|
# express our flops throughput as ratio of A100 bfloat16 peak flops
|
||||||
flops_achieved = flops_per_iter * (1.0/dt) # per second
|
flops_achieved = flops_per_iter * (1.0/dt) # per second
|
||||||
flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
|
flops_promised = 70e12 # RTX 3090 BF16 FP32 accumulate
|
||||||
mfu = flops_achieved / flops_promised
|
mfu = flops_achieved / flops_promised
|
||||||
return mfu
|
return mfu
|
||||||
|
|
||||||
|
93
train.py
@ -28,6 +28,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|||||||
from torch.distributed import init_process_group, destroy_process_group
|
from torch.distributed import init_process_group, destroy_process_group
|
||||||
|
|
||||||
from model import GPTConfig, GPT
|
from model import GPTConfig, GPT
|
||||||
|
import random
|
||||||
|
|
||||||
|
seed = 1
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# default config values designed to train a gpt2 (124M) on OpenWebText
|
# default config values designed to train a gpt2 (124M) on OpenWebText
|
||||||
@ -38,34 +44,49 @@ log_interval = 1
|
|||||||
eval_iters = 200
|
eval_iters = 200
|
||||||
eval_only = False # if True, script exits right after the first eval
|
eval_only = False # if True, script exits right after the first eval
|
||||||
always_save_checkpoint = True # if True, always save a checkpoint after each eval
|
always_save_checkpoint = True # if True, always save a checkpoint after each eval
|
||||||
init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
|
init_from = 'resume' # 'scratch' or 'resume' or 'gpt2*'
|
||||||
# wandb logging
|
# wandb logging
|
||||||
wandb_log = False # disabled by default
|
|
||||||
wandb_project = 'owt'
|
|
||||||
wandb_run_name = 'gpt2' # 'run' + str(time.time())
|
|
||||||
# data
|
# data
|
||||||
dataset = 'openwebtext'
|
dataset = 'openwebtext'
|
||||||
gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
|
|
||||||
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
wandb_log = False
|
||||||
|
wandb_project = 'owt'
|
||||||
|
wandb_run_name='gpt2'
|
||||||
|
|
||||||
|
# these make the total batch size be ~0.5M
|
||||||
|
# 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
|
||||||
|
batch_size = 8
|
||||||
|
block_size = 1024
|
||||||
|
gradient_accumulation_steps = 8 * 8
|
||||||
|
|
||||||
|
# this makes total number of tokens be 300B
|
||||||
|
max_iters = 3000
|
||||||
|
lr_decay_iters = 3000
|
||||||
|
|
||||||
|
# eval stuff
|
||||||
|
eval_interval = 500
|
||||||
|
eval_iters = 200
|
||||||
|
log_interval = 10
|
||||||
|
|
||||||
|
# weight decay
|
||||||
|
weight_decay = 1e-1
|
||||||
|
|
||||||
block_size = 1024
|
block_size = 1024
|
||||||
# model
|
# model
|
||||||
n_layer = 12
|
n_layer = 6
|
||||||
n_head = 12
|
n_head = 8
|
||||||
n_embd = 768
|
n_embd = 512
|
||||||
dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
|
dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
|
||||||
bias = False # do we use bias inside LayerNorm and Linear layers?
|
bias = False # do we use bias inside LayerNorm and Linear layers?
|
||||||
# adamw optimizer
|
# adamw optimizer
|
||||||
learning_rate = 6e-4 # max learning rate
|
learning_rate = 6e-4 # max learning rate
|
||||||
max_iters = 600000 # total number of training iterations
|
|
||||||
weight_decay = 1e-1
|
|
||||||
beta1 = 0.9
|
beta1 = 0.9
|
||||||
beta2 = 0.95
|
beta2 = 0.95
|
||||||
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
|
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
|
||||||
# learning rate decay settings
|
# learning rate decay settings
|
||||||
decay_lr = True # whether to decay the learning rate
|
decay_lr = True # whether to decay the learning rate
|
||||||
warmup_iters = 2000 # how many steps to warm up for
|
warmup_iters = 500 # how many steps to warm up for
|
||||||
lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
|
min_lr = 6e-4 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
|
||||||
min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
|
|
||||||
# DDP settings
|
# DDP settings
|
||||||
backend = 'nccl' # 'nccl', 'gloo', etc.
|
backend = 'nccl' # 'nccl', 'gloo', etc.
|
||||||
# system
|
# system
|
||||||
@ -74,7 +95,6 @@ dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported
|
|||||||
compile = True # use PyTorch 2.0 to compile the model to be faster
|
compile = True # use PyTorch 2.0 to compile the model to be faster
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||||
exec(open('configurator.py').read()) # overrides from command line or config file
|
|
||||||
config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
@ -103,7 +123,6 @@ print(f"tokens per iteration will be: {tokens_per_iter:,}")
|
|||||||
|
|
||||||
if master_process:
|
if master_process:
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
torch.manual_seed(1337 + seed_offset)
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
||||||
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
||||||
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
|
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
|
||||||
@ -112,15 +131,16 @@ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torc
|
|||||||
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
||||||
|
|
||||||
# poor man's data loader
|
# poor man's data loader
|
||||||
data_dir = os.path.join('data', dataset)
|
data_dir = "."
|
||||||
def get_batch(split):
|
def get_batch(split, step):
|
||||||
# We recreate np.memmap every batch to avoid a memory leak, as per
|
# We recreate np.memmap every batch to avoid a memory leak, as per
|
||||||
# https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
|
# https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
|
||||||
if split == 'train':
|
if split == 'train':
|
||||||
data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
|
data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
|
||||||
else:
|
else:
|
||||||
data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
|
data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
|
||||||
ix = torch.randint(len(data) - block_size, (batch_size,))
|
d_rng = random.Random(f"{split}-{step}-{seed}")
|
||||||
|
ix = [ d_rng.randint(0, len(data) - block_size) for _ in range(batch_size) ]
|
||||||
x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
|
x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
|
||||||
y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
|
y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
|
||||||
if device_type == 'cuda':
|
if device_type == 'cuda':
|
||||||
@ -158,7 +178,7 @@ if init_from == 'scratch':
|
|||||||
elif init_from == 'resume':
|
elif init_from == 'resume':
|
||||||
print(f"Resuming training from {out_dir}")
|
print(f"Resuming training from {out_dir}")
|
||||||
# resume training from a checkpoint.
|
# resume training from a checkpoint.
|
||||||
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
|
ckpt_path = os.path.join(out_dir, 'ckpt1500.pt')
|
||||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||||
checkpoint_model_args = checkpoint['model_args']
|
checkpoint_model_args = checkpoint['model_args']
|
||||||
# force these config attributes to be equal otherwise we can't even resume training
|
# force these config attributes to be equal otherwise we can't even resume training
|
||||||
@ -213,13 +233,13 @@ if ddp:
|
|||||||
|
|
||||||
# helps estimate an arbitrarily accurate loss over either split using many batches
|
# helps estimate an arbitrarily accurate loss over either split using many batches
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def estimate_loss():
|
def estimate_loss(step):
|
||||||
out = {}
|
out = {}
|
||||||
model.eval()
|
model.eval()
|
||||||
for split in ['train', 'val']:
|
for split in ['train', 'val']:
|
||||||
losses = torch.zeros(eval_iters)
|
losses = torch.zeros(eval_iters)
|
||||||
for k in range(eval_iters):
|
for k in range(eval_iters):
|
||||||
X, Y = get_batch(split)
|
X, Y = get_batch(split, step)
|
||||||
with ctx:
|
with ctx:
|
||||||
logits, loss = model(X, Y)
|
logits, loss = model(X, Y)
|
||||||
losses[k] = loss.item()
|
losses[k] = loss.item()
|
||||||
@ -247,9 +267,9 @@ if wandb_log and master_process:
|
|||||||
wandb.init(project=wandb_project, name=wandb_run_name, config=config)
|
wandb.init(project=wandb_project, name=wandb_run_name, config=config)
|
||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
X, Y = get_batch('train') # fetch the very first batch
|
X, Y = get_batch('train', f"{iter_num}-{0}") # fetch the very first batch
|
||||||
t0 = time.time()
|
|
||||||
local_iter_num = 0 # number of iterations in the lifetime of this process
|
local_iter_num = 0 # number of iterations in the lifetime of this process
|
||||||
|
t0 = time.time()
|
||||||
raw_model = model.module if ddp else model # unwrap DDP container if needed
|
raw_model = model.module if ddp else model # unwrap DDP container if needed
|
||||||
running_mfu = -1.0
|
running_mfu = -1.0
|
||||||
while True:
|
while True:
|
||||||
@ -261,7 +281,7 @@ while True:
|
|||||||
|
|
||||||
# evaluate the loss on train/val sets and write checkpoints
|
# evaluate the loss on train/val sets and write checkpoints
|
||||||
if iter_num % eval_interval == 0 and master_process:
|
if iter_num % eval_interval == 0 and master_process:
|
||||||
losses = estimate_loss()
|
losses = estimate_loss(iter_num)
|
||||||
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
|
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
|
||||||
if wandb_log:
|
if wandb_log:
|
||||||
wandb.log({
|
wandb.log({
|
||||||
@ -273,17 +293,16 @@ while True:
|
|||||||
})
|
})
|
||||||
if losses['val'] < best_val_loss or always_save_checkpoint:
|
if losses['val'] < best_val_loss or always_save_checkpoint:
|
||||||
best_val_loss = losses['val']
|
best_val_loss = losses['val']
|
||||||
if iter_num > 0:
|
checkpoint = {
|
||||||
checkpoint = {
|
'model': raw_model.state_dict(),
|
||||||
'model': raw_model.state_dict(),
|
'optimizer': optimizer.state_dict(),
|
||||||
'optimizer': optimizer.state_dict(),
|
'model_args': model_args,
|
||||||
'model_args': model_args,
|
'iter_num': iter_num,
|
||||||
'iter_num': iter_num,
|
'best_val_loss': best_val_loss,
|
||||||
'best_val_loss': best_val_loss,
|
'config': config,
|
||||||
'config': config,
|
}
|
||||||
}
|
print(f"saving checkpoint to {out_dir}")
|
||||||
print(f"saving checkpoint to {out_dir}")
|
torch.save(checkpoint, os.path.join(out_dir, f'ckpt{iter_num}.pt'))
|
||||||
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
|
|
||||||
if iter_num == 0 and eval_only:
|
if iter_num == 0 and eval_only:
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -300,7 +319,7 @@ while True:
|
|||||||
logits, loss = model(X, Y)
|
logits, loss = model(X, Y)
|
||||||
loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
|
loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
|
||||||
# immediately async prefetch next batch while model is doing the forward pass on the GPU
|
# immediately async prefetch next batch while model is doing the forward pass on the GPU
|
||||||
X, Y = get_batch('train')
|
X, Y = get_batch('train', f"{iter_num}-{micro_step + 1}")
|
||||||
# backward pass, with gradient scaling if training in fp16
|
# backward pass, with gradient scaling if training in fp16
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
# clip the gradient
|
# clip the gradient
|
||||||
|