Skip to content

Commit

Permalink
Fix untyped device
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Dec 15, 2024
1 parent c8611b6 commit 9dfd7ce
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 44 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
[features]
# default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "stack", "opencl", "fork", "graph", "untyped"]

# default = ["cpu", "cached", "autograd", "static-api", "blas", "macro", "fork", "serde"]
default = ["no-std"]
default = ["cpu", "cached", "autograd", "static-api", "blas", "macro", "fork", "serde", "untyped"]
# default = ["no-std"]
# default = ["opencl"]
# default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "nnapi"]

Expand Down
30 changes: 17 additions & 13 deletions src/devices/untyped/dummy_cuda.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::{impl_buffer_hook_traits, Base, Buffer, Device, OnDropBuffer, Shape, Unit};
use crate::{impl_buffer_hook_traits, Base, Buffer, Device, Shape, Unit, WrappedData};

pub struct CUDA<Mods = Base> {
pub modules: Mods,
}

impl<Mods: OnDropBuffer> Device for CUDA<Mods> {
type Data<U: Unit, S: Shape> = Mods::Wrap<U, crate::Num<U>>;
impl<Mods: WrappedData> Device for CUDA<Mods> {
type Data<'a, U: Unit, S: Shape> = Mods::Wrap<'a, U, crate::Num<U>>;
type Base<T: Unit, S: Shape> = crate::Num<T>;
type Error = crate::DeviceError;

Expand All @@ -19,26 +19,30 @@ impl<Mods: OnDropBuffer> Device for CUDA<Mods> {
Err(crate::DeviceError::CPUDeviceNotAvailable)
}

fn base_to_data<T: Unit, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S> {
fn default_base_to_data<'a, T: Unit, S: Shape>(&'a self, base: Self::Base<T, S>) -> Self::Data<'a, T, S> {
self.modules.wrap_in_base(base)
}

fn default_base_to_data_unbound<'a, T: Unit, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<'a, T, S> {
self.modules.wrap_in_base_unbound(base)
}

fn wrap_to_data<T: Unit, S: Shape>(
fn wrap_to_data<'a, T: Unit, S: Shape>(
&self,
wrap: Self::Wrap<T, Self::Base<T, S>>,
) -> Self::Data<T, S> {
wrap: Self::Wrap<'a, T, Self::Base<T, S>>,
) -> Self::Data<'a, T, S> {
wrap
}

fn data_as_wrap<T: Unit, S: Shape>(
data: &Self::Data<T, S>,
) -> &Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap<'a, 'b, T: Unit, S: Shape>(
data: &'b Self::Data<'a, T, S>,
) -> &'b Self::Wrap<'a, T, Self::Base<T, S>> {
data
}

fn data_as_wrap_mut<T: Unit, S: Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap_mut<'a, 'b, T: Unit, S: Shape>(
data: &'b mut Self::Data<'a, T, S>,
) -> &'b mut Self::Wrap<'a, T, Self::Base<T, S>> {
data
}
}
Expand Down
10 changes: 5 additions & 5 deletions src/devices/untyped/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,17 @@ impl<T: 'static + AsType + Default + Clone, S: Shape> Read<T, S> for Untyped {
}
}

