Skip to content

Commit

Permalink
evaluate static string matches in TypedExprNormalization (#1262)
Browse files Browse the repository at this point in the history
* evaluate static string matches in matchless

* refactor to try to move to TypedExprNormalization...

* fix tests

* close TODO by normalizing Lit Str and Chr patterns

* simplify conversion to regex

* clean up toRegex more
  • Loading branch information
johnynek authored Nov 20, 2024
1 parent 84fb8da commit 35cbe08
Show file tree
Hide file tree
Showing 17 changed files with 570 additions and 330 deletions.
9 changes: 7 additions & 2 deletions core/src/main/scala/org/bykn/bosatsu/Lit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,15 @@ object Lit {
} else new Integer(BigInteger.valueOf(l))
}

case class Str(toStr: String) extends Lit {
// Means this lit could be the result of a string match
sealed abstract class StringMatchResult extends Lit {
def asStr: String
}
case class Str(toStr: String) extends StringMatchResult {
def unboxToAny: Any = toStr
def asStr = toStr
}
case class Chr(asStr: String) extends Lit {
case class Chr(asStr: String) extends StringMatchResult {
def toCodePoint: Int = asStr.codePointAt(0)
def unboxToAny: Any = asStr
}
Expand Down
73 changes: 6 additions & 67 deletions core/src/main/scala/org/bykn/bosatsu/Matchless.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.bykn.bosatsu

import cats.{Monad, Monoid}
import cats.data.{Chain, NonEmptyList, WriterT}
import org.bykn.bosatsu.pattern.StrPart
import org.bykn.bosatsu.rankn.{DataRepr, Type, RefSpace}

import Identifier.{Bindable, Constructor}
Expand All @@ -24,66 +25,6 @@ object Matchless {
def body: Expr
}

/** One element of a compiled string pattern. */
sealed abstract class StrPart
object StrPart {
  // A part that can stand for any number of code points (see MatchSize.from);
  // `capture` marks whether the matched text is bound.
  sealed abstract class Glob(val capture: Boolean) extends StrPart
  // A part that stands for exactly one code point (see MatchSize.from);
  // `capture` marks whether the matched text is bound.
  sealed abstract class CharPart(val capture: Boolean) extends StrPart
  case object WildStr extends Glob(false)
  case object IndexStr extends Glob(true)
  case object WildChar extends CharPart(false)
  case object IndexChar extends CharPart(true)
  // A literal run of text that must appear verbatim.
  case class LitStr(asString: String) extends StrPart

  /** How many code points a sequence of parts can consume: either exactly
    * `charCount`, or `charCount` and anything above it.
    */
  sealed abstract class MatchSize(val isExact: Boolean) {
    def charCount: Int
    // can a run of `cp` code points be matched?
    def canMatch(cp: Int): Boolean
    // Same question asked with a UTF-16 length instead of a code point
    // count; for utf16 we only know chars/2 <= cpCount <= chars.
    def canMatchUtf16Count(chars: Int): Boolean
  }
  object MatchSize {
    case class Exactly(charCount: Int) extends MatchSize(true) {
      def canMatch(cp: Int): Boolean = charCount == cp
      def canMatchUtf16Count(chars: Int): Boolean =
        // charCount must lie in the feasible range [chars / 2, chars]
        ((chars / 2) <= charCount) && (charCount <= chars)
    }
    case class AtLeast(charCount: Int) extends MatchSize(false) {
      def canMatch(cp: Int): Boolean = cp >= charCount
      def canMatchUtf16Count(chars: Int): Boolean =
        // the code point count may be as large as `chars`, so it is enough
        // that this upper end reaches charCount
        chars >= charCount
    }

    private val zeroOrMore = AtLeast(0)
    private val noChars = Exactly(0)
    private val oneChar = Exactly(1)

    // Size consumed by a single part.
    def from(sp: StrPart): MatchSize =
      sp match {
        case LitStr(lit) =>
          Exactly(lit.codePointCount(0, lit.length))
        case _: CharPart => oneChar
        case _: Glob     => zeroOrMore
      }

    // Combined size of a whole collection of parts, summed with the monoid below.
    def apply[F[_]: cats.Foldable](f: F[StrPart]): MatchSize =
      cats.Foldable[F].foldMap(f)(from)

    implicit val monoidMatchSize: Monoid[MatchSize] =
      new Monoid[MatchSize] {
        def empty: MatchSize = noChars
        // counts add up; the sum stays exact only if both sides are exact
        def combine(l: MatchSize, r: MatchSize) = {
          val total = l.charCount + r.charCount
          if (l.isExact && r.isExact) Exactly(total)
          else AtLeast(total)
        }
      }
  }
}

// name is set for recursive (but not tail recursive) methods
case class Lambda(
captures: List[Expr],
Expand Down Expand Up @@ -527,10 +468,6 @@ object Matchless {
case Pattern.StrPart.NamedChar(n) => n
}

val muts = sbinds.traverse { b =>
makeAnon.map(LocalAnonMut(_)).map((b, _))
}

val pat = items.toList.map {
case Pattern.StrPart.NamedStr(_) => StrPart.IndexStr
case Pattern.StrPart.NamedChar(_) => StrPart.IndexChar
Expand All @@ -539,10 +476,13 @@ object Matchless {
case Pattern.StrPart.LitStr(s) => StrPart.LitStr(s)
}

muts.map { binds =>
sbinds.traverse { b =>
makeAnon.map(LocalAnonMut(_)).map((b, _))
}
.map { binds =>
val ms = binds.map(_._2)

NonEmptyList.of((ms, MatchString(arg, pat, ms), binds))
NonEmptyList.one((ms, MatchString(arg, pat, ms), binds))
}
case lp @ Pattern.ListPat(_) =>
lp.toPositionalStruct(empty, cons) match {
Expand Down Expand Up @@ -817,7 +757,6 @@ object Matchless {
(Pattern[(PackageName, Constructor), Type], Expr)
]
): F[Expr] = {

def recur(
arg: CheapExpr,
branches: NonEmptyList[
Expand Down
125 changes: 3 additions & 122 deletions core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@ import cats.{Eval, Functor, Applicative}
import cats.data.NonEmptyList
import cats.evidence.Is
import java.math.BigInteger
import org.bykn.bosatsu.pattern.StrPart
import scala.collection.immutable.LongMap
import scala.collection.mutable.{LongMap => MLongMap}

import Identifier.Bindable
import Value._

import cats.implicits._

object MatchlessToValue {
import Matchless._

Expand Down Expand Up @@ -209,7 +208,7 @@ object MatchlessToValue {
// we have nothing to bind
loop(str).map { strV =>
val arg = strV.asExternal.toAny.asInstanceOf[String]
matchString(arg, pat, 0) != null
StrPart.matchString(arg, pat, 0) != null
}
case _ =>
val bary = binds.iterator.collect { case LocalAnonMut(id) =>
Expand All @@ -219,7 +218,7 @@ object MatchlessToValue {
// this may be static
val matchScope = loop(str).map { str =>
val arg = str.asExternal.toAny.asInstanceOf[String]
matchString(arg, pat, bary.length)
StrPart.matchString(arg, pat, bary.length)
}
// if we mutate scope, it has to be dynamic
Dynamic { scope =>
Expand Down Expand Up @@ -547,123 +546,5 @@ object MatchlessToValue {
}

}

// shared result for matches that bind nothing, so no array is allocated
private[this] val emptyStringArray: Array[String] = new Array[String](0)

/** Match `str` against the pattern parts `pat`, collecting captured text.
  *
  * @param str
  *   the string to match; the whole string must be consumed
  * @param pat
  *   the pattern parts, in order; two adjacent globs are an invariant
  *   violation and raise an error
  * @param binds
  *   the number of capture slots to allocate; when it is not positive a
  *   shared empty array is used
  * @return
  *   the captured substrings in pattern order when the match succeeds, or
  *   `null` when it does not
  */
def matchString(
    str: String,
    pat: List[StrPart],
    binds: Int
): Array[String] = {
  import Matchless.StrPart._

  val strLen = str.length()
  val results =
    if (binds > 0) new Array[String](binds) else emptyStringArray

  // offset: UTF-16 index into str where matching resumes
  // next: index of the next free capture slot in results
  def loop(offset: Int, pat: List[StrPart], next: Int): Boolean =
    pat match {
      case Nil => offset == strLen
      case LitStr(expect) :: tail =>
        val len = expect.length
        str.regionMatches(offset, expect, 0, len) && loop(
          offset + len,
          tail,
          next
        )
      case (c: CharPart) :: tail =>
        try {
          // throws if there is no code point left at offset
          val nextOffset = str.offsetByCodePoints(offset, 1)
          val n =
            if (c.capture) {
              results(next) = str.substring(offset, nextOffset)
              next + 1
            } else next

          loop(nextOffset, tail, n)
        } catch {
          case _: IndexOutOfBoundsException => false
        }
      case (h: Glob) :: tail =>
        tail match {
          case Nil =>
            // we capture all the rest
            if (h.capture) {
              results(next) = str.substring(offset)
            }
            true
          case rest @ ((_: CharPart) :: _) =>
            val matchableSizes = MatchSize[List](rest)

            def canMatch(off: Int): Boolean =
              matchableSizes.canMatch(str.codePointCount(off, strLen))

            // (.*)(.)tail2
            // this is a naive algorithm that just
            // checks at all possible later offsets
            // a smarter algorithm could see if there
            // are Lit parts that can match or not
            var matched = false
            var off1 = offset
            val n1 = if (h.capture) (next + 1) else next
            while (!matched && (off1 < strLen)) {
              matched = canMatch(off1) && loop(off1, rest, n1)
              if (!matched) {
                off1 = off1 + Character.charCount(str.codePointAt(off1))
              }
            }

            matched && {
              if (h.capture) {
                results(next) = str.substring(offset, off1)
              }
              true
            }
          case LitStr(expect) :: tail2 =>
            val next1 = if (h.capture) next + 1 else next

            val matchableSizes = MatchSize(tail2)

            def canMatch(off: Int): Boolean =
              matchableSizes.canMatchUtf16Count(strLen - off)

            var start = offset
            var result = false
            while (start >= 0) {
              val candidate = str.indexOf(expect, start)
              if (candidate >= 0) {
                // we have to skip the current expect string
                val nextOff = candidate + expect.length
                val check1 =
                  canMatch(nextOff) && loop(nextOff, tail2, next1)
                if (check1) {
                  // this was a match, write into next if needed
                  if (h.capture) {
                    results(next) = str.substring(offset, candidate)
                  }
                  result = true
                  start = -1
                } else if (candidate < strLen) {
                  // we couldn't match here, try just after candidate
                  start = candidate + Character.charCount(
                    str.codePointAt(candidate)
                  )
                } else {
                  // candidate == strLen happens only for an empty expect;
                  // there is no code point to step over (codePointAt would
                  // throw) and no later position to try
                  start = -1
                }
              } else {
                // no more candidates
                start = -1
              }
            }
            result
          // $COVERAGE-OFF$
          case (_: Glob) :: _ =>
            // this should be an error at compile time since it
            // is never meaningful to have two adjacent globs
            sys.error(s"invariant violation, adjacent globs: $pat")
          // $COVERAGE-ON$
        }
    }

  if (loop(0, pat, 0)) results else null
}
}
}
56 changes: 43 additions & 13 deletions core/src/main/scala/org/bykn/bosatsu/Pattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import cats.data.NonEmptyList
import cats.parse.{Parser0 => P0, Parser => P}
import org.typelevel.paiges.{Doc, Document}
import org.bykn.bosatsu.pattern.{NamedSeqPattern, SeqPattern, SeqPart}
import java.util.regex.{Pattern => RegexPattern}

import Parser.{Combinators, maybeSpace, MaybeTupleOrParens}
import cats.implicits._
Expand Down Expand Up @@ -464,10 +465,10 @@ object Pattern {
}
}

lazy val toNamedSeqPattern: NamedSeqPattern[Char] =
lazy val toNamedSeqPattern: NamedSeqPattern[Int] =
StrPat.toNamedSeqPattern(this)

lazy val toSeqPattern: SeqPattern[Char] = toNamedSeqPattern.unname
lazy val toSeqPattern: SeqPattern[Int] = toNamedSeqPattern.unname

lazy val toLiteralString: Option[String] =
toSeqPattern.toLiteralSeq.map(_.mkString)
Expand All @@ -476,6 +477,30 @@ object Pattern {

/** True when this pattern accepts `str`: either `isTotal` holds, in which
  * case `matcher` is not consulted, or `matcher(str)` yields a defined
  * result.
  */
def matches(str: String): Boolean =
isTotal || matcher(str).isDefined

/**
 * Build a compiled regular expression equivalent to this pattern.
 *
 * String wildcards use the reluctant quantifier `.*?`, named parts are
 * wrapped in a capturing group, and literal text is quoted so that none
 * of its characters are read as regex syntax. The result is compiled
 * with the DOTALL flag.
 */
def toRegex: RegexPattern = {
  val regexSource = new java.lang.StringBuilder

  parts.iterator.foreach { part =>
    val fragment: String =
      part match {
        case StrPart.LitStr(text) =>
          // literal text must not be interpreted as regex syntax
          RegexPattern.quote(text)
        case StrPart.WildChar     => "."
        case StrPart.NamedChar(_) => "(.)"
        case StrPart.WildStr      => ".*?"
        case StrPart.NamedStr(_)  => "(.*?)"
      }
    regexSource.append(fragment)
  }

  RegexPattern.compile(regexSource.toString, RegexPattern.DOTALL)
}
}

/** Patterns like Some(_) as foo as binds tighter than |, so use ( ) with
Expand Down Expand Up @@ -600,14 +625,19 @@ object Pattern {
val Empty: StrPat = fromLitStr("")
val Wild: StrPat = StrPat(NonEmptyList.one(StrPart.WildStr))

def fromSeqPattern(sp: SeqPattern[Char]): StrPat = {
def lit(rev: List[Char]): List[StrPart.LitStr] =
def fromSeqPattern(sp: SeqPattern[Int]): StrPat = {
def lit(rev: List[Int]): List[StrPart.LitStr] =
if (rev.isEmpty) Nil
else StrPart.LitStr(rev.reverse.mkString) :: Nil
else {
val cps = rev.reverse
val bldr = new java.lang.StringBuilder
cps.foreach(bldr.appendCodePoint(_))
StrPart.LitStr(bldr.toString) :: Nil
}

def loop(
ps: List[SeqPart[Char]],
front: List[Char]
ps: List[SeqPart[Int]],
front: List[Int]
): NonEmptyList[StrPart] =
ps match {
case Nil => NonEmptyList.fromList(lit(front)).getOrElse(Empty.parts)
Expand Down Expand Up @@ -638,10 +668,10 @@ object Pattern {
StrPat(loop(sp.toList, Nil))
}

def toNamedSeqPattern(sp: StrPat): NamedSeqPattern[Char] = {
val empty: NamedSeqPattern[Char] = NamedSeqPattern.NEmpty
def toNamedSeqPattern(sp: StrPat): NamedSeqPattern[Int] = {
val empty: NamedSeqPattern[Int] = NamedSeqPattern.NEmpty

def partToNsp(s: StrPart): NamedSeqPattern[Char] =
def partToNsp(s: StrPart): NamedSeqPattern[Int] =
s match {
case StrPart.NamedStr(n) =>
NamedSeqPattern.Bind(n.sourceCodeRepr, NamedSeqPattern.Wild)
Expand All @@ -650,9 +680,9 @@ object Pattern {
case StrPart.WildStr => NamedSeqPattern.Wild
case StrPart.WildChar => NamedSeqPattern.Any
case StrPart.LitStr(s) =>
if (s.isEmpty) empty
else
s.toList.foldRight(empty) { (c, tail) =>
StringUtil
.codePoints(s)
.foldRight(empty) { (c, tail) =>
NamedSeqPattern.NCat(NamedSeqPattern.fromLit(c), tail)
}
}
Expand Down
Loading

0 comments on commit 35cbe08

Please sign in to comment.