-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_zero.py
592 lines (470 loc) · 18.6 KB
/
main_zero.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
import argparse
import gc
import logging
import random as pyrandom
from functools import partial
from typing import Any, Callable, Tuple, Union
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import webdataset as wds
from flax.training import checkpoints, train_state
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit
from jax.sharding import Mesh
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
from src.models.GPT import model_getter
from src.partitioning.partition import create_opt_spec, set_partitions_zero
from src.partitioning.xmap_train_functions import (
eval_step,
train_step,
update_opt_state,
)
from src.training.training_utils import compute_tokens_seen, initialized
from src.utils.configs import flatten_dict
from src.utils.dataloader import numpy_collate
# Module-wide logger. The root handler is installed first, then this module's
# logger is set to DEBUG so the progress messages from main() are emitted.
logging.basicConfig()
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
def parse():
    """Build the command-line interface and return the parsed arguments.

    Options:
        --cfg        path to the training config (default ``conf/config.yaml``).
        --model-cfg  path to the model config (default ``conf/model_config.yaml``);
                     exposed as ``args.model_cfg``.
        --resume     boolean switch, ``False`` unless the flag is given.
    """
    cli = argparse.ArgumentParser(description="Transformer Training")
    path_options = (
        ("--cfg", "conf/config.yaml"),
        ("--model-cfg", "conf/model_config.yaml"),
    )
    for flag, default_path in path_options:
        cli.add_argument(flag, default=default_path, type=str)
    cli.add_argument("--resume", action="store_true", default=False)
    return cli.parse_args()
def save_checkpoint_params(params: Any, step: int, workdir: str) -> None:
    """Write a copy of the model parameters to ``workdir`` (prefix ``params_``).

    Only process 0 writes; on any other process this returns immediately.
    The parameters are copied to host memory with ``jax.device_get`` and wrapped
    in a ``TrainState`` that carries no optimizer, so Flax's checkpoint helper
    can serialize them. ``keep=5`` bounds the number of retained checkpoints and
    ``overwrite=True`` allows replacing existing ones.

    TODO: Add async manager to do this in a background process
    """
    if jax.process_index() != 0:
        return
    host_params = jax.device_get(params)
    holder = train_state.TrainState(
        step=step, apply_fn=None, params=host_params, tx=None, opt_state=None
    )
    checkpoints.save_checkpoint(
        workdir,
        holder,
        step,
        keep=5,
        overwrite=True,
        prefix="params_",
    )
def save_checkpoint_optimizer(opt_state: Any, step: int, workdir: str) -> None:
    """Write the optimizer state to ``workdir`` (prefix ``optimizer_``).

    Only process 0 writes; on any other process this returns immediately.
    This function does NOT gather shards across hosts. It only copies
    ``opt_state`` to host memory with ``jax.device_get``, so pass state that is
    already addressable from process 0 (main() gathers it with
    ``process_allgather`` before calling). ``keep=5`` bounds the number of
    retained checkpoints and ``overwrite=True`` allows replacing existing ones.

    TODO: Add async manager to do this in a background process
    """
    if jax.process_index() != 0:
        return
    host_opt_state = jax.device_get(opt_state)
    holder = train_state.TrainState(
        step=step, apply_fn=None, params=None, tx=None, opt_state=host_opt_state
    )
    checkpoints.save_checkpoint(
        workdir,
        holder,
        step,
        keep=5,
        overwrite=True,
        prefix="optimizer_",
    )
def restore_param_checkpoint(workdir: str) -> Any:
    """Load the latest ``params_`` checkpoint in ``workdir`` and return its params frozen.

    ``target=None`` makes Flax return the raw restored dict; the ``"params"``
    entry is extracted and converted to a ``FrozenDict``.
    """
    raw_checkpoint = checkpoints.restore_checkpoint(
        workdir, target=None, prefix="params_"
    )
    restored_params = raw_checkpoint["params"]
    return flax.core.freeze(restored_params)
