Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT] Mark some APIs can be directly run in simulation mode #70293

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
is_break_graph_api,
is_break_graph_tensor_methods,
is_builtin_fn,
is_directly_run_api,
is_not_supported_paddle_layer,
is_paddle_api,
magic_method_builtin_dispatch,
Expand Down Expand Up @@ -699,6 +700,22 @@ def call_function(self, /, *args, **kwargs):
)
return handler(*args, **kwargs)

# If API can be directly called in simulation mode (e.g. user defined native code
# without graph affect), we can directly call it.
if is_directly_run_api(self.value):
from ..function_graph import convert_to_py_value

res = self.value(
*convert_to_py_value(args),
**convert_to_py_value(kwargs),
)

return VariableFactory.from_value(
res,
self.graph,
DummyTracker([self, *list(args), *list(kwargs.values())]),
)

# Try to inline call the magic function
magic_methods = magic_method_builtin_dispatch(self.value)
for magic_method in magic_methods:
Expand Down
1 change: 1 addition & 0 deletions python/paddle/jit/sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from .paddle_api_config import ( # noqa: F401
get_tensor_methods,
is_break_graph_tensor_methods,
is_directly_run_api,
is_inplace_api,
is_not_supported_paddle_layer,
)
Expand Down
25 changes: 25 additions & 0 deletions python/paddle/jit/sot/utils/paddle_api_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,28 @@ def is_break_graph_tensor_methods(method_name):

def add_break_graph_apis(apis: list):
break_graph_set.update(apis)


def is_directly_run_api(api):
from .utils import hashable

if not hashable(api):
return False
NATIVE_CODE_PURE_FUNCTIONS = {
paddle.base.libpaddle.is_compiled_with_avx,
paddle.base.libpaddle.is_compiled_with_cuda,
paddle.base.libpaddle.is_compiled_with_cudnn_frontend,
paddle.base.libpaddle.is_compiled_with_rocm,
paddle.base.libpaddle.is_compiled_with_custom_device,
paddle.base.libpaddle.is_compiled_with_ipu,
paddle.base.libpaddle.is_compiled_with_xpu,
paddle.base.libpaddle.is_compiled_with_mkldnn,
paddle.base.libpaddle.is_compiled_with_nccl,
paddle.base.libpaddle.is_compiled_with_mpi,
paddle.base.libpaddle.is_compiled_with_mpi_aware,
paddle.base.libpaddle.is_compiled_with_cinn,
paddle.base.libpaddle.is_compiled_with_distribute,
paddle.base.libpaddle.is_compiled_with_brpc,
paddle.base.libpaddle.is_compiled_with_dist,
}
return api in NATIVE_CODE_PURE_FUNCTIONS
41 changes: 41 additions & 0 deletions test/sot/test_builtin_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,5 +388,46 @@ def test_builtin_type_conversion_breakgraph(self):
)


@check_no_breakgraph
def test_native_code_function():
res1 = paddle.base.libpaddle.is_compiled_with_avx()
res2 = paddle.base.libpaddle.is_compiled_with_cuda()
res3 = paddle.base.libpaddle.is_compiled_with_cudnn_frontend()
res4 = paddle.base.libpaddle.is_compiled_with_rocm()
res5 = paddle.base.libpaddle.is_compiled_with_custom_device("npu")
res6 = paddle.base.libpaddle.is_compiled_with_ipu()
res7 = paddle.base.libpaddle.is_compiled_with_xpu()
res8 = paddle.base.libpaddle.is_compiled_with_mkldnn()
res9 = paddle.base.libpaddle.is_compiled_with_nccl()
res10 = paddle.base.libpaddle.is_compiled_with_mpi()
res11 = paddle.base.libpaddle.is_compiled_with_mpi_aware()
res12 = paddle.base.libpaddle.is_compiled_with_cinn()
res13 = paddle.base.libpaddle.is_compiled_with_distribute()
res14 = paddle.base.libpaddle.is_compiled_with_brpc()
res15 = paddle.base.libpaddle.is_compiled_with_dist()
return (
res1,
res2,
res3,
res4,
res5,
res6,
res7,
res8,
res9,
res10,
res11,
res12,
res13,
res14,
res15,
)


class TestNativeCodeFunction(TestCaseBase):
def test_native_code_function(self):
self.assert_results(test_native_code_function)


if __name__ == "__main__":
unittest.main()