Skip to content

Commit

Permalink
Add --test option to specify test to run
Browse files Browse the repository at this point in the history
  • Loading branch information
kwen2501 committed May 29, 2024
1 parent 1ceaa4e commit 2f712cc
Showing 1 changed file with 26 additions and 28 deletions.
54 changes: 26 additions & 28 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ class OverrideDefinitions:

override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
test_descr: str = "default"
test_id: str = "default"
requires_seed_checkpoint: bool = False
ngpu: int = 4


def build_test_list(args):
def build_test_list():
"""
key is the config file name and value is a list of OverrideDefinitions
that is used to generate variations of integration tests based on the
Expand All @@ -45,156 +46,154 @@ def build_test_list(args):
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_1f1b/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
],
],
"PP 1D test 1f1b",
"pp_1f1b",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_gpipe/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 1",
],
],
"PP 1D test gpipe",
"pp_gpipe",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_dp_1f1b/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 2",
],
],
"PP+DP 1f1b 2D test",
"pp_dp_1f1b",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_dp_gpipe/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 2",
],
],
"PP+DP gpipe 2D test",
"pp_dp_gpipe",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_tp/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
],
],
"PP+TP 2D test",
"pp_tp",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_tracer/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer
],
],
"PP tracer frontend test",
"pp_tracer",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
f"--job.dump_folder {args.output_dir}/default/",
],
[],
],
"Default",
"default",
"default",
),
OverrideDefinitions(
[
[
"--training.compile --model.norm_type=rmsnorm",
f"--job.dump_folder {args.output_dir}/1d_compile/",
],
],
"1D compile",
"1d_compile",
),
OverrideDefinitions(
[
[
"--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
f"--job.dump_folder {args.output_dir}/2d_compile/",
],
],
"2D compile",
"2d_compile",
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
f"--job.dump_folder {args.output_dir}/eager_2d/",
],
],
"Eager mode 2DParallel",
"eager_2d",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
],
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
"--training.steps 20",
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/model_weights_only_fp32/",
"--checkpoint.model_weights_only",
],
],
"Checkpoint Integration Test - Save Model Weights Only fp32",
"model_weights_only_fp32",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/model_weights_only_bf16/",
"--checkpoint.model_weights_only",
"--checkpoint.export_dtype bfloat16",
],
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
"model_weights_only_bf16",
),
]
return integration_tests_flavors
Expand All @@ -210,25 +209,22 @@ def _run_cmd(cmd):
)


def run_test(test_flavor: OverrideDefinitions, full_path: str):
def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
# run_test supports sequence of tests.
for override_arg in test_flavor.override_args:
test_id = test_flavor.test_id
dump_folder_arg = f"--job.dump_folder {output_dir}/{test_id}"

cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
cmd += " " + dump_folder_arg

if override_arg:
cmd += " " + " ".join(override_arg)
logger.info(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)

if test_flavor.requires_seed_checkpoint:
dump_folder_arg = None
for arg in override_arg:
if "--job.dump_folder" in arg:
dump_folder_arg = arg
assert (
dump_folder_arg is not None
), "Can't use seed checkpoint if folder is not specified"
logger.info("Creating seed checkpoint")
result = _run_cmd(
f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}"
Expand All @@ -244,7 +240,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str):


def run_tests(args):
integration_tests_flavors = build_test_list(args)
integration_tests_flavors = build_test_list()
for config_file in os.listdir(args.config_dir):
if config_file.endswith(".toml"):
full_path = os.path.join(args.config_dir, config_file)
Expand All @@ -255,13 +251,15 @@ def run_tests(args):
)
if is_integration_test:
for test_flavor in integration_tests_flavors[config_file]:
run_test(test_flavor, full_path)
if args.test == "all" or test_flavor.test_id == args.test:
run_test(test_flavor, full_path, args.output_dir)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("output_dir")
parser.add_argument("--config_dir", default="./train_configs")
parser.add_argument("--test", default="all", help="test to run, acceptable values: `test_id` in `build_test_list` (default: all)")
args = parser.parse_args()

if not os.path.exists(args.output_dir):
Expand Down

0 comments on commit 2f712cc

Please sign in to comment.