From bfd62539043b55e57f025dce972518f2fa639de6 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Tue, 1 Aug 2023 11:17:41 -0400 Subject: [PATCH] Remove branching and nested `Ref`s (#64) --- Cargo.lock | 1 + crates/core/src/id.rs | 40 --- crates/core/src/lib.rs | 132 +++++---- crates/frontend/src/translate.rs | 120 +++----- crates/frontend/tests/interp.rs | 11 +- crates/interp/Cargo.toml | 1 + crates/interp/src/lib.rs | 351 ++++++++++++----------- crates/web/src/lib.rs | 459 ++++++++++++++----------------- packages/core/src/bool.ts | 43 +-- packages/core/src/debug.test.ts | 134 ++++----- packages/core/src/ffi.ts | 20 +- packages/core/src/fn.ts | 21 +- packages/core/src/index.test.ts | 18 +- packages/core/src/index.ts | 2 +- packages/core/src/interp.ts | 2 +- 15 files changed, 563 insertions(+), 792 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e069812..c6e6f36 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -661,6 +661,7 @@ dependencies = [ name = "rose-interp" version = "0.0.0" dependencies = [ + "enumset", "indexmap 2.0.0", "rose", "serde", diff --git a/crates/core/src/id.rs b/crates/core/src/id.rs index 79f5a57..aace964 100644 --- a/crates/core/src/id.rs +++ b/crates/core/src/id.rs @@ -86,26 +86,6 @@ impl Ty { } } -/// Index of an instantiated function reference in a definition context. -#[cfg_attr(test, derive(TS), ts(export))] -#[cfg_attr( - feature = "serde", - derive(Serialize, Deserialize), - serde(rename = "FuncId") -)] -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct Func(usize); - -pub fn func(id: usize) -> Func { - Func(id) -} - -impl Func { - pub fn func(self) -> usize { - self.0 - } -} - /// Index of a local variable in a function definition context. #[cfg_attr(test, derive(TS), ts(export))] #[cfg_attr( @@ -125,23 +105,3 @@ impl Var { self.0 } } - -/// Index of a block in a function definition context. -#[cfg_attr(test, derive(TS), ts(export))] -#[cfg_attr( - feature = "serde", - derive(Serialize, Deserialize), - serde(rename = "BlockId") -)] -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct Block(usize); - -pub fn block(id: usize) -> Block { - Block(id) -} - -impl Block { - pub fn block(self) -> usize { - self.0 - } -} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index c48e593..1162bf6 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -11,14 +11,17 @@ use ts_rs::TS; /// A type constraint. #[cfg_attr(test, derive(TS), ts(export))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, EnumSetType)] +#[allow(clippy::derived_hash_with_manual_eq)] // `PartialEq` impl comes from enumset; should be fine +#[derive(Debug, EnumSetType, Hash)] pub enum Constraint { + /// Not a `Ref`. + Value, /// Can be the `index` type of an `Array`. Index, - /// Has a zero value and an addition operation. - Vector, - /// Can be the `scope` type of a `Ref`. - Scope, + /// Allows a `Ref` to be read when used as its `scope` type. + Read, + /// Allows a `Ref` to be accumulated into when used as its `scope` type. + Accum, } /// A type. @@ -28,7 +31,6 @@ pub enum Constraint { pub enum Ty { Unit, Bool, - /// Satisfies `Constraint::Vector`. F64, /// A nonnegative integer less than `size`. Satisfies `Constraint::Index`. Fin { @@ -37,77 +39,52 @@ pub enum Ty { Generic { id: id::Generic, }, - /// Satisfies `Constraint::Scope`. Scope { - id: id::Block, + /// Must be either `Read` or `Accum`. + kind: Constraint, + /// The `arg` variable of the `Expr` introducing this scope. + id: id::Var, }, Ref { - /// Must satisfy `Constraint::Scope`. scope: id::Ty, inner: id::Ty, }, - /// Satisfies `Constraint::Vector` if `elem` does. Array { /// Must satisfy `Constraint::Index`. index: id::Ty, elem: id::Ty, }, - /// Satisfies `Constraint::Vector` if all `members` do. Tuple { - members: Vec, + members: Vec, // TODO: change to `Box<[id::Ty]` }, } -/// Reference to a function, with types supplied for its generic parameters. -#[derive(Debug)] -pub struct Func { - pub id: id::Function, - pub generics: Vec, -} - /// A function definition. #[derive(Debug)] pub struct Function { /// Generic type parameters. - pub generics: Vec>, + pub generics: Box<[EnumSet]>, /// Types used in this function definition. - pub types: Vec, - /// Instantiations referenced functions with generic type parameters. - pub funcs: Vec, - /// Parameter type. - pub param: id::Ty, - /// Return type. - pub ret: id::Ty, + pub types: Box<[Ty]>, /// Local variable types. - pub vars: Vec, - /// Blocks of code. - pub blocks: Vec, - /// Main block. - pub main: id::Block, + pub vars: Box<[id::Ty]>, + /// Parameter variables. + pub params: Box<[id::Var]>, + /// Return variable. + pub ret: id::Var, + /// Function body. + pub body: Box<[Instr]>, } /// Wrapper for a `Function` that knows how to resolve its `id::Function`s. pub trait FuncNode { fn def(&self) -> &Function; - /// Only valid with `id::Function`s from `self.def().funcs`. fn get(&self, id: id::Function) -> Option where Self: Sized; } -#[cfg_attr(test, derive(TS), ts(export))] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug)] -pub struct Block { - /// Input variable to this block. - pub arg: id::Var, - pub code: Vec, - /// Output variable from this block. - pub ret: id::Var, -} - -#[cfg_attr(test, derive(TS), ts(export))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] pub struct Instr { @@ -115,7 +92,6 @@ pub struct Instr { pub expr: Expr, } -#[cfg_attr(test, derive(TS), ts(export))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] pub enum Expr { @@ -131,10 +107,10 @@ pub enum Expr { }, Array { - elems: Vec, + elems: Box<[id::Var]>, }, Tuple { - members: Vec, + members: Box<[id::Var]>, }, Index { @@ -154,7 +130,7 @@ pub enum Expr { Field { /// Must actually be a `Ref` of a tuple, not just a tuple. tuple: id::Var, - field: id::Member, + member: id::Member, }, Unary { @@ -166,38 +142,58 @@ pub enum Expr { left: id::Var, right: id::Var, }, + Select { + /// Must be of type `Bool`. + cond: id::Var, + then: id::Var, + els: id::Var, + }, Call { - func: id::Func, - arg: id::Var, - }, - If { - cond: id::Var, - /// `arg` has type `Unit`. - then: id::Block, - /// `arg` has type `Unit`. - els: id::Block, + id: id::Function, + generics: Box<[id::Ty]>, + args: Box<[id::Var]>, }, For { /// Must satisfy `Constraint::Index`. index: id::Ty, - /// `arg` has type `index`. - body: id::Block, + /// has type `index`. + arg: id::Var, + body: Box<[Instr]>, + /// Variable from `body` holding an array element. + ret: id::Var, }, - Accum { - /// Final contents of the `Ref`. + /// Scope for a `Ref` with `Constraint::Read`. Returns `Unit`. + Read { + /// Contents of the `Ref`. var: id::Var, - /// Must satisfy `Constraint::Vector`. - vector: id::Ty, - /// `arg` has type `Ref` with scope `body` and inner type `vector`. - body: id::Block, + /// Has type `Ref` with scope `arg` and inner type same as `var`. + arg: id::Var, + body: Box<[Instr]>, + /// Variable from `body` holding the result of this block; escapes into outer scope. + ret: id::Var, + }, + /// Scope for a `Ref` with `Constraint::Accum`. Returns the final contents of the `Ref`. + Accum { + /// Topology of the `Ref`. + shape: id::Var, + /// Has type `Ref` with scope `arg` and inner type same as `shape`. + arg: id::Var, + body: Box<[Instr]>, + /// Variable from `body` holding the result of this block; escapes into outer scope. + ret: id::Var, }, - /// Accumulate into a `Ref`. Returned type is `Unit`. + /// Read from a `Ref` whose `scope` satisfies `Constraint::Read`. + Ask { + /// The `Ref`, which must be in scope. + var: id::Var, + }, + /// Accumulate into a `Ref` whose `scope` satisfies `Constraint::Accum`. Returns `Unit`. Add { /// The `Ref`, which must be in scope. accum: id::Var, - /// Must be of the `Ref`'s inner type, which must satisfy `Constraint::Vector`. + /// Must be of the `Ref`'s inner type. addend: id::Var, }, } diff --git a/crates/frontend/src/translate.rs b/crates/frontend/src/translate.rs index 5a7f4a4..8d0db5b 100644 --- a/crates/frontend/src/translate.rs +++ b/crates/frontend/src/translate.rs @@ -148,9 +148,7 @@ struct BlockCtx<'input, 'a> { g: HashMap<&'input str, id::Generic>, l: HashMap<&'input str, id::Var>, t: IndexSet, - f: Vec, v: Vec, - b: Vec, c: Vec, } @@ -166,12 +164,6 @@ impl<'input, 'a> BlockCtx<'input, 'a> { id } - fn newfunc(&mut self, f: ir::Func) -> id::Func { - let id = id::func(self.f.len()); - self.f.push(f); - id - } - fn gettype(&self, id: id::Ty) -> &ir::Ty { &self.t[id.ty()] } @@ -225,7 +217,7 @@ impl<'input, 'a> BlockCtx<'input, 'a> { index, elem: self.getlocal(x), }); - Ok(self.instr(ty, ir::Expr::Array { elems: vars })) + Ok(self.instr(ty, ir::Expr::Array { elems: vars.into() })) } None => Err(TypeError::EmptyVec), } @@ -279,17 +271,18 @@ impl<'input, 'a> BlockCtx<'input, 'a> { let vars = args .into_iter() .map(|elem| self.typecheck(elem)) - .collect::, TypeError>>()?; + .collect::, TypeError>>()?; let types: Vec = vars.iter().map(|&v| self.getlocal(v)).collect(); if let Some((i, _, f)) = self.m.funcs.get_full(func.val) { let (generics, ret) = self.unify(f, &types)?; - let func = self.newfunc(ir::Func { - id: id::function(i), - generics, - }); - let ty = self.newtype(ir::Ty::Tuple { members: types }); - let arg = self.instr(ty, ir::Expr::Tuple { members: vars }); - Ok(self.instr(ret, ir::Expr::Call { func, arg })) + Ok(self.instr( + ret, + ir::Expr::Call { + id: id::function(i), + generics: generics.into(), + args: vars, + }, + )) } else { let real = self.newtype(ir::Ty::F64); // TODO: validate argument types for builtin functions @@ -314,38 +307,17 @@ impl<'input, 'a> BlockCtx<'input, 'a> { } ast::Expr::If { cond, then, els } => { let c = self.typecheck(*cond)?; // TODO: ensure this is `Bool` - let code = std::mem::take(&mut self.c); - // the `BlockCtx` type can only think about one under-construction block at a time, - // so when constructing an `If`, we keep swapping them out until we're done - - let unit = self.newtype(ir::Ty::Unit); - - let arg_then = self.newlocal(unit); - let ret_then = self.typecheck(*then)?; - let block_then = id::block(self.b.len()); - let code_then = std::mem::take(&mut self.c); - self.b.push(ir::Block { - arg: arg_then, - code: code_then, - ret: ret_then, - }); - let arg_els = self.newlocal(unit); - let ret_els = self.typecheck(*els)?; - let block_els = id::block(self.b.len()); - let code_els = std::mem::replace(&mut self.c, code); - self.b.push(ir::Block { - arg: arg_els, - code: code_els, - ret: ret_els, - }); + // IR doesn't currently support branching, so just evaluate both branches + let t = self.typecheck(*then)?; + let e = self.typecheck(*els)?; Ok(self.instr( - self.getlocal(ret_then), // TODO: ensure this matches the type of `ret_els` - ir::Expr::If { + self.getlocal(t), // TODO: ensure this matches the type of `e` + ir::Expr::Select { cond: c, - then: block_then, - els: block_els, + then: t, + els: e, }, )) } @@ -360,19 +332,21 @@ impl<'input, 'a> BlockCtx<'input, 'a> { let arg = self.newlocal(i); self.l.insert(index, arg); let elem = self.typecheck(*body)?; - let body = id::block(self.b.len()); - let code_for = std::mem::replace(&mut self.c, code); - self.b.push(ir::Block { - arg, - code: code_for, - ret: elem, - }); + let body = std::mem::replace(&mut self.c, code).into_boxed_slice(); let v = self.newtype(ir::Ty::Array { index: i, elem: self.getlocal(elem), }); - Ok(self.instr(v, ir::Expr::For { index: i, body })) + Ok(self.instr( + v, + ir::Expr::For { + index: i, + arg, + body, + ret: elem, + }, + )) } ast::Expr::Unary { op: _, arg: _ } => todo!(), ast::Expr::Binary { op, left, right } => { @@ -443,47 +417,29 @@ impl<'input> Module<'input> { // TODO: handle return type separately from params w.r.t. generics params.iter().map(|&(_, t)| t).chain([typ]), )?; - let generics = vec![EnumSet::only(ir::Constraint::Index); genericnames.len()]; - let ret = paramtypes.pop().expect("`parse_types` should preserve len"); - let (param_id, _) = typevars.insert_full(ir::Ty::Tuple { - members: paramtypes.clone(), // should be a way to do this without `clone`... - }); - let param = id::ty(param_id); - let arg = id::var(0); + let generics = + vec![EnumSet::only(ir::Constraint::Index); genericnames.len()].into(); + paramtypes.pop().expect("`parse_types` should preserve len"); // pop off return type + let args = (0..params.len()).map(id::var).collect(); let mut ctx = BlockCtx { m: self, g: genericnames, l: HashMap::new(), t: typevars, - f: vec![], - v: vec![param], - b: vec![], + v: paramtypes, c: vec![], }; - for (i, ((bind, _), t)) in params.into_iter().zip(paramtypes).enumerate() { - let expr = ir::Expr::Member { - tuple: arg, - member: id::member(i), - }; - let var = ctx.instr(t, expr); - ctx.bind(bind, var); + for (i, (bind, _)) in params.into_iter().enumerate() { + ctx.bind(bind, id::var(i)); } let retvar = ctx.typecheck(body)?; // TODO: ensure this matches `ret` - let main = id::block(ctx.b.len()); - ctx.b.push(ir::Block { - arg, - code: ctx.c, - ret: retvar, - }); let f = ir::Function { generics, types: ctx.t.into_iter().collect(), - funcs: ctx.f, - param, - ret, - vars: ctx.v, - blocks: ctx.b, - main, + vars: ctx.v.into(), + params: args, + ret: retvar, + body: ctx.c.into(), }; // TODO: check for duplicate function names self.funcs.insert(name, f); diff --git a/crates/frontend/tests/interp.rs b/crates/frontend/tests/interp.rs index 5caab1e..9e788b7 100644 --- a/crates/frontend/tests/interp.rs +++ b/crates/frontend/tests/interp.rs @@ -1,7 +1,6 @@ use indexmap::IndexSet; use rose_frontend::parse; -use rose_interp::{interp, Val}; -use std::rc::Rc; +use rose_interp::{interp, val_f64}; #[test] fn test_add() { @@ -11,10 +10,10 @@ fn test_add() { module.get_func("add").unwrap(), IndexSet::new(), &[], - Val::Tuple(Rc::new(vec![Val::F64(2.), Val::F64(2.)])), + [val_f64(2.), val_f64(2.)].into_iter(), ) .unwrap(); - assert_eq!(answer, Val::F64(4.)); + assert_eq!(answer, val_f64(4.)); } #[test] @@ -25,8 +24,8 @@ fn test_sub() { module.get_func("sub").unwrap(), IndexSet::new(), &[], - Val::Tuple(Rc::new(vec![Val::F64(2.), Val::F64(2.)])), + [val_f64(2.), val_f64(2.)].into_iter(), ) .unwrap(); - assert_eq!(answer, Val::F64(0.)); + assert_eq!(answer, val_f64(0.)); } diff --git a/crates/interp/Cargo.toml b/crates/interp/Cargo.toml index 7e73434..f545430 100644 --- a/crates/interp/Cargo.toml +++ b/crates/interp/Cargo.toml @@ -5,6 +5,7 @@ publish = false edition = "2021" [dependencies] +enumset = "1" indexmap = "2" rose = { path = "../core" } serde = { version = "1", features = ["derive", "rc"], optional = true } diff --git a/crates/interp/src/lib.rs b/crates/interp/src/lib.rs index 03276d0..37f9ec9 100644 --- a/crates/interp/src/lib.rs +++ b/crates/interp/src/lib.rs @@ -14,11 +14,25 @@ use ts_rs::TS; pub enum Val { Unit, Bool(bool), - F64(f64), + F64(Cell), Fin(usize), - Ref(Rc>), - Array(Rc>), // assume all indices are `Fin` - Tuple(Rc>), + Ref(Rc), + Array(Vals), // assume all indices are `Fin` + Tuple(Vals), +} + +pub type Vals = Rc>; // TODO: change to `Rc<[Val]>` https://github.com/rose-lang/rose/issues/63 + +pub fn vals(v: [Val; N]) -> Vals { + Rc::new(v.to_vec()) +} + +pub fn collect_vals(it: impl Iterator) -> Vals { + Rc::new(it.collect()) +} + +pub fn val_f64(x: f64) -> Val { + Val::F64(Cell::new(x)) } impl Val { @@ -31,29 +45,38 @@ impl Val { fn f64(&self) -> f64 { match self { - &Val::F64(x) => x, + Val::F64(x) => x.get(), _ => unreachable!(), } } -} -impl Val { - /// Pull out the immutable inner value represented by this mutable `Ref` type. - fn immut(&self) -> Self { + fn inner(&self) -> &Self { match self { - Self::Ref(x) => Self::F64(x.get()), - Self::Array(x) => Self::Array(Rc::new(x.iter().map(|x| x.immut()).collect())), - Self::Tuple(x) => Self::Tuple(Rc::new(x.iter().map(|x| x.immut()).collect())), - Self::Unit | Self::Bool(..) | Self::F64(..) | Self::Fin { .. } => { - unreachable!() - } + Val::Ref(x) => x.as_ref(), + _ => unreachable!(), + } + } + + /// Return a zero value with this value's topology. + fn zero(&self) -> Self { + match self { + Self::Unit => Self::Unit, + &Self::Bool(x) => Self::Bool(x), + Self::F64(_) => Self::F64(Cell::new(0.)), + &Self::Fin(x) => Self::Fin(x), + Self::Ref(_) => unreachable!(), + Self::Array(x) => Self::Array(collect_vals(x.iter().map(|x| x.zero()))), + Self::Tuple(x) => Self::Tuple(collect_vals(x.iter().map(|x| x.zero()))), } } /// Add `x` to this value, which must represent a mutable `Ref` type. fn add(&self, x: &Self) { match (self, x) { - (Self::Ref(a), Self::F64(b)) => a.set(a.get() + b), + (Self::Unit, Self::Unit) + | (Self::Bool(_), Self::Bool(_)) + | (Self::Fin(_), Self::Fin(_)) => {} + (Self::F64(a), Self::F64(b)) => a.set(a.get() + b.get()), (Self::Array(a), Self::Array(b)) => { for (a, b) in a.iter().zip(b.iter()) { a.add(b); @@ -69,26 +92,6 @@ impl Val { } } -/// Return zero a value of `Ref` type for this type, which must satisfy `Constraint::Vector`. -fn zero(types: &IndexSet, ty: id::Ty) -> Val { - match &types[ty.ty()] { - Ty::F64 => Val::Ref(Rc::new(Cell::new(0.))), - &Ty::Array { index, elem } => match types[index.ty()] { - Ty::Fin { size } => Val::Array(Rc::new((0..size).map(|_| zero(types, elem)).collect())), - _ => unreachable!(), - }, - Ty::Tuple { members } => { - Val::Tuple(Rc::new(members.iter().map(|&x| zero(types, x)).collect())) - } - Ty::Unit - | Ty::Bool - | Ty::Fin { .. } - | Ty::Generic { .. } - | Ty::Scope { .. } - | Ty::Ref { .. } => unreachable!(), - } -} - /// Resolve `ty` via `generics` and `types`, then return its ID in `typemap`, inserting if need be. /// /// This is meant to be used to pull all the types from a callee into a broader context. The @@ -106,7 +109,10 @@ fn resolve(typemap: &mut IndexSet, generics: &[id::Ty], types: &[id::Ty], ty Ty::F64 => Ty::F64, &Ty::Fin { size } => Ty::Fin { size }, - Ty::Scope { id: _ } => Ty::Scope { id: id::block(0) }, // we erase scope info + &Ty::Scope { kind, id: _ } => Ty::Scope { + kind, + id: id::var(usize::MAX), // we erase scope info + }, Ty::Ref { scope, inner } => Ty::Ref { scope: types[scope.ty()], inner: types[inner.ty()], @@ -133,7 +139,7 @@ struct Interpreter<'a, F: FuncNode> { impl<'a, F: FuncNode> Interpreter<'a, F> { fn new(typemap: &'a mut IndexSet, f: &'a F, generics: &'a [id::Ty]) -> Self { let mut types = vec![]; - for ty in &f.def().types { + for ty in f.def().types.iter() { types.push(resolve(typemap, generics, &types, ty)); } Self { @@ -152,15 +158,15 @@ impl<'a, F: FuncNode> Interpreter<'a, F> { match expr { Expr::Unit => Val::Unit, &Expr::Bool { val } => Val::Bool(val), - &Expr::F64 { val } => Val::F64(val), + &Expr::F64 { val } => val_f64(val), &Expr::Fin { val } => Val::Fin(val), - Expr::Array { elems } => Val::Array(Rc::new( - elems.iter().map(|&x| self.get(x).clone()).collect(), - )), - Expr::Tuple { members } => Val::Tuple(Rc::new( - members.iter().map(|&x| self.get(x).clone()).collect(), - )), + Expr::Array { elems } => { + Val::Array(collect_vals(elems.iter().map(|&x| self.get(x).clone()))) + } + Expr::Tuple { members } => { + Val::Tuple(collect_vals(members.iter().map(|&x| self.get(x).clone()))) + } &Expr::Index { array, index } => match (self.get(array), self.get(index)) { (Val::Array(v), &Val::Fin(i)) => v[i].clone(), @@ -171,13 +177,12 @@ impl<'a, F: FuncNode> Interpreter<'a, F> { _ => unreachable!(), }, - // a `Ref` of `F64` becomes `Ref`, while composites just wrap those individual refs - &Expr::Slice { array, index } => match (self.get(array), self.get(index)) { + &Expr::Slice { array, index } => match (self.get(array).inner(), self.get(index)) { (Val::Array(v), &Val::Fin(i)) => v[i].clone(), _ => unreachable!(), }, - &Expr::Field { tuple, field } => match self.get(tuple) { - Val::Tuple(x) => x[field.member()].clone(), + &Expr::Field { tuple, member } => match self.get(tuple).inner() { + Val::Tuple(x) => x[member.member()].clone(), _ => unreachable!(), }, @@ -186,9 +191,9 @@ impl<'a, F: FuncNode> Interpreter<'a, F> { match op { Unop::Not => Val::Bool(!x.bool()), - Unop::Neg => Val::F64(-x.f64()), - Unop::Abs => Val::F64(x.f64().abs()), - Unop::Sqrt => Val::F64(x.f64().sqrt()), + Unop::Neg => val_f64(-x.f64()), + Unop::Abs => val_f64(x.f64().abs()), + Unop::Sqrt => val_f64(x.f64().sqrt()), } } &Expr::Binary { op, left, right } => { @@ -207,70 +212,92 @@ impl<'a, F: FuncNode> Interpreter<'a, F> { Binop::Gt => Val::Bool(x.f64() > y.f64()), Binop::Geq => Val::Bool(x.f64() >= y.f64()), - Binop::Add => Val::F64(x.f64() + y.f64()), - Binop::Sub => Val::F64(x.f64() - y.f64()), - Binop::Mul => Val::F64(x.f64() * y.f64()), - Binop::Div => Val::F64(x.f64() / y.f64()), + Binop::Add => val_f64(x.f64() + y.f64()), + Binop::Sub => val_f64(x.f64() - y.f64()), + Binop::Mul => val_f64(x.f64() * y.f64()), + Binop::Div => val_f64(x.f64() / y.f64()), } } - - &Expr::Call { func, arg } => { - let f = &self.f.def().funcs[func.func()]; - let generics: Vec = - f.generics.iter().map(|id| self.types[id.ty()]).collect(); - call( - self.f.get(f.id).unwrap(), - self.typemap, - &generics, - self.get(arg).clone(), - ) - } - &Expr::If { cond, then, els } => { + &Expr::Select { cond, then, els } => { if self.get(cond).bool() { - self.block(then, Val::Unit).clone() + self.get(then).clone() } else { - self.block(els, Val::Unit).clone() + self.get(els).clone() } } - &Expr::For { index, body } => { + + Expr::Call { id, generics, args } => { + let resolved: Vec = generics.iter().map(|id| self.types[id.ty()]).collect(); + let vals = args.iter().map(|id| self.vars[id.var()].clone().unwrap()); + call(self.f.get(*id).unwrap(), self.typemap, &resolved, vals) + } + Expr::For { + index, + arg, + body, + ret, + } => { let n = match self.typemap[self.types[index.ty()].ty()] { Ty::Fin { size } => size, _ => unreachable!(), }; - let v: Vec = (0..n) - .map(|i| self.block(body, Val::Fin(i)).clone()) - .collect(); - Val::Array(Rc::new(v)) + Val::Array(collect_vals( + (0..n).map(|i| self.block(*arg, body, *ret, Val::Fin(i)).clone()), + )) + } + Expr::Read { + var, + arg, + body, + ret, + } => { + let r = Val::Ref(Rc::new(self.get(*var).clone())); + self.block(*arg, body, *ret, r); + Val::Unit } - &Expr::Accum { var, vector, body } => { - let x = zero(self.typemap, self.types[vector.ty()]); - let y = self.block(body, x.clone()).clone(); - self.vars[var.var()] = Some(x.immut()); - y + Expr::Accum { + shape, + arg, + body, + ret, + } => { + let x = Val::Ref(Rc::new(self.get(*shape).zero())); + self.block(*arg, body, *ret, x.clone()); + x.inner().clone() } + &Expr::Ask { var } => self.get(var).inner().clone(), &Expr::Add { accum, addend } => { - self.get(accum).add(self.get(addend)); + self.get(accum).inner().add(self.get(addend)); Val::Unit } } } - fn block(&mut self, b: id::Block, arg: Val) -> &Val { - let block = &self.f.def().blocks[b.block()]; - self.vars[block.arg.var()] = Some(arg); - for instr in &block.code { + fn block(&mut self, param: id::Var, body: &[rose::Instr], ret: id::Var, arg: Val) -> &Val { + self.vars[param.var()] = Some(arg); + for instr in body.iter() { self.vars[instr.var.var()] = Some(self.expr(&instr.expr)); } - self.vars[block.ret.var()].as_ref().unwrap() + self.vars[ret.var()].as_ref().unwrap() } } /// Assumes `generics` and `arg` are valid. -fn call(f: impl FuncNode, types: &mut IndexSet, generics: &[id::Ty], arg: Val) -> Val { - Interpreter::new(types, &f, generics) - .block(f.def().main, arg) - .clone() +fn call( + f: impl FuncNode, + types: &mut IndexSet, + generics: &[id::Ty], + args: impl Iterator, +) -> Val { + let mut interp = Interpreter::new(types, &f, generics); + for (var, arg) in f.def().params.iter().zip(args) { + interp.vars[var.var()] = Some(arg.clone()); + } + for instr in f.def().body.iter() { + interp.vars[instr.var.var()] = Some(interp.expr(&instr.expr)); + } + interp.vars[f.def().ret.var()].as_ref().unwrap().clone() } #[derive(Debug, thiserror::Error)] @@ -281,16 +308,16 @@ pub fn interp( f: impl FuncNode, mut types: IndexSet, generics: &[id::Ty], - arg: Val, + args: impl Iterator, ) -> Result { // TODO: check that `generics` and `arg` are valid - Ok(call(f, &mut types, generics, arg)) + Ok(call(f, &mut types, generics, args)) } #[cfg(test)] mod tests { use super::*; - use rose::{Block, Func, Function, Instr}; + use rose::{Function, Instr}; #[derive(Clone, Copy, Debug)] struct FuncInSlice<'a> { @@ -314,46 +341,20 @@ mod tests { #[test] fn test_two_plus_two() { let funcs = vec![Function { - generics: vec![], - types: vec![ - Ty::F64, - Ty::Tuple { - members: vec![id::ty(0), id::ty(0)], + generics: vec![].into(), + types: vec![Ty::F64].into(), + vars: vec![id::ty(0), id::ty(0), id::ty(0)].into(), + params: vec![id::var(0), id::var(1)].into(), + ret: id::var(2), + body: vec![Instr { + var: id::var(2), + expr: Expr::Binary { + op: Binop::Add, + left: id::var(0), + right: id::var(1), }, - ], - funcs: vec![], - param: id::ty(1), - ret: id::ty(0), - vars: vec![id::ty(1), id::ty(0), id::ty(0), id::ty(0)], - blocks: vec![Block { - arg: id::var(0), - code: vec![ - Instr { - var: id::var(1), - expr: Expr::Member { - tuple: id::var(0), - member: id::member(0), - }, - }, - Instr { - var: id::var(2), - expr: Expr::Member { - tuple: id::var(0), - member: id::member(1), - }, - }, - Instr { - var: id::var(3), - expr: Expr::Binary { - op: Binop::Add, - left: id::var(1), - right: id::var(2), - }, - }, - ], - ret: id::var(3), - }], - main: id::block(0), + }] + .into(), }]; let answer = interp( FuncInSlice { @@ -362,64 +363,52 @@ mod tests { }, IndexSet::new(), &[], - Val::Tuple(Rc::new(vec![Val::F64(2.), Val::F64(2.)])), + [val_f64(2.), val_f64(2.)].into_iter(), ) .unwrap(); - assert_eq!(answer, Val::F64(4.)); + assert_eq!(answer, val_f64(4.)); } #[test] fn test_nested_call() { let funcs = vec![ Function { - generics: vec![], - types: vec![Ty::Unit, Ty::F64], - funcs: vec![], - param: id::ty(0), - ret: id::ty(1), - vars: vec![id::ty(0), id::ty(1)], - blocks: vec![Block { - arg: id::var(0), - code: vec![Instr { - var: id::var(1), - expr: Expr::F64 { val: 42. }, - }], - ret: id::var(1), - }], - main: id::block(0), + generics: vec![].into(), + types: vec![Ty::F64].into(), + vars: vec![id::ty(0)].into(), + params: vec![].into(), + ret: id::var(0), + body: vec![Instr { + var: id::var(0), + expr: Expr::F64 { val: 42. }, + }] + .into(), }, Function { - generics: vec![], - types: vec![Ty::Unit, Ty::F64], - funcs: vec![Func { - id: id::function(0), - generics: vec![], - }], - param: id::ty(0), - ret: id::ty(1), - vars: vec![id::ty(0), id::ty(1), id::ty(1)], - blocks: vec![Block { - arg: id::var(0), - code: vec![ - Instr { - var: id::var(1), - expr: Expr::Call { - func: id::func(0), - arg: id::var(0), - }, + generics: vec![].into(), + types: vec![Ty::F64].into(), + vars: vec![id::ty(0), id::ty(0)].into(), + params: vec![].into(), + ret: id::var(1), + body: vec![ + Instr { + var: id::var(0), + expr: Expr::Call { + id: id::function(0), + generics: vec![].into(), + args: vec![].into(), }, - Instr { - var: id::var(2), - expr: Expr::Binary { - op: Binop::Mul, - left: id::var(1), - right: id::var(1), - }, + }, + Instr { + var: id::var(1), + expr: Expr::Binary { + op: Binop::Mul, + left: id::var(0), + right: id::var(0), }, - ], - ret: id::var(2), - }], - main: id::block(0), + }, + ] + .into(), }, ]; let answer = interp( @@ -429,9 +418,9 @@ mod tests { }, IndexSet::new(), &[], - Val::Unit, + [].into_iter(), ) .unwrap(); - assert_eq!(answer, Val::F64(1764.)); + assert_eq!(answer, val_f64(1764.)); } } diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index 839ef4f..c31d083 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -34,6 +34,7 @@ pub fn layouts() -> Result { to_js_value(&[ ("Expr", layout::()), + ("Function", layout::()), ("Instr", layout::()), ("Ty", layout::()), ]) @@ -63,128 +64,147 @@ impl<'a> rose::FuncNode for &'a Func { pub fn pprint(f: &Func) -> Result { use std::fmt::Write as _; // see https://doc.rust-lang.org/std/macro.write.html - fn print_block( + fn print_instr( mut s: &mut String, def: &rose::Function, spaces: usize, - b: id::Block, + instr: &rose::Instr, ) -> Result<(), JsError> { - for instr in def.blocks[b.block()].code.iter() { - for _ in 0..spaces { - write!(&mut s, " ")?; + for _ in 0..spaces { + write!(&mut s, " ")?; + } + let x = instr.var.var(); + write!(&mut s, "x{}: T{} = ", x, def.vars[x].ty())?; + match &instr.expr { + rose::Expr::Unit => writeln!(&mut s, "unit")?, + rose::Expr::Bool { val } => writeln!(&mut s, "{val}")?, + rose::Expr::F64 { val } => writeln!(&mut s, "{val}")?, + rose::Expr::Fin { val } => writeln!(&mut s, "{val}")?, + rose::Expr::Array { elems } => { + write!(&mut s, "[")?; + print_elems(s, 'x', elems.iter().map(|elem| elem.var()))?; + writeln!(&mut s, "]")?; } - let x = instr.var.var(); - write!(&mut s, "x{}: T{} = ", x, def.vars[x].ty())?; - match &instr.expr { - rose::Expr::Unit => writeln!(&mut s, "unit")?, - rose::Expr::Bool { val } => writeln!(&mut s, "{val}")?, - rose::Expr::F64 { val } => writeln!(&mut s, "{val}")?, - rose::Expr::Fin { val } => writeln!(&mut s, "{val}")?, - rose::Expr::Array { elems } => { - write!(&mut s, "[")?; - print_elems(s, 'x', elems.iter().map(|elem| elem.var()))?; - writeln!(&mut s, "]")?; - } - rose::Expr::Tuple { members } => { - write!(&mut s, "(")?; - print_elems(s, 'x', members.iter().map(|member| member.var()))?; - writeln!(&mut s, ")")?; - } - rose::Expr::Index { array, index } => { - writeln!(&mut s, "x{}[x{}]", array.var(), index.var())? - } - rose::Expr::Member { tuple, member } => { - writeln!(&mut s, "x{}.{}", tuple.var(), member.member())? - } - rose::Expr::Slice { array, index } => { - writeln!(&mut s, "x{}![x{}]", array.var(), index.var())? - } - rose::Expr::Field { tuple, field } => { - writeln!(&mut s, "x{}!.{}", tuple.var(), field.member())? - } - rose::Expr::Unary { op, arg } => match op { - rose::Unop::Not => writeln!(&mut s, "not x{}", arg.var())?, - rose::Unop::Neg => writeln!(&mut s, "-x{}", arg.var())?, - rose::Unop::Abs => writeln!(&mut s, "|x{}|", arg.var())?, - rose::Unop::Sqrt => writeln!(&mut s, "sqrt(x{})", arg.var())?, - }, - rose::Expr::Binary { op, left, right } => match op { - rose::Binop::And => writeln!(&mut s, "x{} and x{}", left.var(), right.var())?, - rose::Binop::Or => writeln!(&mut s, "x{} or x{}", left.var(), right.var())?, - rose::Binop::Iff => writeln!(&mut s, "x{} iff x{}", left.var(), right.var())?, - rose::Binop::Xor => writeln!(&mut s, "x{} xor x{}", left.var(), right.var())?, - rose::Binop::Neq => writeln!(&mut s, "x{} != x{}", left.var(), right.var())?, - rose::Binop::Lt => writeln!(&mut s, "x{} < x{}", left.var(), right.var())?, - rose::Binop::Leq => writeln!(&mut s, "x{} <= x{}", left.var(), right.var())?, - rose::Binop::Eq => writeln!(&mut s, "x{} == x{}", left.var(), right.var())?, - rose::Binop::Gt => writeln!(&mut s, "x{} > x{}", left.var(), right.var())?, - rose::Binop::Geq => writeln!(&mut s, "x{} >= x{}", left.var(), right.var())?, - rose::Binop::Add => writeln!(&mut s, "x{} + x{}", left.var(), right.var())?, - rose::Binop::Sub => writeln!(&mut s, "x{} - x{}", left.var(), right.var())?, - rose::Binop::Mul => writeln!(&mut s, "x{} * x{}", left.var(), right.var())?, - rose::Binop::Div => writeln!(&mut s, "x{} / x{}", left.var(), right.var())?, - }, - rose::Expr::Call { func, arg } => { - writeln!(&mut s, "f{}(x{})", func.func(), arg.var())? + rose::Expr::Tuple { members } => { + write!(&mut s, "(")?; + print_elems(s, 'x', members.iter().map(|member| member.var()))?; + writeln!(&mut s, ")")?; + } + rose::Expr::Index { array, index } => { + writeln!(&mut s, "x{}[x{}]", array.var(), index.var())? + } + rose::Expr::Member { tuple, member } => { + writeln!(&mut s, "x{}.{}", tuple.var(), member.member())? + } + rose::Expr::Slice { array, index } => { + writeln!(&mut s, "x{}![x{}]", array.var(), index.var())? + } + rose::Expr::Field { tuple, member } => { + writeln!(&mut s, "x{}!.{}", tuple.var(), member.member())? + } + rose::Expr::Unary { op, arg } => match op { + rose::Unop::Not => writeln!(&mut s, "not x{}", arg.var())?, + rose::Unop::Neg => writeln!(&mut s, "-x{}", arg.var())?, + rose::Unop::Abs => writeln!(&mut s, "|x{}|", arg.var())?, + rose::Unop::Sqrt => writeln!(&mut s, "sqrt(x{})", arg.var())?, + }, + rose::Expr::Binary { op, left, right } => match op { + rose::Binop::And => writeln!(&mut s, "x{} and x{}", left.var(), right.var())?, + rose::Binop::Or => writeln!(&mut s, "x{} or x{}", left.var(), right.var())?, + rose::Binop::Iff => writeln!(&mut s, "x{} iff x{}", left.var(), right.var())?, + rose::Binop::Xor => writeln!(&mut s, "x{} xor x{}", left.var(), right.var())?, + rose::Binop::Neq => writeln!(&mut s, "x{} != x{}", left.var(), right.var())?, + rose::Binop::Lt => writeln!(&mut s, "x{} < x{}", left.var(), right.var())?, + rose::Binop::Leq => writeln!(&mut s, "x{} <= x{}", left.var(), right.var())?, + rose::Binop::Eq => writeln!(&mut s, "x{} == x{}", left.var(), right.var())?, + rose::Binop::Gt => writeln!(&mut s, "x{} > x{}", left.var(), right.var())?, + rose::Binop::Geq => writeln!(&mut s, "x{} >= x{}", left.var(), right.var())?, + rose::Binop::Add => writeln!(&mut s, "x{} + x{}", left.var(), right.var())?, + rose::Binop::Sub => writeln!(&mut s, "x{} - x{}", left.var(), right.var())?, + rose::Binop::Mul => writeln!(&mut s, "x{} * x{}", left.var(), right.var())?, + rose::Binop::Div => writeln!(&mut s, "x{} / x{}", left.var(), right.var())?, + }, + rose::Expr::Select { cond, then, els } => { + writeln!(&mut s, "x{} ? x{} : x{}", cond.var(), then.var(), els.var())? + } + rose::Expr::Call { id, generics, args } => { + write!(&mut s, "f{}<", id.function())?; + print_elems(s, 'T', generics.iter().map(|generic| generic.ty()))?; + write!(&mut s, ">(")?; + print_elems(s, 'x', args.iter().map(|arg| arg.var()))?; + writeln!(&mut s, ")")?; + } + rose::Expr::For { + index, + arg, + body, + ret, + } => { + writeln!(&mut s, "for x{}: T{} {{", arg.var(), index.ty())?; + print_block(s, def, spaces + 2, body, *ret)?; + for _ in 0..spaces { + write!(&mut s, " ")?; } - rose::Expr::If { cond, then, els } => { - writeln!(&mut s, "if x{} {{", cond.var())?; - for _ in 0..spaces { - write!(&mut s, " ")?; - } - let x = def.blocks[then.block()].arg.var(); - writeln!(&mut s, " x{x}: T{}", def.vars[x].ty())?; - print_block(s, def, spaces + 2, *then)?; - for _ in 0..spaces { - write!(&mut s, " ")?; - } - writeln!(&mut s, "}} else {{")?; - for _ in 0..spaces { - write!(&mut s, " ")?; - } - let y = def.blocks[els.block()].arg.var(); - writeln!(&mut s, " x{y}: T{}", def.vars[y].ty())?; - print_block(s, def, spaces + 2, *els)?; - for _ in 0..spaces { - write!(&mut s, " ")?; - } - writeln!(&mut s, "}}")? + writeln!(&mut s, "}}")? + } + rose::Expr::Read { + var, + arg, + body, + ret, + } => { + writeln!(&mut s, "read x{} {{", var.var())?; + for _ in 0..spaces { + write!(&mut s, " ")?; } - rose::Expr::For { index, body } => { - writeln!( - &mut s, - "for x{}: T{} {{", - def.blocks[body.block()].arg.var(), - index.ty() - )?; - print_block(s, def, spaces + 2, *body)?; - for _ in 0..spaces { - write!(&mut s, " ")?; - } - writeln!(&mut s, "}}")? + let x = arg.var(); + writeln!(&mut s, " x{x}: T{}", def.vars[x].ty())?; + print_block(s, def, spaces + 2, body, *ret)?; + for _ in 0..spaces { + write!(&mut s, " ")?; } - rose::Expr::Accum { var, vector, body } => { - writeln!(&mut s, "accum x{}: T{} {{", var.var(), vector.ty())?; - for _ in 0..spaces { - write!(&mut s, " ")?; - } - let x = def.blocks[body.block()].arg.var(); - writeln!(&mut s, " x{x}: T{}", def.vars[x].ty())?; - print_block(s, def, spaces + 2, *body)?; - for _ in 0..spaces { - write!(&mut s, " ")?; - } - writeln!(&mut s, "}}")? + writeln!(&mut s, "}}")? + } + rose::Expr::Accum { + shape, + arg, + body, + ret, + } => { + writeln!(&mut s, "accum x{} {{", shape.var())?; + for _ in 0..spaces { + write!(&mut s, " ")?; } - rose::Expr::Add { accum, addend } => { - writeln!(&mut s, "x{} += x{}", accum.var(), addend.var())? + let x = arg.var(); + writeln!(&mut s, " x{x}: T{}", def.vars[x].ty())?; + print_block(s, def, spaces + 2, body, *ret)?; + for _ in 0..spaces { + write!(&mut s, " ")?; } + writeln!(&mut s, "}}")? + } + rose::Expr::Ask { var } => writeln!(&mut s, "ask x{}", var.var())?, + rose::Expr::Add { accum, addend } => { + writeln!(&mut s, "x{} += x{}", accum.var(), addend.var())? } } + Ok(()) + } + + fn print_block( + mut s: &mut String, + def: &rose::Function, + spaces: usize, + body: &[rose::Instr], + ret: id::Var, + ) -> Result<(), JsError> { + for instr in body.iter() { + print_instr(s, def, spaces, instr)?; + } for _ in 0..spaces { write!(&mut s, " ")?; } - writeln!(&mut s, "x{}", def.blocks[b.block()].ret.var())?; + writeln!(&mut s, "x{}", ret.var())?; Ok(()) } @@ -227,7 +247,7 @@ pub fn pprint(f: &Func) -> Result { rose::Ty::Unit | rose::Ty::Bool | rose::Ty::F64 => writeln!(&mut s, "{ty:?}")?, rose::Ty::Fin { size } => writeln!(&mut s, "{size}")?, rose::Ty::Generic { id } => writeln!(&mut s, "G{}", id.generic())?, - rose::Ty::Scope { id } => writeln!(&mut s, "B{}", id.block())?, + rose::Ty::Scope { kind, id } => writeln!(&mut s, "B{}: {kind:?}", id.var())?, rose::Ty::Ref { scope, inner } => { writeln!(&mut s, "Ref T{} T{}", scope.ty(), inner.ty())? } @@ -241,23 +261,21 @@ pub fn pprint(f: &Func) -> Result { } } } - for (i, func) in def.funcs.iter().enumerate() { - write!(&mut s, "f{} = F{}<", i, func.id.function())?; - print_elems( - &mut s, - 'T', - func.generics.iter().map(|generic| generic.ty()), - )?; - writeln!(&mut s, ">")?; - } - writeln!( - &mut s, - "x{}: T{} -> T{} {{", - def.blocks[def.main.block()].arg.var(), - def.param.ty(), - def.ret.ty(), - )?; - print_block(&mut s, def, 2, def.main)?; + write!(&mut s, "(")?; + let mut first = true; + for param in def.params.iter() { + if first { + first = false; + } else { + write!(&mut s, ", ")?; + } + write!(&mut s, "x{}: T{}", param.var(), def.vars[param.var()].ty())?; + } + writeln!(&mut s, ") -> T{} {{", def.vars[def.ret.var()].ty())?; + for instr in def.body.iter() { + print_instr(&mut s, def, 2, instr)?; + } + writeln!(&mut s, " x{}", def.ret.var())?; writeln!(&mut s, "}}")?; Ok(s) @@ -269,37 +287,29 @@ pub struct Context { functions: Vec, generics: Vec>, types: IndexSet, - funcs: Vec<(rose::Func, id::Ty)>, - param: id::Ty, - ret: id::Ty, vars: Vec, - blocks: Vec, + params: Vec, } #[wasm_bindgen] -pub fn bake(ctx: Context, main: usize) -> Func { +pub fn bake(ctx: Context, out: usize, main: Block) -> Func { let Context { functions, generics, types, - funcs, - param, - ret, + params, vars, - blocks, } = ctx; Func { rc: Rc::new(( functions, rose::Function { - generics, + generics: generics.into(), types: types.into_iter().collect(), - funcs: funcs.into_iter().map(|(f, _)| f).collect(), - param, - ret, - vars, - blocks, - main: id::block(main), + params: params.into(), + ret: id::var(out), + vars: vars.into(), + body: main.code.into(), }, )), } @@ -326,13 +336,11 @@ impl Default for Block { } } -// just an ephemeral struct return several things which we then unpack on the JS side +// just an ephemeral struct to return several things which we then unpack on the JS side #[wasm_bindgen] pub struct Body { ctx: Option, main: Option, - pub arg: usize, - args: Option>, } #[wasm_bindgen] @@ -350,59 +358,22 @@ impl Body { .take() .ok_or_else(|| JsError::new("block already taken")) } - - #[wasm_bindgen] - pub fn args(&mut self) -> Result, JsError> { - self.args - .take() - .ok_or_else(|| JsError::new("args already taken")) - } } /// The `types` argument is Serde-converted to `indexmap::IndexSet`. #[wasm_bindgen] -pub fn make( - generics: usize, - types: JsValue, - params: Vec, - ret: usize, -) -> Result { - let mut types: IndexSet = serde_wasm_bindgen::from_value(types)?; - - let (param, _) = types.insert_full(rose::Ty::Tuple { - members: params.iter().map(|&i| id::ty(i)).collect(), - }); - let param = id::ty(param); - let mut ctx = Context { +pub fn make(generics: usize, types: JsValue, params: &[usize]) -> Result { + let types: IndexSet = serde_wasm_bindgen::from_value(types)?; + let ctx = Context { functions: vec![], generics: vec![EnumSet::only(rose::Constraint::Index); generics], types, - funcs: vec![], - param, - ret: id::ty(ret), - vars: vec![], - blocks: vec![], + vars: params.iter().map(|&id| id::ty(id)).collect(), + params: (0..params.len()).map(id::var).collect(), }; - - let arg = ctx.var(param); - let mut main = Block { code: vec![] }; - let args = params - .iter() - .enumerate() - .map(|(i, &ty)| { - let expr = rose::Expr::Member { - tuple: arg, - member: id::member(i), - }; - ctx.instr(&mut main, id::ty(ty), expr) - }) - .collect(); - Ok(Body { ctx: Some(ctx), - main: Some(main), - arg: arg.var(), - args: Some(args), + main: Some(Block { code: vec![] }), }) } @@ -424,7 +395,7 @@ fn resolve( ) -> Option { let resolved = match ty { // inner scopes cannot appear in the return type, which is all we care about here - rose::Ty::Scope { id: _ } => return None, + rose::Ty::Scope { kind: _, id: _ } => return None, rose::Ty::Generic { id } => return Some(id::ty(generics[id.generic()])), rose::Ty::Unit => rose::Ty::Unit, @@ -454,45 +425,6 @@ fn resolve( // TODO: catch invalid user-given indices instead of panicking #[wasm_bindgen] impl Context { - #[wasm_bindgen] - pub fn func(&mut self, f: &Func, generics: &[usize]) -> Result { - let mut types = vec![]; - let (_, def) = f.rc.as_ref(); - // push a corresponding type onto our own `types` for each type in the callee - for callee_type in &def.types { - types.push(resolve(&mut self.types, generics, &types, callee_type)); - } - - // push the function reference to the callee - let function_id = id::function(self.functions.len()); - self.functions.push(f.clone()); - - // push data about the callee's interface types - let func_id = self.funcs.len(); - self.funcs.push(( - rose::Func { - generics: generics.iter().map(|&i| id::ty(i)).collect(), - id: function_id, - }, - // the only types we omitted were inner scopes from the callee, which cannot appear in - // its return type - types[def.ret.ty()].unwrap(), - )); - Ok(func_id) - } - - #[wasm_bindgen] - pub fn block(&mut self, b: Block, arg_id: usize, ret_id: usize) -> usize { - let Block { code } = b; - let id = self.blocks.len(); - self.blocks.push(rose::Block { - arg: id::var(arg_id), - code, - ret: id::var(ret_id), - }); - id - } - fn get(&self, var: id::Var) -> id::Ty { self.vars[var.var()] } @@ -554,7 +486,7 @@ impl Context { index, elem: self.get(x), }); - let expr = rose::Expr::Array { elems: xs }; + let expr = rose::Expr::Array { elems: xs.into() }; Ok(self.instr(b, ty, expr)) } @@ -563,7 +495,7 @@ impl Context { let xs: Vec = members.iter().map(|&x| id::var(x)).collect(); let types = xs.iter().map(|&x| self.get(x)).collect(); let ty = self.ty(rose::Ty::Tuple { members: types }); - let expr = rose::Expr::Tuple { members: xs }; + let expr = rose::Expr::Tuple { members: xs.into() }; self.instr(b, ty, expr) } @@ -794,41 +726,67 @@ impl Context { // end of binary #[wasm_bindgen] - pub fn call(&mut self, b: &mut Block, func: usize, arg: usize) -> Result { - let &(_, ty) = self - .funcs - .get(func) - .ok_or_else(|| JsError::new("invalid function ID"))?; + pub fn call( + &mut self, + b: &mut Block, + f: &Func, + generics: &[usize], + args: &[usize], + ) -> Result { + let mut types = vec![]; + let (_, def) = f.rc.as_ref(); + // push a corresponding type onto our own `types` for each type in the callee + for callee_type in def.types.iter() { + types.push(resolve(&mut self.types, generics, &types, callee_type)); + } + + // add the function reference to the callee + let id = id::function(self.functions.len()); + self.functions.push(f.clone()); + + let ty = types[def.vars[def.ret.var()].ty()].unwrap(); let expr = rose::Expr::Call { - func: id::func(func), - arg: id::var(arg), + id, + generics: generics.iter().map(|&i| id::ty(i)).collect(), + args: args.iter().map(|&x| id::var(x)).collect(), }; Ok(self.instr(b, ty, expr)) } - /// `rose::Expr::If` #[wasm_bindgen] - pub fn cond(&mut self, b: &mut Block, cond: usize, then: usize, els: usize) -> usize { - let t = self.get(self.blocks[then].ret); // arbitrary; could have used `els` instead - let expr = rose::Expr::If { + pub fn select(&mut self, b: &mut Block, cond: usize, then: usize, els: usize) -> usize { + let then = id::var(then); + let els = id::var(els); + let t = self.get(then); // arbitrary; could have used `els` instead + let expr = rose::Expr::Select { cond: id::var(cond), - then: id::block(then), - els: id::block(els), + then, + els, }; self.instr(b, t, expr) } // `rose::Expr::For` #[wasm_bindgen] - pub fn arr(&mut self, b: &mut Block, index: usize, body: usize) -> usize { - let rose::Block { arg, ret, .. } = self.blocks[body]; + pub fn arr( + &mut self, + b: &mut Block, + index: usize, + arg: usize, + body: Block, + out: usize, + ) -> usize { + let arg = id::var(arg); + let ret = id::var(out); let ty = self.ty(rose::Ty::Array { index: self.get(arg), elem: self.get(ret), }); let expr = rose::Expr::For { index: id::ty(index), - body: id::block(body), + arg, + body: body.code.into(), + ret, }; self.instr(b, ty, expr) } @@ -836,18 +794,19 @@ impl Context { /// Interpret a function with the given arguments. /// -/// The `types` are Serde-converted to `indexmap::IndexSet`, the `arg` is Serde-converted -/// to `rose_interp::Val`, and the return value is Serde-converted from `rose_interp::Val`. +/// The `types` are Serde-converted to `indexmap::IndexSet`, the `args` are +/// Serde-converted to `Vec`, and the return value is Serde-converted from +/// `rose_interp::Val`. #[wasm_bindgen] pub fn interp( f: &Func, types: JsValue, generics: &[usize], - arg: JsValue, + args: JsValue, ) -> Result { let types: IndexSet = serde_wasm_bindgen::from_value(types)?; - let arg: rose_interp::Val = serde_wasm_bindgen::from_value(arg)?; + let args: Vec = serde_wasm_bindgen::from_value(args)?; let generics: Vec = generics.iter().map(|&i| id::ty(i)).collect(); - let ret = rose_interp::interp(f, types, &generics, arg)?; + let ret = rose_interp::interp(f, types, &generics, args.into_iter())?; Ok(to_js_value(&ret)?) } diff --git a/packages/core/src/bool.ts b/packages/core/src/bool.ts index 015fe7c..956fbf8 100644 --- a/packages/core/src/bool.ts +++ b/packages/core/src/bool.ts @@ -1,5 +1,4 @@ -import { Val, Var, getBlock, getCtx, getVar, setBlock } from "./context.js"; -import * as ffi from "./ffi.js"; +import { Val, Var, getBlock, getCtx, getVar } from "./context.js"; export type Bool = boolean | Var; @@ -33,43 +32,11 @@ export const xor = (p: Bool, q: Bool): Bool => { return { ctx, id: ctx.xor(b, getVar(ctx, b, p), getVar(ctx, b, q)) }; }; -export const cond = ( - cond: Bool, - then: () => T, - els: () => T, -): T | Var => { +export const select = (cond: Bool, then: T, els: T): T | Var => { const ctx = getCtx(); const b = getBlock(); - const p = getVar(ctx, b, cond); - - const at = ctx.varUnit(); // `then` and `els` blocks take in `Unit`-type arg - const bt = new ffi.Block(); - let nt: number; // block ID - try { - setBlock(bt); - const rt = getVar(ctx, bt, then()); - nt = ctx.block(bt, at, rt); - } catch (e) { - bt.free(); - throw e; - } finally { - setBlock(b); - } - - const af = ctx.varUnit(); - const bf = new ffi.Block(); - let nf: number; - try { - setBlock(bf); - const rf = getVar(ctx, bf, els()); - nf = ctx.block(bf, af, rf); - } catch (e) { - bf.free(); - throw e; - } finally { - setBlock(b); - } - - return { ctx, id: ctx.cond(b, p, nt, nf) }; + const t = getVar(ctx, b, then); + const e = getVar(ctx, b, els); + return { ctx, id: ctx.select(b, p, t, e) }; }; diff --git a/packages/core/src/debug.test.ts b/packages/core/src/debug.test.ts index 110d5de..b126fea 100644 --- a/packages/core/src/debug.test.ts +++ b/packages/core/src/debug.test.ts @@ -6,7 +6,6 @@ import { abs, add, and, - cond, div, eq, fn, @@ -20,6 +19,7 @@ import { neq, not, or, + select, sqrt, sub, xor, @@ -28,8 +28,9 @@ import { test("core IR type layouts", () => { // these don't matter too much, but it's good to notice if sizes increase expect(Object.fromEntries(wasm.layouts())).toEqual({ - Expr: { size: 16, align: 8 }, - Instr: { size: 24, align: 8 }, + Expr: { size: 24, align: 8 }, + Function: { size: 44, align: 4 }, + Instr: { size: 32, align: 8 }, Ty: { size: 16, align: 4 }, }); }); @@ -38,17 +39,9 @@ describe("pprint", () => { test("if", () => { const f = fn([Real, Real], Real, (x, y) => { const p = lt(x, y); - const z = cond( - p, - () => { - const a = mul(x, y); - return add(a, x); - }, - () => { - const b = sub(y, x); - return mul(b, y); - }, - ); + const a = mul(x, y); + const b = sub(y, x); + const z = select(p, add(a, x), mul(b, y)); const w = add(z, x); return add(y, w); }); @@ -57,26 +50,16 @@ describe("pprint", () => { ` T0 = Bool T1 = F64 -T2 = (T1, T1) -T3 = Unit -x0: T2 -> T1 { - x1: T1 = x0.0 - x2: T1 = x0.1 - x3: T0 = x1 < x2 - x10: T1 = if x3 { - x4: T3 - x5: T1 = x1 * x2 - x6: T1 = x5 + x1 - x6 - } else { - x7: T3 - x8: T1 = x2 - x1 - x9: T1 = x8 * x2 - x9 - } - x11: T1 = x10 + x1 - x12: T1 = x2 + x11 - x12 +(x0: T1, x1: T1) -> T1 { + x2: T0 = x0 < x1 + x3: T1 = x0 * x1 + x4: T1 = x1 - x0 + x5: T1 = x3 + x0 + x6: T1 = x4 * x1 + x7: T1 = x2 ? x5 : x6 + x8: T1 = x7 + x0 + x9: T1 = x1 + x8 + x9 } `.trimStart(), ); @@ -94,17 +77,11 @@ x0: T2 -> T1 { ` T0 = Bool T1 = F64 -T2 = (T1) -f0 = F0<> -f1 = F1<> -x0: T2 -> T1 { - x1: T1 = x0.0 - x2: T2 = (x1) - x3: T1 = f0(x2) - x4: T2 = (x1) - x5: T1 = f1(x4) - x6: T1 = x3 + x5 - x6 +(x0: T1) -> T1 { + x1: T1 = f0<>(x0) + x2: T1 = f1<>(x0) + x3: T1 = x1 + x2 + x3 } `.trimStart(), ); @@ -122,15 +99,13 @@ x0: T2 -> T1 { ` T0 = Bool T1 = F64 -T2 = (T1) -x0: T2 -> T1 { - x1: T1 = x0.0 - x2: T0 = true - x3: T0 = not x2 - x4: T1 = -x1 - x5: T1 = |x4| - x6: T1 = sqrt(x1) - x6 +(x0: T1) -> T1 { + x1: T0 = true + x2: T0 = not x1 + x3: T1 = -x0 + x4: T1 = |x3| + x5: T1 = sqrt(x0) + x5 } `.trimStart(), ); @@ -157,33 +132,30 @@ x0: T2 -> T1 { ` T0 = Bool T1 = F64 -T2 = (T1, T1) -x0: T2 -> T0 { - x1: T1 = x0.0 - x2: T1 = x0.1 - x3: T1 = x1 + x2 - x4: T1 = x1 - x2 - x5: T1 = x1 * x2 - x6: T1 = x1 / x2 - x7: T0 = true - x8: T0 = false - x9: T0 = x7 and x8 - x10: T0 = true - x11: T0 = false - x12: T0 = x10 or x11 - x13: T0 = true - x14: T0 = false - x15: T0 = x13 iff x14 - x16: T0 = true - x17: T0 = false - x18: T0 = x16 xor x17 - x19: T0 = x1 != x2 - x20: T0 = x1 < x2 - x21: T0 = x1 <= x2 - x22: T0 = x1 == x2 - x23: T0 = x1 > x2 - x24: T0 = x5 >= x6 - x24 +(x0: T1, x1: T1) -> T0 { + x2: T1 = x0 + x1 + x3: T1 = x0 - x1 + x4: T1 = x0 * x1 + x5: T1 = x0 / x1 + x6: T0 = true + x7: T0 = false + x8: T0 = x6 and x7 + x9: T0 = true + x10: T0 = false + x11: T0 = x9 or x10 + x12: T0 = true + x13: T0 = false + x14: T0 = x12 iff x13 + x15: T0 = true + x16: T0 = false + x17: T0 = x15 xor x16 + x18: T0 = x0 != x1 + x19: T0 = x0 < x1 + x20: T0 = x0 <= x1 + x21: T0 = x0 == x1 + x22: T0 = x0 > x1 + x23: T0 = x4 >= x5 + x23 } `.trimStart(), ); diff --git a/packages/core/src/ffi.ts b/packages/core/src/ffi.ts index 5abbbc3..7b80559 100644 --- a/packages/core/src/ffi.ts +++ b/packages/core/src/ffi.ts @@ -24,8 +24,8 @@ export interface Fn { f: wasm.Func; } -export const bake = (ctx: wasm.Context, main: number): Fn => { - const f = wasm.bake(ctx, main); +export const bake = (ctx: wasm.Context, ret: number, main: wasm.Block): Fn => { + const f = wasm.bake(ctx, ret, main); const fn: Fn = { f }; registry.register(fn, () => f.free()); return fn; @@ -39,24 +39,16 @@ export interface Body { */ ctx: wasm.Context; main: wasm.Block; - arg: number; - args: number[]; } export const make = ( generics: number, types: Ty[], params: Uint32Array, - ret: number, ): Body => { - const x = wasm.make(generics, types, params, ret); + const x = wasm.make(generics, types, params); try { - return { - ctx: x.ctx(), - main: x.main(), - arg: x.arg, - args: Array.from(x.args()), - }; + return { ctx: x.ctx(), main: x.main() }; } finally { x.free(); } @@ -66,8 +58,8 @@ export const interp = ( f: Fn, types: Ty[], generics: Uint32Array, - arg: Val, -): Val => wasm.interp(f.f, types, generics, arg); + args: Val[], +): Val => wasm.interp(f.f, types, generics, args); export { Block, Context } from "@rose-lang/wasm"; export type { Ty, Val }; diff --git a/packages/core/src/fn.ts b/packages/core/src/fn.ts index 91f7217..082c1cf 100644 --- a/packages/core/src/fn.ts +++ b/packages/core/src/fn.ts @@ -25,12 +25,9 @@ export interface Fn { const call = (f: Fn, args: Val[]): Val => { const ctx = getCtx(); const b = getBlock(); - const x = ctx.tuple( - b, - new Uint32Array(args.map((arg) => getVar(ctx, b, arg))), - ); - const i = ctx.func(f.f.f, new Uint32Array()); - const y: Var = { ctx, id: ctx.call(b, i, x) }; + const vars = new Uint32Array(args.map((arg) => getVar(ctx, b, arg))); + const generics = new Uint32Array(); // TODO + const y: Var = { ctx, id: ctx.call(b, f.f.f, generics, vars) }; return y; }; @@ -81,19 +78,13 @@ export const fn = ( throw Error("can't define a function while defining another function"); const sig = makeSig(params, ret); let func: ffi.Fn; - const { - ctx, - main, - arg, - args: ids, - } = ffi.make(0, sig.types, sig.params, sig.ret); + const { ctx, main } = ffi.make(0, sig.types, sig.params); try { setCtx(ctx); setBlock(main); - const x = f(...(ids.map((id): Var => ({ ctx, id })) as Args)); + const x = f(...(params.map((_, id): Var => ({ ctx, id })) as Args)); const y = getVar(ctx, main, x); - const b = ctx.block(main, arg, y); - func = ffi.bake(ctx, b); + func = ffi.bake(ctx, y, main); } catch (e) { // `ctx` and `main` point into Wasm memory, so if we didn't finish the // `ffi.bake` call above then we need to be sure to `free` them diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 8a39b2d..8eaa49e 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -3,12 +3,12 @@ import { Bool, Real, add, - cond, div, fn, interp, lt, mul, + select, sub, } from "./index.js"; @@ -25,26 +25,14 @@ test("basic arithmetic", () => { }); test("branch", () => { - const f = fn([Bool], Real, (x) => - cond( - x, - () => 1, - () => 2, - ), - ); + const f = fn([Bool], Real, (x) => select(x, 1, 2)); const g = interp(f); expect(g(false)).toBe(2); expect(g(true)).toBe(1); }); test("call", () => { - const ifCond = fn([Bool, Real, Real], Real, (p, x, y) => - cond( - p, - () => x, - () => y, - ), - ); + const ifCond = fn([Bool, Real, Real], Real, (p, x, y) => select(p, x, y)); const f = fn([Real], Real, (x) => ifCond(lt(x, 0), 0, x)); const relu = interp(f); expect(relu(-1)).toBe(0); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 67fe019..cc27752 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -8,7 +8,7 @@ export const Bool: Bools = { tag: "Bool" }; export type Real = real.Real; export const Real: Reals = { tag: "Real" }; -export { and, cond, iff, not, or, xor } from "./bool.js"; +export { and, iff, not, or, select, xor } from "./bool.js"; export { fn } from "./fn.js"; export { interp } from "./interp.js"; export { diff --git a/packages/core/src/interp.ts b/packages/core/src/interp.ts index f0685a5..5d4df83 100644 --- a/packages/core/src/interp.ts +++ b/packages/core/src/interp.ts @@ -34,6 +34,6 @@ export const interp = // just return a closure that calls the interpreter (...args: Args) => { // TODO: support generics - const x = ffi.interp(f.f, [], new Uint32Array(), { Tuple: args.map(pack) }); + const x = ffi.interp(f.f, [], new Uint32Array(), args.map(pack)); return unpack(x) as Resolve; };