Skip to content

Commit

Permalink
Find test annotations in methods and nested functions
Browse files Browse the repository at this point in the history
Reviewed By: tianhan0

Differential Revision: D64469925

fbshipit-source-id: 4815fe23b7c9408e2640a94f1e78ee14203d979d
  • Loading branch information
arthaud authored and facebook-github-bot committed Oct 16, 2024
1 parent aaf733f commit a26195c
Showing 1 changed file with 38 additions and 5 deletions.
43 changes: 38 additions & 5 deletions tools/pysa_integration_tests/runner_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,16 +510,46 @@ def parse_test_annotation(
)


class FunctionDefinitions:
def __init__(self) -> None:
self.functions: Dict[str, Union[ast.FunctionDef, ast.AsyncFunctionDef]] = {}

def add_function(
self, name: str, function: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> None:
self.functions[name] = function

@staticmethod
def from_ast(parsed_ast: ast.AST) -> "FunctionDefinitions":
functions = FunctionDefinitions()
functions.add_from_statements(
parsed_ast.body, # pyre-ignore: _ast.AST has no attribute body
prefix="",
)
return functions

def add_from_statements(self, statements: List[ast.stmt], prefix: str) -> None:
for statement in statements:
if isinstance(statement, (ast.FunctionDef, ast.AsyncFunctionDef)):
self.add_function(f"{prefix}{statement.name}", statement)
self.add_from_statements(
statement.body, prefix=f"{prefix}{statement.name}."
)
elif isinstance(statement, ast.ClassDef):
self.add_from_statements(
statement.body, prefix=f"{prefix}{statement.name}."
)


def parse_test_annotations_from_source(
source: str,
) -> Dict[str, FunctionTestAnnotations]:
parsed_ast = ast.parse(source)

annotated_functions: Dict[str, FunctionTestAnnotations] = {}
for function in parsed_ast.body:
if not isinstance(function, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
functions = FunctionDefinitions.from_ast(parsed_ast)

annotated_functions: Dict[str, FunctionTestAnnotations] = {}
for qualified_name, function in functions.functions.items():
annotations: List[TestAnnotation] = []
for decorator_expression in function.decorator_list:
if not isinstance(decorator_expression, ast.Call):
Expand All @@ -530,7 +560,7 @@ def parse_test_annotations_from_source(
annotations.append(annotation)

if len(annotations) > 0:
annotated_functions[function.name] = FunctionTestAnnotations(
annotated_functions[qualified_name] = FunctionTestAnnotations(
definition_line=function.lineno, annotations=annotations
)

Expand Down Expand Up @@ -562,6 +592,9 @@ def parse_test_annotations_from_directory(

result = DirectoryTestAnnotations()
for path in directory.glob("**/*.py"):
if path.name == "runner_lib.py":
continue

base_module = ".".join(path.relative_to(repository_root).parts)
base_module = base_module[:-3] # Remove .py suffix

Expand Down

0 comments on commit a26195c

Please sign in to comment.