Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Be able to generate egg from python example #226

Merged
merged 6 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ jobs:
- uses: CodSpeedHQ/action@v3
with:
token: ${{ secrets.CODSPEED_TOKEN }}
run: uv run pytest -vvv -n auto
# allow updating snapshots because the benchmarks are nondeterministic
run: uv run pytest -vvv -n auto --snapshot-update

docs:
runs-on: ubuntu-latest
Expand Down
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ _This project uses semantic versioning_
## UNRELEASED

- Upgrade dependencies including [egglog](https://github.com/egraphs-good/egglog/compare/saulshanabrook:egg-smol:a555b2f5e82c684442775cc1a5da94b71930113c...b0db06832264c9b22694bd3de2bdacd55bbe9e32)
- Fix bug with non glob star import
- Fix bug extracting functions

## 8.0.0 (2024-10-17)

Expand Down
2 changes: 1 addition & 1 deletion python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def from_expr(self, tp: JustTypeRef, term: bindings._Term) -> TypedExprDecl:
if term.name == "py-object":
call = bindings.termdag_term_to_expr(self.termdag, term)
expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call))
if term.name == "unstable-fn":
elif term.name == "unstable-fn":
# Get function name
fn_term, *arg_terms = term.args
fn_value = self.resolve_term(fn_term, JustTypeRef("String"))
Expand Down
9 changes: 4 additions & 5 deletions python/egglog/exp/array_api_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,17 @@ def jit(fn: X) -> X:
# 1. Create variables for each of the two args in the functions
sig = inspect.signature(fn)
arg1, arg2 = sig.parameters.keys()

with EGraph() as egraph:
egraph = EGraph()
with egraph:
res = fn(NDArray.var(arg1), NDArray.var(arg2))
egraph.register(res)
egraph.run(array_api_numba_schedule)
res_optimized = egraph.extract(res)
egraph.display(split_primitive_outputs=True, n_inline_leaves=3)
# egraph.display(split_primitive_outputs=True, n_inline_leaves=3)

egraph = EGraph()
fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
egraph.register(fn_program)
egraph.run(array_api_program_gen_schedule)
fn = cast(X, egraph.eval(fn_program.py_object))
fn = cast(X, egraph.eval(egraph.extract(fn_program.py_object)))
fn.expr = res_optimized # type: ignore[attr-defined]
return fn
24 changes: 11 additions & 13 deletions python/egglog/exp/array_api_program_gen.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# mypy: disable-error-code="empty-body"
from __future__ import annotations

import numpy as np

from egglog import *

from .array_api import *
Expand All @@ -13,9 +11,12 @@
# Depends on `np` as a global variable.
##

array_api_program_gen_ruleset = ruleset()
array_api_program_gen_ruleset = ruleset(name="array_api_program_gen_ruleset")
array_api_program_gen_eval_ruleset = ruleset(name="array_api_program_gen_eval_ruleset")

array_api_program_gen_schedule = array_api_program_gen_ruleset.saturate() + program_gen_ruleset.saturate()
array_api_program_gen_schedule = (
array_api_program_gen_ruleset | program_gen_ruleset | array_api_program_gen_eval_ruleset | eval_program_rulseset
).saturate()


