Skip to content

Commit

Permalink
Change dataclasses Cython code generation to use the PyxCodeWriter fo…
Browse files Browse the repository at this point in the history
…r intentation.
  • Loading branch information
scoder committed Aug 25, 2024
1 parent 1945ffc commit 1413f7d
Showing 1 changed file with 85 additions and 77 deletions.
162 changes: 85 additions & 77 deletions Cython/Compiler/Dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,8 @@ def __init__(self, writer=None, placeholders=None, extra_stats=None):
def add_code_line(self, code_line):
self.writer.putln(code_line)

def add_code_lines(self, code_lines):
for line in code_lines:
self.writer.putln(line)
def add_code_chunk(self, code_chunk):
self.writer.put_chunk(code_chunk)

def reset(self):
# don't attempt to reset placeholders - it really doesn't matter if
Expand All @@ -124,8 +123,14 @@ def reset(self):
def empty(self):
return self.writer.empty()

def indenter(self):
return self.writer.indenter()
def indent(self):
self.writer.indent()

def dedent(self):
self.writer.dedent()

def indenter(self, block_opener_line):
return self.writer.indenter(block_opener_line)

def new_placeholder(self, field_names, value):
name = self._new_placeholder_name(field_names)
Expand All @@ -139,7 +144,7 @@ def add_extra_statements(self, statements):

def _new_placeholder_name(self, field_names):
while True:
name = "DATACLASS_PLACEHOLDER_%d" % self._placeholder_count
name = f"DATACLASS_PLACEHOLDER_{self._placeholder_count:d}"
if (name not in self.placeholders
and name not in field_names):
# make sure name isn't already used and doesn't
Expand Down Expand Up @@ -399,6 +404,7 @@ def generate_init_code(code, init, node, fields, kw_only):

function_start_point = code.insertion_point()
code = code.insertion_point()
code.indent()

# create a temp to get _HAS_DEFAULT_FACTORY
dataclass_module = make_dataclasses_module_callnode(node.pos)
Expand All @@ -414,7 +420,7 @@ def generate_init_code(code, init, node, fields, kw_only):
for name, field in fields.items():
entry = node.scope.lookup(name)
if entry.annotation:
annotation = ": %s" % entry.annotation.string.value
annotation = f": {entry.annotation.string.value}"
else:
annotation = ""
assignment = ''
Expand All @@ -425,47 +431,47 @@ def generate_init_code(code, init, node, fields, kw_only):
ph_name = default_factory_placeholder
else:
ph_name = code.new_placeholder(fields, field.default) # 'default' should be a node
assignment = " = %s" % ph_name
assignment = f" = {ph_name}"
elif seen_default and not kw_only and field.init.value:
error(entry.pos, ("non-default argument '%s' follows default argument "
"in dataclass __init__") % name)
code.reset()
return

if field.init.value:
args.append("%s%s%s" % (name, annotation, assignment))
args.append(f"{name}{annotation}{assignment}")

if field.is_initvar:
continue
elif field.default_factory is MISSING:
if field.init.value:
code.add_code_line(" %s.%s = %s" % (selfname, name, name))
code.add_code_line(f"{selfname}.{name} = {name}")
elif assignment:
# not an argument to the function, but is still initialized
code.add_code_line(" %s.%s%s" % (selfname, name, assignment))
code.add_code_line(f"{selfname}.{name}{assignment}")
else:
ph_name = code.new_placeholder(fields, field.default_factory)
if field.init.value:
# close to:
# def __init__(self, name=_PLACEHOLDER_VALUE):
# self.name = name_default_factory() if name is _PLACEHOLDER_VALUE else name
code.add_code_line(" %s.%s = %s() if %s is %s else %s" % (
selfname, name, ph_name, name, default_factory_placeholder, name))
code.add_code_line(
f"{selfname}.{name} = {ph_name}() if {name} is {default_factory_placeholder} else {name}"
)
else:
# still need to use the default factory to initialize
code.add_code_line(" %s.%s = %s()" % (
selfname, name, ph_name))
code.add_code_line(f"{selfname}.{name} = {ph_name}()")

if node.scope.lookup("__post_init__"):
post_init_vars = ", ".join(name for name, field in fields.items()
if field.is_initvar)
code.add_code_line(" %s.__post_init__(%s)" % (selfname, post_init_vars))
code.add_code_line(f"{selfname}.__post_init__({post_init_vars})")

if code.empty():
code.add_code_line(" pass")
code.add_code_line("pass")

args = ", ".join(args)
function_start_point.add_code_line("def __init__(%s):" % args)
function_start_point.add_code_line(f"def __init__({args}):")


def generate_match_args(code, match_args, node, fields, global_kw_only):
Expand Down Expand Up @@ -521,25 +527,35 @@ def generate_repr_code(code, repr, node, fields):
break

if needs_recursive_guard:
code.add_code_line("__pyx_recursive_repr_guard = __import__('threading').local()")
code.add_code_line("__pyx_recursive_repr_guard.running = set()")
code.add_code_line("def __repr__(self):")
if needs_recursive_guard:
code.add_code_line(" key = id(self)")
code.add_code_line(" guard_set = self.__pyx_recursive_repr_guard.running")
code.add_code_line(" if key in guard_set: return '...'")
code.add_code_line(" guard_set.add(key)")
code.add_code_line(" try:")
strs = ["%s={self.%s!r}" % (name, name)
for name, field in fields.items()
if field.repr.value and not field.is_initvar]
format_string = ", ".join(strs)

