Skip to content

Commit

Permalink
Merge branch 'refactor/no_folding' into refactor/folding_alt2
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Dec 28, 2023
2 parents 4df618f + fe925d6 commit 8ddd2ca
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 44 deletions.
48 changes: 47 additions & 1 deletion vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,12 @@ class Constant(ExprNode):

def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict):
super().__init__(parent, **kwargs)
self._metadata["folded_value"] = self

def get_folded_value_throwing(self) -> "VyperNode":
return self

def get_folded_value_maybe(self) -> Optional["VyperNode"]:
return self


class Num(Constant):
Expand Down Expand Up @@ -911,6 +916,18 @@ def s(self):
return self.value


def check_literal(node: VyperNode) -> bool:
"""
Check if the given node is a literal value.
"""
if isinstance(node, Constant):
return True
elif isinstance(node, (Tuple, List)):
return all(check_literal(item) for item in node.elements)

return False


class List(ExprNode):
__slots__ = ("elements",)
_is_prefoldable = True
Expand All @@ -920,15 +937,44 @@ def fold(self) -> Optional[ExprNode]:
elements = [e.get_folded_value_throwing() for e in self.elements]
return type(self).from_node(self, elements=elements)

def get_folded_value_throwing(self) -> "VyperNode":
if check_literal(self):
return self

return super().get_folded_value_throwing()

def get_folded_value_maybe(self) -> Optional["VyperNode"]:
if check_literal(self):
return self

return super().get_folded_value_maybe()


class Tuple(ExprNode):
__slots__ = ("elements",)
_is_prefoldable = True
_translated_fields = {"elts": "elements"}

def validate(self):
if not self.elements:
raise InvalidLiteral("Cannot have an empty tuple", self)

def fold(self) -> Optional[ExprNode]:
elements = [e.get_folded_value_throwing() for e in self.elements]
return type(self).from_node(self, elements=elements)

def get_folded_value_throwing(self) -> "VyperNode":
if check_literal(self):
return self

return super().get_folded_value_throwing()

def get_folded_value_maybe(self) -> Optional["VyperNode"]:
if check_literal(self):
return self

return super().get_folded_value_maybe()


class NameConstant(Constant):
__slots__ = ()
Expand Down
2 changes: 1 addition & 1 deletion vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def process_arg(arg, expected_arg_type, context):

def process_kwarg(kwarg_node, kwarg_settings, expected_kwarg_type, context):
if kwarg_settings.require_literal:
return kwarg_node.value
return kwarg_node.get_folded_value_throwing().value

return process_arg(kwarg_node, expected_kwarg_type, context)

Expand Down
11 changes: 11 additions & 0 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class Expr:
# TODO: Once other refactors are made reevaluate all inline imports

def __init__(self, node, context):
if isinstance(node, vy_ast.VyperNode):
node = node._metadata.get("folded_value", node)

self.expr = node
self.context = context

Expand Down Expand Up @@ -185,6 +188,14 @@ def parse_Name(self):
# TODO: use self.expr._expr_info
elif self.expr.id in self.context.globals:
varinfo = self.context.globals[self.expr.id]
if varinfo.modifiability == Modifiability.ALWAYS_CONSTANT:
# non-struct constants should have been dispatched via the `Expr` ctor
# using the folded value metadata
assert isinstance(varinfo.typ, StructT)
value_node = varinfo.decl_node.value
value_node = value_node._metadata.get("folded_value", value_node)
return Expr.parse_value_expr(value_node, self.context)

assert varinfo.modifiability == Modifiability.IMMUTABLE, "not an immutable!"

ofst = varinfo.position.offset
Expand Down
29 changes: 11 additions & 18 deletions vyper/codegen/function_definitions/external_function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from vyper import ast as vy_ast
from vyper.codegen.abi_encoder import abi_encoding_matches_vyper
from vyper.codegen.context import Context, VariableRecord
from vyper.codegen.core import get_element_ptr, getpos, make_setter, needs_clamp
Expand Down Expand Up @@ -51,7 +50,7 @@ def _register_function_args(func_t: ContractFunctionT, context: Context) -> list


