Skip to content

Commit

Permalink
Merge pull request #129 from marvin-hansen/main
Browse files Browse the repository at this point in the history
Replaced Cell types with Arc/RwLock to make interior mutability thread safe
  • Loading branch information
marvin-hansen authored Jan 25, 2024
2 parents 08418a7 + c9f5160 commit 45b4763
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 102 deletions.
11 changes: 6 additions & 5 deletions deep_causality/src/types/reasoning_types/assumption/assumable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,21 @@ impl Assumable for Assumption {
}

fn assumption_tested(&self) -> bool {
*self.assumption_tested.borrow()
*self.assumption_tested.read().unwrap()
}

fn assumption_valid(&self) -> bool {
*self.assumption_valid.borrow()
*self.assumption_valid.read().unwrap()
}

fn verify_assumption(&self, data: &[NumericalValue]) -> bool {
let res = (self.assumption_fn)(data);
// int. mutability: https://doc.rust-lang.org/book/ch15-05-interior-mutability.html
*self.assumption_tested.borrow_mut() = true;
let mut guard_tested = self.assumption_tested.write().unwrap();
*guard_tested = true;

if res {
*self.assumption_valid.borrow_mut() = true;
let mut guard_valid = self.assumption_valid.write().unwrap();
*guard_valid = true;
}
res
}
Expand Down
4 changes: 2 additions & 2 deletions deep_causality/src/types/reasoning_types/assumption/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ impl Assumption {
"Assumption: id: {}, description: {}, assumption_fn: fn(&[NumericalValue]) -> bool;, assumption_tested: {},assumption_valid: {}",
self.id,
self.description,
self.assumption_tested.borrow(),
self.assumption_valid.borrow()
self.assumption_tested.read().unwrap().clone(),
self.assumption_valid.read().unwrap().clone()
)
}
}
14 changes: 9 additions & 5 deletions deep_causality/src/types/reasoning_types/assumption/mod.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
// SPDX-License-Identifier: MIT
// Copyright (c) "2023" . The DeepCausality Authors. All Rights Reserved.

use std::cell::RefCell;
use std::sync::{Arc, RwLock};

use crate::prelude::{DescriptionValue, EvalFn, IdentificationValue};

mod assumable;
mod debug;
mod identifiable;

// Interior mutability in Rust, part 2: thread safety
// https://ricardomartins.cc/2016/06/25/interior-mutability-thread-safety
type ArcRWLock<T> = Arc<RwLock<T>>;

#[derive(Clone)]
pub struct Assumption {
id: IdentificationValue,
description: DescriptionValue,
assumption_fn: EvalFn,
assumption_tested: RefCell<bool>,
assumption_valid: RefCell<bool>,
assumption_tested: ArcRWLock<bool>,
assumption_valid: ArcRWLock<bool>,
}

// Constructor
Expand All @@ -29,8 +33,8 @@ impl Assumption {
id,
description,
assumption_fn,
assumption_tested: RefCell::from(false),
assumption_valid: RefCell::from(false),
assumption_tested: Arc::new(RwLock::new(false)),
assumption_valid: Arc::new(RwLock::new(false)),
}
}
}
73 changes: 23 additions & 50 deletions deep_causality/src/types/reasoning_types/causaloid/causable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use std::ops::*;

use crate::errors::CausalityError;
use crate::prelude::{
Causable, CausableGraphExplaining, CausableGraphReasoning, CausableReasoning, Causaloid,
Datable, IdentificationValue, NumericalValue, SpaceTemporal, Spatial, Temporable,
Causable, CausableGraph, CausableGraphExplaining, CausableGraphReasoning, CausableReasoning,
Causaloid, Datable, IdentificationValue, NumericalValue, SpaceTemporal, Spatial, Temporable,
};
use crate::types::reasoning_types::causaloid::causal_type::CausalType;

