Skip to content

Commit

Permalink
Test c generation of all predef (#1259)
Browse files Browse the repository at this point in the history
* Test c generation of all predef

* make node tests pass
  • Loading branch information
johnynek authored Nov 18, 2024
1 parent 0c5ece3 commit 6c9f52e
Show file tree
Hide file tree
Showing 10 changed files with 215 additions and 72 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package org.bykn.bosatsu.codegen.clang

import cats.data.NonEmptyList
import org.bykn.bosatsu.{PackageName, PackageMap, TestUtils, Identifier, Predef}
import Identifier.Name
import org.bykn.bosatsu.MatchlessFromTypedExpr

import org.bykn.bosatsu.DirectEC.directEC

class ClangGenTest extends munit.FunSuite {
val predef_c = Code.Include(true, "bosatsu_predef.h")

def predef(s: String, arity: Int) =
(PackageName.PredefName -> Name(s)) -> (predef_c,
ClangGen.generatedName(PackageName.PredefName, Name(s)),
arity)

val jvmExternals = {
val ext = Predef.jvmExternals.toMap.iterator.map { case ((_, n), ffi) => predef(n, ffi.arity) }
.toMap[(PackageName, Identifier), (Code.Include, Code.Ident, Int)]

{ (pn: (PackageName, Identifier)) => ext.get(pn) }
}

def md5HashToHex(content: String): String = {
val md = java.security.MessageDigest.getInstance("MD5")
val digest = md.digest(content.getBytes("UTF-8"))
digest.map("%02x".format(_)).mkString
}
def testFilesCompilesToHash(path0: String, paths: String*)(hashHex: String)(implicit loc: munit.Location) = {
val pm: PackageMap.Typed[Any] = TestUtils.compileFile(path0, paths*)
/*
val exCode = ClangGen.generateExternalsStub(pm)
println(exCode.render(80))
sys.error("stop")
*/
val matchlessMap = MatchlessFromTypedExpr.compile(pm)
val topoSort = pm.topoSort.toSuccess.get
val sortedEnv = cats.Functor[Vector].compose[NonEmptyList].map(topoSort) { pn =>
(pn, matchlessMap(pn))
}

val res = ClangGen.renderMain(
sortedEnv = sortedEnv,
externals = jvmExternals,
value = (PackageName.PredefName, Identifier.Name("ignored")),
evaluator = (Code.Include(true, "eval.h"), Code.Ident("evaluator_run"))
)

res match {
case Right(d) =>
val everything = d.render(80)
val hashed = md5HashToHex(everything)
assertEquals(hashed, hashHex, s"compilation didn't match. Compiled code:\n\n${"//" * 40}\n\n$everything")
case Left(e) => fail(e.toString)
}
}

test("test_workspace/Ackermann.bosatsu") {
/*
To inspect the code, change the hash, and it will print the code out
*/
testFilesCompilesToHash("test_workspace/Ackermann.bosatsu")(
"46716ef3c97cf2a79bf17d4033d55854"
)
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
package org.bykn.bosatsu.codegen.python

import cats.Show
import cats.data.NonEmptyList
import java.io.{ByteArrayInputStream, InputStream}
import java.nio.file.{Paths, Files}
import java.util.concurrent.Semaphore
import org.bykn.bosatsu.{
PackageMap,
MatchlessFromTypedExpr,
Parser,
Package,
LocationMap,
PackageName
PackageName,
TestUtils
}
import org.scalacheck.Gen
import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{
Expand All @@ -24,6 +19,8 @@ import org.python.core.{PyInteger, PyFunction, PyObject, PyTuple}
import org.bykn.bosatsu.DirectEC.directEC
import org.scalatest.funsuite.AnyFunSuite

import TestUtils.compileFile

// Jython seems to have some thread safety issues
object JythonBarrier {
private val sem = new Semaphore(1)
Expand Down Expand Up @@ -87,27 +84,6 @@ class PythonGenTest extends AnyFunSuite {
}
}

def compileFile(path: String, rest: String*): PackageMap.Typed[Any] = {
def toS(s: String): String =
new String(Files.readAllBytes(Paths.get(s)), "UTF-8")

val packNEL =
NonEmptyList(path, rest.toList)
.map { s =>
val str = toS(s)
val pack = Parser.unsafeParse(Package.parser(None), str)
(("", LocationMap(str)), pack)
}

val res = PackageMap.typeCheckParsed(packNEL, Nil, "")
res.left match {
case Some(err) => sys.error(err.toString)
case None => ()
}

res.right.get
}

def isfromString(s: String): InputStream =
new ByteArrayInputStream(s.getBytes("UTF-8"))

Expand Down
12 changes: 4 additions & 8 deletions core/src/main/scala/org/bykn/bosatsu/FfiCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package org.bykn.bosatsu

import cats.data.NonEmptyList

sealed abstract class FfiCall {
sealed abstract class FfiCall(val arity: Int) {
def call(t: rankn.Type): Value
}

object FfiCall {
final case class Fn1(fn: Value => Value) extends FfiCall {
final case class Fn1(fn: Value => Value) extends FfiCall(1) {
import Value.FnValue

private[this] val evalFn: FnValue = FnValue { case NonEmptyList(a, _) =>
Expand All @@ -16,7 +16,7 @@ object FfiCall {

def call(t: rankn.Type): Value = evalFn
}
final case class Fn2(fn: (Value, Value) => Value) extends FfiCall {
final case class Fn2(fn: (Value, Value) => Value) extends FfiCall(2) {
import Value.FnValue

private[this] val evalFn: FnValue =
Expand All @@ -26,7 +26,7 @@ object FfiCall {

def call(t: rankn.Type): Value = evalFn
}
final case class Fn3(fn: (Value, Value, Value) => Value) extends FfiCall {
final case class Fn3(fn: (Value, Value, Value) => Value) extends FfiCall(3) {
import Value.FnValue

private[this] val evalFn: FnValue =
Expand All @@ -37,10 +37,6 @@ object FfiCall {
def call(t: rankn.Type): Value = evalFn
}

final case class FromFn(callFn: rankn.Type => Value) extends FfiCall {
def call(t: rankn.Type): Value = callFn(t)
}

def getJavaType(t: rankn.Type): List[Class[_]] = {
def one(t: rankn.Type): Option[Class[_]] =
loop(t, false) match {
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/org/bykn/bosatsu/MainModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ abstract class MainModule[IO[_]](implicit
val intrinsic = PythonGen.intrinsicValues
val missingExternals =
allExternals.iterator.flatMap { case (p, names) =>
val missing = names.filterNot { case n =>
val missing = names.filterNot { case (n, _) =>
exts((p, n)) || intrinsic.get(p).exists(_(n))
}

Expand Down Expand Up @@ -703,7 +703,7 @@ abstract class MainModule[IO[_]](implicit
Doc.char('[') +
Doc.intercalate(
Doc.comma + Doc.lineOrSpace,
names.map(b => Doc.text(b.sourceCodeRepr))
names.map { case (b, _) => Doc.text(b.sourceCodeRepr) }
) + Doc.char(']')).nested(4)
}

Expand Down
10 changes: 7 additions & 3 deletions core/src/main/scala/org/bykn/bosatsu/PackageMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,17 @@ case class PackageMap[A, B, C, +D](

def allExternals(implicit
ev: Package[A, B, C, D] <:< Package.Typed[Any]
): Map[PackageName, List[Identifier.Bindable]] =
): Map[PackageName, List[(Identifier.Bindable, rankn.Type)]] =
toMap.iterator.map { case (name, pack) =>
(name, ev(pack).externalDefs)
val tpack = ev(pack)
(name, tpack.externalDefs.map { n =>
(n, tpack.types.getExternalValue(name, n)
.getOrElse(sys.error(s"invariant violation, unknown type: $name $n")) )
})
}.toMap

def topoSort(
ev: Package[A, B, C, D] <:< Package.Typed[Any]
implicit ev: Package[A, B, C, D] <:< Package.Typed[Any]
): Toposort.Result[PackageName] = {

val packNames = toMap.keys.iterator.toList.sorted
Expand Down
97 changes: 76 additions & 21 deletions core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import cats.data.{StateT, EitherT, NonEmptyList, Chain}
import java.math.BigInteger
import java.nio.charset.StandardCharsets
import org.bykn.bosatsu.codegen.Idents
import org.bykn.bosatsu.rankn.DataRepr
import org.bykn.bosatsu.{Identifier, Lit, Matchless, PackageName}
import org.bykn.bosatsu.rankn.{DataRepr, Type}
import org.bykn.bosatsu.{Identifier, Lit, Matchless, PackageName, PackageMap}
import org.bykn.bosatsu.Matchless.Expr
import org.bykn.bosatsu.Identifier.Bindable
import org.typelevel.paiges.Doc
Expand All @@ -21,9 +21,43 @@ object ClangGen {
case class Unbound(bn: Bindable, inside: Option[(PackageName, Bindable)]) extends Error
}

def generateExternalsStub(pm: PackageMap.Typed[Any]): Doc = {
val includes = Code.Include(true, "bosatsu_runtime.h") :: Nil

def toStmt(pn: PackageName, ident: Identifier.Bindable, arity: Int): Code.Statement = {
val cIdent = generatedName(pn, ident)
val args = Idents.allSimpleIdents.take(arity).map { nm =>
Code.Param(Code.TypeIdent.BValue, Code.Ident(nm))
}
Code.DeclareFn(Nil, Code.TypeIdent.BValue, cIdent, args.toList, Some(
Code.block(Code.Return(Some(Code.IntLiteral.Zero)))
))
}

def tpeArity(t: Type): Int =
t match {
case Type.Fun.MaybeQuant(_, args, _) => args.length
case _ => 0
}

val fns = pm.allExternals
.iterator
.flatMap { case (p, vs) =>
vs.iterator.map { case (n, tpe) =>
Code.toDoc(toStmt(p, n, tpeArity(tpe)))
}
}
.toList

val line2 = Doc.hardLine + Doc.hardLine

Doc.intercalate(Doc.hardLine, includes.map(Code.toDoc)) + line2 +
Doc.intercalate(line2, fns)
}

def renderMain(
sortedEnv: Vector[NonEmptyList[(PackageName, List[(Bindable, Expr)])]],
externals: Map[(PackageName, Bindable), (Code.Include, Code.Ident)],
externals: ((PackageName, Bindable)) => Option[(Code.Include, Code.Ident, Int)],
value: (PackageName, Bindable),
evaluator: (Code.Include, Code.Ident)
): Either[Error, Doc] = {
Expand All @@ -44,23 +78,23 @@ object ClangGen {
.iterator.flatMap(_.iterator)
.flatMap { case (p, vs) =>
vs.iterator.map { case (b, e) =>
(p, b) -> (e, Impl.generatedName(p, b))
(p, b) -> (e, generatedName(p, b))
}
}
.toMap

run(allValues, externals, res)
}

private object Impl {
type AllValues = Map[(PackageName, Bindable), (Expr, Code.Ident)]
type Externals = Map[(PackageName, Bindable), (Code.Include, Code.Ident)]
private def fullName(p: PackageName, b: Bindable): String =
p.asString + "/" + b.asString

def fullName(p: PackageName, b: Bindable): String =
p.asString + "/" + b.asString
def generatedName(p: PackageName, b: Bindable): Code.Ident =
Code.Ident(Idents.escape("___bsts_g_", fullName(p, b)))

def generatedName(p: PackageName, b: Bindable): Code.Ident =
Code.Ident(Idents.escape("___bsts_g_", fullName(p, b)))
private object Impl {
type AllValues = Map[(PackageName, Bindable), (Expr, Code.Ident)]
type Externals = Function1[(PackageName, Bindable), Option[(Code.Include, Code.Ident, Int)]]

trait Env {
import Matchless._
Expand Down Expand Up @@ -410,11 +444,7 @@ object ClangGen {
case Some(nm) =>
pv(Code.Ident("STATIC_PUREFN")(nm))
case None =>
// read_or_build(&__bvalue_foo, make_foo);
for {
value <- staticValueName(pack, name)
consFn <- constructorFn(pack, name)
} yield Code.Ident("read_or_build")(value.addr, consFn): Code.ValueLike
globalIdent(pack, name).map { nm => nm() }
}
case Local(arg) =>
directFn(arg)
Expand Down Expand Up @@ -494,7 +524,7 @@ object ClangGen {
case ZeroNat =>
pv(Code.Ident("BSTS_NAT_0"))
case SuccNat =>
val arg = Identifier.Name("arg0")
val arg = Identifier.Name("nat")
// This relies on optimizing App(SuccNat, _) otherwise
// it creates an infinite loop.
// Also, this we should cache creation of Lambda/Closure values
Expand Down Expand Up @@ -567,25 +597,37 @@ object ClangGen {
_ <- appendStatement(stmt)
} yield ()
case someValue =>
// TODO: if we can create the value statically, we don't
// need the read_or_build trick
//
// we materialize an Atomic value to hold the static data
// then we generate a function to populate the value
for {
vl <- innerToValue(someValue)
value <- staticValueName(p, b)
consFn <- constructorFn(p, b)
_ <- appendStatement(Code.DeclareVar(
Code.Attr.Static :: Nil,
Code.TypeIdent.AtomicBValue,
value,
Some(Code.IntLiteral.Zero)
))
consFn <- constructorFn(p, b)
_ <- appendStatement(Code.DeclareFn(
Code.Attr.Static :: Nil,
Code.TypeIdent.BValue,
consFn,
Nil,
Some(Code.block(Code.returnValue(vl)))
))
readFn <- globalIdent(p, b)
res = Code.Ident("read_or_build")(value.addr, consFn)
_ <- appendStatement(Code.DeclareFn(
Code.Attr.Static :: Nil,
Code.TypeIdent.BValue,
readFn,
Nil,
Some(Code.block(Code.returnValue(res)))
))
} yield ()
}
}
Expand Down Expand Up @@ -652,8 +694,9 @@ object ClangGen {
def globalIdent(pn: PackageName, bn: Bindable): T[Code.Ident] =
StateT { s =>
val key = (pn, bn)
s.externals.get(key) match {
case Some((incl, ident)) =>
s.externals(key) match {
case Some((incl, ident, _)) =>
// TODO: suspect that we are ignoring arity here
val withIncl =
if (s.includeSet(incl)) s
else s.copy(includeSet = s.includeSet + incl, includes = s.includes :+ incl)
Expand Down Expand Up @@ -775,9 +818,21 @@ object ClangGen {
// record that this name is a top level function, so applying it can be direct
def directFn(pack: PackageName, b: Bindable): T[Option[Code.Ident]] =
StateT { s =>
s.allValues.get((pack, b)) match {
val key = (pack, b)
s.allValues.get(key) match {
case Some((_: Matchless.FnExpr, ident)) =>
result(s, Some(ident))
case None =>
// this is external
s.externals(key) match {
case Some((incl, ident, arity)) if arity > 0 =>
// TODO: suspect that we are ignoring arity here
val withIncl =
if (s.includeSet(incl)) s
else s.copy(includeSet = s.includeSet + incl, includes = s.includes :+ incl)
result(withIncl, Some(ident))
case _ => result(s, None)
}
case _ => result(s, None)
}
}
Expand Down
Loading

0 comments on commit 6c9f52e

Please sign in to comment.