Skip to content

Commit

Permalink
Join repeated simple f-string expressions.
Browse files Browse the repository at this point in the history
  • Loading branch information
scoder committed Sep 3, 2024
1 parent 8cdcbd1 commit 2be1855
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 4 deletions.
25 changes: 21 additions & 4 deletions Cython/Compiler/ExprNodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3547,8 +3547,10 @@ def generate_result_code(self, code):
class JoinedStrNode(ExprNode):
# F-strings
#
# values [UnicodeNode|FormattedValueNode] Substrings of the f-string
# values [UnicodeNode|FormattedValueNode|CloneNode] Substrings of the f-string
#
# CloneNodes for repeated substrings are only inserted right before the code generation phase.

type = unicode_type
is_temp = True
gil_message = "String concatenation"
Expand Down Expand Up @@ -3576,21 +3578,31 @@ def generate_evaluation_code(self, code):
elif isinstance(node, FormattedValueNode) and node.value.type.is_numeric:
# formatted C numbers are always ASCII
pass
elif isinstance(node, CloneNode):
# we already know the result
pass
else:
unknown_nodes.add(node)

length_parts = []
counts = {}
charval_parts = [str(max_char_value)]
for node in self.values:
node.generate_evaluation_code(code)

if isinstance(node, UnicodeNode):
length_parts.append(str(len(node.value)))
length_part = str(len(node.value))
else:
# TODO: add exception handling for these macro calls if not ASSUME_SAFE_SIZE/MACROS
length_parts.append("__Pyx_PyUnicode_GET_LENGTH(%s)" % node.py_result())
length_part = f"__Pyx_PyUnicode_GET_LENGTH({node.py_result()})"
if node in unknown_nodes:
charval_parts.append("__Pyx_PyUnicode_MAX_CHAR_VALUE(%s)" % node.py_result())
charval_parts.append(f"__Pyx_PyUnicode_MAX_CHAR_VALUE({node.py_result()})")

if length_part in counts:
counts[length_part] += 1
else:
length_parts.append(length_part)
counts[length_part] = 1

if use_stack_memory:
values_array = code.funcstate.allocate_temp(
Expand All @@ -3606,6 +3618,11 @@ def generate_evaluation_code(self, code):
for i, node in enumerate(self.values):
code.putln('%s[%d] = %s;' % (values_array, i, node.py_result()))

length_parts = [
f"{part} * {counts[part]}" if counts[part] > 1 else part
for part in length_parts
]

code.mark_pos(self.pos)
self.allocate_temp_result(code)
code.globalstate.use_utility_code(UtilityCode.load_cached("JoinPyUnicode", "StringTools.c"))
Expand Down
55 changes: 55 additions & 0 deletions Cython/Compiler/Optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5027,6 +5027,7 @@ class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin):
- eliminate useless string formatting steps
- inject branch hints for unlikely if-cases that only raise exceptions
- replace Python function calls that look like method calls by a faster PyMethodCallNode
- replace duplicate FormattedValueNodes in f-strings with CloneNodes
"""
in_loop = False

Expand Down Expand Up @@ -5178,6 +5179,60 @@ def _set_ifclause_branch_hint(self, clause, statements_node, inverse=False):
clause.branch_hint = 'likely' if inverse else 'unlikely'
break

def visit_JoinedStrNode(self, node: ExprNodes.JoinedStrNode):
    """
    Replace repeated formatted values of an f-string with CloneNodes.

    The same formatting expression often appears more than once in a single
    f-string.  Each later FormattedValueNode that formats the same simple name
    with the same C format spec, format spec and conversion character as an
    earlier one is swapped for a CloneNode of that earlier node.

    This is deliberately a heuristic: even a simple expression could change its
    value while the f-string is being built (e.g. a ".__format__" method that
    modifies the world).  That is considered too unlikely in real-world code
    to be handled here.

    Returns the (in-place modified) JoinedStrNode.
    """
    # NOTE(review): this handler does not call self.visitchildren(node);
    # confirm that the f-string's sub-expressions need no further processing
    # by this transform.
    formatted_value_type = ExprNodes.FormattedValueNode
    coerce_to_py_type = ExprNodes.CoerceToPyTypeNode

    first_occurrence = {}
    new_values = list(node.values)
    for index, part in enumerate(node.values):
        if not isinstance(part, formatted_value_type):
            # Unicode string constants are deduplicated already.
            continue

        target = part.value
        if isinstance(target, coerce_to_py_type):
            # Coerced C values are probably safe, so look at the wrapped C value.
            target = target.arg
        elif part.c_format_spec is None and not target.type.is_builtin_type:
            # Simple formatted C values are safe, and most builtin Python types
            # probably are as well.  Anything else is an arbitrary Python object
            # that can change its formatting on each access.
            # FIXME: A builtin container type may still format user defined values,
            # so we might want to be more specific and allow only simple Python types.
            continue

        if not target.is_name or not target.is_simple():
            # Everything but simple (local) names risks changing on access.
            # NOTE: Potentially, any non-trivial operation might re-assign values,
            # e.g. with the walrus operator, but we ignore this here since it's
            # really unusual.  Otherwise, we'd have to stop with 'break' instead
            # of allowing 'continue'.
            continue

        key = (
            target.name,
            part.c_format_spec,
            part.format_spec,
            part.conversion_char or 's',
        )
        original = first_occurrence.setdefault(key, part)
        if original is not part:
            # A later repetition: reuse the result of the first occurrence.
            clone = ExprNodes.CloneNode(original)
            clone.pos = part.pos
            new_values[index] = clone

    node.values[:] = new_values
    return node


class ConsolidateOverflowCheck(Visitor.CythonTransform):
"""
Expand Down
52 changes: 52 additions & 0 deletions tests/run/fstring.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,58 @@ def sideeffect(l):
return list(l)


