Skip to content

Commit

Permalink
Implementing python's global and nonlocal (#1735)
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto authored Oct 8, 2024
1 parent 3e2db5c commit 1b4602d
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,10 @@ import de.fraunhofer.aisec.cpg.frontends.python.PythonLanguage.Companion.MODIFIE
import de.fraunhofer.aisec.cpg.graph.*
import de.fraunhofer.aisec.cpg.graph.Annotation
import de.fraunhofer.aisec.cpg.graph.declarations.*
import de.fraunhofer.aisec.cpg.graph.statements.AssertStatement
import de.fraunhofer.aisec.cpg.graph.statements.CatchClause
import de.fraunhofer.aisec.cpg.graph.statements.DeclarationStatement
import de.fraunhofer.aisec.cpg.graph.statements.ForEachStatement
import de.fraunhofer.aisec.cpg.graph.statements.Statement
import de.fraunhofer.aisec.cpg.graph.statements.TryStatement
import de.fraunhofer.aisec.cpg.graph.statements.expressions.AssignExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.Block
import de.fraunhofer.aisec.cpg.graph.statements.expressions.DeleteExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.Expression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.InitializerListExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.MemberExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.ProblemExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.Reference
import de.fraunhofer.aisec.cpg.graph.scopes.BlockScope
import de.fraunhofer.aisec.cpg.graph.scopes.NameScope
import de.fraunhofer.aisec.cpg.graph.statements.*
import de.fraunhofer.aisec.cpg.graph.statements.expressions.*
import de.fraunhofer.aisec.cpg.graph.types.FunctionType
import de.fraunhofer.aisec.cpg.helpers.Util
import kotlin.collections.plusAssign
Expand Down Expand Up @@ -76,9 +66,9 @@ class StatementHandler(frontend: PythonLanguageFrontend) :
is Python.AST.Assert -> handleAssert(node)
is Python.AST.Try -> handleTryStatement(node)
is Python.AST.Delete -> handleDelete(node)
is Python.AST.Global,
is Python.AST.Global -> handleGlobal(node)
is Python.AST.Nonlocal -> handleNonLocal(node)
is Python.AST.Match,
is Python.AST.Nonlocal,
is Python.AST.Raise,
is Python.AST.TryStar,
is Python.AST.With,
Expand Down Expand Up @@ -537,6 +527,43 @@ class StatementHandler(frontend: PythonLanguageFrontend) :
return wrapDeclarationToStatement(result)
}

/**
 * Translates a Python [`Global`](https://docs.python.org/3/library/ast.html#ast.Global) into a
 * [LookupScopeStatement].
 *
 * The resulting statement carries the local names of all entries in `global.names` and points
 * to the first [NameScope] found among the children of the scope manager's global scope. If no
 * such scope exists, `null` is passed on as the target scope.
 */
private fun handleGlobal(global: Python.AST.Global): LookupScopeStatement {
    // Technically, our global scope is not identical to the python "global" scope. The reason
    // behind that is that we wrap each file in a namespace (as defined in the python spec). So
    // the "global" scope is actually our current namespace scope.
    val pythonGlobalScope =
        frontend.scopeManager.globalScope?.children?.firstOrNull { it is NameScope }

    return newLookupScopeStatement(
        global.names.map { parseName(it).localName },
        pythonGlobalScope,
        rawNode = global
    )
}

/**
 * Translates a Python [`Nonlocal`](https://docs.python.org/3/library/ast.html#ast.Nonlocal)
 * into a [LookupScopeStatement].
 *
 * Note that, despite its name, the `global` parameter holds the `nonlocal` AST node. The
 * resulting statement carries the local names of all entries in `global.names` and points to
 * the first [BlockScope] reported by the scope manager that is not the current scope (or
 * `null` if no such scope is found).
 */
private fun handleNonLocal(global: Python.AST.Nonlocal): LookupScopeStatement {
    // We need to find the first outer function scope, or rather the block scope belonging to
    // the function
    val outerFunctionScope =
        frontend.scopeManager.firstScopeOrNull {
            it is BlockScope && it != frontend.scopeManager.currentScope
        }

    return newLookupScopeStatement(
        global.names.map { parseName(it).localName },
        outerFunctionScope,
        rawNode = global
    )
}

