Skip to content

Commit

Permalink
Fix a crash when calling the optimised sorted() by using a dedicated …
Browse files Browse the repository at this point in the history
…"SortedListNode" for this simple code injection.

This also avoids needlessly juggling Python temps for the list object.

Closes cython#6496
  • Loading branch information
scoder committed Nov 12, 2024
1 parent a27ff56 commit 10c27c1
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 46 deletions.
106 changes: 83 additions & 23 deletions Cython/Compiler/ExprNodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,63 @@ def get_known_standard_library_import(self):
return None


class _TempModifierNode(ExprNode):
    """Base class for nodes that inherit the result of their temp argument and can modify it.

    The node itself is not a temp ('is_temp' is False).  Its type, its result
    code and its temp handling are all forwarded to the single subexpression
    'arg'.  By default no code is generated for the node itself; subclasses
    override generate_result_code() to emit code that works on the result
    of the argument.
    """
    is_temp = False
    subexprs = ['arg']

    def __init__(self, pos, arg):
        super().__init__(pos, arg=arg)

    # -- Type information: always that of the argument. --

    @property
    def type(self):
        """The type of the wrapped argument."""
        return self.arg.type

    def infer_type(self, env):
        """Return whatever type the argument infers for itself."""
        return self.arg.infer_type(env)

    def analyse_types(self, env):
        """Analyse the argument in place and keep this node in the tree."""
        self.arg = self.arg.analyse_types(env)
        return self

    # -- Properties of the value: forwarded to the argument. --

    def calculate_constant_result(self):
        """Forward the constant result calculation to the argument."""
        return self.arg.calculate_constant_result()

    def may_be_none(self):
        """Return the argument's answer."""
        return self.arg.may_be_none()

    def is_simple(self):
        """Return the argument's answer."""
        return self.arg.is_simple()

    def nonlocally_immutable(self):
        """Return the argument's answer."""
        return self.arg.nonlocally_immutable()

    # -- Result and temp handling: the argument's result is reused as is. --

    def result_in_temp(self):
        """Return whether the argument keeps its result in a temp."""
        return self.arg.result_in_temp()

    def calculate_result_code(self):
        """Return the result code of the argument."""
        return self.arg.result()

    def allocate_temp_result(self, code):
        """Let the argument allocate its temp result."""
        return self.arg.allocate_temp_result(code)

    def free_temps(self, code):
        """Let the argument free its temps."""
        self.arg.free_temps(code)

    # -- Code generation. --

    def generate_result_code(self, code):
        """Generate no code by default."""

    def generate_post_assignment_code(self, code):
        """Let the argument generate its post-assignment code."""
        self.arg.generate_post_assignment_code(code)


#-------------------------------------------------------------------
#
# Constants
#
#-------------------------------------------------------------------

class AtomicExprNode(ExprNode):
# Abstract base class for expression nodes which have
# no sub-expressions.
Expand All @@ -1218,6 +1275,7 @@ def generate_subexpr_evaluation_code(self, code):
def generate_subexpr_disposal_code(self, code):
pass


class PyConstNode(AtomicExprNode):
# Abstract base class for constant Python values.

Expand Down Expand Up @@ -1895,6 +1953,12 @@ def generate_result_code(self, code):
self.generate_gotref(code)


#-------------------------------------------------------------------
#
# Simple expressions
#
#-------------------------------------------------------------------

class NewExprNode(AtomicExprNode):

# C++ new statement
Expand Down Expand Up @@ -2696,6 +2760,7 @@ def get_known_standard_library_import(self):
return self.entry.known_standard_library_import
return None


class BackquoteNode(ExprNode):
# `expr`
#
Expand Down Expand Up @@ -2725,6 +2790,12 @@ def generate_result_code(self, code):
self.generate_gotref(code)


#-------------------------------------------------------------------
#
# Control-flow related expressions
#
#-------------------------------------------------------------------