@function
Expand Down Expand Up @@ -98,17 +99,14 @@ def _tuple_int_program(i: Int, ti: TupleInt, k: i64, idx_fn: Callable[[Int], Int
def ndarray_program(x: NDArray) -> Program: ...


@function
def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> Program: ...
@function(ruleset=array_api_program_gen_ruleset)
def ndarray_function_two_program(res: NDArray, l: NDArray, r: NDArray) -> Program:
return ndarray_program(res).function_two(ndarray_program(l), ndarray_program(r))


@array_api_program_gen_ruleset.register
def _ndarray_function_two(f: Program, res: NDArray, l: NDArray, r: NDArray, o: PyObject):
# When we have function, set the program and trigger it to be compiled
yield rule(eq(f).to(ndarray_function_two(res, l, r))).then(
union(f).with_(ndarray_program(res).function_two(ndarray_program(l), ndarray_program(r))),
f.eval_py_object({"np": np}),
)
@function(ruleset=array_api_program_gen_eval_ruleset)
def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> EvalProgram:
return EvalProgram(ndarray_function_two_program(res, l, r), {"np": np})


@function
Expand Down
40 changes: 23 additions & 17 deletions python/egglog/exp/program_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,18 @@ def parent(self) -> Program:
Only keeps the original parent, not any additional ones, so that each set of statements is only added once.
"""

@method(default=Unit())
def eval_py_object(self, globals: object) -> Unit:
@property
def is_identifer(self) -> Bool:
"""
Returns whether the expression is an identifier. Used so that we don't re-assign any identifiers.
"""


converter(String, Program, Program)


class EvalProgram(Expr):
def __init__(self, program: Program, globals: object) -> None:
"""
Evaluates the program and saves as the py_object
"""
Expand All @@ -98,38 +108,34 @@ def py_object(self) -> PyObject:
"""

@property
def is_identifer(self) -> Bool:
def statements(self) -> String:
"""
Returns whether the expression is an identifier. Used so that we don't re-assign any identifiers.
Returns the statements of the program, if it's been compiled
"""


converter(String, Program, Program)

program_gen_ruleset = ruleset()


@program_gen_ruleset.register
def _py_object(p: Program, expr: String, statements: String, g: PyObject):
@ruleset
def eval_program_rulseset(ep: EvalProgram, p: Program, expr: String, statements: String, g: PyObject):
# When we evaluate a program, we first want to compile to a string
yield rule(p.eval_py_object(g)).then(p.compile())
yield rule(EvalProgram(p, g)).then(p.compile())
# Then we want to evaluate the statements/expr
yield rule(
p.eval_py_object(g),
eq(ep).to(EvalProgram(p, g)),
eq(p.statements).to(statements),
eq(p.expr).to(expr),
).then(
set_(p.py_object).to(
set_(ep.py_object).to(
py_eval(
"l['___res']",
PyObject.dict(PyObject.from_string("l"), py_exec(join(statements, "\n", "___res = ", expr), g)),
)
)
),
set_(ep.statements).to(statements),
)


@program_gen_ruleset.register
def _compile(
@ruleset
def program_gen_ruleset(
s: String,
s1: String,
s2: String,
Expand Down
93 changes: 66 additions & 27 deletions python/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from egglog.exp.array_api import *
from egglog.exp.array_api_jit import jit
from egglog.exp.array_api_numba import array_api_numba_schedule
from egglog.exp.array_api_program_gen import *

Expand Down Expand Up @@ -103,51 +104,69 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
return globals[var]


def load_source(expr, egraph: EGraph):
with egraph:
fn_program = egraph.let("fn_program", ndarray_function_two(expr, NDArray.var("X"), NDArray.var("y")))
egraph.run(array_api_program_gen_schedule)
return egraph.eval(egraph.extract(fn_program.statements))
def load_source(fn_program: EvalProgram, egraph: EGraph):
egraph.register(fn_program)
egraph.run(array_api_program_gen_schedule)
# do all the needed pieces in here so they are included in the benchmark
return egraph.eval(egraph.extract(fn_program.py_object))


def trace_lda(egraph: EGraph):
X_arr = NDArray.var("X")
assume_dtype(X_arr, X_np.dtype)
assume_shape(X_arr, X_np.shape)
assume_isfinite(X_arr)
def lda(X, y):
assume_dtype(X, X_np.dtype)
assume_shape(X, X_np.shape)
assume_isfinite(X)

y_arr = NDArray.var("y")
assume_dtype(y_arr, y_np.dtype)
assume_shape(y_arr, y_np.shape)
assume_value_one_of(y_arr, tuple(map(int, np.unique(y_np)))) # type: ignore[arg-type]
assume_dtype(y, y_np.dtype)
assume_shape(y, y_np.shape)
assume_value_one_of(y, tuple(map(int, np.unique(y_np)))) # type: ignore[arg-type]
return run_lda(X, y)

with egraph:
return run_lda(X_arr, y_arr)

def simplify_lda(egraph: EGraph, expr: NDArray) -> NDArray:
egraph.register(expr)
egraph.run(array_api_numba_schedule)
return egraph.extract(expr)


@pytest.mark.benchmark(min_rounds=3)
class TestLDA:
"""
Incrementally benchmark each part of the LDA to see how long it takes to run.
"""

def test_trace(self, snapshot_py, benchmark):
X_r2 = benchmark(trace_lda, EGraph())
X = NDArray.var("X")
y = NDArray.var("y")
with EGraph():
X_r2 = benchmark(lda, X, y)
assert str(X_r2) == snapshot_py

def test_optimize(self, snapshot_py, benchmark):
egraph = EGraph()
expr = trace_lda(egraph)
simplified = benchmark(egraph.simplify, expr, array_api_numba_schedule)
X = NDArray.var("X")
y = NDArray.var("y")
with egraph:
expr = lda(X, y)
simplified = benchmark(simplify_lda, egraph, expr)
assert str(simplified) == snapshot_py

@pytest.mark.xfail(reason="Original source is not working")
def test_source(self, snapshot_py, benchmark):
egraph = EGraph()
expr = trace_lda(egraph)
assert benchmark(load_source, expr, egraph) == snapshot_py
# @pytest.mark.xfail(reason="Original source is not working")
# def test_source(self, snapshot_py, benchmark):
# egraph = EGraph()
# expr = trace_lda(egraph)
# assert benchmark(load_source, expr, egraph) == snapshot_py

def test_source_optimized(self, snapshot_py, benchmark):
egraph = EGraph()
expr = trace_lda(egraph)
optimized_expr = egraph.simplify(expr, array_api_numba_schedule)
assert benchmark(load_source, optimized_expr, egraph) == snapshot_py
X = NDArray.var("X")
y = NDArray.var("y")
with egraph:
expr = lda(X, y)
optimized_expr = simplify_lda(egraph, expr)
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
py_object = benchmark(load_source, fn_program, egraph)
assert np.allclose(py_object(X_np, y_np), res_np)
assert egraph.eval(fn_program.statements) == snapshot_py

@pytest.mark.parametrize(
"fn",
Expand All @@ -156,9 +175,29 @@ def test_source_optimized(self, snapshot_py, benchmark):
pytest.param(run_lda, id="array_api"),
pytest.param(_load_py_snapshot(test_source_optimized, "__fn"), id="array_api-optimized"),
pytest.param(numba.njit(_load_py_snapshot(test_source_optimized, "__fn")), id="array_api-optimized-numba"),
pytest.param(jit(lda), id="array_api-jit"),
],
)
def test_execution(self, fn, benchmark):
# warmup once for numba
assert np.allclose(res_np, fn(X_np, y_np))
benchmark(fn, X_np, y_np)


# if calling as script, print out egglog source for test
# similar to jit, but don't include pyobject parts so it works in vanilla egglog
if __name__ == "__main__":
print("Generating egglog source for test")
egraph = EGraph(save_egglog_string=True)
X_ = NDArray.var("X")
y_ = NDArray.var("y")
with egraph:
expr = lda(X_, y_)
optimized_expr = egraph.simplify(expr, array_api_numba_schedule)
fn_program = ndarray_function_two_program(optimized_expr, X_, y_)
egraph.register(fn_program.compile())
egraph.run(array_api_program_gen_ruleset.saturate() + program_gen_ruleset.saturate())
egraph.extract(fn_program.statements)
name = "python.egg"
print("Saving to", name)
Path(name).write_text(egraph.as_egglog_string)
9 changes: 5 additions & 4 deletions python/tests/test_program_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_to_string(snapshot_py) -> None:
egraph = EGraph()
egraph.register(fn)
egraph.register(fn.compile())
egraph.run(to_program_ruleset * 100 + program_gen_ruleset * 200)
egraph.run((to_program_ruleset | program_gen_ruleset).saturate())
# egraph.display(n_inline_leaves=1)
assert egraph.eval(fn.expr) == "my_fn"
assert egraph.eval(fn.statements) == snapshot_py
Expand All @@ -67,8 +67,9 @@ def test_py_object():
z = Math.var("z")
fn = (x + y + z).program.function_two(x.program, y.program)
egraph = EGraph()
egraph.register(fn.eval_py_object({"z": 10}))
egraph.run(to_program_ruleset * 100 + program_gen_ruleset * 100)
res = egraph.eval(fn.py_object)
evalled = EvalProgram(fn, {"z": 10})
egraph.register(evalled)
egraph.run((to_program_ruleset | eval_program_rulseset | program_gen_ruleset).saturate())
res = egraph.eval(evalled.py_object)
assert res(1, 2) == 13 # type: ignore[operator]
assert inspect.getsource(res) # type: ignore[arg-type]