From 9523435273ff1c313837e800ddfa3944863e3890 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Mon, 18 Dec 2023 07:48:19 -0500 Subject: [PATCH] fix shared --- luisa_compute/src/lang.rs | 6 ++- luisa_compute/src/lang/functions.rs | 11 +++--- luisa_compute/src/lang/types/shared.rs | 52 ++++++++++++++++++-------- 3 files changed, 47 insertions(+), 22 deletions(-) diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index c4181d7..c8249b4 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -309,7 +309,7 @@ pub(crate) struct FnRecorder { pub(crate) block_size: Option<[u32; 3]>, pub(crate) building_kernel: bool, pub(crate) pools: CArc, - pub(crate) arena: Bump, + pub(crate) arena: Rc, pub(crate) dtors: Vec<(*mut u8, fn(*mut u8))>, pub(crate) callable_ret_type: Option>, pub(crate) const_builder: IrBuilder, @@ -427,7 +427,9 @@ impl FnRecorder { device: None, block_size: None, pools: pools.clone(), - arena: Bump::new(), + arena: parent.as_ref() + .map(|p| p.borrow().arena.clone()) + .unwrap_or_else(|| Rc::new(Bump::new())), building_kernel: false, callable_ret_type: None, kernel_id, diff --git a/luisa_compute/src/lang/functions.rs b/luisa_compute/src/lang/functions.rs index 89b40f1..bcf9fe8 100644 --- a/luisa_compute/src/lang/functions.rs +++ b/luisa_compute/src/lang/functions.rs @@ -184,11 +184,12 @@ pub fn set_block_size(size: [u32; 3]) { }); } -pub fn block_size() -> Expr { - with_recorder(|r| { - let s = r.block_size.unwrap_or_else(|| panic!("Block size not set")); - Uint3::new(s[0], s[1], s[2]).expr() - }) +pub fn block_size() -> [u32; 3] { + with_recorder(|r| r.block_size.unwrap_or_else(|| panic!("Block size not set"))) +} +pub fn block_size_expr() -> Expr { + let sz = block_size(); + Uint3::expr(sz[0], sz[1], sz[2]) } pub unsafe fn bitcast(expr: Expr) -> Expr { diff --git a/luisa_compute/src/lang/types/shared.rs b/luisa_compute/src/lang/types/shared.rs index a3a1aab..a8c4dbb 100644 --- a/luisa_compute/src/lang/types/shared.rs +++ b/luisa_compute/src/lang/types/shared.rs @@ -41,21 +41,6 @@ impl Shared { _ => unreachable!(), } } - pub fn write>>(&self, i: I, value: V) { - let i = i.to_u64(); - let value = value.into(); - - if need_runtime_check() { - check_index_lt_usize(i, self.len()); - } - let i = i.node().get(); - let value = value.node().get(); - let self_node = self.node.get(); - __current_scope(|b| { - let gep = b.call(Func::GetElementPtr, &[self_node, i], T::type_()); - b.update(gep, value); - }); - } pub fn load(&self) -> VLArrayExpr { let self_node = self.node.get(); VLArrayExpr::from_node( @@ -70,3 +55,40 @@ impl Shared { }); } } +impl IndexRead for Shared { + type Element = T; + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); + + if need_runtime_check() { + check_index_lt_usize(i, self.len()); + } + let i = i.node().get(); + let self_node = self.node.get(); + Expr::from_node( + __current_scope(|b| { + let gep = b.call(Func::GetElementPtr, &[self_node, i], T::type_()); + b.load(gep) + }) + .into(), + ) + } +} + +impl IndexWrite for Shared { + fn write>(&self, i: I, value: V) { + let i = i.to_u64(); + let value = value.as_expr(); + + if need_runtime_check() { + check_index_lt_usize(i, self.len()); + } + let i = i.node().get(); + let value = value.node().get(); + let self_node = self.node.get(); + __current_scope(|b| { + let gep = b.call(Func::GetElementPtr, &[self_node, i], T::type_()); + b.update(gep, value); + }); + } +}