Skip to content

Commit

Permalink
[lang] Feat: using magic number 32 to limit loop unrolling (#8169)
Browse files Browse the repository at this point in the history
Issue: #8151

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 9662871</samp>

Add a new option `ti_unroll_limit` to limit the unrolling of static for
loops in `ASTTransformer`. This can improve performance and memory usage
in some cases.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at 9662871</samp>

* Introduce a new class attribute `ti_unroll_limit` to store the maximum
number of iterations that can be unrolled in a static for loop
([link](https://github.com/taichi-dev/taichi/pull/8169/files?diff=unified&w=0#diff-3e22417ffade4af0564893b98dc5101d714b8ba6fd4423ab5bc5129e360fee8fR1130))
* Modify the `build_static_for` method to check the loop iterator
against the `ti_unroll_limit` and raise a syntax error if it exceeds the
limit
([link](https://github.com/taichi-dev/taichi/pull/8169/files?diff=unified&w=0#diff-3e22417ffade4af0564893b98dc5101d714b8ba6fd4423ab5bc5129e360fee8fL1139-R1144),
[link](https://github.com/taichi-dev/taichi/pull/8169/files?diff=unified&w=0#diff-3e22417ffade4af0564893b98dc5101d714b8ba6fd4423ab5bc5129e360fee8fR1156-R1157))

---------

Co-authored-by: Lin Jiang <linjiang@taichi.graphics>
  • Loading branch information
lgyStoic and lin-hitonami authored Jul 12, 2023
1 parent b8ca066 commit 1eed2b9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
36 changes: 36 additions & 0 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,7 @@ def get_for_loop_targets(node):

@staticmethod
def build_static_for(ctx, node, is_grouped):
ti_unroll_limit = impl.get_runtime().unrolling_limit
if is_grouped:
assert len(node.iter.args[0].args) == 1
ndrange_arg = build_stmt(ctx, node.iter.args[0].args[0])
Expand All @@ -1149,7 +1150,24 @@ def build_static_for(ctx, node, is_grouped):
if len(targets) != 1:
raise TaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
target = targets[0]
iter_time = 0
alert_already = False

for value in impl.grouped(ndrange_arg):
iter_time += 1
if not alert_already and ti_unroll_limit and iter_time > ti_unroll_limit:
alert_already = True
warnings.warn_explicit(
f"""You are unrolling more than
{ti_unroll_limit} iterations, so the compile time may be extremely long.
You can use a non-static for loop if you want to decrease the compile time.
You can disable this warning by setting ti.init(unrolling_limit=0).""",
SyntaxWarning,
ctx.file,
node.lineno + ctx.lineno_offset,
module="taichi",
)

with ctx.variable_scope_guard():
ctx.create_variable(target, value)
build_stmts(ctx, node.body)
Expand All @@ -1161,9 +1179,27 @@ def build_static_for(ctx, node, is_grouped):
else:
build_stmt(ctx, node.iter)
targets = ASTTransformer.get_for_loop_targets(node)

iter_time = 0
alert_already = False
for target_values in node.iter.ptr:
if not isinstance(target_values, collections.abc.Sequence) or len(targets) == 1:
target_values = [target_values]

iter_time += 1
if not alert_already and ti_unroll_limit and iter_time > ti_unroll_limit:
alert_already = True
warnings.warn_explicit(
f"""You are unrolling more than
{ti_unroll_limit} iterations, so the compile time may be extremely long.
You can use a non-static for loop if you want to decrease the compile time.
You can disable this warning by setting ti.init(unrolling_limit=0).""",
SyntaxWarning,
ctx.file,
node.lineno + ctx.lineno_offset,
module="taichi",
)

with ctx.variable_scope_guard():
for target, target_value in zip(targets, target_values):
ctx.create_variable(target, target_value)
Expand Down
3 changes: 3 additions & 0 deletions python/taichi/lang/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def __init__(self):
self.gdb_trigger = False
self.short_circuit_operators = True
self.print_full_traceback = False
self.unrolling_limit = 32


def prepare_sandbox():
Expand Down Expand Up @@ -416,6 +417,7 @@ def init(
env_spec.add("gdb_trigger")
env_spec.add("short_circuit_operators")
env_spec.add("print_full_traceback")
env_spec.add("unrolling_limit")

# compiler configurations (ti.cfg):
for key in dir(cfg):
Expand All @@ -436,6 +438,7 @@ def init(
_ti_core.set_core_trigger_gdb_when_crash(spec_cfg.gdb_trigger)
impl.get_runtime().short_circuit_operators = spec_cfg.short_circuit_operators
impl.get_runtime().print_full_traceback = spec_cfg.print_full_traceback
impl.get_runtime().unrolling_limit = spec_cfg.unrolling_limit
_logging.set_logging_level(spec_cfg.log_level.lower())

# select arch (backend):
Expand Down

0 comments on commit 1eed2b9

Please sign in to comment.