Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support index arithmetic #124

Merged
merged 3 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion crates/autodiff/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,17 @@ impl Autodiff<'_> {
},
&Expr::Binary { op, left, right } => match op {
// boring cases
Binop::And | Binop::Or | Binop::Iff | Binop::Xor => self.code.push(Instr {
Binop::And
| Binop::Or
| Binop::Iff
| Binop::Xor
| Binop::INeq
| Binop::ILt
| Binop::ILeq
| Binop::IEq
| Binop::IGt
| Binop::IGeq
| Binop::IAdd => self.code.push(Instr {
var,
expr: Expr::Binary { op, left, right },
}),
Expand Down
11 changes: 11 additions & 0 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,17 @@ pub enum Binop {
Iff,
Xor,

// `Fin` -> `Fin` -> `Bool`
INeq,
ILt,
ILeq,
IEq,
IGt,
IGeq,

// `Fin` -> `Fin` -> `Fin`
IAdd,

// `F64` -> `F64` -> `Bool`
Neq,
Lt,
Expand Down
9 changes: 9 additions & 0 deletions crates/interp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,15 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> {
Binop::Iff => Val::Bool(x.bool() == y.bool()),
Binop::Xor => Val::Bool(x.bool() != y.bool()),

Binop::INeq => Val::Bool(x.fin() != y.fin()),
Binop::ILt => Val::Bool(x.fin() < y.fin()),
Binop::ILeq => Val::Bool(x.fin() <= y.fin()),
Binop::IEq => Val::Bool(x.fin() == y.fin()),
Binop::IGt => Val::Bool(x.fin() > y.fin()),
Binop::IGeq => Val::Bool(x.fin() >= y.fin()),

Binop::IAdd => Val::Fin(x.fin() + y.fin()),

Binop::Neq => Val::Bool(x.f64() != y.f64()),
Binop::Lt => Val::Bool(x.f64() < y.f64()),
Binop::Leq => Val::Bool(x.f64() <= y.f64()),
Expand Down
19 changes: 18 additions & 1 deletion crates/transpose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,13 @@ impl<'a> Transpose<'a> {
| Binop::Or
| Binop::Iff
| Binop::Xor
| Binop::INeq
| Binop::ILt
| Binop::ILeq
| Binop::IEq
| Binop::IGt
| Binop::IGeq
| Binop::IAdd
| Binop::Neq
| Binop::Lt
| Binop::Leq
Expand Down Expand Up @@ -704,7 +711,17 @@ impl<'a> Transpose<'a> {
}
_ => {
let (a, b) = match op {
Binop::And | Binop::Or | Binop::Iff | Binop::Xor => (left, right),
Binop::And
| Binop::Or
| Binop::Iff
| Binop::Xor
| Binop::INeq
| Binop::ILt
| Binop::ILeq
| Binop::IEq
| Binop::IGt
| Binop::IGeq
| Binop::IAdd => (left, right),
Binop::Neq
| Binop::Lt
| Binop::Leq
Expand Down
7 changes: 7 additions & 0 deletions crates/wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,13 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> {
Binop::Or => self.wasm.instruction(&Instruction::I32Or),
Binop::Iff => self.wasm.instruction(&Instruction::I32Eq),
Binop::Xor => self.wasm.instruction(&Instruction::I32Xor),
Binop::INeq => self.wasm.instruction(&Instruction::I32Ne),
Binop::ILt => self.wasm.instruction(&Instruction::I32LtU),
Binop::ILeq => self.wasm.instruction(&Instruction::I32LeU),
Binop::IEq => self.wasm.instruction(&Instruction::I32Eq),
Binop::IGt => self.wasm.instruction(&Instruction::I32GtU),
Binop::IGeq => self.wasm.instruction(&Instruction::I32GeU),
Binop::IAdd => self.wasm.instruction(&Instruction::I32Add),
Binop::Neq => self.wasm.instruction(&Instruction::F64Ne),
Binop::Lt => self.wasm.instruction(&Instruction::F64Lt),
Binop::Leq => self.wasm.instruction(&Instruction::F64Le),
Expand Down
92 changes: 92 additions & 0 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,98 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index not equal" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ineq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::INeq,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index less than" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ilt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::ILt,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index less than or equal" instruction on `left` and
/// `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ileq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::ILeq,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index equal" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ieq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::IEq,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index greater than" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn igt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::IGt,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index greater than or equal" instruction on `left` and
/// `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn igeq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::IGeq,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index add" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn iadd(&mut self, f: &mut FuncBuilder, t: usize, left: usize, right: usize) -> usize {
let expr = rose::Expr::Binary {
op: rose::Binop::IAdd,
left: id::var(left),
right: id::var(right),
};
self.instr(f, id::ty(t), expr)
}

/// Return the variable ID for a new "not equal" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
Expand Down
7 changes: 7 additions & 0 deletions crates/web/src/pprint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,13 @@ impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Function<'a, '_, O, T> {
Binop::Or => writeln!(f, "x{} or x{}", left.var(), right.var())?,
Binop::Iff => writeln!(f, "x{} iff x{}", left.var(), right.var())?,
Binop::Xor => writeln!(f, "x{} xor x{}", left.var(), right.var())?,
Binop::INeq => writeln!(f, "x{} != x{}", left.var(), right.var())?,
Binop::ILt => writeln!(f, "x{} < x{}", left.var(), right.var())?,
Binop::ILeq => writeln!(f, "x{} <= x{}", left.var(), right.var())?,
Binop::IEq => writeln!(f, "x{} == x{}", left.var(), right.var())?,
Binop::IGt => writeln!(f, "x{} > x{}", left.var(), right.var())?,
Binop::IGeq => writeln!(f, "x{} >= x{}", left.var(), right.var())?,
Binop::IAdd => writeln!(f, "x{} + x{}", left.var(), right.var())?,
Binop::Neq => writeln!(f, "x{} != x{}", left.var(), right.var())?,
Binop::Lt => writeln!(f, "x{} < x{}", left.var(), right.var())?,
Binop::Leq => writeln!(f, "x{} <= x{}", left.var(), right.var())?,
Expand Down
52 changes: 51 additions & 1 deletion packages/core/src/impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ type Zero = typeof zeroSymbol;
export type Tan = Zero | Var;

/** An abstract natural number, which can be used to index into a vector. */
type Nat = number | symbol;
export type Nat = number | symbol;

/** The portion of an abstract vector that can be directly indexed. */
interface VecIndex<T> {
Expand Down Expand Up @@ -954,6 +954,56 @@ export const xor = (p: Bool, q: Bool): Bool => {
return newVar(ctx.block.xor(ctx.func, boolId(ctx, p), boolId(ctx, q)));
};

/** Return an abstract boolean for if `i` is not equal to `j`. */
export const ineq = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.ineq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is less than `j`. */
export const ilt = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.ilt(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is less than or equal to `j`. */
export const ileq = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.ileq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is equal to `j`. */
export const ieq = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.ieq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is greater than `j`. */
export const igt = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.igt(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is greater than or equal to `j`. */
export const igeq = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.igeq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return the abstract index `i` plus the abstract index `y`. */
export const iadd = (ty: Nats, i: Nat, j: Nat): Nat => {
const ctx = getCtx();
const t = tyId(ctx, ty);
const k = ctx.block.iadd(ctx.func, t, valId(ctx, t, i), valId(ctx, t, j));
return idVal(ctx, t, k) as Nat;
};

/** Return an abstract value selecting between `then` and `els` via `cond`. */
export const select = <const T>(
cond: Bool,
Expand Down
Loading