[BE] dump compile trace to CI output for debugging, and reduce CI workload

ghstack-source-id: fe1076b088646dd1d3f3da4bb646c17bd3fa555d
Pull Request resolved: #739
tianyu-l committed Dec 17, 2024
1 parent 5ce8a0c commit 00f5302
Showing 2 changed files with 9 additions and 12 deletions.
15 changes: 6 additions & 9 deletions test_runner.py
@@ -47,7 +47,10 @@ def build_test_list():
     integration_tests_flavors["debug_model.toml"] = [
         OverrideDefinitions(
             [
-                [],
+                [
+                    "--profiling.enable_profiling",
+                    "--metrics.enable_tensorboard",
+                ],
             ],
             "default",
             "default",
@@ -138,7 +141,6 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 4",
                     "--experimental.pipeline_parallel_schedule InterleavedZeroBubble",
                 ],
@@ -150,7 +152,6 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_schedule 1F1B",
                     "--training.data_parallel_shard_degree 1",
@@ -163,7 +164,6 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_schedule GPipe",
                     "--training.data_parallel_shard_degree 1",
@@ -176,7 +176,6 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_schedule 1F1B",
                     "--training.data_parallel_shard_degree 2",
@@ -188,7 +187,6 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_schedule GPipe",
                     "--training.data_parallel_shard_degree 2",
@@ -200,7 +198,6 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--training.tensor_parallel_degree 2",
                 ],
@@ -244,7 +241,6 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 4",
                     "--experimental.pipeline_parallel_schedule Interleaved1F1B",
                 ],
@@ -256,7 +252,6 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_schedule PipelineScheduleMulti",
                     "--experimental.pipeline_parallel_schedule_csv ./test/assets/custom_schedule.csv",
@@ -413,6 +408,8 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):

     for override_arg in test_flavor.override_args:
         cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
+        # dump compile trace for debugging purpose
+        cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
         if test_name == "fsdp2_memory_estimation":
             cmd = (
                 f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
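The run_test change above prepends TORCH_TRACE to every test command. TORCH_TRACE is PyTorch's environment variable for dumping torch.compile's structured trace logs into a directory; writing them under the test's output directory lets CI surface them for debugging, and the resulting logs can typically be inspected with the tlparse tool. A minimal sketch of what the runner now does, with an illustrative subprocess call that is not taken from this diff:

import subprocess


def run_flavor_with_trace(cmd: str, output_dir: str, test_name: str) -> None:
    # Prepend TORCH_TRACE so torch.compile writes structured trace logs under
    # the test's output directory, where CI can pick them up.
    traced_cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
    subprocess.run(traced_cmd, shell=True, check=True)


# Hypothetical usage, mirroring the runner's command template:
# run_flavor_with_trace(
#     'CONFIG_FILE=./train_configs/debug_model.toml NGPU=8 LOG_RANK=0 ./run_llama_train.sh',
#     output_dir="./outputs/test",
#     test_name="default",
# )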
6 changes: 3 additions & 3 deletions train_configs/debug_model.toml
@@ -6,7 +6,7 @@ description = "Llama 3 debug training"
 use_for_integration_test = true
 
 [profiling]
-enable_profiling = true
+enable_profiling = false
 save_traces_folder = "profile_trace"
 profile_freq = 10
 enable_memory_snapshot = false
@@ -15,7 +15,7 @@ save_memory_snapshot_folder = "memory_snapshot"
 [metrics]
 log_freq = 1
 enable_color_printing = true
-enable_tensorboard = true
+enable_tensorboard = false
 save_tb_folder = "tb"
 enable_wandb = false
 
@@ -51,7 +51,7 @@ enable_async_tensor_parallel = false
 enable_checkpoint = false
 folder = "checkpoint"
 interval_type = "steps"
-interval = 5
+interval = 10
 model_weights_only = false
 export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
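These config changes disable profiling and TensorBoard by default and checkpoint half as often in the debug config, trimming per-test work in CI; profiling and TensorBoard still get exercised once per CI run because the "default" flavor in test_runner.py re-enables them via CLI overrides. A small sketch of how a flavor's flag set would end up on the command line (the helper name is illustrative, not the runner's actual code):

def apply_overrides(base_cmd: str, override_arg: list[str]) -> str:
    # Append one flavor's flags to the base training command, e.g.
    # "... ./run_llama_train.sh --profiling.enable_profiling --metrics.enable_tensorboard"
    return " ".join([base_cmd, *override_arg])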
