
Commit

replace create_seed_checkpoint.md with a note in docs/checkpoint.md
[ghstack-poisoned]
tianyu-l committed Dec 13, 2024
1 parent e846b69 commit 86c8e55
Showing 4 changed files with 17 additions and 58 deletions.
34 changes: 0 additions & 34 deletions create_seed_checkpoint.sh

This file was deleted.

17 changes: 14 additions & 3 deletions docs/checkpoint.md
@@ -1,4 +1,4 @@
## How to convert a Llama3 checkpoint for use in torchtitan
## How to convert a Llama 3 checkpoint for use in torchtitan

If you want to continue training from an existing model checkpoint, the checkpoint must be in the DCP format expected by the checkpoint manager.
An example script for converting the original Llama3 checkpoints into the expected DCP format can be found in `scripts/convert_llama_to_dcp.py`.
@@ -9,8 +9,7 @@ python3 scripts/convert_llama_to_dcp.py <input_dir> <output_dir>
```



## How to Convert a torchtitan Checkpoint for Use in torchtune
## How to convert a torchtitan checkpoint for use in torchtune

This guide will walk you through the steps required to convert a checkpoint from torchtitan so that it can be loaded into torchtune.

@@ -66,3 +65,15 @@ python -m torch.distributed.checkpoint.format_utils dcp_to_torch torchtitan/outp
```

That's it. You have now successfully converted a sharded torchtitan checkpoint for use in torchtune.
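
As an optional sanity check, the resulting `.pt` file should be a regular `torch.save` artifact and therefore loadable with `torch.load`; the path below is only a placeholder for wherever the `dcp_to_torch` step wrote its output:

```bash
# Placeholder path; point this at the .pt file produced by the dcp_to_torch step above.
python3 -c "import torch; sd = torch.load('outputs/checkpoint.pt', map_location='cpu', weights_only=False); print(f'{len(sd)} entries')"
```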


## How to create a seed checkpoint
Sometimes one needs a seed checkpoint to initialize the model from step 0.
For example, it is hard, if not impossible, for meta initialization on multiple devices to reproduce the initialization obtained on a single device.
A seed checkpoint initializes the model on a single CPU, and can then be loaded by another job on an arbitrary number of GPUs via DCP resharding.

To create a seed checkpoint, use the same model config as you use for training, e.g.
```bash
NGPU=1 CONFIG_FILE=<path_to_model_config> ./run_llama_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --training.data_parallel_shard_degree 1
```
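
Once the seed checkpoint has been written, a subsequent training job can load it on an arbitrary number of GPUs via DCP resharding. A minimal sketch, assuming the same config and dump folder are reused so the checkpoint manager finds the step-0 checkpoint:

```bash
# Sketch only: reuse the config (and dump folder) from the seed-checkpoint run.
NGPU=8 CONFIG_FILE=<path_to_model_config> ./run_llama_train.sh --checkpoint.enable_checkpoint
```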
22 changes: 2 additions & 20 deletions test_runner.py
@@ -30,7 +30,6 @@ class OverrideDefinitions:
override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
test_descr: str = "default"
test_name: str = "default"
requires_seed_checkpoint: bool = False
ngpu: int = 4
model_flavor: str = "debugmodel"

@@ -146,7 +145,6 @@ def build_test_list():
],
"PP looped zero bubble test",
"pp_looped_zero_bubble",
requires_seed_checkpoint=True,
ngpu=4,
),
OverrideDefinitions(
@@ -160,7 +158,6 @@ def build_test_list():
],
"PP 1D test 1F1B",
"pp_1f1b",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
@@ -174,7 +171,6 @@ def build_test_list():
],
"PP 1D test GPipe",
"pp_gpipe",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
@@ -188,7 +184,6 @@ def build_test_list():
],
"PP+DP 1F1B 2D test",
"pp_dp_1f1b",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
@@ -201,7 +196,6 @@ def build_test_list():
],
"PP+DP GPipe 2D test",
"pp_dp_gpipe",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
@@ -213,7 +207,6 @@ def build_test_list():
],
"PP+TP 2D test",
"pp_tp",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
@@ -233,7 +226,6 @@ def build_test_list():
],
"PP+DP+TP 3D test with save/load resume ckpt",
"pp_dp_tp",
requires_seed_checkpoint=True,
ngpu=8,
),
OverrideDefinitions(
@@ -247,7 +239,6 @@ def build_test_list():
],
"PP+DP+TP 3D test with torch.compile",
"3d_compile",
requires_seed_checkpoint=True,
ngpu=8,
),
OverrideDefinitions(
@@ -260,7 +251,6 @@ def build_test_list():
],
"PP looped 1F1B test",
"pp_looped_1f1b",
requires_seed_checkpoint=True,
ngpu=4,
),
OverrideDefinitions(
@@ -384,7 +374,7 @@ def build_test_list():
]
],
"FSDP2 Memory Tracking and Estimation",
"fsdp2_mem_tracker",
"fsdp2_memory_estimation",
ngpu=2,
),
OverrideDefinitions(
@@ -421,17 +411,9 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}"
all_ranks = ",".join(map(str, range(test_flavor.ngpu)))

if test_flavor.requires_seed_checkpoint:
cmd = f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg} {model_flavor_arg}"
logger.info(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)
result = _run_cmd(cmd)
logger.info(result.stdout)

for override_arg in test_flavor.override_args:
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
if test_name == "fsdp2_mem_tracker":
if test_name == "fsdp2_memory_estimation":
cmd = (
f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
"./scripts/estimate/run_memory_estimation.sh"
2 changes: 1 addition & 1 deletion torchtitan/utils.py
@@ -111,7 +111,7 @@ def set_determinism(

# As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh.
# IF PP is also used, this seed is unique per PP rank.
if spmd_mesh:
if spmd_mesh and spmd_mesh.get_coordinate() is not None:
torch.distributed.tensor._random.manual_seed(seed, spmd_mesh)