impl<T, S> ApplyFunction<T, S> for Untyped
impl<'a, T, S> ApplyFunction<'a, T, S> for Untyped
where
T: CDatatype + Default + Copy + AsType,
S: Shape,
{
fn apply_fn<F>(
&self,
&'a self,
// buf: &D::Data<T, S>,
buf: &crate::Buffer<T, Self, S>,
f: impl Fn(crate::Resolve<T>) -> F + Copy + 'static,
) -> crate::Buffer<T, Self, S>
) -> crate::Buffer<'a, T, Self, S>
where
F: crate::TwoWay<T> + 'static,
{
Expand Down Expand Up @@ -216,8 +216,8 @@ mod tests {
out
}

impl<T: Unit, S: Shape> AddEw<T, Self, S> for Untyped {
fn add(&self, lhs: &Buffer<T, Self, S>, rhs: &Buffer<T, Self, S>) -> Buffer<T, Self, S> {
impl<'a, T: Unit, S: Shape> AddEw<'a, T, Self, S> for Untyped {
fn add(&'a self, lhs: &Buffer<T, Self, S>, rhs: &Buffer<T, Self, S>) -> Buffer<'a, T, Self, S> {
untyped_binary_op!(self, lhs, rhs, alloc_and_add_slice, alloc_and_add_cu)
}
}
Expand Down
58 changes: 34 additions & 24 deletions src/devices/untyped/untyped_device.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::{
Alloc, Base, Buffer, Device, HasId, HasModules, IsShapeIndep, OnDropBuffer, OnNewBuffer,
PtrType, Retriever, Shape, Unit, WrappedData, CPU,
Alloc, Base, Buffer, Device, HasId, HasModules, IsBasePtr, IsShapeIndep, OnNewBuffer, PtrType, Retriever, Shape, Unit, WrappedData, CPU
};

use super::{
Expand All @@ -25,33 +24,37 @@ pub struct Untyped {

impl Device for Untyped {
type Base<T: Unit, S: crate::Shape> = UntypedData;
type Data<T: Unit, S: crate::Shape> = UntypedData;
type Data<'a, T: Unit, S: crate::Shape> = UntypedData;
type Error = crate::Error;

#[inline]
fn base_to_data<T: Unit, S: crate::Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S> {
fn default_base_to_data<'a, T: Unit, S: crate::Shape>(&'a self, base: Self::Base<T, S>) -> Self::Data<'a, T, S> {
base
}

fn default_base_to_data_unbound<'a, T: Unit, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<'a, T, S> {
base
}

#[inline]
fn wrap_to_data<T: Unit, S: crate::Shape>(
fn wrap_to_data<'a, T: Unit, S: crate::Shape>(
&self,
wrap: Self::Wrap<T, Self::Base<T, S>>,
) -> Self::Data<T, S> {
wrap: Self::Wrap<'a, T, Self::Base<T, S>>,
) -> Self::Data<'a, T, S> {
wrap
}

#[inline]
fn data_as_wrap<T: Unit, S: crate::Shape>(
data: &Self::Data<T, S>,
) -> &Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap<'a, 'b, T: Unit, S: crate::Shape>(
data: &'b Self::Data<'a, T, S>,
) -> &'b Self::Wrap<'a, T, Self::Base<T, S>> {
data
}

#[inline]
fn data_as_wrap_mut<T: Unit, S: crate::Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap_mut<'a, 'b, T: Unit, S: crate::Shape>(
data: &'b mut Self::Data<'a, T, S>,
) -> &'b mut Self::Wrap<'a, T, Self::Base<T, S>> {
data
}

Expand Down Expand Up @@ -79,31 +82,38 @@ impl HasModules for Untyped {
}
}

impl OnDropBuffer for Untyped {}
impl<'dev, T: Unit, D: Device, S: Shape> OnNewBuffer<'dev, T, D, S> for Untyped {}

impl WrappedData for Untyped {
type Wrap<'a, T: Unit, Base: HasId + PtrType> = Base;
type Wrap<'a, T: Unit, Base: IsBasePtr> = Base;

#[inline]
fn wrap_in_base<T: Unit, Base: crate::HasId + crate::PtrType>(
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(
&'a self,
base: Base,
) -> Self::Wrap<'a, T, Base> {
base
}

#[inline]
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(
&self,
base: Base,
) -> Self::Wrap<T, Base> {
) -> Self::Wrap<'a, T, Base> {
base
}

#[inline]
fn wrapped_as_base<T: Unit, Base: crate::HasId + crate::PtrType>(
wrap: &Self::Wrap<T, Base>,
) -> &Base {
fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(
wrap: &'b Self::Wrap<'a, T, Base>,
) -> &'b Base {
wrap
}

#[inline]
fn wrapped_as_base_mut<T: Unit, Base: crate::HasId + crate::PtrType>(
wrap: &mut Self::Wrap<T, Base>,
) -> &mut Base {
fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: IsBasePtr>(
wrap: &'b mut Self::Wrap<'a, T, Base>,
) -> &'b mut Base {
wrap
}
}
Expand Down Expand Up @@ -152,7 +162,7 @@ impl<T: AsType> Alloc<T> for Untyped {
}
}

impl<T: AsType, S: Shape> Retriever<T, S> for Untyped {
impl<'a, T: AsType, S: Shape> Retriever<'a, T, S> for Untyped {
#[inline]
fn retrieve<const NUM_PARENTS: usize>(
&self,
Expand Down

0 comments on commit 9dfd7ce

Please sign in to comment.