def restore_opt_checkpoint(workdir: str) -> Tuple[Any, int]:
    """Rebuild the optax optimizer state from the latest ``optimizer_`` checkpoint.

    With ``target=None`` Flax restores nested plain dicts keyed by strings
    (``"1"``, ``"0"``, ...), which optax cannot consume directly. The Adam
    moments and count are pulled out of that dict, converted to ``jnp`` arrays,
    and the optax state tuple is reassembled by hand as::

        (EmptyState(),
         (ScaleByAdamState(count, mu, nu),
          MaskedState(inner_state=EmptyState()),
          ScaleByScheduleState(count=step)))

    NOTE(review): this layout presumably must match
    ``optax.chain(optax.clip(...), optax.adamw(..., mask=...))`` as built in
    ``create_zero_train_state``; keep the two in sync.

    Returns:
        ``(opt_state, step)``, where ``step`` is the checkpoint's stored step
        (also used as the schedule count).
    """
    checkpoint = checkpoints.restore_checkpoint(
        workdir, target=None, prefix="optimizer_"
    )
    adam_raw = checkpoint["opt_state"]["1"]["0"]

    def as_jnp(tree):
        # Convert every restored leaf to a jnp array.
        return jax.tree_util.tree_map(lambda leaf: jnp.array(leaf), tree)

    adam_state = optax.ScaleByAdamState(
        as_jnp(adam_raw["count"]),
        flax.core.FrozenDict(as_jnp(adam_raw["mu"])),
        flax.core.FrozenDict(as_jnp(adam_raw["nu"])),
    )
    restored_step = checkpoint["step"]
    adamw_states = (
        adam_state,
        optax.MaskedState(inner_state=optax.EmptyState()),
        optax.ScaleByScheduleState(count=jnp.array(restored_step)),
    )
    return (optax.EmptyState(), adamw_states), restored_step
def create_zero_train_state(
    rng: jax.random.PRNGKey,
    learning_rate_fn: Union[float, Callable],
    weight_decay: float,
    model: nn.Module,
) -> Tuple[Any, Any, optax.GradientTransformation]:
    """Initialize model parameters and build the optimizer transformation.

    Despite the name, no ``TrainState`` and no optimizer state are created here.
    Callers initialize the optimizer state themselves (e.g. via ``tx.init``).

    Args:
        rng: PRNG key used for ``initialized`` and for the abstract
            ``jax.eval_shape(model.init, ...)`` trace.
        learning_rate_fn: constant learning rate or schedule for ``optax.adamw``.
        weight_decay: AdamW weight-decay coefficient.
        model: Flax module; its ``block_size`` and ``embedding_dim`` are read.

    Returns:
        ``(params, param_shape, tx)``:
        ``params`` are the parameters returned by ``initialized``;
        ``param_shape`` is the abstract (shape/dtype only) output of
        ``model.init`` from ``jax.eval_shape``;
        ``tx`` is ``optax.chain(optax.clip(1.0), optax.adamw(...))``.
    """
    params = initialized(rng, model, input_shape=(1, model.block_size))

    # Weight decay applies only to leaves that are not 1-D and are not shaped
    # like the (block_size, embedding_dim) position embedding, i.e. it is off
    # for bias terms, LayerNorm terms and position embeddings.
    decay_mask = jax.tree_util.tree_map(
        lambda leaf: (
            leaf.ndim != 1
            and leaf.shape != (model.block_size, model.embedding_dim)
        ),
        params,
    )

    # NOTE(review): optax.clip clips each update element to [-1, 1]; if global
    # gradient-norm clipping was intended, optax.clip_by_global_norm is the
    # usual choice. Left unchanged to keep training/resume behavior identical.
    tx = optax.chain(
        optax.clip(1.0),
        optax.adamw(
            learning_rate=learning_rate_fn,
            weight_decay=weight_decay,
            mask=decay_mask,
            b2=0.95,
        ),
    )

    # Abstract trace only: records output shapes/dtypes without allocating.
    init_batch = jnp.ones((1, model.block_size), dtype=jnp.int32)
    param_shape = jax.eval_shape(model.init, rng, init_batch)
    return params, param_shape, tx
