Skip to content

Commit

Permalink
Attempt compile time encoding/decoding when using str.encode() on literals.
Browse files Browse the repository at this point in the history
  • Loading branch information
scoder committed Aug 30, 2024
1 parent 3469324 commit f83932e
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions Cython/Compiler/Optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3868,29 +3868,29 @@ def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_

string_node = args[0]

if len(args) == 1:
null_node = ExprNodes.NullNode(node.pos)
return self._substitute_method_call(
node, function, "PyUnicode_AsEncodedString",
self.PyUnicode_AsEncodedString_func_type,
'encode', is_unbound_method, [string_node, null_node, null_node])

parameters = self._unpack_encoding_and_error_mode(node.pos, args)
if parameters is None:
return node
encoding, encoding_node, error_handling, error_handling_node = parameters

if encoding and isinstance(string_node, ExprNodes.UnicodeNode):
if string_node.has_constant_result():
# constant, so try to do the encoding at compile time
try:
value = string_node.value.encode(encoding, error_handling)
value = string_node.constant_result.encode(encoding, error_handling)
except:
# well, looks like we can't
pass
else:
value = bytes_literal(value, encoding)
value = bytes_literal(value, encoding or 'UTF-8')
return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type)

if len(args) == 1:
null_node = ExprNodes.NullNode(node.pos)
return self._substitute_method_call(
node, function, "PyUnicode_AsEncodedString",
self.PyUnicode_AsEncodedString_func_type,
'encode', is_unbound_method, [string_node, null_node, null_node])

if encoding and error_handling == 'strict':
# try to find a specific encoder function
codec_name = self._find_special_codec_name(encoding)
Expand Down Expand Up @@ -3944,6 +3944,20 @@ def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_me
self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
return node

# Try to extract encoding parameters and attempt constant decode.
parameters = self._unpack_encoding_and_error_mode(node.pos, args)
if parameters is None:
return node
encoding, encoding_node, error_handling, error_handling_node = parameters

if args[0].has_constant_result():
try:
constant_result = args[0].constant_result.decode(encoding, error_handling)
except (AttributeError, ValueError, UnicodeDecodeError):
pass
else:
return UnicodeNode(args[0].pos, value=encoded_string(constant_result, encoding))

# normalise input nodes
string_node = args[0]
start = stop = None
Expand Down Expand Up @@ -3971,11 +3985,6 @@ def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_me
# nothing to optimise here
return node

parameters = self._unpack_encoding_and_error_mode(node.pos, args)
if parameters is None:
return node
encoding, encoding_node, error_handling, error_handling_node = parameters

if not start:
start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
elif not start.type.is_int:
Expand Down

0 comments on commit f83932e

Please sign in to comment.