code.add_code_line(' name = getattr(type(self), "__qualname__", type(self).__name__)')
code.add_code_line(" return f'{name}(%s)'" % format_string)
if needs_recursive_guard:
code.add_code_line(" finally:")
code.add_code_line(" guard_set.remove(key)")
code.add_code_chunk("""
__pyx_recursive_repr_guard = __import__('threading').local()
__pyx_recursive_repr_guard.running = set()
""")

with code.indenter("def __repr__(self):"):
if needs_recursive_guard:
code.add_code_chunk("""
key = id(self)
guard_set = self.__pyx_recursive_repr_guard.running
if key in guard_set: return '...'
guard_set.add(key)
try:
""")
code.indent()

strs = ["%s={self.%s!r}" % (name, name)
for name, field in fields.items()
if field.repr.value and not field.is_initvar]
format_string = ", ".join(strs)

code.add_code_chunk(f'''
name = getattr(type(self), "__qualname__", None) or type(self).__name__
return f'{{name}}({format_string})'
''')
if needs_recursive_guard:
code.dedent()
with code.indenter("finally:"):
code.add_code_line("guard_set.remove(key)")


def generate_cmp_code(code, op, funcname, node, fields):
Expand All @@ -548,39 +564,33 @@ def generate_cmp_code(code, op, funcname, node, fields):

names = [name for name, field in fields.items() if (field.compare.value and not field.is_initvar)]

code.add_code_lines([
"def %s(self, other):" % funcname,
" if other.__class__ is not self.__class__:"
" return NotImplemented",
with code.indenter(f"def {funcname}(self, other):"):
code.add_code_chunk(f"""
if other.__class__ is not self.__class__: return NotImplemented
cdef {node.class_name} other_cast
other_cast = <{node.class_name}>other
""")

# The Python implementation of dataclasses.py does a tuple comparison
# (roughly):
# return self._attributes_to_tuple() {op} other._attributes_to_tuple()
#
# For the Cython implementation a tuple comparison isn't an option because
# not all attributes can be converted to Python objects and stored in a tuple
#
" cdef %s other_cast" % node.class_name,
" other_cast = <%s>other" % node.class_name,
])
# TODO - better diagnostics of whether the types support comparison before
# generating the code. Plus, do we want to convert C structs to dicts and
# compare them that way (I think not, but it might be in demand)?
checks = []
op_without_equals = op.replace('=', '')

# The Python implementation of dataclasses.py does a tuple comparison
# (roughly):
# return self._attributes_to_tuple() {op} other._attributes_to_tuple()
#
# For the Cython implementation a tuple comparison isn't an option because
# not all attributes can be converted to Python objects and stored in a tuple
#
# TODO - better diagnostics of whether the types support comparison before
# generating the code. Plus, do we want to convert C structs to dicts and
# compare them that way (I think not, but it might be in demand)?
checks = []
op_without_equals = op.replace('=', '')

for name in names:
if op != '==':
# tuple comparison rules - early elements take precedence
code.add_code_line(" if self.%s %s other_cast.%s: return True" % (
name, op_without_equals, name))
code.add_code_line(" if self.%s != other_cast.%s: return False" % (
name, name))
if "=" in op:
code.add_code_line(" return True") # "() == ()" is True
else:
code.add_code_line(" return False")
for name in names:
if op != '==':
# tuple comparison rules - early elements take precedence
code.add_code_line(f"if self.{name} {op_without_equals} other_cast.{name}: return True")
code.add_code_line(f"if self.{name} != other_cast.{name}: return False")
code.add_code_line(f"return {'True' if '=' in op else 'False'}") # "() == ()" is True


def generate_eq_code(code, eq, node, fields):
Expand Down Expand Up @@ -672,10 +682,8 @@ def generate_hash_code(code, unsafe_hash, eq, frozen, node, fields):
hash_tuple_items += "," # ensure that one arg form is a tuple

# if we're here we want to generate a hash
code.add_code_lines([
"def __hash__(self):",
" return hash((%s))" % hash_tuple_items,
])
with code.indenter("def __hash__(self):"):
code.add_code_line(f"return hash(({hash_tuple_items}))")


def get_field_type(pos, entry):
Expand Down Expand Up @@ -837,11 +845,11 @@ def _set_up_dataclass_fields(node, fields, dataclass_module):
key=ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(name)),
value=dc_field_call))
dc_fields_namevalue_assignments.append(
dedent("""\
__dataclass_fields__[{0!r}].name = {0!r}
__dataclass_fields__[{0!r}].type = {1}
__dataclass_fields__[{0!r}]._field_type = {2}
""").format(name, type_placeholder_name, field_type_placeholder_name))
dedent(f"""\
__dataclass_fields__[{name!r}].name = {name!r}
__dataclass_fields__[{name!r}].type = {type_placeholder_name}
__dataclass_fields__[{name!r}]._field_type = {field_type_placeholder_name}
"""))

dataclass_fields_assignment = \
Nodes.SingleAssignmentNode(node.pos,
Expand Down

0 comments on commit 1413f7d

Please sign in to comment.