def main():
    """Run the ZeRO-style data-parallel training loop.

    Parses CLI args and the OmegaConf config, builds the model and optimizer,
    optionally warm-starts or resumes from ``gs://`` checkpoints, streams
    WebDataset shards through torch DataLoaders, and trains with an xmap'd
    gradient step followed by a pjit'd optimizer update whose state is sharded
    over the ``"dp"`` mesh axis. Every ``evaluation_frequency`` iterations it
    evaluates, logs to wandb (process 0) and checkpoints params/optimizer state.

    Returns:
        ``True`` once the step count exceeds ``total_steps``; ``None`` if the
        training loader is exhausted first.

    Raises:
        NotImplementedError: when not running on TPU, or when checkpoint
            restore/save is needed but no bucket is configured.
    """
    from itertools import islice

    # Imported explicitly: ``import jax`` does not guarantee that the
    # ``jax.experimental.multihost_utils`` submodule is loaded as an attribute.
    from jax.experimental import multihost_utils

    args = parse()
    cfg = OmegaConf.load(args.cfg)

    # getting system information
    num_devices = jax.device_count()
    num_local_devices = jax.local_device_count()
    num_host = num_devices // num_local_devices
    platform = jax.local_devices()[0].platform

    # setting up GCP bucket/client info if training on TPU
    save_to_bucket = False
    client = None
    if platform == "tpu":
        if cfg.data.bucket_path is not None:
            # use GCP
            from google.cloud import storage
            from google.cloud.exceptions import NotFound  # noqa: F401

            client = storage.Client()
            save_to_bucket = True
        # Shard indices are read for every TPU run so that train_shards and
        # validation_shards are always bound, with or without a bucket.
        with open(cfg.data.index_path_train) as index_file:
            train_shards = index_file.read().splitlines()
        with open(cfg.data.index_path_validation) as index_file:
            validation_shards = index_file.read().splitlines()
    else:
        raise NotImplementedError("Training not currently supported on GPU.")

    model, model_config = model_getter(
        cfg.model.size, config_path=args.model_cfg, return_cfg=True, dtype=jnp.float32
    )

    # NOTE(review): decay_steps is hard-coded rather than derived from
    # cfg.training.total_steps — confirm this is intentional.
    learning_rate_fn = optax.warmup_cosine_decay_schedule(
        init_value=0,
        peak_value=cfg.training.peak_learning_rate,
        warmup_steps=cfg.training.warmup_steps,
        decay_steps=143000,
        end_value=cfg.training.end_learning_rate,
    )

    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    resume_step = 0

    params, param_shape, tx = create_zero_train_state(
        init_rng,
        learning_rate_fn,
        weight_decay=cfg.training.weight_decay,
        model=model,
    )

    devices = np.asarray(jax.devices())
    mesh = Mesh(devices, ("dp",))

    # Parameters carry no named axis in xmap, i.e. they are replicated.
    axis_list_params = [...]

    in_axes = (
        axis_list_params,
        ["batch", ...],
        [...],
    )
    out_axes = (axis_list_params, [...])

    # standard data parallel training step with xmap!
    train_step_xmap = xmap(
        partial(
            train_step,
            model=model,
            accum_steps=cfg.training.gradient_accumulation_steps,
        ),
        in_axes=in_axes,
        out_axes=out_axes,
        axis_resources={"batch": "dp"},
    )

    eval_axes = (
        axis_list_params,
        ["batch", ...],
    )
    eval_step_xmap = xmap(
        partial(eval_step, model=model),
        in_axes=eval_axes,
        out_axes=[...],
        axis_resources={"batch": "dp"},
    )

    opt_state_shapes = jax.eval_shape(tx.init, params)
    grad_param_spec = set_partitions_zero(param_shape)
    opt_state_spec = create_opt_spec(grad_param_spec, opt_state_shapes)

    if cfg.model.warm_init and not args.resume:
        # only start from warm init params @ beginning of training run
        del params
        if save_to_bucket:
            # NOTE(review): the optimizer state restored here is not used —
            # because args.resume is False, opt_state is re-initialized with
            # tx.init inside the mesh below.
            opt_state, step = restore_opt_checkpoint(
                workdir=f"gs://{cfg.data.bucket_path}/{cfg.model.warm_init_dir}/optimizer"
            )
            opt_state = jax.device_get(opt_state)  # copy to CPU
            params = restore_param_checkpoint(
                workdir=f"gs://{cfg.data.bucket_path}/{cfg.model.warm_init_dir}/params"
            )
            params = jax.device_get(params)  # copy to CPU
        else:
            raise NotImplementedError(
                "Checkpointing not currently implemented for GPU."
            )
        if jax.process_index() == 0:
            logger.debug("Warm starting training for pretrained checkpoint.")

    if args.resume:
        del params
        if save_to_bucket:
            opt_state, step = restore_opt_checkpoint(
                workdir=f"gs://{cfg.data.bucket_path}/{cfg.data.checkpoint_directory}/optimizer"
            )
            opt_state = jax.device_get(opt_state)  # copy to CPU
            params = restore_param_checkpoint(
                workdir=f"gs://{cfg.data.bucket_path}/{cfg.data.checkpoint_directory}/params"
            )
            params = jax.device_get(params)  # copy to CPU
            resume_step = int(step)
        else:
            raise NotImplementedError(
                "Checkpointing not currently implemented for GPU."
            )
        if jax.process_index() == 0:
            logger.debug(f"Resuming training from step {resume_step}")

    params = jax.device_get(params)  # copy params to VM CPU

    if jax.process_index() == 0:
        logger.debug(f"VM setup with {num_devices} devices.")
        logger.debug(f"Host setup with {num_local_devices} devices.")
        logger.debug(f"Using platform: {platform}.")
        logger.debug(
            f"Performing data parallel training. Model parameters are replicated across all devices. Optimizer state is sharded across {num_devices} devices"
        )

    if not args.resume and cfg.data.bucket_path is not None:
        # Fresh run: clear old optimizer/params checkpoints from the bucket.
        # ``client`` was created above (this branch is only reachable on TPU
        # with a bucket configured).
        if jax.process_index() == 0:
            bucket = storage.Bucket(client, f"{cfg.data.bucket_path}")
            for subdir in ("optimizer", "params"):
                for blob in bucket.list_blobs(
                    prefix=f"{cfg.data.checkpoint_directory}/{subdir}"
                ):
                    blob.delete()

    local_batch_size = cfg.training.batch_size // jax.local_device_count()

    total_tokens = num_host * (
        cfg.training.batch_size
        * compute_tokens_seen(
            cfg.training.total_steps,
            max_context=cfg.data.max_context,
        )
    )

    if jax.process_index() == 0:
        run_id = wandb.util.generate_id()
        wandb.init(id=run_id, resume="allow", project=cfg.data.wandb_project)
        flat_dict = flatten_dict(cfg)
        for key in model_config.keys():
            flat_dict[f"model.{key}"] = model_config[key]
        flat_dict["training.local_batch_size"] = local_batch_size
        flat_dict["runtime"] = platform
        flat_dict["Total Training Tokens"] = total_tokens / 1e9
        flat_dict["Total Devices"] = num_devices
        wandb.config.update(flat_dict)

    def preprocess(batch):
        # Truncate each sample to max_context tokens and convert to int32.
        x = batch["input_id.pth"][: cfg.data.max_context]
        # ``torch.tensor`` is a factory function, not a type, so compare
        # against the ``torch.Tensor`` class.
        if isinstance(x, torch.Tensor):
            return jnp.array(x.long(), dtype=jnp.int32)
        return jnp.array(x, dtype=jnp.int32)

    def split_by_jax_process(src):
        # Each host keeps every num_host-th shard, starting at its process index.
        host_id = jax.process_index()
        if num_host > 1:
            yield from islice(src, host_id, None, num_host)
        else:
            yield from src

    train_dataset = wds.DataPipeline(
        wds.SimpleShardList(train_shards),
        split_by_jax_process,
        wds.tarfile_to_samples(handler=wds.warn_and_continue),
        wds.shuffle(1e7, initial=1e7, rng=pyrandom.Random(23 + resume_step)),
        wds.decode(handler=wds.warn_and_continue),
        wds.map(preprocess),
    ).repeat(nepochs=cfg.training.max_epochs)

    validation_dataset = wds.DataPipeline(
        wds.SimpleShardList(validation_shards),
        split_by_jax_process,
        wds.tarfile_to_samples(handler=wds.warn_and_continue),
        wds.shuffle(1e6, initial=1e6, rng=pyrandom.Random(23 + resume_step)),
        wds.decode(handler=wds.warn_and_continue),
        wds.map(preprocess),
    )

    tl = DataLoader(
        dataset=train_dataset,
        batch_size=cfg.training.batch_size,
        collate_fn=numpy_collate,
        drop_last=True,
        num_workers=0,
    )

    vl = DataLoader(
        dataset=validation_dataset,
        batch_size=cfg.training.batch_size // 4,
        collate_fn=numpy_collate,
        drop_last=True,
        num_workers=0,
    )

    running_metrics = []

    seq_len = min(cfg.training.train_context, cfg.data.max_context)
    accum_steps = cfg.training.gradient_accumulation_steps

    rng = jax.random.fold_in(rng, resume_step)  # fold in resume step to create new rng

    # quick way to track global step count when resuming a run
    new_steps = 0
    iterator_resume_step = int(resume_step % cfg.data.steps_per_epoch)

    with mesh:
        params = jax.device_get(params)
        if args.resume:
            opt_state = pjit(
                lambda x: x, in_axis_resources=None, out_axis_resources=opt_state_spec
            )(opt_state)
        else:
            opt_state = pjit(
                tx.init, in_axis_resources=None, out_axis_resources=opt_state_spec
            )(params)

        update_opt_state_pjit = pjit(
            partial(update_opt_state, optimizer=tx, grad_spec=grad_param_spec),
            in_shardings=(grad_param_spec, opt_state_spec, grad_param_spec),
            out_shardings=(None, opt_state_spec),
        )

        grad_shard = pjit(
            lambda x: x, in_axis_resources=None, out_axis_resources=grad_param_spec
        )

        for i, text in enumerate(tqdm(tl, disable=not jax.process_index() == 0)):
            # NOTE(review): with ``>`` the step numbered total_steps is also
            # trained (total_steps + 1 updates from scratch); confirm whether
            # ``>=`` was intended.
            if (resume_step + new_steps) > cfg.training.total_steps:
                if jax.process_index() == 0:
                    logger.debug("Training has completed.")
                return True

            if i < iterator_resume_step:
                continue

            rng, dropout_rng = jax.random.split(rng, 2)

            if seq_len < cfg.data.max_context:
                text = text.reshape(-1, seq_len)

            # Add a 'grad_accum' axis that train_step iterates through:
            # (N, seq_len) -> (N // accum_steps, accum_steps, seq_len)
            text = text.reshape(
                accum_steps,
                text.shape[0] // accum_steps,
                seq_len,
            ).transpose(1, 0, 2)
            # Split the leading axis across devices. The per-device size is
            # inferred (-1) instead of using max_context // train_context, which
            # is 0 or wrong when train_context >= max_context (seq_len is then
            # max_context).
            text = text.reshape(jax.device_count(), -1, accum_steps, seq_len)

            grads, metrics = train_step_xmap(params, text, dropout_rng)
            grads = grad_shard(grads)
            params = grad_shard(params)
            params, opt_state = update_opt_state_pjit(grads, opt_state, params)
            del grads  # manually free grad mem

            metrics["Train Sequence Length"] = seq_len
            metrics["Learning Rate"] = learning_rate_fn(resume_step + new_steps)

            running_metrics.append(metrics)
            train_metrics_np = {
                k: np.mean([m[k] for m in running_metrics])
                for k in running_metrics[0]
            }
            running_metrics = []
            validation_metrics = []

            absolute_step = resume_step + new_steps

            train_metrics_np["Tokens Seen (B)"] = (
                num_host
                * (
                    cfg.training.batch_size
                    * compute_tokens_seen(
                        absolute_step,
                        max_context=cfg.data.max_context,
                    )
                )
                / 1e9
            )

            new_steps += 1

            if i % cfg.training.evaluation_frequency == 0:
                for val_it, val_text in enumerate(
                    tqdm(vl, disable=not jax.process_index() == 0)
                ):
                    if val_it >= cfg.training.maximum_evaluation_steps:
                        break
                    val_text = val_text.reshape(-1, seq_len)
                    val_text = val_text.reshape(
                        jax.device_count(),
                        val_text.shape[0] // jax.device_count(),
                        seq_len,
                    )
                    validation_metrics.append(eval_step_xmap(params, val_text))

                # An empty validation loader (or zero eval steps) would
                # otherwise raise IndexError on validation_metrics[0].
                if validation_metrics:
                    validation_metrics_np = {
                        k: np.mean([m[k] for m in validation_metrics])
                        for k in validation_metrics[0]
                    }
                else:
                    validation_metrics_np = {}

                # Gathered on every process (it collects data across processes);
                # only process 0 writes the checkpoint below.
                opt_state_cpu = multihost_utils.process_allgather(opt_state)

                if jax.process_index() == 0:
                    train_metrics_np.update(validation_metrics_np)
                    wandb.log(train_metrics_np)
                    if save_to_bucket:
                        save_checkpoint_params(
                            params,
                            absolute_step,
                            workdir=f"gs://{cfg.data.bucket_path}/{cfg.data.checkpoint_directory}/params",
                        )
                        save_checkpoint_optimizer(
                            opt_state_cpu,
                            absolute_step,
                            workdir=f"gs://{cfg.data.bucket_path}/{cfg.data.checkpoint_directory}/optimizer",
                        )
                    else:
                        raise NotImplementedError(
                            "Checkpointing not currently implemented for GPU/CPU"
                        )
            elif jax.process_index() == 0:
                wandb.log(train_metrics_np)


if __name__ == "__main__":
    main()