Skip to content

Commit

Permalink
Implement searchList in C (#1264)
Browse files Browse the repository at this point in the history
* Implement searchList in C

* minor formatting
  • Loading branch information
johnynek authored Nov 20, 2024
1 parent 35cbe08 commit 52b8ec6
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 11 deletions.
111 changes: 105 additions & 6 deletions core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ object ClangGen {
def directFn(p: PackageName, b: Bindable): T[Option[Code.Ident]]
def directFn(b: Bindable): T[Option[(Code.Ident, Boolean)]]
def inTop[A](p: PackageName, bn: Bindable)(ta: T[A]): T[A]
def currentTop: T[Option[(PackageName, Bindable)]]
def staticValueName(p: PackageName, b: Bindable): T[Code.Ident]
def constructorFn(p: PackageName, b: Bindable): T[Code.Ident]

Expand Down Expand Up @@ -318,10 +319,11 @@ object ClangGen {
// this is just get_variant(expr) == expect
vl.onExpr { expr => pv(Code.Ident("get_variant")(expr) =:= Code.IntLiteral(expect)) }(newLocalName)
}
case sl @ SearchList(lst, init, check, leftAcc) =>
// TODO: ???
println(s"TODO: implement boolToValue($sl) returning false")
pv(Code.FalseLit)
case SearchList(lst, init, check, leftAcc) =>
(boolToValue(check), innerToValue(init))
.flatMapN { (condV, initV) =>
searchList(lst, initV, condV, leftAcc)
}
case ms @ MatchString(arg, parts, binds) =>
// TODO: ???
println(s"TODO: implement boolToValue($ms) returning false")
Expand All @@ -334,6 +336,100 @@ object ClangGen {
case TrueConst => pv(Code.TrueLit)
}

def searchList(
locMut: LocalAnonMut,
initVL: Code.ValueLike,
checkVL: Code.ValueLike,
optLeft: Option[LocalAnonMut]
): T[Code.ValueLike] = {
import Code.Expression

val emptyList: Expression =
Code.Ident("alloc_enum0")(Code.IntLiteral(0))

def isNonEmptyList(expr: Expression): Expression =
Code.Ident("get_variant")(expr) =:= Code.IntLiteral(1)

def headList(expr: Expression): Expression =
Code.Ident("get_enum_index")(expr, Code.IntLiteral(0))

def tailList(expr: Expression): Expression =
Code.Ident("get_enum_index")(expr, Code.IntLiteral(1))

def consList(head: Expression, tail: Expression): Expression =
Code.Ident("alloc_enum2")(Code.IntLiteral(1), head, tail)
/*
* here is the implementation from MatchlessToValue
*
Dynamic { (scope: Scope) =>
var res = false
var currentList = initF(scope)
var leftList = VList.VNil
while (currentList ne null) {
currentList match {
case nonempty@VList.Cons(head, tail) =>
scope.updateMut(mutV, nonempty)
scope.updateMut(left, leftList)
res = checkF(scope)
if (res) { currentList = null }
else {
currentList = tail
leftList = VList.Cons(head, leftList)
}
case _ =>
currentList = null
// we don't match empty lists
}
}
res
}
*/
for {
currentList <- getAnon(locMut.ident)
optLeft <- optLeft.traverse(lm => getAnon(lm.ident))
res <- newLocalName("result")
tmpList <- newLocalName("tmp_list")
declTmpList <- Code.ValueLike.declareVar(Code.TypeIdent.BValue, tmpList, initVL)(newLocalName)
/*
top <- currentTop
_ = println(s"""in $top: searchList(
$locMut: LocalAnonMut,
$initVL: Code.ValueLike,
$checkVL: Code.ValueLike,
$optLeft: Option[LocalAnonMut]
)""")
*/
} yield
(Code
.Statements(
Code.DeclareVar(Nil, Code.TypeIdent.Bool, res, Some(Code.FalseLit)),
declTmpList
)
.maybeCombine(
optLeft.map(_ := emptyList),
) +
// we don't match empty lists, so if currentList reaches Empty we are done
Code.While(
isNonEmptyList(tmpList),
Code.block(
currentList := tmpList,
res := checkVL,
Code.ifThenElse(res,
{ tmpList := emptyList },
{
(tmpList := tailList(tmpList))
.maybeCombine(
optLeft.map { left =>
left := consList(headList(currentList), left)
}
)
}
)
)
)
) :+ res
}

// We have to lift functions to the top level and not
// create any nesting
def innerFn(fn: FnExpr): T[Code.ValueLike] =
Expand Down Expand Up @@ -509,7 +605,7 @@ object ClangGen {
for {
name <- getBinding(arg)
result <- innerToValue(in)
stmt <- Code.ValueLike.declareVar(name, Code.TypeIdent.BValue, v)(newLocalName)
stmt <- Code.ValueLike.declareVar(Code.TypeIdent.BValue, name, v)(newLocalName)
} yield stmt +: result
}
}
Expand All @@ -521,7 +617,7 @@ object ClangGen {
for {
name <- getAnon(idx)
result <- innerToValue(in)
stmt <- Code.ValueLike.declareVar(name, Code.TypeIdent.BValue, v)(newLocalName)
stmt <- Code.ValueLike.declareVar(Code.TypeIdent.BValue, name, v)(newLocalName)
} yield stmt +: result
}
}
Expand Down Expand Up @@ -941,6 +1037,9 @@ object ClangGen {
_ <- StateT { (s: State) => result(s.copy(currentTop = None), ()) }
} yield a

