Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/more simplification #30

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ See the [youtube presentation](https://youtu.be/cESdgot_ZxY) for more details ab

[This lecture](https://youtu.be/v=KcfD3Iv--UM) is a pedagogical explanation of the Curry-Howard correspondence in the context of functional programming.

See also a [recent presentation at the Haskell User's Group meetup](https://youtu.be/OFBwrMo1ESk).

# Unit tests

`sbt test`
Expand All @@ -156,6 +158,7 @@ Build the tutorial (thanks to the [tut plugin](https://github.com/tpolecat/tut))

# Revision history

- 0.3.8 Support Scala 2.13 (keep supporting Scala 2.11 and 2.12)
- 0.3.7 Implement the `typeExpr` macro instead of the old test-only API. Detect and use `val`s from the immediately enclosing class. Minor performance improvements and bug fixes (alpha-conversion for STLC terms). Tests for automatic discovery of some monads.
- 0.3.6 STLC terms are now emitted for `implement` as well; the JVM bytecode limit is obviated; fixed bug with recognizing `Function10`.
- 0.3.5 Added `:@@` and `@@:` operations to the STLC interpreter. Fixed a bug whereby `Tuple2(x._1, x._2)` was not simplified to `x`. Fixed other bugs in alpha-conversion of type parameters.
Expand Down
8 changes: 4 additions & 4 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ lazy val curryhoward: Project = (project in file("."))
.settings(common)
.settings(
organization := "io.chymyst",
version := "0.3.8",
version := "0.3.9",

licenses := Seq("Apache License, Version 2.0" -> url("https://www.apache.org/licenses/LICENSE-2.0.txt")),
homepage := Some(url("https://github.com/Chymyst/curryhoward")),
Expand Down Expand Up @@ -157,14 +157,14 @@ lazy val curryhoward: Project = (project in file("."))
/////////////////////////////////////////////////////////////////////////////////////////////////////
// Publishing to Sonatype Maven repository
publishMavenStyle := true
publishTo := sonatypePublishToBundle.value
/*{
publishTo := //sonatypePublishToBundle.value
{
val nexus = "https://oss.sonatype.org/"
if (isSnapshot.value)
Some("snapshots" at nexus + "content/repositories/snapshots")
else
Some("releases" at nexus + "service/local/staging/deploy/maven2")
}*/
}
//
publishArtifact in Test := false
//
Expand Down
4 changes: 2 additions & 2 deletions docs/Tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -888,8 +888,8 @@ res26: Boolean = true
| `a.substTypeVar(b, c)` | `TermExpr ⇒ (TermExpr, TermExpr) ⇒ TermExpr` | replace a type variable in `a`; the type variable is specified as the type of `b`, and the replacement type is specified as the type of `c` |
| `a.substTypeVars(s)` | `TermExpr ⇒ Map[TP, TypeExpr] ⇒ TermExpr` | replace all type variables in `a` according to the given substitution map `s` -- all type variables are substituted at once |
| `u()` | `TermExpr ⇒ () ⇒ TermExpr` and `TypeExpr ⇒ () ⇒ TermExpr` | create a "named Unit" term of type `u.t` -- the type of `u` must be a named unit type, e.g. `None.type` or a case class with no constructors |
| `c(x...)` | `TermExpr ⇒ TermExpr* ⇒ TermExpr` and `TypeExpr ⇒ TermExpr* ⇒ TermExpr` | create a named conjunction term of type `c.t` -- the type of `c` must be a conjunction whose parts match the types of the arguments `x...` |
| `d(x)` | `TermExpr ⇒ TermExpr ⇒ TermExpr` and `TypeExpr ⇒ TermExpr ⇒ TermExpr` | create a disjunction term of type `d.t` using term `x` -- the type of `x` must match one of the disjunction parts in the type `d`, which must be a disjunction type |
| `c(x...)` | `TypeExpr ⇒ TermExpr* ⇒ TermExpr` and `TypeExpr ⇒ TermExpr* ⇒ TermExpr` | create a named conjunction term of type `c` -- the type `c` must be a conjunction whose parts match the types of the arguments `x...` |
| `d(x)` | `TypeExpr ⇒ TermExpr ⇒ TermExpr` and `TypeExpr ⇒ TermExpr ⇒ TermExpr` | create a disjunction term of type `d` using term `x` -- the type of `x` must match one of the disjunction parts in the type `d`, which must be a disjunction type |
| `c(i)` | `TermExpr ⇒ Int ⇒ TermExpr` | project a conjunction term onto part with given zero-based index -- the type of `c` must be a conjunction with sufficiently many parts |
| `c("id")` | `TermExpr ⇒ String ⇒ TermExpr` | project a conjunction term onto part with given accessor name -- the type of `c` must be a named conjunction that supports this accessor |
| `d.cases(x =>: ..., y =>: ..., ...)` | `TermExpr ⇒ TermExpr* ⇒ TermExpr` | create a term that pattern-matches on the given disjunction term -- the type of `d` must be a disjunction whose arguments match the arguments `x`, `y`, ... of the given case clauses |
103 changes: 91 additions & 12 deletions src/main/scala/io/chymyst/ch/TermExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ object TermExpr {
thisVar.name === otherVar.name && (thisVar.t === otherVar.t || TypeExpr.isDisjunctionPart(thisVar.t, otherVar.t))
}

/** Replace all non-free occurrences of variable `replaceVar` by expression `byExpr` in `origExpr`.
/** Replace all free occurrences of variable `replaceVar` by expression `byExpr` in `origExpr`.
*
* @param replaceVar A variable that may occur freely in `origExpr`.
* @param byExpr A new expression to replace all free occurrences of that variable.
Expand All @@ -156,7 +156,7 @@ object TermExpr {
// Check that all instances of replaceVar in origExpr have the correct type.
val badVars = origExpr.freeVars.filter(_.name === replaceVar.name).filterNot(varMatchesType(_, replaceVar))
if (badVars.nonEmpty) {
throw new Exception(s"In subst($replaceVar, $byExpr, $origExpr), found variable(s) ${badVars.map(v ⇒ s"(${v.name}:${v.t.prettyPrint})").mkString(", ")} with incorrect type(s), expected variable type ${replaceVar.t.prettyPrint}")
throw new Exception(s"In subst($replaceVar:${replaceVar.t.prettyPrint}, $byExpr, $origExpr), found variable(s) ${badVars.map(v ⇒ s"(${v.name}:${v.t.prettyPrint})").mkString(", ")} with incorrect type(s), expected variable type ${replaceVar.t.prettyPrint}")
}
// Do we need an alpha-conversion? Better be safe than sorry.
val (convertedReplaceVar, convertedOrigExpr) = if (byExpr.usedVarNames contains replaceVar.name) {
Expand All @@ -167,7 +167,7 @@ object TermExpr {

substMap(convertedOrigExpr) {
case c@CurriedE(heads, _) if heads.exists(_.name === convertedReplaceVar.name) ⇒ c
// If a variable from `heads` collides with `convertedReplaceVar`, we do not replace anything in the body.
// If a variable from `heads` collides with `convertedReplaceVar`, we do not replace anything in the body because the variable occurs as non-free.

case v@VarE(_, _) if varMatchesType(v, convertedReplaceVar) ⇒ byExpr
}
Expand Down Expand Up @@ -342,6 +342,43 @@ object TermExpr {
}

private[ch] def roundFactor(x: Double): Int = math.round(x * 10000).toInt

/** Generate all necessary fresh variables for equality checking of functions that consume disjunction types.
*
* @param typeExpr The type of the argument expression.
* @return A sequence of [[TermExpr]] values containing the necessary fresh variables.
*/
def subtypeVars(typeExpr: TypeExpr): Seq[TermExpr] = typeExpr match {
case dt@DisjunctT(_, _, terms) ⇒ terms.zipWithIndex.flatMap { case (t, i) ⇒ subtypeVars(t).map(v ⇒ DisjunctE(i, terms.length, v, dt)) }
case nct@NamedConjunctT(_, _, _, wrapped) ⇒
TheoremProver.explode(wrapped.map(subtypeVars)).map(NamedConjunctE(_, nct))
case _ ⇒ Seq(VarE(freshIdents(), typeExpr))
}

/** Extensional equality check. If the term expressions are functions, fresh variables are substituted as arguments and the results are compared with `equiv`.
*
* @param termExpr1 The first term.
* @param termExpr2 The second term.
* @return `true` if the terms are extensionally equal.
*/
def extEqual(termExpr1: TermExpr, termExpr2: TermExpr): Boolean = {
val t1 = termExpr1.simplify
val t2 = termExpr2.simplify
(t1.t === t2.t) && (
(t1 equiv t2) || {
println(s"DEBUG: checking extensional equality of ${t1.prettyPrint} and ${t2.prettyPrint}")
(t1, t2) match {
case (CurriedE(h1 :: _, _), CurriedE(_ :: _, _)) ⇒
subtypeVars(h1.t).forall { term ⇒
val result = extEqual(t1(term), t2(term))
if (!result) println(s"DEBUG: found inequality after substituting term ${term.prettyPrint}")
result
}
case _ ⇒ false
}
}
)
}
}

sealed trait TermExpr {
Expand Down Expand Up @@ -507,12 +544,33 @@ sealed trait TermExpr {
"(" + leftZeros.mkString(" + ") + leftZerosString + term.prettyPrintWithParentheses(0) + rightZerosString + rightZeros.mkString(" + ") + ")"
}

lazy val printScala: String = printScalaWithTypes()

private[ch] def printScalaWithTypes(withTypes: Boolean = false): String = this match {
case VarE(name, _) ⇒ name + (if (withTypes) ": " + t.prettyPrint else "")
case AppE(head, arg) ⇒
val h = head.printScalaWithTypes(true)
val b = arg.printScalaWithTypes()
s"$h($b)"
case CurriedE(heads, body) ⇒
s"${heads.map(_.printScalaWithTypes(true)).mkString(" ⇒ ")} ⇒ ${body.printScalaWithTypes()}"
case UnitE(_) ⇒ "()"
case ConjunctE(terms) ⇒ "(" + terms.map(_.printScalaWithTypes()).mkString(", ") + ")"
case NamedConjunctE(terms, tExpr) ⇒ if (tExpr.wrapped.isEmpty) tExpr.constructor.toString
else s"${tExpr.constructor.toString}(${terms.map(_.printScalaWithTypes()).mkString(", ")})"
case ProjectE(index, term) ⇒ term.printScalaWithTypes() + "." + term.accessor(index)
case MatchE(term, cases) ⇒
term.printScalaWithTypes() + " match { case " + cases.map(_.printScalaWithTypes(true)).mkString("; case ") + " }"
case DisjunctE(index, total, term, _) ⇒
term.printScalaWithTypes()
}

private def prettyVars: Iterator[String] = for {
number ← Iterator.single("") ++ Iterator.from(1).map(_.toString)
letter ← ('a' to 'z').toIterator
} yield s"$letter$number"

private lazy val renameBoundVars: TermExpr = TermExpr.substMap(this) {
private[ch] lazy val renameBoundVars: TermExpr = TermExpr.substMap(this) {
case CurriedE(heads, body) ⇒
val oldAndNewVars = heads.map { v ⇒ (v, VarE(TermExpr.freshIdents(), v.t)) }
val renamedBody = oldAndNewVars.foldLeft(body.renameBoundVars) { case (prev, (oldVar, newVar)) ⇒
Expand Down Expand Up @@ -803,36 +861,57 @@ final case class MatchE(term: TermExpr, cases: List[TermExpr]) extends TermExpr
}

private[ch] override def simplifyOnceInternal(withEta: Boolean): TermExpr = {
lazy val casesSimplified = cases.map(_.simplifyOnce(withEta))
val ncases = cases.length
term.simplifyOnce(withEta) match {
// Match a fixed part of the disjunction; can be simplified to just one clause.
// Example: Left(a) match { case Left(x) => f(x); case Right(y) => ... } can be simplified to just f(a).
case DisjunctE(index, total, termInjected, _) ⇒
if (total === cases.length) {
if (total === ncases) {
AppE(cases(index).simplifyOnce(withEta), termInjected).simplifyOnce(withEta)
} else throw new Exception(s"Internal error: MatchE with ${cases.length} cases applied to DisjunctE with $total parts, but must be of equal size")
} else throw new Exception(s"Internal error: MatchE with $ncases cases applied to DisjunctE with $total parts, but must be of equal size")

// Match of an inner match, can be simplified to a single match.
// Example: (Left(a) match { case Left(x) ⇒ ...; case Right(y) ⇒ ... }) match { case ... ⇒ ... }
// can be simplified to Left(a) match { case Left(x) ⇒ ... match { case ... ⇒ ...}; case Right(y) ⇒ ... match { case ... ⇒ ... } }
// Example: (q match { case Left(x) ⇒ ...; case Right(y) ⇒ ... }) match { case ... ⇒ ... }
// can be simplified to q match { case Left(x) ⇒ ... match { case ... ⇒ ...}; case Right(y) ⇒ ... match { case ... ⇒ ... } }
case MatchE(innerTerm, innerCases) ⇒
MatchE(innerTerm, innerCases map { case CurriedE(List(head), body) ⇒ CurriedE(List(head), MatchE(body, cases)) })
.simplifyOnce(withEta)

// Detect the identity patterns:
// MatchE(_, List(a ⇒ DisjunctE(0, total, a, _), a ⇒ DisjunctE(1, total, a, _), ...))
// MatchE(_, a: T1 ⇒ DisjunctE(i, total, NamedConjunctE(List(ProjectE(0, a), Project(1, a), ...), T1), ...), _)
case termSimplified ⇒
if (cases.nonEmpty && {

// Replace redundant matches on the same term, can be simplified by eliminating one match subexpresssion.
// Example: q match { case x ⇒ q match { case y ⇒ b; case other ⇒ ... } ... }
// We already know that q was matched as Left(x). Therefore, we can replace y by x in b and remove the `case other` clause altogether.
// Doing a .renameBoundVars on the cases leads to infinite loops somewhere due to incorrect alpha-conversion.
val casesSimplified = cases.map(_.simplifyOnce(withEta))
/*
.zipWithIndex.map { case (c@CurriedE(List(headVar), _), i) ⇒
TermExpr.substMap(c) {
case MatchE(otherTerm, otherCases) if otherTerm === termSimplified ⇒
// We already matched `otherTerm`, and we are now in case `c`, which is `case x ⇒ ...`.
// Therefore we can discard any of the `otherCases` except the one corresponding to `c`.
// We can replace the `q match { case y ⇒ b; ...}` by `b` after replacing `x` by `y` in `b`.
val remainingCase = otherCases(i)
val result = AppE(remainingCase, headVar).simplifyOnce(withEta)
// println(s"DEBUG: replacing ${MatchE(otherTerm, otherCases)} by $result in ${c.simplifyOnce(withEta)}")
result
}
}
*/
if (casesSimplified.nonEmpty && {
casesSimplified.zipWithIndex.forall {
// Detect a ⇒ a pattern
case (CurriedE(List(head@VarE(_, _)), body@VarE(_, _)), _)
if head.name === body.name
⇒ true
case (CurriedE(List(head@VarE(_, _)), DisjunctE(i, len, x, _)), ind)
if x === head && len === cases.length && ind === i
if x === head && len === ncases && ind === i
⇒ true
case (CurriedE(List(head@VarE(_, headT)), DisjunctE(i, len, NamedConjunctE(projectionTerms, conjT), _)), ind) ⇒
len === cases.length && ind === i && headT === conjT &&
len === ncases && ind === i && headT === conjT &&
projectionTerms.zipWithIndex.forall {
case (ProjectE(k, head1), j) if k === j && head1 === head ⇒ true
case _ ⇒ false
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/io/chymyst/ch/TheoremProver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@ object TheoremProver {
val transformedProofs = explodedNewProofs.map(ruleResult.backTransform)
val t1 = System.currentTimeMillis()

val result = transformedProofs.sortBy(_.informationLossScore).take(maxTermsToSelect(sequent))
val result = transformedProofs.map(_.simplifyOnce(withEta = false)).distinct.sortBy(_.informationLossScore).take(maxTermsToSelect(sequent))
// Note: at this point, it is a mistake to do prettyRename, because we are calling this function recursively.
// We will call prettyRename() at the very end of the proof search.
// It is also a mistake to do a `.simplifyOnce(withEta = true)`. The eta-conversion produces incorrect code here.
if (debug) {
println(s"DEBUG: elapsed ${System.currentTimeMillis() - t0} ms, .map(_.simplify()).distinct took ${System.currentTimeMillis() - t1} ms, produced ${result.size} terms out of ${transformedProofs.size} back-transformed terms; after rule ${ruleResult.ruleName} for sequent $sequent")
// println(s"DEBUG: for sequent $sequent, after rule ${ruleResult.ruleName}, transformed ${transformedProofs.length} proof terms:\n ${transformedProofs.mkString(" ;\n ")} ,\nafter simplifying:\n ${result.mkString(" ;\n ")} .")
Expand Down
75 changes: 75 additions & 0 deletions src/main/scala/io/chymyst/ch/data/CategoryTheory.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package io.chymyst.ch.data

// Declarations of standard type classes, to be used in macros.

trait Semigroup[T] {
def combine(x: T, y: T): T
}

trait Monoid[T] extends Semigroup[T] {
def empty: T
}

object Monoid {
def empty[T](implicit ev: Monoid[T]): T = ev.empty

implicit class MonoidSyntax[T](t: T)(implicit ev: Monoid[T]) {

def combine(y: T): T = ev.combine(t, y)
}

}

trait Functor[F[_]] {
def map[A, B](fa: F[A])(f: A ⇒ B): F[B]
}

trait ContraFunctor[F[_]] {
def map[A, B](fa: F[A])(f: B ⇒ A): F[B]
}

trait Filterable[F[_]] extends Functor[F] {
def deflate[A](fa: F[Option[A]]): F[A]
}

trait ContraFilterable[F[_]] extends ContraFunctor[F] {
def inflate[A](fa: F[A]): F[Option[A]]
}

trait Semimonad[F[_]] extends Functor[F] {
def join[A](ffa: F[F[A]]): F[A]
}

trait Pointed[F[_]] extends Functor[F] {
def point[A]: F[A]
}

trait Zippable[F[_]] extends Functor[F] {
def zip[A, B](fa: F[A], fb: F[B]): F[(A, B)]
}

trait Foldable[F[_]] extends Functor[F] {
def foldMap[A, B: Monoid](fa: F[A])(f: A ⇒ B)
}

trait Traversable[F[_]] extends Functor[F] {
def sequence[Z[_] : Zippable, A](fga: F[Z[A]]): Z[F[A]]
}

trait Monad[F[_]] extends Pointed[F] with Semimonad[F]

trait Applicative[F[_]] extends Pointed[F] with Zippable[F]

trait Cosemimonad[F[_]] extends Functor[F] {
def cojoin[A](fa: F[A]): F[F[A]]
}

trait Copointed[F[_]] extends Functor[F] {
def extract[A](fa: F[A]): A
}

trait Comonad[F[_]] extends Copointed[F] with Cosemimonad[F]

trait Cozippable[F[_]] extends Functor[F] {
def decide[A, B](fab: F[Either[A, B]]): Either[F[A], F[B]]
}
17 changes: 0 additions & 17 deletions src/main/scala/io/chymyst/ch/data/Monoid.scala

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ package io.chymyst.ch.data

import io.chymyst.ch._

object LawChecking {
object SymbolicLawChecking {

def checkFlattenAssociativity(fmap: TermExpr, flatten: TermExpr): Boolean = {
// fmap ftn . ftn = ftn . ftn
val lhs = flatten :@@ flatten
val rhs = (fmap :@ flatten) :@@ flatten
// println(s"check associativity laws for flatten = ${flatten.prettyPrint}:\n\tlhs = ${lhs.simplify.prettyRenamePrint}\n\trhs = ${rhs.simplify.prettyRenamePrint}")
lhs equiv rhs
TermExpr.extEqual(lhs, rhs)
}

def checkPureFlattenLaws(fmap: TermExpr, pure: TermExpr, flatten: TermExpr): Boolean = {
Expand All @@ -23,7 +23,7 @@ object LawChecking {
val fpf = (fmap :@ pure) :@@ flatten

// println(s"check identity laws for pure = ${pure.prettyPrint} and flatten = ${flatten.prettyPrint}:\n\tlhs1 = ${pf.simplify.prettyPrint}\n\trhs1 = ${idFA.simplify.prettyPrint}\n\tlhs2 = ${fpf.simplify.prettyPrint}\n\trhs2 = ${idFA.simplify.prettyPrint}")
(pf equiv idFA) && (fpf equiv idFA)
TermExpr.extEqual(pf, idFA) && TermExpr.extEqual(fpf, idFA)
}

}
Loading