-
Notifications
You must be signed in to change notification settings - Fork 228
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add --test option to specify test to run #368
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,11 +29,12 @@ class OverrideDefinitions: | |
|
||
override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) | ||
test_descr: str = "default" | ||
test_id: str = "default" | ||
requires_seed_checkpoint: bool = False | ||
ngpu: int = 4 | ||
|
||
|
||
def build_test_list(args): | ||
def build_test_list(): | ||
""" | ||
key is the config file name and value is a list of OverrideDefinitions | ||
that is used to generate variations of integration tests based on the | ||
|
@@ -45,156 +46,154 @@ def build_test_list(args): | |
[ | ||
[ | ||
"--checkpoint.enable_checkpoint", | ||
f"--job.dump_folder {args.output_dir}/pp_1f1b/", | ||
"--experimental.pipeline_parallel_degree 2", | ||
"--experimental.pipeline_parallel_split_points layers.1", | ||
"--experimental.pipeline_parallel_schedule 1f1b", | ||
"--training.data_parallel_degree 1", | ||
], | ||
], | ||
"PP 1D test 1f1b", | ||
"pp_1f1b", | ||
requires_seed_checkpoint=True, | ||
ngpu=2, | ||
), | ||
OverrideDefinitions( | ||
[ | ||
[ | ||
"--checkpoint.enable_checkpoint", | ||
f"--job.dump_folder {args.output_dir}/pp_gpipe/", | ||
"--experimental.pipeline_parallel_degree 2", | ||
"--experimental.pipeline_parallel_split_points layers.1", | ||
"--experimental.pipeline_parallel_schedule gpipe", | ||
"--training.data_parallel_degree 1", | ||
], | ||
], | ||
"PP 1D test gpipe", | ||
"pp_gpipe", | ||
requires_seed_checkpoint=True, | ||
ngpu=2, | ||
), | ||
OverrideDefinitions( | ||
[ | ||
[ | ||
"--checkpoint.enable_checkpoint", | ||
f"--job.dump_folder {args.output_dir}/pp_dp_1f1b/", | ||
"--experimental.pipeline_parallel_degree 2", | ||
"--experimental.pipeline_parallel_split_points layers.1", | ||
"--experimental.pipeline_parallel_schedule 1f1b", | ||
"--training.data_parallel_degree 2", | ||
], | ||
], | ||
"PP+DP 1f1b 2D test", | ||
"pp_dp_1f1b", | ||
requires_seed_checkpoint=True, | ||
), | ||
OverrideDefinitions( | ||
[ | ||
[ | ||
"--checkpoint.enable_checkpoint", | ||
f"--job.dump_folder {args.output_dir}/pp_dp_gpipe/", | ||
"--experimental.pipeline_parallel_degree 2", | ||
"--experimental.pipeline_parallel_split_points layers.1", | ||
"--experimental.pipeline_parallel_schedule gpipe", | ||
"--training.data_parallel_degree 2", | ||
], | ||
], | ||
"PP+DP gpipe 2D test", | ||
"pp_dp_gpipe", | ||
requires_seed_checkpoint=True, | ||
), | ||
OverrideDefinitions( | ||
[ | ||
[ | ||
"--checkpoint.enable_checkpoint", | ||
f"--job.dump_folder {args.output_dir}/pp_tp/", | ||
"--experimental.pipeline_parallel_degree 2", | ||
"--experimental.pipeline_parallel_split_points layers.1", | ||
"--training.tensor_parallel_degree 2", | ||
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP | ||
], | ||
], | ||
"PP+TP 2D test", | ||
"pp_tp", | ||
requires_seed_checkpoint=True, | ||
), | ||
OverrideDefinitions( | ||
[ | ||
[ | ||
"--checkpoint.enable_checkpoint", | ||
f"--job.dump_folder {args.output_dir}/pp_tracer/", | ||
"--experimental.pipeline_parallel_degree 2", | ||
"--experimental.pipeline_parallel_split_points layers.1", | ||
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer | ||
], | ||
], | ||
"PP tracer frontend test", | ||
"pp_tracer", | ||
requires_seed_checkpoint=True, | ||
), | ||
OverrideDefinitions( | ||
[ | ||
[ | ||
f"--job.dump_folder {args.output_dir}/default/", | ||
], | ||
[], | ||
], | ||
"Default", | ||
"default", | ||
"default", | ||
), | ||
OverrideDefinitions( | ||
[ | ||
[ | ||
"--training.compile --model.norm_type=rmsnorm", | ||
f"--job.dump_folder {args.output_dir}/1d_compile/", | ||
], | ||
], | ||
"1D compile", | ||
"1d_compile", | ||
), | ||
OverrideDefinitions( | ||
[ | ||
[ | ||
"--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", | ||
f"--job.dump_folder {args.output_dir}/2d_compile/", | ||
], | ||
], | ||
"2D compile", | ||
"2d_compile", | ||
), | ||
OverrideDefinitions( | ||
[ | ||
[ | ||
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", | ||
f"--job.dump_folder {args.output_dir}/eager_2d/", | ||
], | ||
], | ||
"Eager mode 2DParallel", | ||
"eager_2d", | ||
), | ||
OverrideDefinitions( | ||
[ | ||
[ | ||
"--checkpoint.enable_checkpoint", | ||
f"--job.dump_folder {args.output_dir}/full_checkpoint/", | ||
], | ||
[ | ||
"--checkpoint.enable_checkpoint", | ||
f"--job.dump_folder {args.output_dir}/full_checkpoint/", | ||
"--training.steps 20", | ||
], | ||
], | ||
"Checkpoint Integration Test - Save Load Full Checkpoint", | ||
"full_checkpoint", | ||
), | ||
OverrideDefinitions( | ||
[ | ||
[ | ||
"--checkpoint.enable_checkpoint", | ||
f"--job.dump_folder {args.output_dir}/model_weights_only_fp32/", | ||
"--checkpoint.model_weights_only", | ||
], | ||
], | ||
"Checkpoint Integration Test - Save Model Weights Only fp32", | ||
"model_weights_only_fp32", | ||
), | ||
OverrideDefinitions( | ||
[ | ||
[ | ||
"--checkpoint.enable_checkpoint", | ||
f"--job.dump_folder {args.output_dir}/model_weights_only_bf16/", | ||
"--checkpoint.model_weights_only", | ||
"--checkpoint.export_dtype bfloat16", | ||
], | ||
], | ||
"Checkpoint Integration Test - Save Model Weights Only bf16", | ||
"model_weights_only_bf16", | ||
), | ||
] | ||
return integration_tests_flavors | ||
|
@@ -210,25 +209,22 @@ def _run_cmd(cmd): | |
) | ||
|
||
|
||
def run_test(test_flavor: OverrideDefinitions, full_path: str): | ||
def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): | ||
# run_test supports sequence of tests. | ||
for override_arg in test_flavor.override_args: | ||
test_id = test_flavor.test_id | ||
dump_folder_arg = f"--job.dump_folder {output_dir}/{test_id}" | ||
|
||
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh" | ||
cmd += " " + dump_folder_arg | ||
|
||
if override_arg: | ||
cmd += " " + " ".join(override_arg) | ||
logger.info( | ||
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" | ||
) | ||
|
||
if test_flavor.requires_seed_checkpoint: | ||
dump_folder_arg = None | ||
for arg in override_arg: | ||
if "--job.dump_folder" in arg: | ||
dump_folder_arg = arg | ||
assert ( | ||
dump_folder_arg is not None | ||
), "Can't use seed checkpoint if folder is not specified" | ||
logger.info("Creating seed checkpoint") | ||
result = _run_cmd( | ||
f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}" | ||
|
@@ -244,7 +240,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str): | |
|
||
|
||
def run_tests(args): | ||
integration_tests_flavors = build_test_list(args) | ||
integration_tests_flavors = build_test_list() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we remove the […]? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's right. |
||
for config_file in os.listdir(args.config_dir): | ||
if config_file.endswith(".toml"): | ||
full_path = os.path.join(args.config_dir, config_file) | ||
|
@@ -255,13 +251,19 @@ def run_tests(args): | |
) | ||
if is_integration_test: | ||
for test_flavor in integration_tests_flavors[config_file]: | ||
run_test(test_flavor, full_path) | ||
if args.test == "all" or test_flavor.test_id == args.test: | ||
run_test(test_flavor, full_path, args.output_dir) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should still assert the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If not provided, the program entry would fail:
|
||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("output_dir") | ||
parser.add_argument("--config_dir", default="./train_configs") | ||
parser.add_argument( | ||
"--test", | ||
default="all", | ||
help="test to run, acceptable values: `test_id` in `build_test_list` (default: all)", | ||
) | ||
args = parser.parse_args() | ||
|
||
if not os.path.exists(args.output_dir): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This feels like a `test_name` rather than `test_id`? Could we rename to `test_name`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed 👍