Expand All @@ -29,14 +29,13 @@ where
+ Clone,
{
fn explain(&self) -> Result<String, CausalityError> {
return if self.active.get() {
return if self.is_active() {
match self.causal_type {
CausalType::Singleton => {
let reason = format!(
"Causaloid: {} {} on last data {} evaluated to {}",
"Causaloid: {} {} evaluated to {}",
self.id,
self.description,
self.last_obs.get(),
self.is_active()
);
Ok(reason)
Expand All @@ -63,7 +62,11 @@ where
}

fn is_active(&self) -> bool {
self.active.get()
match self.causal_type {
CausalType::Singleton => *self.active.read().unwrap(),
CausalType::Collection => self.causal_coll.as_ref().unwrap().number_active() > 0f64,
CausalType::Graph => self.causal_graph.as_ref().unwrap().number_active() > 0f64,
}
}

fn is_singleton(&self) -> bool {
Expand All @@ -79,34 +82,33 @@ where
let contextual_causal_fn = self
.context_causal_fn
.expect("Causaloid::verify_single_cause: context_causal_fn is None");

let context = self
.context
.expect("Causaloid::verify_single_cause: context is None");

let res = match (contextual_causal_fn)(obs.to_owned(), context) {
Ok(res) => {
// store the applied data to provide details in explain()
self.last_obs.set(obs.to_owned());
res
}
Ok(res) => res,
Err(e) => return Err(e),
};

Ok(self.check_active(res))
let mut guard = self.active.write().unwrap();
*guard = res;

Ok(res)
} else {
let causal_fn = self
.causal_fn
.expect("Causaloid::verify_single_cause: causal_fn is None");
let res = match (causal_fn)(obs.to_owned()) {
Ok(res) => {
// store the applied data to provide details in explain()
self.last_obs.set(obs.to_owned());
res
}
Ok(res) => res,
Err(e) => return Err(e),
};

Ok(self.check_active(res))
let mut guard = self.active.write().unwrap();
*guard = res;

Ok(res)
}
}

Expand All @@ -117,7 +119,7 @@ where
) -> Result<bool, CausalityError> {
match self.causal_type {
CausalType::Singleton => Err(CausalityError(
"Causaloid is singleton. Call verify_singleton instead.".into(),
"Causaloid is singleton. Call verify_single_cause instead.".into(),
)),

CausalType::Collection => match &self.causal_coll {
Expand All @@ -130,7 +132,7 @@ where
Err(e) => return Err(e),
};

Ok(self.check_active(res))
Ok(res)
}
},

Expand All @@ -144,38 +146,9 @@ where
Err(e) => return Err(CausalityError(e.to_string())),
};

Ok(self.check_active(res))
Ok(res)
}
},
}
}
}

impl<'l, D, S, T, ST, V> Causaloid<'l, D, S, T, ST, V>
where
D: Datable + Clone,
S: Spatial<V> + Clone,
T: Temporable<V> + Clone,
ST: SpaceTemporal<V> + Clone,
V: Default
+ Copy
+ Clone
+ Hash
+ Eq
+ PartialEq
+ Add<V, Output = V>
+ Sub<V, Output = V>
+ Mul<V, Output = V>
+ Clone,
{
#[inline(always)]
fn check_active(&self, res: bool) -> bool {
if res {
self.active.set(true);
true
} else {
self.active.set(false);
false
}
}
}
13 changes: 5 additions & 8 deletions deep_causality/src/types/reasoning_types/causaloid/getters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,13 @@ where
+ Clone,
{
pub fn active(&self) -> bool {
self.active.get()
self.is_active()
}
pub fn causal_collection(&self) -> Option<CausalVec<'l, D, S, T, ST, V>> {
self.causal_coll.clone()
pub fn causal_collection(&self) -> Option<&CausalVec<'l, D, S, T, ST, V>> {
self.causal_coll
}
pub fn causal_graph(&self) -> Option<CausalGraph<'l, D, S, T, ST, V>> {
self.causal_graph.clone()
}
pub fn last_obs(&self) -> NumericalValue {
self.last_obs.get()
pub fn causal_graph(&self) -> Option<&CausalGraph<'l, D, S, T, ST, V>> {
self.causal_graph
}
pub fn description(&self) -> &'l str {
self.description
Expand Down
Loading

0 comments on commit 45b4763

Please sign in to comment.