Skip to content

Commit

Permalink
Add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
ravenrothkopf committed Jul 28, 2023
1 parent 10a1284 commit 422dc12
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 11 deletions.
75 changes: 72 additions & 3 deletions crates/autodiff/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use enumset::EnumSet;
use indexmap::IndexSet;
use rose::{id, Binop, Block, Constraint, Expr, FuncNode, Function, Instr, Ty, Unop};
use rose::{id, Binop, Block, Constraint, Expr, Func, FuncNode, Function, Instr, Ty, Unop};
use std::collections::HashMap;

pub struct Derivative {
Expand Down Expand Up @@ -424,7 +424,6 @@ impl Forward<'_> {
if let Some(&new_id) = self.block_mapping.get(&old_id) {
return new_id;
}

// otherwise, process the block
let old = &self.old_blocks[old_id.block()];

Expand Down Expand Up @@ -469,6 +468,7 @@ pub fn forward(f: Derivative) -> Function {
blocks: vec![],
block_mapping: HashMap::new(),
};
let unitvar = g.unitvar();
for ty in &f.types {
let primal = match ty {
Ty::Unit => Ty::Unit,
Expand All @@ -477,7 +477,7 @@ pub fn forward(f: Derivative) -> Function {
&Ty::Fin { size } => Ty::Fin { size },
&Ty::Generic { id } => Ty::Generic { id },
Ty::Scope { id } => {
let processed_block = g.block(*id, g.unitvar(), vec![]);
let processed_block = g.block(*id, unitvar, vec![]);
Ty::Scope {
id: processed_block,
}
Expand Down Expand Up @@ -536,6 +536,7 @@ pub fn forward(f: Derivative) -> Function {
let main = g.block(f.main, arg, code);

let b = &g.blocks[main.block()];

Function {
generics: g.generics,
types: g.types.into_iter().collect(),
Expand All @@ -555,3 +556,71 @@ pub fn unzip(f: Derivative) -> (Function, Linear) {
pub fn transpose(f: Linear) -> Linear {
f // TODO
}

#[cfg(test)]
mod tests {
use super::*;

#[derive(Clone, Debug)]
struct TestFuncNode {
f: Function,
}

impl rose::FuncNode for TestFuncNode {
fn def(&self) -> &rose::Function {
&self.f
}

fn get(&self, _id: id::Function) -> Option<Self> {
None
}
}

fn func1() -> Function {
Function {
generics: vec![],
types: vec![Ty::Unit, Ty::F64],
funcs: vec![],
param: id::ty(0),
ret: id::ty(1),
vars: vec![id::ty(0), id::ty(1)],
blocks: vec![Block {
arg: id::var(0),
code: vec![Instr {
var: id::var(1),
expr: Expr::F64 { val: 42. },
}],
ret: id::var(1),
}],
main: id::block(0),
}
}

#[test]
fn test_block_mapping() {
// get funcs
let og_func = TestFuncNode { f: func1() };
let cloned_func = og_func.f.clone();
let derivative = derivative(og_func);
let new_func = forward(derivative);

// extract blocks
let old_blocks = cloned_func.blocks;
let new_blocks = new_func.blocks;

// check block mapping for each block index
for i in 0..old_blocks.len() {
let old_block = &old_blocks[i];
let new_block = &new_blocks[i];
// check if the blocks have the same indices
assert_eq!(old_block.arg.var(), new_block.arg.var());
assert_eq!(old_block.ret.var(), new_block.ret.var());
}

// check if the new block index is strictly greater than the number of blocks in the original function
for new_block in new_blocks.iter().skip(old_blocks.len()) {
assert!(new_block.arg.var() > old_blocks.len());
assert!(new_block.ret.var() > old_blocks.len());
}
}
}
14 changes: 14 additions & 0 deletions crates/frontend/tests/interp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,17 @@ fn test_sub() {
.unwrap();
assert_eq!(answer, Val::F64(0.));
}

#[test]
fn test_mul() {
let src = include_str!("mul.rose");
let module = parse(src).unwrap();
let answer = interp(
module.get_func("mul").unwrap(),
IndexSet::new(),
&[],
Val::Tuple(Rc::new(vec![Val::F64(2.), Val::F64(2.)])),
)
.unwrap();
assert_eq!(answer, Val::F64(0.));
}
1 change: 1 addition & 0 deletions crates/frontend/tests/mul.rose
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
def mul(x: R, y: R): R = x * y
1 change: 1 addition & 0 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ pub fn pprint(f: &Func) -> Result<String, JsError> {
rose::Unop::Neg => writeln!(&mut s, "-x{}", arg.var())?,
rose::Unop::Abs => writeln!(&mut s, "|x{}|", arg.var())?,
rose::Unop::Sqrt => writeln!(&mut s, "sqrt(x{})", arg.var())?,
rose::Unop::Sign => writeln!(&mut s, "sign(x{})", arg.var())?,
},
rose::Expr::Binary { op, left, right } => match op {
rose::Binop::And => writeln!(&mut s, "x{} and x{}", left.var(), right.var())?,
Expand Down
9 changes: 9 additions & 0 deletions packages/core/src/autodiff.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import { expect, test } from "vitest";
import { Real, derivative, fn, interp, mul } from "./index.js";

test("derivative", () => {
const f = fn([Real], Real, (x) => mul(x, x));
const g = derivative(f);
const h = interp(g as any);
expect(h(3, 1)).toBe(6);
});
8 changes: 0 additions & 8 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import {
Real,
add,
cond,
derivative,
div,
fn,
interp,
Expand Down Expand Up @@ -52,10 +51,3 @@ test("call", () => {
expect(relu(0)).toBe(0);
expect(relu(1)).toBe(1);
});

test("derivative", () => {
const f = fn([Real], Real, (x) => mul(x, x));
const g = derivative(f);
const h = interp(g as any);
expect(h(3, 1)).toBe(6);
});

0 comments on commit 422dc12

Please sign in to comment.