From 422dc12262242507cdabb7b1a346b1febc348aaf Mon Sep 17 00:00:00 2001 From: Raven Rothkopf Date: Fri, 28 Jul 2023 09:19:50 -0400 Subject: [PATCH] Add test cases --- crates/autodiff/src/lib.rs | 75 ++++++++++++++++++++++++++++-- crates/frontend/tests/interp.rs | 14 ++++++ crates/frontend/tests/mul.rose | 1 + crates/web/src/lib.rs | 1 + packages/core/src/autodiff.test.ts | 9 ++++ packages/core/src/index.test.ts | 8 ---- 6 files changed, 97 insertions(+), 11 deletions(-) create mode 100644 crates/frontend/tests/mul.rose create mode 100644 packages/core/src/autodiff.test.ts diff --git a/crates/autodiff/src/lib.rs b/crates/autodiff/src/lib.rs index 6bd357d..7b633ef 100644 --- a/crates/autodiff/src/lib.rs +++ b/crates/autodiff/src/lib.rs @@ -1,6 +1,6 @@ use enumset::EnumSet; use indexmap::IndexSet; -use rose::{id, Binop, Block, Constraint, Expr, FuncNode, Function, Instr, Ty, Unop}; +use rose::{id, Binop, Block, Constraint, Expr, Func, FuncNode, Function, Instr, Ty, Unop}; use std::collections::HashMap; pub struct Derivative { @@ -424,7 +424,6 @@ impl Forward<'_> { if let Some(&new_id) = self.block_mapping.get(&old_id) { return new_id; } - // otherwise, process the block let old = &self.old_blocks[old_id.block()]; @@ -469,6 +468,7 @@ pub fn forward(f: Derivative) -> Function { blocks: vec![], block_mapping: HashMap::new(), }; + let unitvar = g.unitvar(); for ty in &f.types { let primal = match ty { Ty::Unit => Ty::Unit, @@ -477,7 +477,7 @@ pub fn forward(f: Derivative) -> Function { &Ty::Fin { size } => Ty::Fin { size }, &Ty::Generic { id } => Ty::Generic { id }, Ty::Scope { id } => { - let processed_block = g.block(*id, g.unitvar(), vec![]); + let processed_block = g.block(*id, unitvar, vec![]); Ty::Scope { id: processed_block, } @@ -536,6 +536,7 @@ pub fn forward(f: Derivative) -> Function { let main = g.block(f.main, arg, code); let b = &g.blocks[main.block()]; + Function { generics: g.generics, types: g.types.into_iter().collect(), @@ -555,3 +556,71 @@ pub fn unzip(f: Derivative) -> (Function, Linear) { pub fn transpose(f: Linear) -> Linear { f // TODO } + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Clone, Debug)] + struct TestFuncNode { + f: Function, + } + + impl rose::FuncNode for TestFuncNode { + fn def(&self) -> &rose::Function { + &self.f + } + + fn get(&self, _id: id::Function) -> Option { + None + } + } + + fn func1() -> Function { + Function { + generics: vec![], + types: vec![Ty::Unit, Ty::F64], + funcs: vec![], + param: id::ty(0), + ret: id::ty(1), + vars: vec![id::ty(0), id::ty(1)], + blocks: vec![Block { + arg: id::var(0), + code: vec![Instr { + var: id::var(1), + expr: Expr::F64 { val: 42. }, + }], + ret: id::var(1), + }], + main: id::block(0), + } + } + + #[test] + fn test_block_mapping() { + // get funcs + let og_func = TestFuncNode { f: func1() }; + let cloned_func = og_func.f.clone(); + let derivative = derivative(og_func); + let new_func = forward(derivative); + + // extract blocks + let old_blocks = cloned_func.blocks; + let new_blocks = new_func.blocks; + + // check block mapping for each block index + for i in 0..old_blocks.len() { + let old_block = &old_blocks[i]; + let new_block = &new_blocks[i]; + // check if the blocks have the same indices + assert_eq!(old_block.arg.var(), new_block.arg.var()); + assert_eq!(old_block.ret.var(), new_block.ret.var()); + } + + // check if the new block index is strictly greater than the number of blocks in the original function + for new_block in new_blocks.iter().skip(old_blocks.len()) { + assert!(new_block.arg.var() > old_blocks.len()); + assert!(new_block.ret.var() > old_blocks.len()); + } + } +} diff --git a/crates/frontend/tests/interp.rs b/crates/frontend/tests/interp.rs index 5caab1e..5f47605 100644 --- a/crates/frontend/tests/interp.rs +++ b/crates/frontend/tests/interp.rs @@ -30,3 +30,17 @@ fn test_sub() { .unwrap(); assert_eq!(answer, Val::F64(0.)); } + +#[test] +fn test_mul() { + let src = include_str!("mul.rose"); + let module = parse(src).unwrap(); + let answer = interp( + module.get_func("mul").unwrap(), + IndexSet::new(), + &[], + Val::Tuple(Rc::new(vec![Val::F64(2.), Val::F64(2.)])), + ) + .unwrap(); + assert_eq!(answer, Val::F64(0.)); +} diff --git a/crates/frontend/tests/mul.rose b/crates/frontend/tests/mul.rose new file mode 100644 index 0000000..2cb7992 --- /dev/null +++ b/crates/frontend/tests/mul.rose @@ -0,0 +1 @@ +def mul(x: R, y: R): R = x * y \ No newline at end of file diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index 7b75cb7..44c21bb 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -107,6 +107,7 @@ pub fn pprint(f: &Func) -> Result { rose::Unop::Neg => writeln!(&mut s, "-x{}", arg.var())?, rose::Unop::Abs => writeln!(&mut s, "|x{}|", arg.var())?, rose::Unop::Sqrt => writeln!(&mut s, "sqrt(x{})", arg.var())?, + rose::Unop::Sign => writeln!(&mut s, "sign(x{})", arg.var())?, }, rose::Expr::Binary { op, left, right } => match op { rose::Binop::And => writeln!(&mut s, "x{} and x{}", left.var(), right.var())?, diff --git a/packages/core/src/autodiff.test.ts b/packages/core/src/autodiff.test.ts new file mode 100644 index 0000000..8d893e7 --- /dev/null +++ b/packages/core/src/autodiff.test.ts @@ -0,0 +1,9 @@ +import { expect, test } from "vitest"; +import { Real, derivative, fn, interp, mul } from "./index.js"; + +test("derivative", () => { + const f = fn([Real], Real, (x) => mul(x, x)); + const g = derivative(f); + const h = interp(g as any); + expect(h(3, 1)).toBe(6); +}); diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 70afe2a..8a39b2d 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -4,7 +4,6 @@ import { Real, add, cond, - derivative, div, fn, interp, @@ -52,10 +51,3 @@ test("call", () => { expect(relu(0)).toBe(0); expect(relu(1)).toBe(1); }); - -test("derivative", () => { - const f = fn([Real], Real, (x) => mul(x, x)); - const g = derivative(f); - const h = interp(g as any); - expect(h(3, 1)).toBe(6); -});