From c4bf6dfbfaf201a59375c35ec931e318061ae85d Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Mon, 15 Jan 2024 10:50:04 -0500 Subject: [PATCH 1/3] Export `Nat` --- packages/core/src/impl.ts | 2 +- packages/core/src/index.ts | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 99b02e3..5f70185 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -90,7 +90,7 @@ type Zero = typeof zeroSymbol; export type Tan = Zero | Var; /** An abstract natural number, which can be used to index into a vector. */ -type Nat = number | symbol; +export type Nat = number | symbol; /** The portion of an abstract vector that can be directly indexed. */ interface VecIndex { diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 589e750..49be4d8 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -3,6 +3,7 @@ export { Bools, Dual, Fn, + Nat, Nats, Null, Nulls, From b2e5d8e8bc255b8f734b3bc49eac46b7ea30ce26 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Mon, 15 Jan 2024 11:18:13 -0500 Subject: [PATCH 2/3] Implement index comparison ops --- crates/autodiff/src/lib.rs | 11 +++- crates/core/src/lib.rs | 8 +++ crates/interp/src/lib.rs | 7 +++ crates/transpose/src/lib.rs | 17 ++++++- crates/wasm/src/lib.rs | 6 +++ crates/web/src/lib.rs | 80 +++++++++++++++++++++++++++++ crates/web/src/pprint.rs | 6 +++ packages/core/src/impl.ts | 42 ++++++++++++++++ packages/core/src/index.test.ts | 89 +++++++++++++++++++++++++++++++++ packages/core/src/index.ts | 6 +++ 10 files changed, 270 insertions(+), 2 deletions(-) diff --git a/crates/autodiff/src/lib.rs b/crates/autodiff/src/lib.rs index 1d2a9eb..b8d6223 100644 --- a/crates/autodiff/src/lib.rs +++ b/crates/autodiff/src/lib.rs @@ -240,7 +240,16 @@ impl Autodiff<'_> { }, &Expr::Binary { op, left, right } => match op { // boring cases - Binop::And | Binop::Or | Binop::Iff | Binop::Xor => self.code.push(Instr { + Binop::And + | Binop::Or + | Binop::Iff + | Binop::Xor + | Binop::INeq + | Binop::ILt + | Binop::ILeq + | Binop::IEq + | Binop::IGt + | Binop::IGeq => self.code.push(Instr { var, expr: Expr::Binary { op, left, right }, }), diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 478123e..ca40ca9 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -210,6 +210,14 @@ pub enum Binop { Iff, Xor, + // `Fin` -> `Fin` -> `Bool` + INeq, + ILt, + ILeq, + IEq, + IGt, + IGeq, + // `F64` -> `F64` -> `Bool` Neq, Lt, diff --git a/crates/interp/src/lib.rs b/crates/interp/src/lib.rs index 185c629..279ec0e 100644 --- a/crates/interp/src/lib.rs +++ b/crates/interp/src/lib.rs @@ -236,6 +236,13 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> { Binop::Iff => Val::Bool(x.bool() == y.bool()), Binop::Xor => Val::Bool(x.bool() != y.bool()), + Binop::INeq => Val::Bool(x.fin() != y.fin()), + Binop::ILt => Val::Bool(x.fin() < y.fin()), + Binop::ILeq => Val::Bool(x.fin() <= y.fin()), + Binop::IEq => Val::Bool(x.fin() == y.fin()), + Binop::IGt => Val::Bool(x.fin() > y.fin()), + Binop::IGeq => Val::Bool(x.fin() >= y.fin()), + Binop::Neq => Val::Bool(x.f64() != y.f64()), Binop::Lt => Val::Bool(x.f64() < y.f64()), Binop::Leq => Val::Bool(x.f64() <= y.f64()), diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index 40c3e35..92d4017 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -630,6 +630,12 @@ impl<'a> Transpose<'a> { | Binop::Or | Binop::Iff | Binop::Xor + | Binop::INeq + | Binop::ILt + | Binop::ILeq + | Binop::IEq + | Binop::IGt + | Binop::IGeq | Binop::Neq | Binop::Lt | Binop::Leq @@ -704,7 +710,16 @@ impl<'a> Transpose<'a> { } _ => { let (a, b) = match op { - Binop::And | Binop::Or | Binop::Iff | Binop::Xor => (left, right), + Binop::And + | Binop::Or + | Binop::Iff + | Binop::Xor + | Binop::INeq + | Binop::ILt + | Binop::ILeq + | Binop::IEq + | Binop::IGt + | Binop::IGeq => (left, right), Binop::Neq | Binop::Lt | Binop::Leq diff --git a/crates/wasm/src/lib.rs b/crates/wasm/src/lib.rs index 4107cba..3794970 100644 --- a/crates/wasm/src/lib.rs +++ b/crates/wasm/src/lib.rs @@ -782,6 +782,12 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { Binop::Or => self.wasm.instruction(&Instruction::I32Or), Binop::Iff => self.wasm.instruction(&Instruction::I32Eq), Binop::Xor => self.wasm.instruction(&Instruction::I32Xor), + Binop::INeq => self.wasm.instruction(&Instruction::I32Ne), + Binop::ILt => self.wasm.instruction(&Instruction::I32LtU), + Binop::ILeq => self.wasm.instruction(&Instruction::I32LeU), + Binop::IEq => self.wasm.instruction(&Instruction::I32Eq), + Binop::IGt => self.wasm.instruction(&Instruction::I32GtU), + Binop::IGeq => self.wasm.instruction(&Instruction::I32GeU), Binop::Neq => self.wasm.instruction(&Instruction::F64Ne), Binop::Lt => self.wasm.instruction(&Instruction::F64Lt), Binop::Leq => self.wasm.instruction(&Instruction::F64Le), diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index 34623c8..c0951ae 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -1318,6 +1318,86 @@ impl Block { self.instr(f, t, expr) } + /// Return the variable ID for a new "integer not equal" instruction on `left` and `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn ineq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { + let t = id::ty(f.ty_bool()); + let expr = rose::Expr::Binary { + op: rose::Binop::INeq, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, t, expr) + } + + /// Return the variable ID for a new "integer less than" instruction on `left` and `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn ilt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { + let t = id::ty(f.ty_bool()); + let expr = rose::Expr::Binary { + op: rose::Binop::ILt, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, t, expr) + } + + /// Return the variable ID for a new "integer less than or equal" instruction on `left` and + /// `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn ileq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { + let t = id::ty(f.ty_bool()); + let expr = rose::Expr::Binary { + op: rose::Binop::ILeq, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, t, expr) + } + + /// Return the variable ID for a new "integer equal" instruction on `left` and `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn ieq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { + let t = id::ty(f.ty_bool()); + let expr = rose::Expr::Binary { + op: rose::Binop::IEq, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, t, expr) + } + + /// Return the variable ID for a new "integer greater than" instruction on `left` and `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn igt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { + let t = id::ty(f.ty_bool()); + let expr = rose::Expr::Binary { + op: rose::Binop::IGt, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, t, expr) + } + + /// Return the variable ID for a new "integer greater than or equal" instruction on `left` and + /// `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn igeq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { + let t = id::ty(f.ty_bool()); + let expr = rose::Expr::Binary { + op: rose::Binop::IGeq, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, t, expr) + } + /// Return the variable ID for a new "not equal" instruction on `left` and `right`. /// /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. diff --git a/crates/web/src/pprint.rs b/crates/web/src/pprint.rs index df38690..d043c68 100644 --- a/crates/web/src/pprint.rs +++ b/crates/web/src/pprint.rs @@ -171,6 +171,12 @@ impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Function<'a, '_, O, T> { Binop::Or => writeln!(f, "x{} or x{}", left.var(), right.var())?, Binop::Iff => writeln!(f, "x{} iff x{}", left.var(), right.var())?, Binop::Xor => writeln!(f, "x{} xor x{}", left.var(), right.var())?, + Binop::INeq => writeln!(f, "x{} != x{}", left.var(), right.var())?, + Binop::ILt => writeln!(f, "x{} < x{}", left.var(), right.var())?, + Binop::ILeq => writeln!(f, "x{} <= x{}", left.var(), right.var())?, + Binop::IEq => writeln!(f, "x{} == x{}", left.var(), right.var())?, + Binop::IGt => writeln!(f, "x{} > x{}", left.var(), right.var())?, + Binop::IGeq => writeln!(f, "x{} >= x{}", left.var(), right.var())?, Binop::Neq => writeln!(f, "x{} != x{}", left.var(), right.var())?, Binop::Lt => writeln!(f, "x{} < x{}", left.var(), right.var())?, Binop::Leq => writeln!(f, "x{} <= x{}", left.var(), right.var())?, diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 5f70185..8df72af 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -954,6 +954,48 @@ export const xor = (p: Bool, q: Bool): Bool => { return newVar(ctx.block.xor(ctx.func, boolId(ctx, p), boolId(ctx, q))); }; +/** Return an abstract boolean for if `i` is not equal to `j`. */ +export const ineq = (ty: Nats, i: Nat, j: Nat): Bool => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + return newVar(ctx.block.ineq(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); +}; + +/** Return an abstract boolean for if `i` is less than `j`. */ +export const ilt = (ty: Nats, i: Nat, j: Nat): Bool => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + return newVar(ctx.block.ilt(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); +}; + +/** Return an abstract boolean for if `i` is less than or equal to `j`. */ +export const ileq = (ty: Nats, i: Nat, j: Nat): Bool => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + return newVar(ctx.block.ileq(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); +}; + +/** Return an abstract boolean for if `i` is equal to `j`. */ +export const ieq = (ty: Nats, i: Nat, j: Nat): Bool => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + return newVar(ctx.block.ieq(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); +}; + +/** Return an abstract boolean for if `i` is greater than `j`. */ +export const igt = (ty: Nats, i: Nat, j: Nat): Bool => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + return newVar(ctx.block.igt(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); +}; + +/** Return an abstract boolean for if `i` is greater than or equal to `j`. */ +export const igeq = (ty: Nats, i: Nat, j: Nat): Bool => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + return newVar(ctx.block.igeq(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); +}; + /** Return an abstract value selecting between `then` and `els` via `cond`. */ export const select = ( cond: Bool, diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 4528f64..a603b7b 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -14,7 +14,13 @@ import { floor, fn, gt, + ieq, iff, + igeq, + igt, + ileq, + ilt, + ineq, interp, jvp, mul, @@ -942,4 +948,87 @@ describe("valid", () => { const h = await compile(g); expect(h({ v: [2], i: 0 })).toEqual({ v: [1], i: 0 }); }); + + test("index comparison", async () => { + const f = fn( + [2, 2], + { neq: Bool, lt: Bool, leq: Bool, eq: Bool, gt: Bool, geq: Bool }, + (i, j) => ({ + neq: ineq(2, i, j), + lt: ilt(2, i, j), + leq: ileq(2, i, j), + eq: ieq(2, i, j), + gt: igt(2, i, j), + geq: igeq(2, i, j), + }), + ); + + let g = interp(f); + expect(g(0, 0)).toEqual({ + neq: false, + lt: false, + leq: true, + eq: true, + gt: false, + geq: true, + }); + expect(g(0, 1)).toEqual({ + neq: true, + lt: true, + leq: true, + eq: false, + gt: false, + geq: false, + }); + expect(g(1, 0)).toEqual({ + neq: true, + lt: false, + leq: false, + eq: false, + gt: true, + geq: true, + }); + expect(g(1, 1)).toEqual({ + neq: false, + lt: false, + leq: true, + eq: true, + gt: false, + geq: true, + }); + + g = await compile(f); + expect(g(0, 0)).toEqual({ + neq: false, + lt: false, + leq: true, + eq: true, + gt: false, + geq: true, + }); + expect(g(0, 1)).toEqual({ + neq: true, + lt: true, + leq: true, + eq: false, + gt: false, + geq: false, + }); + expect(g(1, 0)).toEqual({ + neq: true, + lt: false, + leq: false, + eq: false, + gt: true, + geq: true, + }); + expect(g(1, 1)).toEqual({ + neq: false, + lt: false, + leq: true, + eq: true, + gt: false, + geq: true, + }); + }); }); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 49be4d8..2b8dc66 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -26,7 +26,13 @@ export { fn, geq, gt, + ieq, iff, + igeq, + igt, + ileq, + ilt, + ineq, interp, jvp, leq, From e3c8633baf9f863b7fd6782fa4c14808029a13fa Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Mon, 15 Jan 2024 11:32:14 -0500 Subject: [PATCH 3/3] Implement index addition --- crates/autodiff/src/lib.rs | 3 ++- crates/core/src/lib.rs | 3 +++ crates/interp/src/lib.rs | 2 ++ crates/transpose/src/lib.rs | 4 +++- crates/wasm/src/lib.rs | 1 + crates/web/src/lib.rs | 24 ++++++++++++++++++------ crates/web/src/pprint.rs | 1 + packages/core/src/impl.ts | 8 ++++++++ packages/core/src/index.test.ts | 21 +++++++++++++++++++++ packages/core/src/index.ts | 1 + 10 files changed, 60 insertions(+), 8 deletions(-) diff --git a/crates/autodiff/src/lib.rs b/crates/autodiff/src/lib.rs index b8d6223..cfde52c 100644 --- a/crates/autodiff/src/lib.rs +++ b/crates/autodiff/src/lib.rs @@ -249,7 +249,8 @@ impl Autodiff<'_> { | Binop::ILeq | Binop::IEq | Binop::IGt - | Binop::IGeq => self.code.push(Instr { + | Binop::IGeq + | Binop::IAdd => self.code.push(Instr { var, expr: Expr::Binary { op, left, right }, }), diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index ca40ca9..bceba08 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -218,6 +218,9 @@ pub enum Binop { IGt, IGeq, + // `Fin` -> `Fin` -> `Fin` + IAdd, + // `F64` -> `F64` -> `Bool` Neq, Lt, diff --git a/crates/interp/src/lib.rs b/crates/interp/src/lib.rs index 279ec0e..7aa6fc8 100644 --- a/crates/interp/src/lib.rs +++ b/crates/interp/src/lib.rs @@ -243,6 +243,8 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> { Binop::IGt => Val::Bool(x.fin() > y.fin()), Binop::IGeq => Val::Bool(x.fin() >= y.fin()), + Binop::IAdd => Val::Fin(x.fin() + y.fin()), + Binop::Neq => Val::Bool(x.f64() != y.f64()), Binop::Lt => Val::Bool(x.f64() < y.f64()), Binop::Leq => Val::Bool(x.f64() <= y.f64()), diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index 92d4017..ba1a901 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -636,6 +636,7 @@ impl<'a> Transpose<'a> { | Binop::IEq | Binop::IGt | Binop::IGeq + | Binop::IAdd | Binop::Neq | Binop::Lt | Binop::Leq @@ -719,7 +720,8 @@ impl<'a> Transpose<'a> { | Binop::ILeq | Binop::IEq | Binop::IGt - | Binop::IGeq => (left, right), + | Binop::IGeq + | Binop::IAdd => (left, right), Binop::Neq | Binop::Lt | Binop::Leq diff --git a/crates/wasm/src/lib.rs b/crates/wasm/src/lib.rs index 3794970..5b088f6 100644 --- a/crates/wasm/src/lib.rs +++ b/crates/wasm/src/lib.rs @@ -788,6 +788,7 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { Binop::IEq => self.wasm.instruction(&Instruction::I32Eq), Binop::IGt => self.wasm.instruction(&Instruction::I32GtU), Binop::IGeq => self.wasm.instruction(&Instruction::I32GeU), + Binop::IAdd => self.wasm.instruction(&Instruction::I32Add), Binop::Neq => self.wasm.instruction(&Instruction::F64Ne), Binop::Lt => self.wasm.instruction(&Instruction::F64Lt), Binop::Leq => self.wasm.instruction(&Instruction::F64Le), diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index c0951ae..aa58bd0 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -1318,7 +1318,7 @@ impl Block { self.instr(f, t, expr) } - /// Return the variable ID for a new "integer not equal" instruction on `left` and `right`. + /// Return the variable ID for a new "index not equal" instruction on `left` and `right`. /// /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. pub fn ineq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { @@ -1331,7 +1331,7 @@ impl Block { self.instr(f, t, expr) } - /// Return the variable ID for a new "integer less than" instruction on `left` and `right`. + /// Return the variable ID for a new "index less than" instruction on `left` and `right`. /// /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. pub fn ilt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { @@ -1344,7 +1344,7 @@ impl Block { self.instr(f, t, expr) } - /// Return the variable ID for a new "integer less than or equal" instruction on `left` and + /// Return the variable ID for a new "index less than or equal" instruction on `left` and /// `right`. /// /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. @@ -1358,7 +1358,7 @@ impl Block { self.instr(f, t, expr) } - /// Return the variable ID for a new "integer equal" instruction on `left` and `right`. + /// Return the variable ID for a new "index equal" instruction on `left` and `right`. /// /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. pub fn ieq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { @@ -1371,7 +1371,7 @@ impl Block { self.instr(f, t, expr) } - /// Return the variable ID for a new "integer greater than" instruction on `left` and `right`. + /// Return the variable ID for a new "index greater than" instruction on `left` and `right`. /// /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. pub fn igt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { @@ -1384,7 +1384,7 @@ impl Block { self.instr(f, t, expr) } - /// Return the variable ID for a new "integer greater than or equal" instruction on `left` and + /// Return the variable ID for a new "index greater than or equal" instruction on `left` and /// `right`. /// /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. @@ -1398,6 +1398,18 @@ impl Block { self.instr(f, t, expr) } + /// Return the variable ID for a new "index add" instruction on `left` and `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn iadd(&mut self, f: &mut FuncBuilder, t: usize, left: usize, right: usize) -> usize { + let expr = rose::Expr::Binary { + op: rose::Binop::IAdd, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, id::ty(t), expr) + } + /// Return the variable ID for a new "not equal" instruction on `left` and `right`. /// /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. diff --git a/crates/web/src/pprint.rs b/crates/web/src/pprint.rs index d043c68..c616cbc 100644 --- a/crates/web/src/pprint.rs +++ b/crates/web/src/pprint.rs @@ -177,6 +177,7 @@ impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Function<'a, '_, O, T> { Binop::IEq => writeln!(f, "x{} == x{}", left.var(), right.var())?, Binop::IGt => writeln!(f, "x{} > x{}", left.var(), right.var())?, Binop::IGeq => writeln!(f, "x{} >= x{}", left.var(), right.var())?, + Binop::IAdd => writeln!(f, "x{} + x{}", left.var(), right.var())?, Binop::Neq => writeln!(f, "x{} != x{}", left.var(), right.var())?, Binop::Lt => writeln!(f, "x{} < x{}", left.var(), right.var())?, Binop::Leq => writeln!(f, "x{} <= x{}", left.var(), right.var())?, diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 8df72af..3c0a135 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -996,6 +996,14 @@ export const igeq = (ty: Nats, i: Nat, j: Nat): Bool => { return newVar(ctx.block.igeq(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); }; +/** Return the abstract index `i` plus the abstract index `y`. */ +export const iadd = (ty: Nats, i: Nat, j: Nat): Nat => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + const k = ctx.block.iadd(ctx.func, t, valId(ctx, t, i), valId(ctx, t, j)); + return idVal(ctx, t, k) as Nat; +}; + /** Return an abstract value selecting between `then` and `els` via `cond`. */ export const select = ( cond: Bool, diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index a603b7b..77c59d7 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -14,6 +14,7 @@ import { floor, fn, gt, + iadd, ieq, iff, igeq, @@ -1031,4 +1032,24 @@ describe("valid", () => { geq: true, }); }); + + test("index addition", async () => { + const f = fn([3, 3], 3, (i, j) => iadd(3, i, j)); + + let g = interp(f); + expect(g(0, 0)).toBe(0); + expect(g(0, 1)).toBe(1); + expect(g(0, 2)).toBe(2); + expect(g(1, 0)).toBe(1); + expect(g(1, 1)).toBe(2); + expect(g(2, 0)).toBe(2); + + g = await compile(f); + expect(g(0, 0)).toBe(0); + expect(g(0, 1)).toBe(1); + expect(g(0, 2)).toBe(2); + expect(g(1, 0)).toBe(1); + expect(g(1, 1)).toBe(2); + expect(g(2, 0)).toBe(2); + }); }); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 2b8dc66..e803e2d 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -26,6 +26,7 @@ export { fn, geq, gt, + iadd, ieq, iff, igeq,