Skip to content

Commit

Permalink
fix shared
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Dec 18, 2023
1 parent ba7fde6 commit 9523435
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 22 deletions.
6 changes: 4 additions & 2 deletions luisa_compute/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ pub(crate) struct FnRecorder {
pub(crate) block_size: Option<[u32; 3]>,
pub(crate) building_kernel: bool,
pub(crate) pools: CArc<ModulePools>,
pub(crate) arena: Bump,
pub(crate) arena: Rc<Bump>,
pub(crate) dtors: Vec<(*mut u8, fn(*mut u8))>,
pub(crate) callable_ret_type: Option<CArc<Type>>,
pub(crate) const_builder: IrBuilder,
Expand Down Expand Up @@ -427,7 +427,9 @@ impl FnRecorder {
device: None,
block_size: None,
pools: pools.clone(),
arena: Bump::new(),
arena: parent.as_ref()
.map(|p| p.borrow().arena.clone())
.unwrap_or_else(|| Rc::new(Bump::new())),
building_kernel: false,
callable_ret_type: None,
kernel_id,
Expand Down
11 changes: 6 additions & 5 deletions luisa_compute/src/lang/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,12 @@ pub fn set_block_size(size: [u32; 3]) {
});
}

pub fn block_size() -> Expr<Uint3> {
with_recorder(|r| {
let s = r.block_size.unwrap_or_else(|| panic!("Block size not set"));
Uint3::new(s[0], s[1], s[2]).expr()
})
pub fn block_size() -> [u32; 3] {
with_recorder(|r| r.block_size.unwrap_or_else(|| panic!("Block size not set")))
}
pub fn block_size_expr() -> Expr<Uint3> {
let sz = block_size();
Uint3::expr(sz[0], sz[1], sz[2])
}

pub unsafe fn bitcast<From: Value, To: Value>(expr: Expr<From>) -> Expr<To> {
Expand Down
52 changes: 37 additions & 15 deletions luisa_compute/src/lang/types/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,6 @@ impl<T: Value> Shared<T> {
_ => unreachable!(),
}
}
pub fn write<I: IntoIndex, V: Into<Expr<T>>>(&self, i: I, value: V) {
let i = i.to_u64();
let value = value.into();

if need_runtime_check() {
check_index_lt_usize(i, self.len());
}
let i = i.node().get();
let value = value.node().get();
let self_node = self.node.get();
__current_scope(|b| {
let gep = b.call(Func::GetElementPtr, &[self_node, i], T::type_());
b.update(gep, value);
});
}
pub fn load(&self) -> VLArrayExpr<T> {
let self_node = self.node.get();
VLArrayExpr::from_node(
Expand All @@ -70,3 +55,40 @@ impl<T: Value> Shared<T> {
});
}
}
impl<T: Value> IndexRead for Shared<T> {
type Element = T;
fn read<I: IntoIndex>(&self, i: I) -> Expr<Self::Element> {
let i = i.to_u64();

if need_runtime_check() {
check_index_lt_usize(i, self.len());
}
let i = i.node().get();
let self_node = self.node.get();
Expr::from_node(
__current_scope(|b| {
let gep = b.call(Func::GetElementPtr, &[self_node, i], T::type_());
b.load(gep)
})
.into(),
)
}
}

impl<T: Value> IndexWrite for Shared<T> {
fn write<I: IntoIndex, V: AsExpr<Value = Self::Element>>(&self, i: I, value: V) {
let i = i.to_u64();
let value = value.as_expr();

if need_runtime_check() {
check_index_lt_usize(i, self.len());
}
let i = i.node().get();
let value = value.node().get();
let self_node = self.node.get();
__current_scope(|b| {
let gep = b.call(Func::GetElementPtr, &[self_node, i], T::type_());
b.update(gep, value);
});
}
}

0 comments on commit 9523435

Please sign in to comment.