-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add rank-1 update of Cholesky decomposition
- Loading branch information
1 parent
ac9ba2b
commit 8071fb1
Showing
4 changed files
with
103 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
/target | ||
Cargo.lock | ||
.idea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
use ndarray::{Array1, ArrayBase, DataMut, Ix2, NdFloat}; | ||
|
||
|
||
pub trait CholeskyUpdate<F> { | ||
fn cholesky_update_inplace(&mut self, update_vector: &Array1<F>); | ||
} | ||
|
||
impl<V,F> CholeskyUpdate<F> for ArrayBase<V,Ix2> | ||
where | ||
F: NdFloat, | ||
V: DataMut<Elem=F>, | ||
{ | ||
fn cholesky_update_inplace(&mut self, update_vector: &Array1<F>) { | ||
let n = self.shape()[0]; | ||
if self.shape()[0] != update_vector.len() { | ||
panic!("update_vector should be same size as self"); | ||
} | ||
let mut w=update_vector.to_owned(); | ||
let mut b=F::from(1.0).unwrap(); | ||
for j in 0..n{ | ||
let ljj=self[(j,j)]; | ||
let ljj2=ljj*ljj; | ||
let wj=w[j]; | ||
let wj2=wj*wj; | ||
let nljj=(ljj2+wj2/b).sqrt(); | ||
let gamma=ljj2*b+wj2; | ||
for k in j+1..n{ | ||
let lkj=self[(k,j)]; | ||
let wk=w[k]-wj*lkj/ljj; | ||
self[(k,j)]=nljj*(lkj/ljj+wj*wk/gamma); | ||
w[k]=wk; | ||
} | ||
b=b+wj2/ljj2; | ||
self[(j,j)]=nljj; | ||
} | ||
} | ||
} | ||
|
||
|
||
|
||
#[cfg(test)] | ||
mod test{ | ||
use approx::assert_abs_diff_eq; | ||
use super::*; | ||
use ndarray::{array, Array}; | ||
use crate::cholesky::Cholesky; | ||
|
||
#[test] | ||
fn test_cholesky_update(){ | ||
let mut arr=array![[1.0, 0.0, 2.0, 3.0, 4.0], | ||
[-2.0, 3.0, 10.0,5.0, 6.0], | ||
[-1.0,-2.0,-7.0, 8.0, 9.0], | ||
[11.0, 12.0, 3.0, 14.0, 5.0], | ||
[8.0, 2.0, 13.0, 4.0, 5.0]]; | ||
arr=arr.t().dot(&arr); | ||
let mut l_tri = arr.cholesky().unwrap(); | ||
|
||
let x = Array::from(vec![1.0, 2.0, 3.0,0.0, 1.0]); | ||
let vt=x.clone().into_shape((1,x.shape()[0])).unwrap(); | ||
let v=x.clone().into_shape((x.shape()[0],1)).unwrap(); | ||
|
||
l_tri.cholesky_update_inplace(&x); | ||
|
||
let restore=l_tri.dot(&l_tri.t()); | ||
let expected=arr+v.dot(&vt); | ||
|
||
assert_abs_diff_eq!(restore, expected, epsilon=1e-7); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use approx::assert_abs_diff_eq; | ||
use ndarray::prelude::*; | ||
use proptest::prelude::*; | ||
use linfa_linalg::{cholesky::*, cholesky_update::*}; | ||
mod common; | ||
|
||
prop_compose! { | ||
fn gram_arr() | ||
(arr in common::square_arr()) -> (Array2<f64>,Array1<f64>){ | ||
let dim = arr.nrows(); | ||
let mut mul = arr.t().dot(&arr); | ||
for i in 0..dim { | ||
mul[(i, i)] += 1.0; | ||
} | ||
|
||
(mul,arr.slice(s![0,..]).to_owned()) | ||
} | ||
} | ||
|
||
fn run_cholesky_update_test(orig: (Array2<f64>, Array1<f64>)) { | ||
let (arr, x) = orig; | ||
let mut l_tri = arr.cholesky().unwrap(); | ||
l_tri.cholesky_update_inplace(&x); | ||
|
||
let vt=x.clone().into_shape((1,x.shape()[0])).unwrap(); | ||
let v=x.clone().into_shape((x.shape()[0],1)).unwrap(); | ||
|
||
let restore = l_tri.dot(&l_tri.t()); | ||
let expected = arr + v.dot(&vt); | ||
assert_abs_diff_eq!(restore, expected, epsilon = 1e-7); | ||
} | ||
|