Skip to content

Commit

Permalink
implement a lazy evaluation / thunk compiler primitive (#25387)
Browse files Browse the repository at this point in the history
# The Problem

I would like to be able to write module code for "built-in" Chapel
features. That's because some code is much easier to express in plain
Chapel than in AST manipulation. For instance, I would like to be able
to use `on` clauses and `forall` loops to implement language features,
rather than having to build lowered forms of them in the AST. It's also
significantly easier to change the Chapel module implementation than
having to modify the AST-building compiler code. For instance, consider
the following code that I would like to use to implement remote variable
declarations:

```Chapel
    var c: owned _remoteVarContainer(inType)?;
    on loc do c = new _remoteVarContainer(expr);
    return new _remoteVarWrapper(try! c : owned _remoteVarContainer(inType));
```

There are a lot of AST nodes here, and this is just one of the three
versions of this code (it handles the case where both the type and the
value of the remote variable are specified). Writing it as C++
`buildWhateverStatement` is error prone.

And yet, __I can't just turn that code into a Chapel function and call
it__:

```Chapel
proc myFn(loc: locale, type inType, value: inType) {
    var c: owned _remoteVarContainer(inType)?;
    on loc do c = new _remoteVarContainer(value);
    return new _remoteVarWrapper(try! c : owned _remoteVarContainer(inType));
}

var myWrapper = myFn(loc, inType, expr); // not the same!
```

The reason is that arguments to functions are evaluated before they are
given to function bodies. In my example, `expr` will get evaluated
before the call to `myFn`, and thus the computation will not occur on
`loc`. This is currently causing a problem for remote variables (my
second snippet matches my actual implementation). Static variables
suffer from this problem because the compiler-generated C++ code is
brittle and long.

# The Solution

What I'd like to be able to do is to pass in an "expression" into Chapel
code, to be evaluated by the function when needed, be it on a different
locale or conditionally. To this end, what I need is to be able to defer
computations, and that's what this PR adds:
[thunks](https://en.wikipedia.org/wiki/Thunk). Then, I can write code
like this:

```Chapel
proc myFn(loc: locale, type inType, in thunk: _thunkRecord) {
    var c: owned _remoteVarContainer(inType)?;
    on loc do c = new _remoteVarContainer(__primitive("force thunk", thunk));
    return new _remoteVarWrapper(try! c : owned _remoteVarContainer(inType));
}

var myWrapper = myFn(loc, inType, __primitive("create thunk", expr)); // same as the "inlined" form!
```

Another example (from a test file) is the following:

```Chapel
proc executeIfTrue(cond: bool, in thunk: _thunkRecord) {
  var temp: thunkToReturnType(thunk.type);
  if cond {
    temp = __primitive("force thunk", thunk);
  }
  return temp;
}

writeln(executeIfTrue(true, __primitive("create thunk", new C?(42)))); // prints "calling C.init", then "{42}"
writeln(executeIfTrue(false, __primitive("create thunk", new C?(42)))); // doesn't print, then prints "nil"
```

# Implementation
This works by doing pretty much what we do for iterators and their
records/classes: for each thunked expression (the `create thunk`
primitive), it creates a new builder function `chpl__thunkN` and,
eventually, a record `_tr_chpl__thunkN`. It re-uses the same logic for
capturing outer variables etc, so that the expression being captured can
refer to global or local variables, fields, etc. In detail, the process
is as follows:

1. __Pass 8 - Normalize__: lifts `create thunk` primitives into builder
functions, named `chpl__thunkN`. These functions accept all the captured
variables as formals. Eventually, these functions are marked to return
the thunk record; however, during normalize, the thunk record doesn't
yet exist. The thunk body -- the expression being deferred -- is copied
into these functions.
2. __Pass 12 - Resolve__: after resolving the types of the
`chpl__thunkN` function, the function resolution code creates the thunk
record (but doesn't yet populate its fields; this matches how iterator
records are handled). It also creates the `invoke` method, which is
called by `force thunk` primitives. The 'invoke' method is empty for the
time being (again, like the prototype methods for iterators) and is
populated once the thunk record's fields have been created. The `force
thunk` primitive is simply rewritten to the `invoke` method.
3. __Pass 20 - Lower Iterators__: at this time, the thunk record is
populated with the fields that correspond to the captured variables, and
the body of the `invoke` method is filled in with the original code that
the user provided; references to outer variables are replaced with field
references to the thunk record. The builder function is made to create a
new thunk record and populate its fields with the formals, preparing it
for invocation.

Much of the logic is shared with lowering iterators, and I have factored
out code where appropriate to ensure that little-to-no code duplication
is involved.

# Next steps
Since this PR is relatively large, I didn't want to also include into it
changes to remote or static varibles. However, I believe that the next
step is to switch the remote variable implementation to use thunks,
which will allow it to consistently execute the initialization
expression on the target locale
(#25298).

Reviewed by @e-kayrakli -- thanks!

# Testing
- [x] paratest
- [x] paratest gasnet
- [x] GPU test
  • Loading branch information
DanilaFe authored Jul 3, 2024
2 parents 355b758 + 48022ea commit a32cc42
Show file tree
Hide file tree
Showing 44 changed files with 893 additions and 103 deletions.
1 change: 1 addition & 0 deletions compiler/AST/AggregateType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ AggregateType::AggregateType(AggregateTag initTag) :
builtReaderInit = false;
initializerResolved = false;
iteratorInfo = NULL;
thunkInvoke = NULL;
doc = NULL;

instantiatedFrom = NULL;
Expand Down
1 change: 1 addition & 0 deletions compiler/AST/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ set(SRCS
stmt.cpp
symbol.cpp
TemporaryConversionThunk.cpp
thunks.cpp
TransformLogicalShortCircuit.cpp
TryStmt.cpp
type.cpp
Expand Down
14 changes: 7 additions & 7 deletions compiler/AST/LoopExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ static void findOuterVars(LoopExpr* loopExpr,
}
}

static ArgSymbol* newOuterVarArg(Symbol* ovar) {
ArgSymbol* newOuterVarArg(Symbol* ovar) {
Type* argType = ovar->type;
if (argType == dtUnknown)
argType = dtAny;
Expand Down Expand Up @@ -715,7 +715,7 @@ static void adjustIndexDefPoints(FnSymbol* xifn, AList* indexDefs) {
forLoop->insertAtHead(expr->remove());
}*/

static void scopeResolveAndNormalize(FnSymbol* fn) {
void scopeResolveAndNormalizeGeneratedLoweringFn(FnSymbol* fn) {
TransformLogicalShortCircuit vis;
addToSymbolTable(fn);
fn->accept(&vis);
Expand Down Expand Up @@ -887,20 +887,20 @@ static CallExpr* buildLoopExprFunctions(LoopExpr* loopExpr) {
fn->insertAtHead(new DefExpr(fifn));
}

scopeResolveAndNormalize(fn);
scopeResolveAndNormalizeGeneratedLoweringFn(fn);
} else {
fn->defPoint->insertBefore(new DefExpr(sifn));
scopeResolveAndNormalize(sifn);
scopeResolveAndNormalizeGeneratedLoweringFn(sifn);

if (forall) {
fn->defPoint->insertBefore(new DefExpr(lifn));
scopeResolveAndNormalize(lifn);
scopeResolveAndNormalizeGeneratedLoweringFn(lifn);

fn->defPoint->insertBefore(new DefExpr(fifn));
scopeResolveAndNormalize(fifn);
scopeResolveAndNormalizeGeneratedLoweringFn(fifn);
}

scopeResolveAndNormalize(fn);
scopeResolveAndNormalizeGeneratedLoweringFn(fn);
}


Expand Down
1 change: 1 addition & 0 deletions compiler/AST/Makefile.share
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ AST_SRCS = \
primitive.cpp \
stmt.cpp \
symbol.cpp \
thunks.cpp \
TryStmt.cpp \
type.cpp \
UseStmt.cpp \
Expand Down
181 changes: 99 additions & 82 deletions compiler/AST/iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,13 +401,7 @@ static inline CallExpr* parentYieldExpr(SymExpr* se) {
}


//
// Now that we have localized yield symbols, the return symbol
// and the PRIM_RETURN CallExpr are not needed and would cause trouble.
// Returns the type yielded by the iterator. (fn->retType is not it.)
//
static Type*
removeRetSymbolAndUses(FnSymbol* fn) {
void removeRetSymbolAndUses(FnSymbol* fn) {
// follows getReturnSymbol()
CallExpr* ret = toCallExpr(fn->body->body.last());
INT_ASSERT(ret && ret->isPrimitive(PRIM_RETURN));
Expand All @@ -425,11 +419,6 @@ removeRetSymbolAndUses(FnSymbol* fn) {

// We cannot remove rsym's definition, because rsym
// may also be referenced in an autoDestroy call.

INT_ASSERT(fn->iteratorInfo != NULL);
Type* yieldedType = fn->iteratorInfo->yieldedType;

return yieldedType;
}


Expand Down Expand Up @@ -862,6 +851,39 @@ static void replaceLocalWithFieldTemp(SymExpr* se,
// E.g. 'yield localvar' is converted to ic.value = ic.FNN_localvar.
//

void replaceLocalUseOrDefWithFieldRef(SymExpr* se,
Symbol* classOrRecord,
std::vector<BaseAST*>& asts,
SymbolMap& local2field,
Vec<SymExpr*>& defSet,
Vec<SymExpr*>& useSet) {
if (useSet.set_in(se) || defSet.set_in(se)) {
// SymExpr is among those we are interested in: def or use of a live local.

// Get the corresponding field in the iterator class.
Symbol* field = local2field.get(se->symbol());

// Get the expression that sets or uses the symexpr.
CallExpr* call = toCallExpr(se->parentExpr);

if (call && call->isPrimitive(PRIM_ADDR_OF)) {

// Convert (addr of var) to (. _ic field).
// Note, GET_MEMBER is not valid on a ref field;
// in that event, GET_MEMBER_VALUE returns the ref.
if (field->isRef())
call->primitive = primitives[PRIM_GET_MEMBER_VALUE];
else
call->primitive = primitives[PRIM_GET_MEMBER];

call->insertAtHead(classOrRecord);
se->setSymbol(field);
} else {
replaceLocalWithFieldTemp(se, classOrRecord, field,
defSet.set_in(se), useSet.set_in(se), asts);
}
}
}

// In the body of an iterator function, replace references to local variables
// with references to fields in the iterator class instead.
Expand Down Expand Up @@ -915,31 +937,8 @@ replaceLocalsWithFields(FnSymbol* fn, // the iterator function
}
}
}
} else if (useSet.set_in(se) || defSet.set_in(se)) {
// SymExpr is among those we are interested in: def or use of a live local.

// Get the corresponding field in the iterator class.
Symbol* field = local2field.get(se->symbol());

// Get the expression that sets or uses the symexpr.
CallExpr* call = toCallExpr(se->parentExpr);

if (call && call->isPrimitive(PRIM_ADDR_OF)) {

// Convert (addr of var) to (. _ic field).
// Note, GET_MEMBER is not valid on a ref field;
// in that event, GET_MEMBER_VALUE returns the ref.
if (field->isRef())
call->primitive = primitives[PRIM_GET_MEMBER_VALUE];
else
call->primitive = primitives[PRIM_GET_MEMBER];

call->insertAtHead(ic);
se->setSymbol(field);
} else {
replaceLocalWithFieldTemp(se, ic, field,
defSet.set_in(se), useSet.set_in(se), asts);
}
} else {
replaceLocalUseOrDefWithFieldRef(se, ic, asts, local2field, defSet, useSet);
}
}
}
Expand Down Expand Up @@ -1808,6 +1807,50 @@ addAllLocalVariables(Vec<Symbol*>& syms, std::vector<BaseAST*>& asts) {
}
}

void insertReturn(FnSymbol* fn, Symbol* toReturn) {
if (fn->hasFlag(FLAG_FN_RETARG)) {
ArgSymbol* retArg = NULL;
for_formals(formal, fn) {
if (formal->hasFlag(FLAG_RETARG))
retArg = formal;
}
fn->insertAtTail(new CallExpr(PRIM_ASSIGN, retArg, toReturn));
fn->insertAtTail(new CallExpr(PRIM_RETURN, gVoid));
} else {
fn->insertAtTail(new CallExpr(PRIM_RETURN, toReturn));
}
}

void initializeRecordFieldWithArgLocals(FnSymbol* fn,
Symbol* rec,
Vec<Symbol*>& locals,
SymbolMap& local2field) {
// For each live argument
forv_Vec(Symbol, local, locals) {
if (!toArgSymbol(local))
continue;

// Get the corresponding field in the iterator class
Symbol* field = local2field.get(local);
Symbol* value = local;

if (local->type == field->type->refType) {
// If a ref var, load the local in to a temp and
// then set the value of the corresponding field.
Symbol* tmp = newTemp(field->type);

fn->insertAtTail(new DefExpr(tmp));

fn->insertAtTail(new CallExpr(PRIM_MOVE,
tmp,
new CallExpr(PRIM_DEREF, local)));

value = tmp;
}

fn->insertAtTail(new CallExpr(PRIM_SET_MEMBER, rec, field, value));
}
}

// Preceding calls to the various build...() functions have copied out
// interesting parts of the iterator function.
Expand Down Expand Up @@ -1853,44 +1896,11 @@ rebuildIterator(IteratorInfo* ii,
fn->insertAtTail(new CallExpr(PRIM_ZERO_VARIABLE, new SymExpr(iterator)));
}

// For each live argument
forv_Vec(Symbol, local, locals) {
if (!toArgSymbol(local))
continue;

// Get the corresponding field in the iterator class
Symbol* field = local2field.get(local);
Symbol* value = local;

if (local->type == field->type->refType) {
// If a ref var, load the local in to a temp and
// then set the value of the corresponding field.
Symbol* tmp = newTemp(field->type);

fn->insertAtTail(new DefExpr(tmp));

fn->insertAtTail(new CallExpr(PRIM_MOVE,
tmp,
new CallExpr(PRIM_DEREF, local)));

value = tmp;
}

fn->insertAtTail(new CallExpr(PRIM_SET_MEMBER, iterator, field, value));
}
// Initialize the iterator record with the live arguments.
initializeRecordFieldWithArgLocals(fn, iterator, locals, local2field);

// Return the filled-in iterator record.
if (fn->hasFlag(FLAG_FN_RETARG)) {
ArgSymbol* retArg = NULL;
for_formals(formal, fn) {
if (formal->hasFlag(FLAG_RETARG))
retArg = formal;
}
fn->insertAtTail(new CallExpr(PRIM_ASSIGN, retArg, iterator));
fn->insertAtTail(new CallExpr(PRIM_RETURN, gVoid));
} else {
fn->insertAtTail(new CallExpr(PRIM_RETURN, iterator));
}
insertReturn(fn, iterator);

ii->getValue->defPoint->insertAfter(new DefExpr(fn));

Expand Down Expand Up @@ -1990,10 +2000,8 @@ removeLocals(Vec<Symbol*>& locals, std::vector<BaseAST*>& asts, Vec<Symbol*>& yl
}


// Creates (and returns) an iterator class field.
// 'type' is used if local==NULL.
static inline Symbol* createICField(int& i, Symbol* local, Type* type,
bool isValueField, FnSymbol* fn) {
Symbol* createICField(int& i, Symbol* local, Type* type,
bool isValueField, FnSymbol* fn) {
// The field name is "value" for the return value of the iterator,
// or F<int>_<local->name> otherwise.
const char* fieldName = isValueField
Expand All @@ -2012,10 +2020,15 @@ static inline Symbol* createICField(int& i, Symbol* local, Type* type,
qt = qt.refToRefType();

INT_ASSERT(qt.type() != dtUnknown);
Symbol* field = new VarSymbol(fieldName, qt);
return new VarSymbol(fieldName, qt);
}

fn->iteratorInfo->iclass->fields.insertAtTail(new DefExpr(field));
// Same as createAndInsertICField, but inserts into the iclass.
static inline Symbol* createAndInsertICField(int& i, Symbol* local, Type* type,
bool isValueField, FnSymbol* fn) {

auto field = createICField(i, local, type, isValueField, fn);
fn->iteratorInfo->iclass->fields.insertAtTail(new DefExpr(field));
return field;
}

Expand Down Expand Up @@ -2094,7 +2107,7 @@ static void addLocalsToClassAndRecord(Vec<Symbol*>& locals, FnSymbol* fn,
int i = 0; // This numbers the fields.
forv_Vec(Symbol, local, locals) {
bool isYieldSym = yldSymSet.set_in(local);
Symbol* field = createICField(i, local, NULL, isYieldSym && oneLocalYS, fn);
Symbol* field = createAndInsertICField(i, local, NULL, isYieldSym && oneLocalYS, fn);
local2field.put(local, field);
if (isYieldSym) {
INT_ASSERT(local->type == yieldedType);
Expand Down Expand Up @@ -2128,7 +2141,7 @@ static void addLocalsToClassAndRecord(Vec<Symbol*>& locals, FnSymbol* fn,
}

if (!valField) {
valField = createICField(i, NULL, yieldedType, true, fn);
valField = createAndInsertICField(i, NULL, yieldedType, true, fn);
}
*valFieldRef = valField;
}
Expand All @@ -2143,7 +2156,11 @@ void lowerIterator(FnSymbol* fn) {
INT_ASSERT(! iteratorsLowered); // ensure formalToPrimMap is valid
SET_LINENO(fn);
std::vector<BaseAST*> asts;
Type* yieldedType = removeRetSymbolAndUses(fn);
removeRetSymbolAndUses(fn);

INT_ASSERT(fn->iteratorInfo != NULL);
Type* yieldedType = fn->iteratorInfo->yieldedType;

collect_asts_postorder(fn, asts);

BlockStmt* singleLoop = NULL;
Expand Down
16 changes: 16 additions & 0 deletions compiler/AST/primitive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,17 @@ returnInfoScalarPromotionType(CallExpr* call) {
return QualifiedType(type, QUAL_VAL);
}

static QualifiedType
returnInfoThunkResultType(CallExpr* call) {
QualifiedType tmp = call->get(1)->qualType();
AggregateType* type = toAggregateType(tmp.type()->getValType());

INT_ASSERT(type);
INT_ASSERT(type->thunkInvoke);

return QualifiedType(type->thunkInvoke->retType, QUAL_VAL);
}

static QualifiedType
returnInfoStaticFieldType(CallExpr* call) {
// The code below is not very general. It can be extended as needed.
Expand Down Expand Up @@ -730,6 +741,11 @@ initPrimitive() {
prim_def(PRIM_OUTER_CONTEXT, "outer context", returnInfoFirst);
prim_def(PRIM_HOIST_TO_CONTEXT, "hoist to context", returnInfoVoid);

prim_def(PRIM_CREATE_THUNK, "create thunk", returnInfoUnknown);
prim_def(PRIM_THUNK_RESULT, "thunk result", returnInfoFirst);
prim_def(PRIM_FORCE_THUNK, "force thunk", returnInfoUnknown);
prim_def(PRIM_THUNK_RESULT_TYPE, "thunk result type", returnInfoThunkResultType);

prim_def(PRIM_ACTUALS_LIST, "actuals list", returnInfoVoid);
prim_def(PRIM_NOOP, "noop", returnInfoVoid);
// dst, src. PRIM_MOVE can set a reference.
Expand Down
Loading

0 comments on commit a32cc42

Please sign in to comment.