Fix parallel pgle-tests execution.
PiperOrigin-RevId: 696031645
Google-ML-Automation committed Nov 13, 2024
1 parent f2a25cc commit 9a28b56
Showing 1 changed file with 53 additions and 31 deletions.
84 changes: 53 additions & 31 deletions tests/pgle_test.py
@@ -12,50 +12,79 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import ExitStack
 from functools import partial
 import glob
 import logging
 import math
 import os
 import tempfile
 import unittest
 
 from absl.testing import absltest
 import jax
+from jax._src import api
+from jax._src import compilation_cache as cc
 from jax._src import config
-from jax._src import profiler
-from jax._src import pjit
 from jax._src import monitoring
+from jax._src import pjit
+from jax._src import profiler
 from jax._src import test_util as jtu
-from jax._src import api
 from jax.experimental import profiler as exp_profiler
-import jax.numpy as jnp
-from jax.sharding import NamedSharding, PartitionSpec
-from jax._src import compilation_cache as cc
-import numpy as np
-
 from jax.experimental.serialize_executable import (
     deserialize_and_load,
     serialize,
 )
+import jax.numpy as jnp
+from jax.sharding import NamedSharding, PartitionSpec
+import numpy as np
 
 jax.config.parse_flags_with_absl()
 
-dump_dir = tempfile.TemporaryDirectory().name
-os.environ['XLA_FLAGS'] = (
-    f'--xla_dump_to={dump_dir}'
-    ' --xla_gpu_experimental_dump_fdo_profiles=true'
-    ' --xla_gpu_enable_latency_hiding_scheduler=true'
-)
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PgleTest(jtu.JaxTestCase):
+  _dump_exit_stack: ExitStack | None = None
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    cls._dump_exit_stack = ExitStack()
+
+    cls.dump_dir = cls._dump_exit_stack.enter_context(tempfile.TemporaryDirectory())
+    if 'XLA_FLAGS' in os.environ:
+      cls.old_xla_flags = os.environ['XLA_FLAGS']
+    else:
+      cls.old_xla_flags = None
+
+    os.environ['XLA_FLAGS'] = (
+        f'--xla_dump_to={cls.dump_dir}'
+        ' --xla_gpu_experimental_dump_fdo_profiles=true'
+        ' --xla_gpu_enable_latency_hiding_scheduler=true'
+        # TODO(patrios): Remove this flag once b/376647494 is fixed.
+        ' --xla_gpu_graph_level=0'
+    )
+    if cls.old_xla_flags:
+      os.environ['XLA_FLAGS'] += ' ' + cls.old_xla_flags
+
+  @classmethod
+  def tearDownClass(cls):
+    if cls.old_xla_flags:
+      os.environ['XLA_FLAGS'] = cls.old_xla_flags
+    cls._dump_exit_stack.close()
+    super().tearDownClass()
+
   def setUp(self):
     super().setUp()
     cc.set_cache_dir(None)
     cc.reset_cache()
 
   def tearDown(self):
+    # Cleanup dump directory
+    for file in os.listdir(self.dump_dir):
+      file_path = os.path.join(self.dump_dir, file)
+      if os.path.isfile(file_path):
+        os.remove(file_path)
+
     cc.set_cache_dir(None)
     super().tearDown()
 
@@ -87,7 +116,6 @@ def f(x, y):
     self.assertIsNotNone(fdo_profile)
     self.assertIn(b'custom', fdo_profile)
 
-  @unittest.skip("Test failing in CI")
   def testPGLEProfilerGetFDOProfileLarge(self):
     mesh = jtu.create_mesh((2,), ('x',))
     its = 500
@@ -106,14 +134,10 @@ def f(x):
     shape = (16, 16)
     x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
 
-    with config.pgle_profiling_runs(0):
-      f_lowered = f.lower(x)
-      f_compiled = f_lowered.compile()
-
     pgle_profiler = profiler.PGLEProfiler(1, 90)
     with config.enable_pgle(False):
       with profiler.PGLEProfiler.trace(pgle_profiler):
-        f_compiled(x)
+        f(x)
     fdo_profile = pgle_profiler.consume_fdo_profile()
     self.assertEqual(fdo_profile.count(b'custom'), its)
 
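The hunk above drops the manual lower()/compile() step: with the profiler attached, calling the jitted f directly is enough. As a standalone reference, the same flow looks roughly like the sketch below, a minimal example using only the calls that appear in this hunk and assuming a runtime that actually emits PGLE profiles (e.g. GPU); the function f and input x are hypothetical stand-ins for the test's own.

import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import profiler

@jax.jit
def f(x):
  return (x @ x.T).sum()

x = jnp.arange(16.0).reshape(4, 4)

# Mirror the test's profiler.PGLEProfiler(1, 90) construction.
pgle_profiler = profiler.PGLEProfiler(1, 90)
with config.enable_pgle(False):  # keep automatic PGLE out of the way
  with profiler.PGLEProfiler.trace(pgle_profiler):
    f(x)  # the traced run feeds the profiler
fdo_profile = pgle_profiler.consume_fdo_profile()  # bytes, or None if nothing was captured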
@@ -177,7 +201,6 @@ def f(x):
     self.assertArraysEqual(compiled(x), expected)
     self.assertEqual(cache_miss_count[0], 0)
 
-  @unittest.skip("Test failing in CI")
   def testAutoPgleWithPersistentCache(self):
     its = 50
     mesh = jtu.create_mesh((2,), ('x',))
@@ -206,11 +229,12 @@ def f(x):
         config.persistent_cache_min_compile_time_secs(0),
         config.pgle_profiling_runs(2),
         tempfile.TemporaryDirectory() as cache_dir):
+      cc.reset_cache()
       cc.set_cache_dir(cache_dir)
       # Run 1: Module should be compiled without FDO
       with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
         f(x)
-      self.assertEqual(cache_miss_count[0], 1)
+      self.assertGreater(cache_miss_count[0], 0)
 
       # Non-pgle profiled version of module should be saved
       non_pgle_profiled_files = os.listdir(cache_dir)
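The added cc.reset_cache() clears in-memory cache state before the cache is pointed at a fresh directory, and the relaxed assertGreater assertions tolerate concurrent test processes also populating the cache. In isolation the cache-scoping idiom reads like the sketch below, built only from calls that appear in this diff; the jitted g is a hypothetical stand-in.

import tempfile

import jax
import jax.numpy as jnp
from jax._src import compilation_cache as cc
from jax._src import config

@jax.jit
def g(x):
  return x * 2

with (config.persistent_cache_min_compile_time_secs(0),
      tempfile.TemporaryDirectory() as cache_dir):
  cc.reset_cache()             # start from clean in-memory cache state
  cc.set_cache_dir(cache_dir)  # persist new executables under cache_dir
  g(jnp.arange(4))             # first call compiles and populates the cache
  cc.set_cache_dir(None)       # detach before cache_dir is deleted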
@@ -221,26 +245,24 @@ def f(x):
         f(x)
       self.assertEqual(cache_miss_count[0], 0)
 
-      module_before_pgle = os.listdir(dump_dir)
-      print(module_before_pgle)
+      module_before_pgle = os.listdir(self.dump_dir)
       self.assertNotEmpty(module_before_pgle)
       # Run 3: Module should be compiled with FDO and stored to persistent cache
       with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
-        # Add xla_dump_to to env flags
         f(x)
-      self.assertEqual(cache_miss_count[0], 1)
+      self.assertGreater(cache_miss_count[0], 0)
 
       # Check if FDO profile file of the biggest module is not empty
       module_after_pgle = [
           x
-          for x in os.listdir(dump_dir)
+          for x in os.listdir(self.dump_dir)
           if x not in module_before_pgle
       ]
       self.assertNotEmpty(module_after_pgle)
       biggest_module_after_pgle = max(
           module_after_pgle,
           key=lambda x: os.path.getsize(
-              os.path.join(dump_dir, x)
+              os.path.join(self.dump_dir, x)
          ),
       )
       base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1])
Expand All @@ -251,7 +273,7 @@ def f(x):
'.fdo_profile'
):
self.assertGreater(
os.path.getsize(os.path.join(dump_dir, module)), 0
os.path.getsize(os.path.join(self.dump_dir, module)), 0
)

for pgle_profiler in profilers_dict.values():
@@ -283,7 +305,7 @@ def check_if_cache_hit(event):
       f(x)
     monitoring._unregister_event_listener_by_callback(check_if_cache_hit)
 
-    self.assertEqual(cache_hit, 1)
+    self.assertGreater(cache_hit, 0)
 
   def testPassingFDOProfile(self):
     mesh = jtu.create_mesh((2,), ('x',))
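Taken together: the module-level dump_dir and XLA_FLAGS mutation, which every concurrently running test process used to share, is replaced by a per-class temporary dump directory plus save/patch/restore of XLA_FLAGS. A minimal standalone sketch of that fixture pattern follows; ExampleTest, its flag value, and the assertion are hypothetical, and only the setUpClass/tearDownClass structure mirrors the commit.

from contextlib import ExitStack
import os
import tempfile
import unittest

class ExampleTest(unittest.TestCase):
  _exit_stack: ExitStack | None = None

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    cls._exit_stack = ExitStack()
    # Unique per-class directory; ExitStack deletes it in tearDownClass.
    cls.dump_dir = cls._exit_stack.enter_context(tempfile.TemporaryDirectory())
    cls.old_xla_flags = os.environ.get('XLA_FLAGS')
    os.environ['XLA_FLAGS'] = f'--xla_dump_to={cls.dump_dir}'
    if cls.old_xla_flags:
      # Keep whatever the caller had set, appended after our flags.
      os.environ['XLA_FLAGS'] += ' ' + cls.old_xla_flags

  @classmethod
  def tearDownClass(cls):
    # Put the environment back exactly as we found it.
    if cls.old_xla_flags is None:
      os.environ.pop('XLA_FLAGS', None)
    else:
      os.environ['XLA_FLAGS'] = cls.old_xla_flags
    cls._exit_stack.close()
    super().tearDownClass()

  def test_flags_point_at_private_dir(self):
    self.assertIn(self.dump_dir, os.environ['XLA_FLAGS'])

if __name__ == '__main__':
  unittest.main()

One difference from the diff: this sketch also unsets XLA_FLAGS when none were set originally, whereas the commit's tearDownClass only restores a previously set value.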