val currentTop: T[Option[(PackageName, Bindable)]] =
StateT { (s: State) => result(s, s.currentTop) }

def staticValueName(p: PackageName, b: Bindable): T[Code.Ident] =
monadImpl.pure(Code.Ident(Idents.escape("___bsts_s_", fullName(p, b))))
def constructorFn(p: PackageName, b: Bindable): T[Code.Ident] =
Expand Down
12 changes: 11 additions & 1 deletion core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ object Code {
}

def declareVar[F[_]: Monad](
ident: Ident,
tpe: TypeIdent,
ident: Ident,
value: ValueLike)(newLocalName: String => F[Code.Ident]): F[Statement] =
value.exprToStatement[F] { expr =>
Monad[F].pure(DeclareVar(Nil, tpe, ident, Some(expr)))
Expand Down Expand Up @@ -329,8 +329,15 @@ object Code {

sealed trait Statement extends Code {
def +(stmt: Statement): Statement = Statements.combine(this, stmt)
def maybeCombine(that: Option[Statement]): Statement =
that match {
case Some(t) => Statements.combine(this, t)
case None => this

}
def :+(vl: ValueLike): ValueLike = (this +: vl)
}

case class Assignment(target: Expression, value: Expression) extends Statement
case class DeclareArray(tpe: TypeIdent, ident: Ident, values: Either[Int, List[Expression]]) extends Statement
case class DeclareVar(attrs: List[Attr], tpe: TypeIdent, ident: Ident, value: Option[Expression]) extends Statement
Expand All @@ -347,6 +354,9 @@ object Code {
def apply(nel: NonEmptyList[Statement]): Statements =
Statements(NonEmptyChain.fromNonEmptyList(nel))

def apply(first: Statement, rest: Statement*): Statements =
Statements(NonEmptyChain.of(first, rest: _*))

def combine(first: Statement, last: Statement): Statement =
first match {
case Statements(items) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1267,10 +1267,10 @@ object PythonGen {
case SearchList(locMut, init, check, optLeft) =>
// check to see if we can find a non-empty
// list that matches check
(loop(init, slotName), boolExpr(check, slotName)).mapN {
(loop(init, slotName), boolExpr(check, slotName)).flatMapN {
(initVL, checkVL) =>
searchList(locMut, initVL, checkVL, optLeft)
}.flatten
}
}

def matchString(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package org.bykn.bosatsu.codegen.clang

import cats.data.NonEmptyList
import org.bykn.bosatsu.{PackageName, TestUtils, Identifier, Predef}
import Identifier.Name
import org.bykn.bosatsu.{PackageName, TestUtils, Identifier}

class ClangGenTest extends munit.FunSuite {
def assertPredefFns(fns: String*)(matches: String)(implicit loc: munit.Location) =
Expand Down

0 comments on commit 52b8ec6

Please sign in to comment.