diff --git a/test_runner.py b/test_runner.py index 59bc49a4..31f53034 100755 --- a/test_runner.py +++ b/test_runner.py @@ -29,11 +29,12 @@ class OverrideDefinitions: override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) test_descr: str = "default" + test_name: str = "default" requires_seed_checkpoint: bool = False ngpu: int = 4 -def build_test_list(args): +def build_test_list(): """ key is the config file name and value is a list of OverrideDefinitions that is used to generate variations of integration tests based on the @@ -45,7 +46,6 @@ def build_test_list(args): [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/pp_1f1b/", "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.1", "--experimental.pipeline_parallel_schedule 1f1b", @@ -53,6 +53,7 @@ def build_test_list(args): ], ], "PP 1D test 1f1b", + "pp_1f1b", requires_seed_checkpoint=True, ngpu=2, ), @@ -60,7 +61,6 @@ def build_test_list(args): [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/pp_gpipe/", "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.1", "--experimental.pipeline_parallel_schedule gpipe", @@ -68,6 +68,7 @@ def build_test_list(args): ], ], "PP 1D test gpipe", + "pp_gpipe", requires_seed_checkpoint=True, ngpu=2, ), @@ -75,7 +76,6 @@ def build_test_list(args): [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/pp_dp_1f1b/", "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.1", "--experimental.pipeline_parallel_schedule 1f1b", @@ -83,13 +83,13 @@ def build_test_list(args): ], ], "PP+DP 1f1b 2D test", + "pp_dp_1f1b", requires_seed_checkpoint=True, ), OverrideDefinitions( [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/pp_dp_gpipe/", "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.1", "--experimental.pipeline_parallel_schedule gpipe", @@ -97,13 +97,13 @@ def build_test_list(args): ], ], "PP+DP gpipe 2D test", + "pp_dp_gpipe", requires_seed_checkpoint=True, ), OverrideDefinitions( [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/pp_tp/", "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.1", "--training.tensor_parallel_degree 2", @@ -111,90 +111,89 @@ def build_test_list(args): ], ], "PP+TP 2D test", + "pp_tp", requires_seed_checkpoint=True, ), OverrideDefinitions( [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/pp_tracer/", "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.1", "--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer ], ], "PP tracer frontend test", + "pp_tracer", requires_seed_checkpoint=True, ), OverrideDefinitions( [ - [ - f"--job.dump_folder {args.output_dir}/default/", - ], + [], ], - "Default", + "default", + "default", ), OverrideDefinitions( [ [ "--training.compile --model.norm_type=rmsnorm", - f"--job.dump_folder {args.output_dir}/1d_compile/", ], ], "1D compile", + "1d_compile", ), OverrideDefinitions( [ [ "--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", - f"--job.dump_folder {args.output_dir}/2d_compile/", ], ], "2D compile", + "2d_compile", ), OverrideDefinitions( [ [ "--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", - f"--job.dump_folder {args.output_dir}/eager_2d/", ], ], "Eager mode 2DParallel", + "eager_2d", ), OverrideDefinitions( [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/full_checkpoint/", ], [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/full_checkpoint/", "--training.steps 20", ], ], "Checkpoint Integration Test - Save Load Full Checkpoint", + "full_checkpoint", ), OverrideDefinitions( [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/model_weights_only_fp32/", "--checkpoint.model_weights_only", ], ], "Checkpoint Integration Test - Save Model Weights Only fp32", + "model_weights_only_fp32", ), OverrideDefinitions( [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/model_weights_only_bf16/", "--checkpoint.model_weights_only", "--checkpoint.export_dtype bfloat16", ], ], "Checkpoint Integration Test - Save Model Weights Only bf16", + "model_weights_only_bf16", ), ] return integration_tests_flavors @@ -210,11 +209,15 @@ def _run_cmd(cmd): ) -def run_test(test_flavor: OverrideDefinitions, full_path: str): +def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): # run_test supports sequence of tests. for override_arg in test_flavor.override_args: + test_name = test_flavor.test_name + dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}" cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh" + cmd += " " + dump_folder_arg + if override_arg: cmd += " " + " ".join(override_arg) logger.info( @@ -222,13 +225,6 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str): ) if test_flavor.requires_seed_checkpoint: - dump_folder_arg = None - for arg in override_arg: - if "--job.dump_folder" in arg: - dump_folder_arg = arg - assert ( - dump_folder_arg is not None - ), "Can't use seed checkpoint if folder is not specified" logger.info("Creating seed checkpoint") result = _run_cmd( f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}" @@ -244,7 +240,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str): def run_tests(args): - integration_tests_flavors = build_test_list(args) + integration_tests_flavors = build_test_list() for config_file in os.listdir(args.config_dir): if config_file.endswith(".toml"): full_path = os.path.join(args.config_dir, config_file) @@ -255,13 +251,19 @@ def run_tests(args): ) if is_integration_test: for test_flavor in integration_tests_flavors[config_file]: - run_test(test_flavor, full_path) + if args.test == "all" or test_flavor.test_name == args.test: + run_test(test_flavor, full_path, args.output_dir) def main(): parser = argparse.ArgumentParser() parser.add_argument("output_dir") parser.add_argument("--config_dir", default="./train_configs") + parser.add_argument( + "--test", + default="all", + help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)", + ) args = parser.parse_args() if not os.path.exists(args.output_dir):