From fabbc566e80d850f838217025b08af3a6a276c7c Mon Sep 17 00:00:00 2001
From: Jianjun Zhu
Date: Mon, 1 Jul 2024 16:33:34 +0800
Subject: [PATCH] Add NPU support for wasi-nn WinML backend.

This change adds NPU (Neural Processing Unit) support to the wasi-nn
WinML backend. Since NPU support in DirectML is still in developer
preview, only a subset of learning models is supported.
---
 Cargo.lock                          |  24 +++
 crates/wasi-nn/Cargo.toml           |  15 +-
 crates/wasi-nn/src/backend/winml.rs | 223 ++++++++++++++++++++++------
 3 files changed, 215 insertions(+), 47 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 47e695e49d79..019eed90d32b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4143,6 +4143,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be"
 dependencies = [
  "windows-core",
+ "windows-implement",
+ "windows-interface",
  "windows-targets 0.52.0",
 ]
 
@@ -4155,6 +4157,28 @@
 dependencies = [
  "windows-targets 0.52.0",
 ]
 
+[[package]]
+name = "windows-implement"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "12168c33176773b86799be25e2a2ba07c7aab9968b37541f1094dbd7a60c8946"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.60",
+]
+
+[[package]]
+name = "windows-interface"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9d8dc32e0095a7eeccebd0e3f09e9509365ecb3fc6ac4d6f5f14a3f6392942d1"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.60",
+]
+
 [[package]]
 name = "windows-sys"
 version = "0.48.0"
diff --git a/crates/wasi-nn/Cargo.toml b/crates/wasi-nn/Cargo.toml
index a5ace788d03a..fa0d5849e263 100644
--- a/crates/wasi-nn/Cargo.toml
+++ b/crates/wasi-nn/Cargo.toml
@@ -40,7 +40,20 @@ ort = { version = "2.0.0-rc.2", default-features = false, features = [
 
 [target.'cfg(windows)'.dependencies.windows]
 version = "0.52"
-features = ["AI_MachineLearning", "Storage_Streams", "Foundation_Collections"]
+features = [
+  "AI_MachineLearning",
+  "Storage_Streams",
+  "Foundation_Collections",
+  # For Int64 input support.
+  "implement",
+  # The following 6 features are needed to create a LearningModelDevice from an NPU.
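+  # DXCore enumerates compute-only (MCDM) adapters, and the Direct3D12 and
+  # WinRT ML bindings turn the selected adapter into the D3D12 command queue
+  # that backs the LearningModelDevice (see winml.rs below).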
+ "Win32_Foundation", + "Win32_Graphics_Direct3D", + "Win32_Graphics_Direct3D12", + "Win32_Graphics_Dxgi", + "Win32_Graphics_DXCore", + "Win32_System_WinRT_ML", +] optional = true [build-dependencies] diff --git a/crates/wasi-nn/src/backend/winml.rs b/crates/wasi-nn/src/backend/winml.rs index b87510bfc877..52405351f5c0 100644 --- a/crates/wasi-nn/src/backend/winml.rs +++ b/crates/wasi-nn/src/backend/winml.rs @@ -14,14 +14,26 @@ use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType}; use crate::{ExecutionContext, Graph}; use std::{fs::File, io::Read, mem::size_of, path::Path}; use windows::core::{ComInterface, HSTRING}; -use windows::Foundation::Collections::IVectorView; +use windows::Foundation::Collections::{IVectorView, IIterable}; use windows::Storage::Streams::{ DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference, }; +use windows::Win32::Graphics::DXCore::{ + DXCoreCreateAdapterFactory, IDXCoreAdapter, IDXCoreAdapterFactory, IDXCoreAdapterList, + DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE, DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS, +}; +use windows::Win32::Graphics::{ + Direct3D::D3D_FEATURE_LEVEL_1_0_CORE, + Direct3D12::{ + D3D12CreateDevice, ID3D12CommandQueue, ID3D12Device, D3D12_COMMAND_LIST_TYPE_COMPUTE, + D3D12_COMMAND_QUEUE_DESC, D3D12_COMMAND_QUEUE_FLAG_NONE, + }, +}; +use windows::Win32::System::WinRT::ML::ILearningModelDeviceFactoryNative; use windows::AI::MachineLearning::{ ILearningModelFeatureDescriptor, LearningModel, LearningModelBinding, LearningModelDevice, LearningModelDeviceKind, LearningModelEvaluationResult, LearningModelSession, - TensorFeatureDescriptor, TensorFloat, + TensorFeatureDescriptor, TensorFloat, TensorFloat16Bit, TensorInt64Bit, TensorKind, }; #[derive(Default)] @@ -45,12 +57,66 @@ impl BackendInner for WinMLBackend { let model = LearningModel::LoadFromStream(&RandomAccessStreamReference::CreateFromStream( &model_stream, )?)?; - let device_kind = match target { - ExecutionTarget::Cpu => LearningModelDeviceKind::Cpu, - ExecutionTarget::Gpu => LearningModelDeviceKind::DirectX, - ExecutionTarget::Tpu => unimplemented!(), + let device = match target { + ExecutionTarget::Cpu => LearningModelDevice::Create(LearningModelDeviceKind::Cpu), + ExecutionTarget::Gpu => LearningModelDevice::Create(LearningModelDeviceKind::DirectX), + ExecutionTarget::Tpu => unsafe { + // Enumerate adapters with DXCore APIs so MCDM (Microsoft Compute Driver Model) devices can be found. + let dx_adapter_factory: IDXCoreAdapterFactory = DXCoreCreateAdapterFactory()?; + let adapter_list = + dx_adapter_factory.CreateAdapterList::(&[ + DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE, + ])?; + let mut selected_device: Option = None; + for i in 0..adapter_list.GetAdapterCount() { + let adapter = adapter_list.GetAdapter::(i)?; + // Select a compute only device. DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML looks more suitable here, but it's defined in DirectX headers. 
+                    if adapter.IsAttributeSupported(&DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE)
+                        && !adapter.IsAttributeSupported(&DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)
+                    {
+                        selected_device = Some(adapter);
+                        break;
+                    }
+                }
+                if selected_device.is_none() {
+                    return Err(BackendError::BackendAccess(anyhow::Error::msg(
+                        "NPU is not available on this device.",
+                    )));
+                }
+
+                let mut d3d12_device: Option<ID3D12Device> = None;
+                D3D12CreateDevice(
+                    &selected_device.unwrap(),
+                    D3D_FEATURE_LEVEL_1_0_CORE,
+                    &mut d3d12_device,
+                )?;
+                if d3d12_device.is_none() {
+                    return Err(BackendError::BackendAccess(anyhow::Error::msg(
+                        "Failed to create D3D12 device.",
+                    )));
+                }
+                let d3d12_command_queue_desc: D3D12_COMMAND_QUEUE_DESC = D3D12_COMMAND_QUEUE_DESC {
+                    Type: D3D12_COMMAND_LIST_TYPE_COMPUTE,
+                    Flags: D3D12_COMMAND_QUEUE_FLAG_NONE,
+                    NodeMask: 0,
+                    Priority: 0,
+                };
+                let d3d12_command_queue = d3d12_device
+                    .unwrap()
+                    .CreateCommandQueue::<ID3D12CommandQueue>(&d3d12_command_queue_desc)?;
+                let factory = windows::core::factory::<
+                    LearningModelDevice,
+                    ILearningModelDeviceFactoryNative,
+                >()?;
+                factory
+                    .CreateFromD3D12CommandQueue(&d3d12_command_queue)?
+                    .cast::<LearningModelDevice>()
+            },
+        };
+        let graph = WinMLGraph {
+            model,
+            device: device?,
+        };
-        let graph = WinMLGraph { model, device_kind };
 
         let box_: Box<dyn BackendGraph> = Box::new(graph);
         Ok(box_.into())
@@ -74,7 +140,7 @@ impl BackendFromDir for WinMLBackend {
 
 struct WinMLGraph {
     model: LearningModel,
-    device_kind: LearningModelDeviceKind,
+    device: LearningModelDevice,
 }
 
 unsafe impl Send for WinMLGraph {}
@@ -82,8 +148,8 @@ unsafe impl Sync for WinMLGraph {}
 
 impl BackendGraph for WinMLGraph {
     fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
-        let device = LearningModelDevice::Create(self.device_kind.clone())?;
-        let session = LearningModelSession::CreateFromModelOnDevice(&self.model, &device)?;
+        let session =
+            LearningModelSession::CreateFromModelOnDevice(&self.model, &self.device)?;
         let box_: Box<dyn BackendExecutionContext> = Box::new(WinMLExecutionContext::new(session));
         Ok(box_.into())
     }
@@ -136,32 +202,58 @@ impl WinMLExecutionContext {
 
 impl BackendExecutionContext for WinMLExecutionContext {
     fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
+        // TODO: Clear previous bindings when needed.
+
         let input_features = self.session.Model()?.InputFeatures()?;
         let index = self.find(id, &input_features)?;
         let input = input_features.GetAt(index)?;
 
-        // TODO: Support other tensor types. Only FP32 is supported right now.
+        // TODO: Support other tensor types. Only FP16, FP32, and I64 are
+        // supported right now.
         match tensor.ty {
-            crate::wit::types::TensorType::Fp32 => {}
-            _ => unimplemented!(),
-        }
+            crate::wit::types::TensorType::Fp16 => unsafe {
+                let data = std::slice::from_raw_parts(
+                    tensor.data.as_ptr() as *const f32,
+                    tensor.data.len() / size_of::<f32>(),
+                );
 
-        // TODO: this is quite unsafe and probably incorrect--will the slice
-        // still be around by the time the binding is used?!
-        let data = unsafe {
-            std::slice::from_raw_parts(
-                tensor.data.as_ptr() as *const f32,
-                tensor.data.len() / size_of::<f32>(),
-            )
-        };
 
+                // TODO: this is quite unsafe and probably incorrect--will the
+                // slice still be around by the time the binding is used?!
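+                // Note: the windows crate projects TensorFloat16Bit::CreateFromArray
+                // as taking an f32 slice, so this path assumes the guest supplies
+                // FP16 tensor contents as 32-bit floats that WinML converts to half
+                // precision when building the tensor.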
+                self.binding.Bind(
+                    &input.Name()?,
+                    &TensorFloat16Bit::CreateFromArray(
+                        &input.cast::<TensorFeatureDescriptor>()?.Shape()?,
+                        data,
+                    )?,
+                )?;
+            },
+            crate::wit::types::TensorType::Fp32 => unsafe {
+                let data = std::slice::from_raw_parts(
+                    tensor.data.as_ptr() as *const f32,
+                    tensor.data.len() / size_of::<f32>(),
+                );
 
-        self.binding.Bind(
-            &input.Name()?,
-            &TensorFloat::CreateFromArray(
-                &input.cast::<TensorFeatureDescriptor>()?.Shape()?,
-                data,
-            )?,
-        )?;
+                self.binding.Bind(
+                    &input.Name()?,
+                    &TensorFloat::CreateFromArray(
+                        &input.cast::<TensorFeatureDescriptor>()?.Shape()?,
+                        data,
+                    )?,
+                )?;
+            },
+            crate::wit::types::TensorType::I64 => unsafe {
+                let data = std::slice::from_raw_parts(
+                    tensor.data.as_ptr() as *const i64,
+                    tensor.data.len() / size_of::<i64>(),
+                );
+                let dim: Vec<i64> = tensor.dimensions.iter().map(|&x| x as i64).collect();
+                let shape: IIterable<i64> = IIterable::<i64>::try_from(dim)?;
+                let tensor = TensorInt64Bit::CreateFromArray(&shape, data)?;
+
+                self.binding.Bind(&input.Name()?, &tensor)?;
+            },
+            _ => unimplemented!(),
+        }
 
         Ok(())
     }
@@ -175,23 +267,62 @@
         if let Some(result) = &self.result {
             let output_features = self.session.Model()?.OutputFeatures()?;
             let index = self.find(id, &output_features)?;
-            let output = output_features.GetAt(index)?;
-            // TODO: this only handles FP32!
-            let tensor = result
-                .Outputs()?
-                .Lookup(&output.Name()?)?
-                .cast::<TensorFloat>()?;
-            let dimensions = dimensions_as_u32(&tensor.Shape()?)?;
-            let view = tensor.GetAsVectorView()?;
-            let mut data = Vec::with_capacity(view.Size()? as usize * size_of::<f32>());
-            for f in view.into_iter() {
-                data.extend(f.to_le_bytes());
-            }
-            Ok(Tensor {
-                ty: TensorType::Fp32,
-                dimensions,
-                data,
-            })
+            let output_feature = output_features.GetAt(index)?;
+            let tensor_kind = match output_feature.Kind()? {
+                windows::AI::MachineLearning::LearningModelFeatureKind::Tensor => output_feature
+                    .cast::<TensorFeatureDescriptor>()?
+                    .TensorKind()?,
+                _ => unimplemented!(),
+            };
+            // TODO: this only handles FP16, FP32, and I64!
+            let output_inspectable = result.Outputs()?.Lookup(&output_feature.Name()?)?;
+            let tensor = match tensor_kind {
+                TensorKind::Float16 => {
+                    let output_tensor = output_inspectable.cast::<TensorFloat16Bit>()?;
+                    let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
+                    let view = output_tensor.GetAsVectorView()?;
+                    // TODO: Move to f16 when it's available in stable Rust.
+                    let mut data = Vec::with_capacity(view.Size()? as usize * size_of::<f32>());
+                    for f in view.into_iter() {
+                        data.extend(f.to_le_bytes());
+                    }
+                    Tensor {
+                        ty: TensorType::Fp16,
+                        dimensions,
+                        data,
+                    }
+                }
+                TensorKind::Float => {
+                    let output_tensor = output_inspectable.cast::<TensorFloat>()?;
+                    let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
+                    let view = output_tensor.GetAsVectorView()?;
+                    let mut data = Vec::with_capacity(view.Size()? as usize * size_of::<f32>());
+                    for f in view.into_iter() {
+                        data.extend(f.to_le_bytes());
+                    }
+                    Tensor {
+                        ty: TensorType::Fp32,
+                        dimensions,
+                        data,
+                    }
+                }
+                TensorKind::Int64 => {
+                    let output_tensor = output_inspectable.cast::<TensorInt64Bit>()?;
+                    let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
+                    let view = output_tensor.GetAsVectorView()?;
+                    let mut data = Vec::with_capacity(view.Size()? as usize * size_of::<i64>());
+                    for f in view.into_iter() {
+                        data.extend(f.to_le_bytes());
+                    }
+                    Tensor {
+                        ty: TensorType::I64,
+                        dimensions,
+                        data,
+                    }
+                }
+                _ => unimplemented!(),
+            };
+            Ok(tensor)
         } else {
             return Err(BackendError::BackendAccess(anyhow::Error::msg(
                 "Output is not ready.",
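
For reference, the NPU device-creation path above can be exercised on its own.
Below is a condensed, self-contained sketch of the same DXCore/D3D12 plumbing;
the helper name `create_npu_device` is hypothetical (not part of this patch),
it assumes the `windows` crate features enabled in Cargo.toml above, and it
panics instead of returning a `BackendError` when no compute-only adapter is
present:

use windows::core::{ComInterface, Result};
use windows::AI::MachineLearning::LearningModelDevice;
use windows::Win32::Graphics::DXCore::{
    DXCoreCreateAdapterFactory, IDXCoreAdapter, IDXCoreAdapterFactory, IDXCoreAdapterList,
    DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE, DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS,
};
use windows::Win32::Graphics::Direct3D::D3D_FEATURE_LEVEL_1_0_CORE;
use windows::Win32::Graphics::Direct3D12::{
    D3D12CreateDevice, ID3D12CommandQueue, ID3D12Device, D3D12_COMMAND_LIST_TYPE_COMPUTE,
    D3D12_COMMAND_QUEUE_DESC, D3D12_COMMAND_QUEUE_FLAG_NONE,
};
use windows::Win32::System::WinRT::ML::ILearningModelDeviceFactoryNative;

// Hypothetical helper: create a LearningModelDevice backed by the first
// compute-only (NPU) adapter on the system.
unsafe fn create_npu_device() -> Result<LearningModelDevice> {
    // Enumerate D3D12 core-compute adapters via DXCore.
    let factory: IDXCoreAdapterFactory = DXCoreCreateAdapterFactory()?;
    let adapters = factory.CreateAdapterList::<IDXCoreAdapterList>(&[
        DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE,
    ])?;
    // Compute-but-not-graphics is the NPU heuristic used by the patch.
    let mut npu: Option<IDXCoreAdapter> = None;
    for i in 0..adapters.GetAdapterCount() {
        let adapter = adapters.GetAdapter::<IDXCoreAdapter>(i)?;
        if adapter.IsAttributeSupported(&DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE)
            && !adapter.IsAttributeSupported(&DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)
        {
            npu = Some(adapter);
            break;
        }
    }
    let npu = npu.expect("no compute-only (NPU) adapter found");

    // Create a D3D12 device and a compute command queue on that adapter.
    let mut device: Option<ID3D12Device> = None;
    D3D12CreateDevice(&npu, D3D_FEATURE_LEVEL_1_0_CORE, &mut device)?;
    let device = device.expect("D3D12CreateDevice returned no device");
    let queue_desc = D3D12_COMMAND_QUEUE_DESC {
        Type: D3D12_COMMAND_LIST_TYPE_COMPUTE,
        Flags: D3D12_COMMAND_QUEUE_FLAG_NONE,
        NodeMask: 0,
        Priority: 0,
    };
    let queue = device.CreateCommandQueue::<ID3D12CommandQueue>(&queue_desc)?;

    // Wrap the compute queue in a WinML LearningModelDevice.
    let factory =
        windows::core::factory::<LearningModelDevice, ILearningModelDeviceFactoryNative>()?;
    factory
        .CreateFromD3D12CommandQueue(&queue)?
        .cast::<LearningModelDevice>()
}

A wasi-nn embedder reaches the equivalent path in the backend by loading a
graph with ExecutionTarget::Tpu, which the WinML backend maps to the NPU
device as shown in the diff above.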