diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 99b02e3..7be8d11 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -532,7 +532,7 @@ export const fn = ( }; /** Construct an opaque function whose implementation runs `f`. */ -export const opaque = ( +export const opaque = ( params: P, ret: R, f: (...args: JsArgs>) => ToJs>, @@ -604,6 +604,8 @@ type ToJs = [T] extends [Null] ? number : [T] extends [Nat] ? number + : [T] extends [Vec] + ? ToJs[] : { [K in keyof T]: ToJs }; /** Map from an abstract value type array to a concrete argument type array. */ diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 4528f64..961e9e6 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -942,4 +942,46 @@ describe("valid", () => { const h = await compile(g); expect(h({ v: [2], i: 0 })).toEqual({ v: [1], i: 0 }); }); + + test("sort", async () => { + const n = 3; + + const sortWithIndices = opaque( + [Vec(n, Real)], + Vec(n, struct({ x: Real, i: n })), + (v) => v.map((x, i) => ({ x, i })).sort((a, b) => a.x - b.x), + ); + sortWithIndices.jvp = fn([Vec(n, Dual)], Vec(n, { x: Dual, i: n }), (v) => { + const w = sortWithIndices(vec(n, Real, (i) => v[i].re)); + return vec(n, { x: Dual, i: n }, (j) => { + const { i } = w[j]; + return { x: v[i], i }; + }); + }); + + const sort = fn([Vec(n, Real)], Vec(n, Real), (v) => { + const w = sortWithIndices(v); + return vec(n, Real, (i) => w[i].x); + }); + + const sortGrad = fn([Vec(n, Real), Vec(n, Real)], Vec(n, Real), (v, g) => + vjp(sort)(v).grad(g), + ); + + const interpreted = interp(sortGrad); + expect(interpreted([42, 121, 342], [-1, -2, -3])).toEqual([-1, -2, -3]); + expect(interpreted([42, 342, 121], [-1, -2, -3])).toEqual([-1, -3, -2]); + expect(interpreted([121, 42, 342], [-1, -2, -3])).toEqual([-2, -1, -3]); + expect(interpreted([121, 342, 42], [-1, -2, -3])).toEqual([-2, -3, -1]); + expect(interpreted([342, 42, 121], [-1, -2, -3])).toEqual([-3, -1, -2]); + expect(interpreted([342, 121, 42], [-1, -2, -3])).toEqual([-3, -2, -1]); + + const compiled = await compile(sortGrad); + expect(compiled([42, 121, 342], [-1, -2, -3])).toEqual([-1, -2, -3]); + expect(compiled([42, 342, 121], [-1, -2, -3])).toEqual([-1, -3, -2]); + expect(compiled([121, 42, 342], [-1, -2, -3])).toEqual([-2, -1, -3]); + expect(compiled([121, 342, 42], [-1, -2, -3])).toEqual([-2, -3, -1]); + expect(compiled([342, 42, 121], [-1, -2, -3])).toEqual([-3, -1, -2]); + expect(compiled([342, 121, 42], [-1, -2, -3])).toEqual([-3, -2, -1]); + }); });