/** Adds the arguments to [result] which might be located in a [recordDeclaration]. */
private fun handleArguments(
args: Python.AST.arguments,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,10 @@ import de.fraunhofer.aisec.cpg.frontends.python.PythonLanguageFrontend
import de.fraunhofer.aisec.cpg.graph.*
import de.fraunhofer.aisec.cpg.graph.declarations.Declaration
import de.fraunhofer.aisec.cpg.graph.declarations.FieldDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.MethodDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.VariableDeclaration
import de.fraunhofer.aisec.cpg.graph.scopes.RecordScope
import de.fraunhofer.aisec.cpg.graph.statements.ForEachStatement
import de.fraunhofer.aisec.cpg.graph.statements.expressions.AssignExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.CallExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.MemberExpression
import de.fraunhofer.aisec.cpg.graph.statements.expressions.Reference
import de.fraunhofer.aisec.cpg.graph.types.InitializerTypePropagation
Expand Down Expand Up @@ -79,59 +77,105 @@ class PythonAddDeclarationsPass(ctx: TranslationContext) : ComponentPass(ctx) {
}
}

/*
* Return null when not creating a new decl
/**
* This function will create a new dynamic [VariableDeclaration] if there is a write access to
* the [ref].
*/
private fun handleReference(node: Reference): VariableDeclaration? {
if (node.resolutionHelper is CallExpression) {
private fun handleWriteToReference(ref: Reference): VariableDeclaration? {
if (ref.access != AccessValues.WRITE) {
return null
}

// TODO(oxisto): Actually the logic here is far more complex in reality, taking into account
// local and global variables, but for now we just resolve it using the scope manager
val resolved = scopeManager.resolveReference(node)
// Look for a potential scope modifier for this reference
// lookupScope
var targetScope =
scopeManager.currentScope?.predefinedLookupScopes[ref.name.toString()]?.targetScope

// There are a couple of things to consider now
var symbol =
// Since this is a WRITE access, we need
// - to look for a local symbol, unless
// - a global keyword is present for this symbol and scope
if (targetScope != null) {
scopeManager.lookupSymbolByName(ref.name, ref.location, targetScope)
} else {
scopeManager.lookupSymbolByName(ref.name, ref.location) {
it.scope == scopeManager.currentScope
}
}

// Nothing to create
if (resolved != null) return null

val decl =
if (scopeManager.isInRecord) {
if (scopeManager.isInFunction) {
if (
node is MemberExpression &&
node.base.name ==
(scopeManager.currentFunction as? MethodDeclaration)?.receiver?.name
) {
// We need to temporarily jump into the scope of the current record to
// add the field. These are instance attributes
scopeManager.withScope(
scopeManager.firstScopeIsInstanceOrNull<RecordScope>()
) {
newFieldDeclaration(node.name)
}
} else {
newVariableDeclaration(node.name)
if (symbol.isNotEmpty()) return null

// First, check if we need to create a field
var field: FieldDeclaration? =
when {
// Check, whether we are referring to a "self.X", which would create a field
scopeManager.isInRecord && scopeManager.isInFunction && ref.refersToReceiver -> {
// We need to temporarily jump into the scope of the current record to
// add the field. These are instance attributes
scopeManager.withScope(scopeManager.firstScopeIsInstanceOrNull<RecordScope>()) {
newFieldDeclaration(ref.name)
}
} else {
newFieldDeclaration(node.name)
}
scopeManager.isInRecord && scopeManager.isInFunction && ref is MemberExpression -> {
// If this is any other member expression, we are usually not interested in
// creating fields, except if this is a receiver
return null
}
scopeManager.isInRecord -> {
// We end up here for fields declared directly in the class body. These are
// class attributes; more or less static fields.
newFieldDeclaration(ref.name)
}
else -> {
null
}
}

// If we didn't create any field up to this point and if we have not returned yet, we
// can create a normal variable. We need to take scope modifications into account.
var decl =
if (field == null && targetScope != null) {
scopeManager.withScope(targetScope) { newVariableDeclaration(ref.name) }
} else if (field == null) {
newVariableDeclaration(ref.name)
} else {
newVariableDeclaration(node.name)
field
}

decl.code = node.code
decl.location = node.location
decl.code = ref.code
decl.location = ref.location
decl.isImplicit = true

log.debug(
"Creating dynamic {} {} in {}",
if (decl is FieldDeclaration) {
"field"
} else {
"variable"
},
decl.name,
decl.scope
)

// Make sure we add the declaration at the correct place, i.e. with the scope we set at the
// creation time
scopeManager.withScope(decl.scope) { scopeManager.addDeclaration(decl) }

return decl
}

/**
 * `true` if this reference is a [MemberExpression] whose base has the same name as the receiver
 * of the method the scope manager is currently in; `false` otherwise.
 */
private val Reference.refersToReceiver: Boolean
    get() {
        // Only member expressions (e.g. "self.x") have a base that could be the receiver
        if (this !is MemberExpression) {
            return false
        }

        val baseName = this.base.name
        val receiverName = scopeManager.currentMethod?.receiver?.name
        return baseName == receiverName
    }

private fun handleAssignExpression(assignExpression: AssignExpression) {
for (target in assignExpression.lhs) {
(target as? Reference)?.let {
val handled = handleReference(target)
val handled = handleWriteToReference(target)
if (handled is Declaration) {
// We cannot assign an initializer here because this will lead to duplicate
// DFG edges, but we need to propagate the type information from our value
Expand All @@ -153,8 +197,10 @@ class PythonAddDeclarationsPass(ctx: TranslationContext) : ComponentPass(ctx) {
private fun handleForEach(node: ForEachStatement) {
when (val forVar = node.variable) {
is Reference -> {
val handled = handleReference(forVar)
(handled as? Declaration)?.let { scopeManager.addDeclaration(it) }
val handled = handleWriteToReference(forVar)
if (handled is Declaration) {
handled.let { node.addDeclaration(it) }
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -688,10 +688,16 @@ class PythonFrontendTest : BaseTest() {

val classFieldDeclaredInFunction = clsFoo.fields["classFieldDeclaredInFunction"]
assertNotNull(classFieldDeclaredInFunction)
// assertEquals(3, clsFoo.fields.size) // TODO should "self" be considered a field here?

assertNull(classFieldNoInitializer.initializer)
assertNotNull(classFieldWithInit)

// Each local variable looked up in the method must be checked individually
val localClassFieldNoInitializer = methBar.variables["classFieldNoInitializer"]
assertNotNull(localClassFieldNoInitializer)

val localClassFieldWithInit = methBar.variables["classFieldWithInit"]
assertNotNull(localClassFieldWithInit)

val localClassFieldDeclaredInFunction = methBar.variables["classFieldDeclaredInFunction"]
assertNotNull(localClassFieldDeclaredInFunction)

// classFieldNoInitializer = classFieldWithInit
val assignClsFieldOutsideFunc = clsFoo.statements[2]
Expand Down Expand Up @@ -725,22 +731,22 @@ class PythonFrontendTest : BaseTest() {
val barStmt3 = barBody.statements[3]
assertIs<AssignExpression>(barStmt3)
assertEquals("=", barStmt3.operatorCode)
assertRefersTo((barStmt3.lhs<Reference>()), classFieldNoInitializer)
assertEquals("shadowed", (barStmt3.rhs<Literal<*>>())?.value)
assertRefersTo(barStmt3.lhs<Reference>(), localClassFieldNoInitializer)
assertLiteralValue("shadowed", barStmt3.rhs<Literal<String>>())

// classFieldWithInit = "shadowed"
val barStmt4 = barBody.statements[4]
assertIs<AssignExpression>(barStmt4)
assertEquals("=", barStmt4.operatorCode)
assertRefersTo((barStmt4.lhs<Reference>()), classFieldWithInit)
assertEquals("shadowed", (barStmt4.rhs<Literal<*>>())?.value)
assertRefersTo(barStmt4.lhs<Reference>(), localClassFieldWithInit)
assertLiteralValue("shadowed", (barStmt4.rhs<Literal<String>>()))

// classFieldDeclaredInFunction = "shadowed"
val barStmt5 = barBody.statements[5]
assertIs<AssignExpression>(barStmt5)
assertEquals("=", barStmt5.operatorCode)
assertRefersTo((barStmt5.lhs<Reference>()), classFieldDeclaredInFunction)
assertEquals("shadowed", (barStmt5.rhs<Literal<*>>())?.value)
assertRefersTo((barStmt5.lhs<Reference>()), localClassFieldDeclaredInFunction)
assertLiteralValue("shadowed", barStmt5.rhs<Literal<String>>())

/* TODO:
foo = Foo()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,4 +231,82 @@ class StatementHandlerTest : BaseTest() {
assertEquals(assertResolvedType("str"), b.type)
}
}

@Test
fun testGlobal() {
    val file = topLevel.resolve("global.py").toFile()
    val result = analyze(listOf(file), topLevel, true) { it.registerLanguage<PythonLanguage>() }
    assertNotNull(result)

    // There should be three variable declarations, two local and one global
    val cVariables = result.variables("c")
    assertEquals(3, cVariables.size)

    // Our scopes do not match 1:1 to python scopes, but rather the python "global" scope is a
    // name space with the name of the file and the function scope is a block scope of the
    // function body
    val pythonGlobalScope = result.finalCtx.scopeManager.lookupScope(file.nameWithoutExtension)

    val globalC = cVariables.firstOrNull { it.scope == pythonGlobalScope }
    assertNotNull(globalC)

    val localC1 =
        cVariables.firstOrNull {
            it.scope?.astNode?.astParent?.name?.localName == "local_write"
        }
    assertNotNull(localC1)

    val localC2 =
        cVariables.firstOrNull {
            it.scope?.astNode?.astParent?.name?.localName == "error_write"
        }
    assertNotNull(localC2)

    // In global_write, all references should point to global c
    val globalWriteRefs = result.functions["global_write"]?.refs("c")
    assertNotNull(globalWriteRefs)
    globalWriteRefs.forEach { assertRefersTo(it, globalC) }

    // In global_read, all references should point to global c
    val globalReadRefs = result.functions["global_read"]?.refs("c")
    assertNotNull(globalReadRefs)
    globalReadRefs.forEach { assertRefersTo(it, globalC) }

    // In local_write, all references should point to local c
    val localWriteRefs = result.functions["local_write"]?.refs("c")
    assertNotNull(localWriteRefs)
    localWriteRefs.forEach { assertRefersTo(it, localC1) }

    // In error_write, all references will point to local c; even though the c on the right side
    // SHOULD be unresolved - but this is a general shortcoming because the resolving will not
    // take the EOG into consideration (yet)
    val errorWriteRefs = result.functions["error_write"]?.refs("c")
    assertNotNull(errorWriteRefs)
    errorWriteRefs.forEach { assertRefersTo(it, localC2) }
}

@Test
fun testNonlocal() {
    val file = topLevel.resolve("nonlocal.py").toFile()
    val result = analyze(listOf(file), topLevel, true) { it.registerLanguage<PythonLanguage>() }
    assertNotNull(result)

    // There will be only one variable declaration because we are currently not adding nested
    // functions to the AST properly :(
    val cVariables = result.variables("c")
    assertEquals(1, cVariables.size)
}

// TODO(oxisto): Re-enable this once we parse nested functions
@Ignore
@Test
fun testNonLocal() {
    val file = topLevel.resolve("nonlocal.py").toFile()
    val result = analyze(listOf(file), topLevel, true) { it.registerLanguage<PythonLanguage>() }
    assertNotNull(result)

    // There should be three variable declarations, two local and one global
    // NOTE(review): confirm this expectation against nonlocal.py when re-enabling the test
    val cVariables = result.variables("c")
    assertEquals(3, cVariables.size)
}
}
Loading

0 comments on commit 1b4602d

Please sign in to comment.