class ImportNode(ExprNode):
# Used as part of import statement implementation.
# Implements result =
Expand Down Expand Up @@ -9601,6 +9672,15 @@ def generate_result_code(self, code):
self.pos, 'PyList_Sort(%s)' % self.py_result())


class SortedListNode(_TempModifierNode):
    """Sorts a newly created Python list in place.

    The list to sort is the result of the 'arg' subexpression.
    """
    type = list_type

    def generate_result_code(self, code):
        # Build the C call on the argument's result first, then wrap it in
        # the "goto error if negative" check and write that out as one line.
        sort_call = f"PyList_Sort({self.arg.result()})"
        checked_call = code.error_goto_if_neg(sort_call, self.pos)
        code.putln(checked_call)


class ModuleNameMixin:
def get_py_mod_name(self, code):
return code.get_py_string_const(
Expand Down Expand Up @@ -14216,16 +14296,17 @@ def free_subexpr_temps(self, code):
self.arg.free_subexpr_temps(code)


class NoneCheckNode(CoercionNode):
class NoneCheckNode(_TempModifierNode):
# This node is used to check that a Python object is not None and
# raises an appropriate exception (as specified by the creating
# transform).

is_nonecheck = True
type = None

def __init__(self, arg, exception_type_cname, exception_message,
exception_format_args=()):
CoercionNode.__init__(self, arg)
super().__init__(arg.pos, arg)
self.type = arg.type
self.result_ctype = arg.ctype()
self.exception_type_cname = exception_type_cname
Expand All @@ -14234,24 +14315,9 @@ def __init__(self, arg, exception_type_cname, exception_message,

nogil_check = None # this node only guards an operation that would fail already

def analyse_types(self, env):
return self

def may_be_none(self):
return False

def is_simple(self):
return self.arg.is_simple()

def result_in_temp(self):
return self.arg.result_in_temp()

def nonlocally_immutable(self):
return self.arg.nonlocally_immutable()

def calculate_result_code(self):
return self.arg.result()

def condition(self):
if self.type.is_pyobject:
return self.arg.py_result()
Expand Down Expand Up @@ -14302,12 +14368,6 @@ def put_nonecheck(self, code):
def generate_result_code(self, code):
self.put_nonecheck(code)

def generate_post_assignment_code(self, code):
self.arg.generate_post_assignment_code(code)

def free_temps(self, code):
self.arg.free_temps(code)


class CoerceToPyTypeNode(CoercionNode):
# This node is used to convert a C data type
Expand Down
27 changes: 4 additions & 23 deletions Cython/Compiler/Optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1797,12 +1797,10 @@ def _handle_simple_function_sorted(self, node, pos_args):
arg = pos_args[0]
if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type:
list_node = arg
loop_node = list_node.loop

elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
gen_expr_node = arg
loop_node = gen_expr_node.loop
yield_statements = _find_yield_statements(loop_node)
yield_statements = _find_yield_statements(gen_expr_node.loop)
if not yield_statements:
return node

Expand All @@ -1820,37 +1818,20 @@ def _handle_simple_function_sorted(self, node, pos_args):
elif arg.is_sequence_constructor:
# sorted([a, b, c]) or sorted((a, b, c)). The result is always a list,
# so starting off with a fresh one is more efficient.
list_node = loop_node = arg.as_list()
list_node = arg.as_list()

else:
# Interestingly, PySequence_List works on a lot of non-sequence
# things as well.
list_node = loop_node = ExprNodes.PythonCapiCallNode(
list_node = ExprNodes.PythonCapiCallNode(
node.pos,
"__Pyx_PySequence_ListKeepNew"
if arg.is_temp and arg.type in (PyrexTypes.py_object_type, Builtin.list_type)
else "PySequence_List",
self.PySequence_List_func_type,
args=pos_args, is_temp=True)

result_node = UtilNodes.ResultRefNode(
pos=loop_node.pos, type=Builtin.list_type, may_hold_none=False)
list_assign_node = Nodes.SingleAssignmentNode(
node.pos, lhs=result_node, rhs=list_node, first=True)

sort_method = ExprNodes.AttributeNode(
node.pos, obj=result_node, attribute=EncodedString('sort'),
# entry ? type ?
needs_none_check=False)
sort_node = Nodes.ExprStatNode(
node.pos, expr=ExprNodes.SimpleCallNode(
node.pos, function=sort_method, args=[]))

sort_node.analyse_declarations(self.current_env())

return UtilNodes.TempResultFromStatNode(
result_node,
Nodes.StatListNode(node.pos, stats=[list_assign_node, sort_node]))
return ExprNodes.SortedListNode(node.pos, list_node)

def __handle_simple_function_sum(self, node, pos_args):
"""Transform sum(genexpr) into an equivalent inlined aggregation loop.
Expand Down
54 changes: 54 additions & 0 deletions tests/run/builtin_sorted.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,17 @@ def sorted_genexp():
return sorted(i*i for i in range(10,0,-1))


# sorted() applied directly to a list comprehension.
# Going by the decorator names, the tree is expected to still contain a
# ComprehensionNode and to contain neither a YieldExprNode nor a NoneCheckNode.
@cython.test_fail_if_path_exists("//YieldExprNode",
"//NoneCheckNode")
@cython.test_assert_path_exists("//ComprehensionNode")
def sorted_listcomp():
"""
>>> sorted_listcomp()
[1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
"""
# The comprehension yields the squares in descending order (100 down to 1),
# so the doctest above only passes if the list really gets sorted.
return sorted([i*i for i in range(10,0,-1)])


@cython.test_fail_if_path_exists("//SimpleCallNode//SimpleCallNode")
@cython.test_assert_path_exists("//SimpleCallNode/NameNode[@name = 'range']")
def sorted_list_of_range():
Expand All @@ -92,3 +103,46 @@ def sorted_tuple_literal():
[1, 1, 2, 2, 3, 3]
"""
return sorted((1, 3, 2) * 2)


# Calls sorted() repeatedly inside a loop: in the try body, in the except
# handler and inside the printed f-strings.
# 'repeat' is the number of loop iterations; 'raise_at' is the iteration index
# at which a ValueError is raised (the default -1 never matches an index).
# Going by the decorator name, no SimpleCallNode is expected to remain in the tree.
@cython.test_fail_if_path_exists("//SimpleCallNode")
def sorted_in_loop(L: list, repeat: cython.int, raise_at: cython.int = -1):
# See https://github.com/cython/cython/issues/6496
"""
>>> L = [3, 1, 2]
>>> sorted_in_loop(L, 3)
OK: [1, 2, 3] [1, 2, 3]
OK: [1, 2, 3] [1, 2, 3]
OK: [1, 2, 3] [1, 2, 3]
[1, 2, 3]
>>> L
[3, 1, 2]
>>> L = [3, 1, 2]
>>> sorted_in_loop(L, 1)
OK: [1, 2, 3] [1, 2, 3]
[1, 2, 3]
>>> L
[3, 1, 2]
>>> L = [3, 1, 2]
>>> sorted_in_loop(L, 5, raise_at=2)
OK: [1, 2, 3] [1, 2, 3]
OK: [1, 2, 3] [1, 2, 3]
EX: [1, 2, 3] [1, 2, 3]
OK: [1, 2, 3] [1, 2, 3]
OK: [1, 2, 3] [1, 2, 3]
[1, 2, 3]
>>> L
[3, 1, 2]
"""
for i in range(repeat):
try:
if i == raise_at:
# Send this iteration through the except branch instead.
raise ValueError
# Rebinds the local name only; the doctest checks that the caller's list
# is still [3, 1, 2] afterwards.
L = sorted(L)
print(f"OK: {sorted(L)} {L}")
except ValueError:
# Same sort-and-print as above, but executed from the exception handler.
L = sorted(L)
print(f"EX: {sorted(L)} {L}")
return L

0 comments on commit 10c27c1

Please sign in to comment.