Skip to content

Commit

Permalink
add rank-1 update of Cholesky decomposition
Browse files Browse the repository at this point in the history
  • Loading branch information
AndersonYin committed Mar 19, 2024
1 parent ac9ba2b commit 8071fb1
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/target
Cargo.lock
.idea
69 changes: 69 additions & 0 deletions src/cholesky_update.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use ndarray::{Array1, ArrayBase, DataMut, Ix2, NdFloat};


pub trait CholeskyUpdate<F> {
fn cholesky_update_inplace(&mut self, update_vector: &Array1<F>);
}

impl<V,F> CholeskyUpdate<F> for ArrayBase<V,Ix2>
where
F: NdFloat,
V: DataMut<Elem=F>,
{
fn cholesky_update_inplace(&mut self, update_vector: &Array1<F>) {
let n = self.shape()[0];
if self.shape()[0] != update_vector.len() {
panic!("update_vector should be same size as self");
}
let mut w=update_vector.to_owned();
let mut b=F::from(1.0).unwrap();
for j in 0..n{
let ljj=self[(j,j)];
let ljj2=ljj*ljj;
let wj=w[j];
let wj2=wj*wj;
let nljj=(ljj2+wj2/b).sqrt();
let gamma=ljj2*b+wj2;
for k in j+1..n{
let lkj=self[(k,j)];
let wk=w[k]-wj*lkj/ljj;
self[(k,j)]=nljj*(lkj/ljj+wj*wk/gamma);
w[k]=wk;
}
b=b+wj2/ljj2;
self[(j,j)]=nljj;
}
}
}



#[cfg(test)]
mod test{
use approx::assert_abs_diff_eq;
use super::*;
use ndarray::{array, Array};
use crate::cholesky::Cholesky;

#[test]
fn test_cholesky_update(){
let mut arr=array![[1.0, 0.0, 2.0, 3.0, 4.0],
[-2.0, 3.0, 10.0,5.0, 6.0],
[-1.0,-2.0,-7.0, 8.0, 9.0],
[11.0, 12.0, 3.0, 14.0, 5.0],
[8.0, 2.0, 13.0, 4.0, 5.0]];
arr=arr.t().dot(&arr);
let mut l_tri = arr.cholesky().unwrap();

let x = Array::from(vec![1.0, 2.0, 3.0,0.0, 1.0]);
let vt=x.clone().into_shape((1,x.shape()[0])).unwrap();
let v=x.clone().into_shape((x.shape()[0],1)).unwrap();

l_tri.cholesky_update_inplace(&x);

let restore=l_tri.dot(&l_tri.t());
let expected=arr+v.dot(&vt);

assert_abs_diff_eq!(restore, expected, epsilon=1e-7);
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub mod reflection;
pub mod svd;
pub mod triangular;
pub mod tridiagonal;
pub mod cholesky_update;

use ndarray::{ArrayBase, Ix2, RawData, ShapeError};
use thiserror::Error;
Expand Down
32 changes: 32 additions & 0 deletions tests/cholesky_update.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use approx::assert_abs_diff_eq;
use ndarray::prelude::*;
use proptest::prelude::*;
use linfa_linalg::{cholesky::*, cholesky_update::*};
mod common;

prop_compose! {
fn gram_arr()
(arr in common::square_arr()) -> (Array2<f64>,Array1<f64>){
let dim = arr.nrows();
let mut mul = arr.t().dot(&arr);
for i in 0..dim {
mul[(i, i)] += 1.0;
}

(mul,arr.slice(s![0,..]).to_owned())
}
}

fn run_cholesky_update_test(orig: (Array2<f64>, Array1<f64>)) {

Check warning on line 20 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-stable-windows-2019

function `run_cholesky_update_test` is never used
let (arr, x) = orig;
let mut l_tri = arr.cholesky().unwrap();
l_tri.cholesky_update_inplace(&x);

let vt=x.clone().into_shape((1,x.shape()[0])).unwrap();
let v=x.clone().into_shape((x.shape()[0],1)).unwrap();

let restore = l_tri.dot(&l_tri.t());
let expected = arr + v.dot(&vt);
assert_abs_diff_eq!(restore, expected, epsilon = 1e-7);
}

0 comments on commit 8071fb1

Please sign in to comment.