From c6e8aa6127dfecd57cc455ef8f337c5f44f0527b Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 10 Mar 2024 21:35:20 -0400 Subject: [PATCH] [WEB] Initial support for asyncify (#16694) This PR enables asyncify support for web runtime. Asyncify is a feature to allow C++ to call async functions in JavaScript. The emcc compiler will unwind and store the stack, returning control to the JS runtime. The JS runtime needs to be able to await the promise and then call rewind to get to the original suspended point. This feature can be potentially useful when we would like to call WebGPU sync in the C++ runtime, as on the web platform everything has to be non-blocking. Because asyncify can increase the wasm size by 2x, we don't enable it by default in emcc.py, and the option still needs to be passed in explicitly. We will confirm potential benefit tradeoffs before turning it on by default. Another catch is that as of now asyncify is not compatible with wasm exceptions, so we temporarily turn wasm-exception off for now. This is an item that is being worked on by emscripten so we might be able to turn it back on later. Test cases are added. 
reference: https://emscripten.org/docs/porting/asyncify.html --- python/tvm/contrib/emcc.py | 9 +- src/runtime/c_runtime_api.cc | 1 - web/Makefile | 5 +- web/apps/node/example.js | 2 +- web/emcc/decorate_as_wasi.py | 1 + web/emcc/wasm_runtime.cc | 5 + web/emcc/webgpu_runtime.cc | 6 +- web/src/artifact_cache.ts | 46 ++++-- web/src/asyncify.ts | 227 +++++++++++++++++++++++++++++ web/src/runtime.ts | 77 +++++++++- web/src/support.ts | 12 ++ web/tests/node/test_packed_func.js | 30 ++++ 12 files changed, 395 insertions(+), 26 deletions(-) create mode 100644 web/src/asyncify.ts diff --git a/python/tvm/contrib/emcc.py b/python/tvm/contrib/emcc.py index fac2043215865..07ff29205e10c 100644 --- a/python/tvm/contrib/emcc.py +++ b/python/tvm/contrib/emcc.py @@ -42,7 +42,14 @@ def create_tvmjs_wasm(output, objects, options=None, cc="emcc"): cmd += ["-O3"] cmd += ["-std=c++17"] cmd += ["--no-entry"] - cmd += ["-fwasm-exceptions"] + # NOTE: asynctify conflicts with wasm-exception + # so we temp disable exception handling for now + # + # We also expect user to explicitly pass in + # -s ASYNCIFY=1 as it can increase wasm size by 2xq + # + # cmd += ["-s", "ASYNCIFY=1"] + # cmd += ["-fwasm-exceptions"] cmd += ["-s", "WASM_BIGINT=1"] cmd += ["-s", "ERROR_ON_UNDEFINED_SYMBOLS=0"] cmd += ["-s", "STANDALONE_WASM=1"] diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 799ef116ce8cd..ea22b89dd7719 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -569,7 +569,6 @@ int TVMByteArrayFree(TVMByteArray* arr) { int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args, TVMValue* ret_val, int* ret_type_code) { API_BEGIN(); - TVMRetValue rv; (static_cast(func)) ->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); diff --git a/web/Makefile b/web/Makefile index bd5e6cbf2bd94..317438842b23e 100644 --- a/web/Makefile +++ b/web/Makefile @@ -27,10 +27,11 @@ all: dist/wasm/tvmjs_runtime.wasm 
dist/wasm/tvmjs_runtime.wasi.js src/tvmjs_runt EMCC = emcc -EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes -fwasm-exceptions +EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes EMCC_LDFLAGS = --no-entry -s WASM_BIGINT=1 -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1\ - -s ERROR_ON_UNDEFINED_SYMBOLS=0 --pre-js emcc/preload.js + -s ERROR_ON_UNDEFINED_SYMBOLS=0 --pre-js emcc/preload.js\ + -s ASYNCIFY=1 dist/wasm/%.bc: emcc/%.cc @mkdir -p $(@D) diff --git a/web/apps/node/example.js b/web/apps/node/example.js index d17ec072fa21f..580bbf57ab803 100644 --- a/web/apps/node/example.js +++ b/web/apps/node/example.js @@ -21,7 +21,7 @@ */ const path = require("path"); const fs = require("fs"); -const tvmjs = require("../../lib"); +const tvmjs = require("../../dist/tvmjs.bundle"); const wasmPath = tvmjs.wasmPath(); const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); diff --git a/web/emcc/decorate_as_wasi.py b/web/emcc/decorate_as_wasi.py index bce0dbb80e9f4..6d6b0a7b82dca 100644 --- a/web/emcc/decorate_as_wasi.py +++ b/web/emcc/decorate_as_wasi.py @@ -20,6 +20,7 @@ template_head = """ function EmccWASI() { +var asyncifyStubs = {}; """ template_tail = """ diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index be9704eaef992..8543361340e70 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -100,6 +100,11 @@ TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret) *ret = args[0]; }); +TVM_REGISTER_GLOBAL("testing.call").set_body([](TVMArgs args, TVMRetValue* ret) { + (args[0].operator PackedFunc()) + .CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.num_args - 1), ret); +}); + TVM_REGISTER_GLOBAL("testing.ret_string").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator String(); }); diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index ce2a7cadb68eb..1d7dbe0787b27 100644 --- a/web/emcc/webgpu_runtime.cc 
+++ b/web/emcc/webgpu_runtime.cc @@ -112,7 +112,11 @@ class WebGPUDeviceAPI : public DeviceAPI { LOG(FATAL) << "Not implemented"; } - void StreamSync(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; } + void StreamSync(Device dev, TVMStreamHandle stream) final { + static const PackedFunc* func = runtime::Registry::Get("__asyncify.WebGPUWaitForTasks"); + ICHECK(func != nullptr) << "Stream sync inside c++ only supported in asyncify mode"; + (*func)(); + } void SetStream(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; } diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts index 394cda83bc437..ffb5011324f55 100644 --- a/web/src/artifact_cache.ts +++ b/web/src/artifact_cache.ts @@ -1,19 +1,37 @@ /* - Common Interface for the artifact cache -*/ + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +/** + * Common Interface for the artifact cache + */ export interface ArtifactCacheTemplate { - /** - * fetch key url from cache - */ - fetchWithCache(url: string); + /** + * fetch key url from cache + */ + fetchWithCache(url: string); - /** - * check if cache has all keys in Cache - */ - hasAllKeys(keys: string[]); + /** + * check if cache has all keys in Cache + */ + hasAllKeys(keys: string[]); - /** - * Delete url in cache if url exists - */ - deleteInCache(url: string); + /** + * Delete url in cache if url exists + */ + deleteInCache(url: string); } diff --git a/web/src/asyncify.ts b/web/src/asyncify.ts new file mode 100644 index 0000000000000..703dbbf80a105 --- /dev/null +++ b/web/src/asyncify.ts @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +// Helper tools to enable asynctify handling +// Thie following code is used to support wrapping of +// functins that can have async await calls in the backend runtime +// reference +// - https://kripken.github.io/blog/wasm/2019/07/16/asyncify.html +// - https://github.com/GoogleChromeLabs/asyncify +import { assert, isPromise } from "./support"; + +/** + * enums to check the current state of asynctify + */ +const enum AsyncifyStateKind { + None = 0, + Unwinding = 1, + Rewinding = 2 +} + +/** The start location of asynctify stack data */ +const ASYNCIFY_DATA_ADDR = 16; +/** The data start of stack rewind/unwind */ +const ASYNCIFY_DATA_START = ASYNCIFY_DATA_ADDR + 8; +/** The data end of stack rewind/unwind */ +const ASYNCIFY_DATA_END = 1024; + +/** Hold asynctify handler instance that runtime can use */ +export class AsyncifyHandler { + /** exports from wasm */ + private exports: Record; + /** current state kind */ + private state: AsyncifyStateKind = AsyncifyStateKind.None; + /** The stored value before unwind */ + private storedPromiseBeforeUnwind : Promise = null; + // NOTE: asynctify do not work with exceptions + // this implementation here is mainly for possible future compact + /** The stored value that is resolved */ + private storedValueBeforeRewind: any = null; + /** The stored exception */ + private storedExceptionBeforeRewind: any = null; + + constructor(exports: Record, memory: WebAssembly.Memory) { + this.exports = exports; + this.initMemory(memory); + } + + // NOTE: wrapImport and wrapExport are closely related to each other + // We mark the logical jump pt in comments to increase the readability + /** + * Whether the wasm enables asynctify + * @returns Whether the wasm enables asynctify + */ + enabled(): boolean { + return this.exports.asyncify_stop_rewind !== undefined; + } + + /** + * Get the current asynctify state + * + * @returns The current asynctify state + */ + getState(): AsyncifyStateKind { + return this.state; + } + + /** + * Wrap a 
function that can be used as import of the wasm asynctify layer + * + * @param func The input import function + * @returns The wrapped function that can be registered to the system + */ + wrapImport(func: (...args: Array) => any): (...args: Array) => any { + return (...args: any) => { + // this is being called second time + // where we are rewinding the stack + if (this.getState() == AsyncifyStateKind.Rewinding) { + // JUMP-PT-REWIND: rewind will jump to this pt + // while rewinding the stack + this.stopRewind(); + // the value has been resolved + if (this.storedValueBeforeRewind !== null) { + assert(this.storedExceptionBeforeRewind === null); + const result = this.storedValueBeforeRewind; + this.storedValueBeforeRewind = null; + return result; + } else { + assert(this.storedValueBeforeRewind === null); + const error = this.storedExceptionBeforeRewind; + this.storedExceptionBeforeRewind = null; + throw error; + } + } + // this function is being called for the first time + assert(this.getState() == AsyncifyStateKind.None); + + // call the function + const value = func(...args); + // if the value is promise + // we need to unwind the stack + // so the caller will be able to evaluate the promise + if (isPromise(value)) { + // The next code step is JUMP-PT-UNWIND in wrapExport + // The value will be passed to that pt through storedPromiseBeforeUnwind + // getState() == Unwinding and we will enter the while loop in wrapExport + this.startUnwind(); + assert(this.storedPromiseBeforeUnwind == null); + this.storedPromiseBeforeUnwind = value; + return undefined; + } else { + // The next code step is JUMP-PT-UNWIND in wrapExport + // normal value, we don't have to do anything + // getState() == None and we will exit while loop there + return value; + } + }; + } + + /** + * Warp an exported asynctify function so it can return promise + * + * @param func The input function + * @returns The wrapped async function + */ + wrapExport(func: (...args: Array) => any): (...args: Array) 
=> Promise { + return async (...args: Array) => { + assert(this.getState() == AsyncifyStateKind.None); + + // call the original function + let result = func(...args); + + // JUMP-PT-UNWIND + // after calling the function + // the caller may hit a unwinding point depending on + // the if (isPromise(value)) condition in wrapImport + while (this.getState() == AsyncifyStateKind.Unwinding) { + this.stopUnwind(); + // try to resolve the promise that the internal requested + // we then store it into the temp value in storedValueBeforeRewind + // which then get passed onto the function(see wrapImport) + // that can return the value + const storedPromiseBeforeUnwind = this.storedPromiseBeforeUnwind; + this.storedPromiseBeforeUnwind = null; + assert(this.storedExceptionBeforeRewind === null); + assert(this.storedValueBeforeRewind == null); + + try { + this.storedValueBeforeRewind = await storedPromiseBeforeUnwind; + } catch (error) { + // the store exception + this.storedExceptionBeforeRewind = error; + } + assert(!isPromise(this.storedValueBeforeRewind)); + // because we called asynctify_stop_unwind,the state is now none + assert(this.getState() == AsyncifyStateKind.None); + + // re-enter the function, jump to JUMP-PT-REWIND in wrapImport + // the value will be passed to that point via storedValueBeforeRewind + // + // NOTE: we guarantee that if exception is throw the asynctify state + // will already be at None, this is because we will goto JUMP-PT-REWIND + // which will call aynctify_stop_rewind + this.startRewind(); + result = func(...args); + } + return result; + }; + } + + private startRewind() : void { + if (this.exports.asyncify_start_rewind === undefined) { + throw Error("Asynctify is not enabled, please compile with -s ASYNCIFY=1 in emcc"); + } + this.exports.asyncify_start_rewind(ASYNCIFY_DATA_ADDR); + this.state = AsyncifyStateKind.Rewinding; + } + + private stopRewind() : void { + if (this.exports.asyncify_stop_rewind === undefined) { + throw Error("Asynctify is 
not enabled, please compile with -s ASYNCIFY=1 in emcc"); + } + this.exports.asyncify_stop_rewind(); + this.state = AsyncifyStateKind.None; + } + + private startUnwind() : void { + if (this.exports.asyncify_start_unwind === undefined) { + throw Error("Asynctify is not enabled, please compile with -s ASYNCIFY=1 in emcc"); + } + this.exports.asyncify_start_unwind(ASYNCIFY_DATA_ADDR); + this.state = AsyncifyStateKind.Unwinding; + } + + private stopUnwind() : void { + if (this.exports.asyncify_stop_unwind === undefined) { + throw Error("Asynctify is not enabled, please compile with -s ASYNCIFY=1 in emcc"); + } + this.exports.asyncify_stop_unwind(); + this.state = AsyncifyStateKind.None; + } + /** + * Initialize the wasm memory to setup necessary meta-data + * for asynctify handling + * @param memory The memory ti + */ + private initMemory(memory: WebAssembly.Memory): void { + // Set the meta-data at address ASYNCTIFY_DATA_ADDR + new Int32Array(memory.buffer, ASYNCIFY_DATA_ADDR, 2).set( + [ASYNCIFY_DATA_START, ASYNCIFY_DATA_END] + ); + } +} diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 6ef2255263248..8df48c43a5f9c 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -25,6 +25,7 @@ import { Disposable } from "./types"; import { Memory, CachedCallStack } from "./memory"; import { assert, StringToUint8Array } from "./support"; import { Environment } from "./environment"; +import { AsyncifyHandler } from "./asyncify"; import { FunctionInfo, WebGPUContext } from "./webgpu"; import { ArtifactCacheTemplate } from "./artifact_cache"; @@ -32,11 +33,18 @@ import * as compact from "./compact"; import * as ctypes from "./ctypes"; /** - * Type for PackedFunc inthe TVMRuntime. + * Type for PackedFunc in the TVMRuntime. 
*/ export type PackedFunc = ((...args: any) => any) & Disposable & { _tvmPackedCell: PackedFuncCell }; +/** + * Type for AyncPackedFunc in TVMRuntime + * possibly may contain stack unwinding through Asynctify + */ +export type AsyncPackedFunc = ((...args: any) => Promise) & + Disposable & { _tvmPackedCell: PackedFuncCell }; + /** * @internal * FFI Library wrapper, maintains most runtime states. @@ -79,7 +87,6 @@ class FFILibrary implements Disposable { if (code != 0) { const msgPtr = (this.exports .TVMGetLastError as ctypes.FTVMGetLastError)(); - console.log("Here"); throw new Error("TVMError: " + this.memory.loadCString(msgPtr)); } } @@ -1057,6 +1064,7 @@ export class Instance implements Disposable { private env: Environment; private objFactory: Map; private ctx: RuntimeContext; + private asyncifyHandler: AsyncifyHandler; private initProgressCallback: Array = []; /** @@ -1099,6 +1107,7 @@ export class Instance implements Disposable { this.lib = new FFILibrary(wasmInstance, env.imports); this.memory = this.lib.memory; this.exports = this.lib.exports; + this.asyncifyHandler = new AsyncifyHandler(this.exports, this.memory.memory); this.objFactory = new Map(); this.ctx = new RuntimeContext( (name: string) => { @@ -1140,6 +1149,14 @@ export class Instance implements Disposable { return results; } + /** + * Check whether we enabled asyncify mode + * @returns The asynctify mode toggle + */ + asyncifyEnabled(): boolean { + return this.asyncifyHandler.enabled(); + } + dispose(): void { // order matters // ctx release goes back into lib. @@ -1922,13 +1939,55 @@ export class Instance implements Disposable { } this.objFactory.set(typeIndex, func); } + + /** + * Wrap a function obtained from tvm runtime as AsyncPackedFunc + * through the asyncify mechanism + * + * You only need to call it if the function may contain callback into async + * JS function via asynctify. A common one can be GPU synchronize. 
+ * + * It is always safe to wrap any function as Asynctify, however you do need + * to make sure you use await when calling the funciton. + * + * @param func The PackedFunc. + * @returns The wrapped AsyncPackedFunc + */ + wrapAsyncifyPackedFunc(func: PackedFunc): AsyncPackedFunc { + const asyncFunc = this.asyncifyHandler.wrapExport(func) as AsyncPackedFunc; + asyncFunc.dispose = func.dispose; + asyncFunc._tvmPackedCell = func._tvmPackedCell; + return asyncFunc; + } + + /** + * Register async function as asynctify callable in global environment. + * + * @param name The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + * + * @note This function is handled via asynctify mechanism + * The wasm needs to be compiled with Asynctify + */ + registerAsyncifyFunc( + name: string, + func: (...args: Array) => Promise, + override = false + ): void { + const asyncWrapped = this.asyncifyHandler.wrapImport(func); + this.registerFunc(name, asyncWrapped, override); + } + /** * Register an asyncfunction to be global function in the server. + * * @param name The name of the function. * @param func function to be registered. * @param override Whether overwrite function in existing registry. * - * @note The async function will only be used for serving remote calls in the rpc. 
+ * @note The async function will only be used for serving remote calls in the rpc + * These functions contains explicit continuation */ registerAsyncServerFunc( name: string, @@ -2036,6 +2095,11 @@ export class Instance implements Disposable { this.registerAsyncServerFunc("wasm.WebGPUWaitForTasks", async () => { await webGPUContext.sync(); }); + if (this.asyncifyHandler.enabled()) { + this.registerAsyncifyFunc("__asyncify.WebGPUWaitForTasks", async () => { + await webGPUContext.sync(); + }); + } this.lib.webGPUContext = webGPUContext; } @@ -2281,7 +2345,6 @@ export class Instance implements Disposable { // normal return path // recycle all js object value in function unless we want to retain them. this.ctx.endScope(); - if (rv !== undefined && rv !== null) { const stack = lib.getOrAllocCallStack(); const valueOffset = stack.allocRawBytes(SizeOf.TVMValue); @@ -2320,8 +2383,10 @@ export class Instance implements Disposable { const rvaluePtr = stack.ptrFromOffset(rvalueOffset); const rcodePtr = stack.ptrFromOffset(rcodeOffset); - // commit to wasm memory, till rvalueOffset (the return value don't need to be committed) - stack.commitToWasmMemory(rvalueOffset); + // pre-store the rcode to be null, in case caller unwind + // and not have chance to reset this rcode. + stack.storeI32(rcodeOffset, ArgTypeCode.Null); + stack.commitToWasmMemory(); this.lib.checkCall( (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)( diff --git a/web/src/support.ts b/web/src/support.ts index 18748c2c85ba0..b03fa363cdce7 100644 --- a/web/src/support.ts +++ b/web/src/support.ts @@ -17,6 +17,18 @@ * under the License. */ + +/** + * Check if value is a promise type + * + * @param value The input value + * @returns Whether value is promise + */ +export function isPromise(value: any): boolean { + return value !== undefined && ( + typeof value == "object" || typeof value == "function" + ) && typeof value.then == "function"; +} /** * Convert string to Uint8array. * @param str The string. 
diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js index f5c0ac6c2fad5..e1d070f0e473d 100644 --- a/web/tests/node/test_packed_func.js +++ b/web/tests/node/test_packed_func.js @@ -22,6 +22,9 @@ const fs = require("fs"); const assert = require("assert"); const tvmjs = require("../../dist/tvmjs.bundle") +// for now skip exception testing +// as it may not be compatible with asyncify +const exceptionEnabled = false; const wasmPath = tvmjs.wasmPath(); const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); @@ -127,6 +130,8 @@ test("RegisterGlobal", () => { }); test("ExceptionPassing", () => { + if (!exceptionEnabled) return; + tvm.beginScope(); tvm.registerFunc("throw_error", function (msg) { throw Error(msg); @@ -141,6 +146,31 @@ test("ExceptionPassing", () => { tvm.endScope(); }); + +test("AsyncifyFunc", async () => { + if (!tvm.asyncifyEnabled()) { + console.log("Skip asyncify tests as it is not enabled.."); + return; + } + tvm.beginScope(); + tvm.registerAsyncifyFunc("async_sleep_echo", async function (x) { + await new Promise(resolve => setTimeout(resolve, 10)); + return x; + }); + let fecho = tvm.wrapAsyncifyPackedFunc( + tvm.getGlobalFunc("async_sleep_echo") + ); + let fcall = tvm.wrapAsyncifyPackedFunc( + tvm.getGlobalFunc("testing.call") + ); + assert((await fecho(1)) == 1); + assert((await fecho(2)) == 2); + assert((await fcall(fecho, 2) == 2)); + tvm.endScope(); + assert(fecho._tvmPackedCell.getHandle(false) == 0); + assert(fcall._tvmPackedCell.getHandle(false) == 0); +}); + test("NDArrayCbArg", () => { tvm.beginScope(); let use_count = tvm.getGlobalFunc("testing.object_use_count");