# Each of the formatted values x, i and f appears twice with identical formatting.
# The tree assertion therefore requires a CloneNode directly below the JoinedStrNode,
# while the doctest checks that the resulting string is unchanged.
@cython.test_assert_path_exists(
    "//JoinedStrNode",
    "//JoinedStrNode/CloneNode",
)
def dedup_same(x, int i, float f):
    """
    >>> dedup_same('abc', 5, 5.5)
    'xabci5f5.5xabci5f5.5'
    """
    return f"x{x}i{i}f{f}x{x}i{i}f{f}"


# The second half repeats x, i and f with an explicit "!s" conversion, while the
# first half uses no conversion.  The tree assertion still requires a CloneNode
# below the JoinedStrNode, and the doctest expects both halves to render the same.
@cython.test_assert_path_exists(
    "//JoinedStrNode",
    "//JoinedStrNode/CloneNode",
)
def dedup_same_kind(x, int i, float f):
    """
    >>> dedup_same_kind('abc', 5, 5.5)
    'xabci5f5.5xabci5f5.5'
    """
    return f"x{x}i{i}f{f}x{x!s}i{i!s}f{f!s}"


# The repeated values use a different conversion or format spec ("!r", ":d", "!a")
# than their first occurrence.  The tree assertion requires that no CloneNode
# appears anywhere below the JoinedStrNode; the doctest shows the "!r" difference.
@cython.test_fail_if_path_exists(
    "//JoinedStrNode//CloneNode",
)
@cython.test_assert_path_exists(
    "//JoinedStrNode",
)
def dedup_different_format_char(x, int i, float f):
    """
    >>> dedup_different_format_char('abc', 5, 5.5)
    "xabci5f5.5x'abc'i5f5.5"
    """
    return f"x{x}i{i}f{f}x{x!r}i{i:d}f{f!a}"


# The repeated values are expressions (x+'', i+1, f+1) rather than plain names,
# and the plain name x occurs only once.  The tree assertion requires that no
# CloneNode appears anywhere below the JoinedStrNode.
@cython.test_fail_if_path_exists(
    "//JoinedStrNode//CloneNode",
)
@cython.test_assert_path_exists(
    "//JoinedStrNode",
)
def dedup_non_simple(x, int i, float f):
    """
    >>> dedup_non_simple('abc', 5, 5.5)
    'xabci6f6.5xabci6f6.5'
    """
    return f"x{x+''}i{i+1}f{f+1}x{x}i{i+1}f{f+1}"


########################################
# await inside f-string

Expand Down

0 comments on commit 2be1855

Please sign in to comment.