From a26195c26ef9c36a0768dfff33d2b8e44d8951c9 Mon Sep 17 00:00:00 2001 From: Maxime Arthaud Date: Wed, 16 Oct 2024 10:48:38 -0700 Subject: [PATCH] Find test annotations in methods and nested functions Reviewed By: tianhan0 Differential Revision: D64469925 fbshipit-source-id: 4815fe23b7c9408e2640a94f1e78ee14203d979d --- tools/pysa_integration_tests/runner_lib.py | 43 +++++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/tools/pysa_integration_tests/runner_lib.py b/tools/pysa_integration_tests/runner_lib.py index cf53851aef..55f1c6b327 100644 --- a/tools/pysa_integration_tests/runner_lib.py +++ b/tools/pysa_integration_tests/runner_lib.py @@ -510,16 +510,46 @@ def parse_test_annotation( ) +class FunctionDefinitions: + def __init__(self) -> None: + self.functions: Dict[str, Union[ast.FunctionDef, ast.AsyncFunctionDef]] = {} + + def add_function( + self, name: str, function: Union[ast.FunctionDef, ast.AsyncFunctionDef] + ) -> None: + self.functions[name] = function + + @staticmethod + def from_ast(parsed_ast: ast.AST) -> "FunctionDefinitions": + functions = FunctionDefinitions() + functions.add_from_statements( + parsed_ast.body, # pyre-ignore: _ast.AST has no attribute body + prefix="", + ) + return functions + + def add_from_statements(self, statements: List[ast.stmt], prefix: str) -> None: + for statement in statements: + if isinstance(statement, (ast.FunctionDef, ast.AsyncFunctionDef)): + self.add_function(f"{prefix}{statement.name}", statement) + self.add_from_statements( + statement.body, prefix=f"{prefix}{statement.name}." + ) + elif isinstance(statement, ast.ClassDef): + self.add_from_statements( + statement.body, prefix=f"{prefix}{statement.name}." + ) + + def parse_test_annotations_from_source( source: str, ) -> Dict[str, FunctionTestAnnotations]: parsed_ast = ast.parse(source) - annotated_functions: Dict[str, FunctionTestAnnotations] = {} - for function in parsed_ast.body: - if not isinstance(function, (ast.FunctionDef, ast.AsyncFunctionDef)): - continue + functions = FunctionDefinitions.from_ast(parsed_ast) + annotated_functions: Dict[str, FunctionTestAnnotations] = {} + for qualified_name, function in functions.functions.items(): annotations: List[TestAnnotation] = [] for decorator_expression in function.decorator_list: if not isinstance(decorator_expression, ast.Call): @@ -530,7 +560,7 @@ def parse_test_annotations_from_source( annotations.append(annotation) if len(annotations) > 0: - annotated_functions[function.name] = FunctionTestAnnotations( + annotated_functions[qualified_name] = FunctionTestAnnotations( definition_line=function.lineno, annotations=annotations ) @@ -562,6 +592,9 @@ def parse_test_annotations_from_directory( result = DirectoryTestAnnotations() for path in directory.glob("**/*.py"): + if path.name == "runner_lib.py": + continue + base_module = ".".join(path.relative_to(repository_root).parts) base_module = base_module[:-3] # Remove .py suffix