From 5d5caca2b1410d58e5e2f354069039d8258a7123 Mon Sep 17 00:00:00 2001
From: ebsmothers
Date: Sun, 8 Sep 2024 08:35:29 -0700
Subject: [PATCH] Make our compile tests actually work (#1522)

---
 tests/recipes/test_full_finetune_single_device.py |  8 +++++---
 tests/recipes/test_lora_finetune_single_device.py | 15 ++++++++-------
 torchtune/training/_compile.py                    |  2 +-
 3 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py
index 646e53382c..9b8a75ceb9 100644
--- a/tests/recipes/test_full_finetune_single_device.py
+++ b/tests/recipes/test_full_finetune_single_device.py
@@ -74,9 +74,6 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch)
         ckpt_dir = ckpt_path.parent
         log_file = gen_log_file_name(tmpdir)
 
-        # To workaround https://github.com/pytorch/torchtune/issues/676
-        if compile:
-            os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager"
         cmd = f"""
         tune run full_finetune_single_device \
             --config {config} \
@@ -99,8 +96,13 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch)
         with pytest.raises(SystemExit, match=""):
             runpy.run_path(TUNE_PATH, run_name="__main__")
 
+        # Make sure to clear compile state in between tests
+        if compile:
+            torch._dynamo.reset()
+
         loss_values = get_loss_values_from_metric_logger(log_file)
         expected_loss_values = self._fetch_expected_loss_values(model_type)
+
         torch.testing.assert_close(
             loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
         )
diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py
index d0609949d4..5253c10346 100644
--- a/tests/recipes/test_lora_finetune_single_device.py
+++ b/tests/recipes/test_lora_finetune_single_device.py
@@ -75,9 +75,6 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch)
         ckpt_dir = ckpt_path.parent
         log_file = gen_log_file_name(tmpdir)
 
-        # To workaround https://github.com/pytorch/torchtune/issues/676
-        if compile:
-            os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager"
         cmd = f"""
         tune run lora_finetune_single_device \
             --config {config} \
@@ -100,6 +97,10 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch)
         with pytest.raises(SystemExit, match=""):
             runpy.run_path(TUNE_PATH, run_name="__main__")
 
+        # Make sure to clear compile state in between tests
+        if compile:
+            torch._dynamo.reset()
+
         loss_values = get_loss_values_from_metric_logger(log_file)
         expected_loss_values = self._fetch_expected_loss_values(model_type)
         torch.testing.assert_close(
@@ -119,10 +120,6 @@ def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch):
         ckpt_dir = ckpt_path.parent
         log_file = gen_log_file_name(tmpdir)
 
-        # To workaround https://github.com/pytorch/torchtune/issues/676
-        if compile:
-            os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager"
-
         cmd = f"""
         tune run lora_finetune_single_device
             --config llama2/7B_qlora_single_device \
@@ -145,6 +142,10 @@ def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch):
         with pytest.raises(SystemExit):
             runpy.run_path(TUNE_PATH, run_name="__main__")
 
+        # Make sure to clear compile state in between tests
+        if compile:
+            torch._dynamo.reset()
+
         loss_values = get_loss_values_from_metric_logger(log_file)
         expected_loss_values = self._fetch_qlora_expected_loss_values(dtype=dtype)
         torch.testing.assert_close(
diff --git a/torchtune/training/_compile.py b/torchtune/training/_compile.py
index 133c3f2c02..3f8d8c279e 100644
--- a/torchtune/training/_compile.py
+++ b/torchtune/training/_compile.py
@@ -34,6 +34,7 @@ def compile_model(
         verbose (bool): Whether to log compile info. Default: True
     Returns:
         None
+
     """
     backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
     if torch_version_ge("2.5.0"):
@@ -65,7 +66,6 @@ def compile_loss(loss: nn.Module, verbose: bool = True) -> None:
     Returns:
         loss (nn.Module): loss with either entire module compiled or (in the case of
             CEWithChunkedOutputLoss) only the upcast and cross-entropy calculation compiled.
-
     """
     backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
     if verbose:
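
Note (trailing commentary, ignored when the patch is applied): with the
aot_eager workaround for pytorch/torchtune#676 removed, these tests now
compile with whatever backend training._compile reads from
TORCH_COMPILE_BACKEND (defaulting to inductor), and they call
torch._dynamo.reset() so compile caches from one parametrized test case
cannot leak into the next. Below is a minimal standalone sketch of that
pattern; the toy module is hypothetical and stands in for the models the
recipes actually build via `tune run`:

    import os

    import torch
    import torch.nn as nn

    def compile_with_env_backend(module: nn.Module) -> nn.Module:
        # Same convention as torchtune.training._compile: use inductor
        # unless TORCH_COMPILE_BACKEND overrides it.
        backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
        return torch.compile(module, backend=backend)

    def run_one_case() -> None:
        model = compile_with_env_backend(nn.Linear(4, 4))
        model(torch.randn(2, 4))  # first call triggers compilation
        # Clear dynamo state afterwards, mirroring the torch._dynamo.reset()
        # calls added in the tests above.
        torch._dynamo.reset()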