def _generate_kwarg_handlers(
func_t: ContractFunctionT, context: Context, code: vy_ast.FunctionDef
func_t: ContractFunctionT, context: Context
) -> dict[str, tuple[int, IRnode]]:
# generate kwarg handlers.
# since they might come in thru calldata or be default,
Expand All @@ -63,7 +62,7 @@ def _generate_kwarg_handlers(
# write default args to memory
# goto external_function_common_ir

def handler_for(calldata_kwargs, original_default_kwargs, folded_default_kwargs):
def handler_for(calldata_kwargs, default_kwargs):
calldata_args = func_t.positional_args + calldata_kwargs
# create a fake type so that get_element_ptr works
calldata_args_t = TupleT(list(arg.typ for arg in calldata_args))
Expand All @@ -82,7 +81,7 @@ def handler_for(calldata_kwargs, original_default_kwargs, folded_default_kwargs)
calldata_min_size = args_abi_t.min_size() + 4

# TODO optimize make_setter by using
# TupleT(list(arg.typ for arg in calldata_kwargs + folded_default_kwargs))
# TupleT(list(arg.typ for arg in calldata_kwargs + default_kwargs))
# (must ensure memory area is contiguous)

for i, arg_meta in enumerate(calldata_kwargs):
Expand All @@ -98,15 +97,15 @@ def handler_for(calldata_kwargs, original_default_kwargs, folded_default_kwargs)
copy_arg.source_pos = getpos(arg_meta.ast_source)
ret.append(copy_arg)

for x, y in zip(original_default_kwargs, folded_default_kwargs):
for x in default_kwargs:
dst = context.lookup_var(x.name).pos
lhs = IRnode(dst, location=MEMORY, typ=x.typ)
lhs.source_pos = getpos(y)
kw_ast_val = y
lhs.source_pos = getpos(x.ast_source)
kw_ast_val = func_t.default_values[x.name] # e.g. `3` in x: int = 3
rhs = Expr(kw_ast_val, context).ir_node

copy_arg = make_setter(lhs, rhs)
copy_arg.source_pos = getpos(y)
copy_arg.source_pos = getpos(x.ast_source)
ret.append(copy_arg)

ret.append(["goto", func_t._ir_info.external_function_base_entry_label])
Expand All @@ -117,25 +116,19 @@ def handler_for(calldata_kwargs, original_default_kwargs, folded_default_kwargs)
ret = {}

keyword_args = func_t.keyword_args
folded_keyword_args = code.args.defaults

# allocate variable slots in memory
for arg in keyword_args:
context.new_variable(arg.name, arg.typ, is_mutable=False)

for i, _ in enumerate(keyword_args):
calldata_kwargs = keyword_args[:i]
# folded ast
original_default_kwargs = keyword_args[i:]
# unfolded ast
folded_default_kwargs = folded_keyword_args[i:]
default_kwargs = keyword_args[i:]

sig, calldata_min_size, ir_node = handler_for(
calldata_kwargs, original_default_kwargs, folded_default_kwargs
)
sig, calldata_min_size, ir_node = handler_for(calldata_kwargs, default_kwargs)
ret[sig] = calldata_min_size, ir_node

sig, calldata_min_size, ir_node = handler_for(keyword_args, [], [])
sig, calldata_min_size, ir_node = handler_for(keyword_args, [])

ret[sig] = calldata_min_size, ir_node

Expand All @@ -160,7 +153,7 @@ def generate_ir_for_external_function(code, func_t, context):
handle_base_args = _register_function_args(func_t, context)

# generate handlers for kwargs and register the variable records
kwarg_handlers = _generate_kwarg_handlers(func_t, context, code)
kwarg_handlers = _generate_kwarg_handlers(func_t, context)

body = ["seq"]
# once optional args have been handled,
Expand Down
2 changes: 1 addition & 1 deletion vyper/compiler/phases.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def generate_folded_ast(
symbol_tables = set_data_positions(vyper_module, storage_layout_overrides)

vyper_module_folded = copy.deepcopy(vyper_module)
vy_ast.folding.fold(vyper_module_folded)
# vy_ast.folding.fold(vyper_module_folded)

return vyper_module_folded, symbol_tables

Expand Down
2 changes: 1 addition & 1 deletion vyper/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __str__(self):
if isinstance(node, vy_ast.VyperNode):
module_node = node.get_ancestor(vy_ast.Module)

if module_node.get("path") not in (None, "<unknown>"):
if module_node and module_node.get("path") not in (None, "<unknown>"):
node_msg = f'{node_msg}contract "{module_node.path}:{node.lineno}", '

fn_node = node.get_ancestor(vy_ast.FunctionDef)
Expand Down
10 changes: 6 additions & 4 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,13 +545,15 @@ def visit(self, node, typ):
# can happen.
super().visit(node, typ)

folded_value = node.get_folded_value_maybe()
if isinstance(folded_value, vy_ast.Constant):
validate_expected_type(folded_value, typ)

# annotate
node._metadata["type"] = typ

# validate and annotate folded value
folded_value = node._metadata.get("folded_value")
if folded_value:
validate_expected_type(folded_value, typ)
folded_value._metadata["type"] = typ

def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None:
_validate_msg_data_attribute(node)

Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode:
resolved_path=str(file.resolved_path),
)
vy_ast.validation.validate_literal_nodes(ret)
vy_ast.folding.fold(ret)
# vy_ast.folding.fold(ret)

return ret

Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/analysis/pre_typecheck.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from vyper import ast as vy_ast
from vyper.exceptions import UnfoldableNode, VyperException
from vyper.exceptions import UnfoldableNode


def get_constants(node: vy_ast.Module) -> dict:
Expand Down Expand Up @@ -66,7 +66,7 @@ def prefold(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]):
try:
node._metadata["folded_value"] = call_type.fold(node)
return
except (UnfoldableNode, VyperException):
except UnfoldableNode:
pass

if getattr(node, "_is_prefoldable", None):
Expand Down
16 changes: 1 addition & 15 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,25 +624,11 @@ def validate_unique_method_ids(functions: List) -> None:
seen.add(method_id)


def _check_literal(node: vy_ast.VyperNode) -> bool:
"""
Check if the given node is a literal value.
"""
if isinstance(node, vy_ast.Constant):
return True
elif isinstance(node, (vy_ast.Tuple, vy_ast.List)):
return all(_check_literal(item) for item in node.elements)

if node.get_folded_value_maybe():
return True
return False


def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> bool:
"""
Check if the given node is not more modifiable than the given modifiability.
"""
if _check_literal(node):
if node.get_folded_value_maybe():
return True

if isinstance(node, (vy_ast.BinOp, vy_ast.Compare)):
Expand Down

0 comments on commit 8ddd2ca

Please sign in to comment.