Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fitter viz update #204

Merged
merged 5 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions sidpy/viz/dataset_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import ipywidgets
from IPython.display import display
import scipy

import dill
import base64

# import matplotlib.animation as animation

Expand Down Expand Up @@ -2058,29 +2059,73 @@ def set_legend(self, set_legend):
def get_xy(self):
return [self.x, self.y]


class SpectralImageFitVisualizer(SpectralImageVisualizer):

def __init__(self, original_dataset, fit_dataset, figure=None, horizontal=True):
def __init__(self, original_dataset, fit_dataset, xvec = None, figure=None, horizontal=True):
'''
Visualizer for spectral image datasets, fit by the Sidpy Fitter
This class is called by Sidpy Fitter for visualizing the raw/fit dataset interactively.

Inputs:
- original_dataset: sidpy.Dataset containing the raw data
- fit_dataset: sidpy.Dataset with the fitted data. This is returned by the
Sidpy Fitter after functional fitting.
- fit_dataset: sidpy.Dataset with the fitted parameters, or the sidpy.Dataset returned by SidpyFitter.
- xvec: Independent dimension vector, default is None (will be acquired from original_dataset if not provided).
- figure: (Optional, default None) - handle to existing figure
- horiziontal: (Optional, default True) - whether spectrum should be plotted horizontally
'''

super().__init__(original_dataset, figure, horizontal)

self.original_dataset = original_dataset
if xvec is not None:
self.xvec = xvec
else:
self.xvec = None
if fit_dataset.shape != original_dataset.shape: #check if we have an actual fitted dataset or just the parameters
self.fit_parameters = fit_dataset
self.fit_dset = self._return_fit_dataset()
else:
self.fit_parameters = None
self.fit_dset = fit_dataset

self.fit_dset = fit_dataset
self.axes[1].clear()
self.get_fit_spectrum()
self.axes[1].plot(self.energy_scale, self.spectrum, 'bo')
self.axes[1].plot(self.energy_scale, self.fit_spectrum, 'r-')

def _return_fit_dataset(self):
#let's get back the fit function
fit_fn_packed = self.fit_parameters.metadata['fitting_functions']
key_f = list(self.fit_parameters.metadata['fitting_functions'].keys())[0]
encoded_value = fit_fn_packed[key_f]
serialized_value = base64.b64decode(encoded_value)
self._fit_function = dill.loads(serialized_value)

#Let's get the independent vector
if self.xvec is None:
ind_dims = []
for ind, (shape1, shape2) in enumerate(zip(self.fit_parameters.shape, self.original_dataset.shape)):
if shape1!=shape2:
ind_dims.append(ind)

#We need to get the vector.
if len(ind_dims)>1:
raise NotImplementedError("2 dimensional indepndent vectors are not implemented yet. TODO!")
else:
ind_vec = self.original_dataset._axes[ind_dims[0]].values
else:
ind_vec = self.xvec.copy()

#create a copy of the original dataset
self.fitted_dataset = self.original_dataset.copy()
self.fitted_dataset = self.fitted_dataset.fold(method = 'spaspec') #TODO: this might not always be the case.
self.fit_parameters_folded = self.fit_parameters[:].reshape((self.fitted_dataset.shape[0],-1))

for ind in range(self.fitted_dataset.shape[0]):
self.fitted_dataset[ind,:] = self._fit_function(ind_vec, *self.fit_parameters_folded[ind])
fitted_dataset = self.fitted_dataset.unfold()

return fitted_dataset

def get_fit_spectrum(self):

Expand Down Expand Up @@ -2128,5 +2173,4 @@ def _update(self, ev=None):
self.axes[1].set_ylabel(self.ylabel)

self.fig.canvas.draw_idle()



68 changes: 62 additions & 6 deletions tests/viz/test_dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,48 @@
import numpy as np
sys.path.insert(0, "../../sidpy/")
import sidpy

from sidpy.proc.fitter import SidFitter


def get_fit_dataset(dset_shape=(5,5,32)):
#Define the function we want each spectrum to

def one_lin_func(xvec, *coeff):
a1,a2 = coeff
return a1*xvec + a2


#create a dataset
xvec = np.linspace(0,1, dset_shape[-1])
data_mat = np.zeros(shape=(dset_shape[0]*dset_shape[1], dset_shape[2]))
noise_level = 0.10

for xind in range(data_mat.shape[0]):
y_values = one_lin_func(xvec, *[np.random.uniform(0,1), np.random.normal()]) + \
noise_level*np.random.normal(size=len(xvec))
data_mat[xind] = y_values

data_mat = data_mat.reshape(dset_shape)

#make it a sidpy dataset
data_set = sidpy.Dataset.from_array(data_mat, name='test_dataset')
data_set.data_type = 'spectral_image'
data_set.units = 'nA'
data_set.quantity = 'Current'

data_set.set_dimension(0, sidpy.Dimension(np.arange(data_set.shape[0]),
name='x', units='um', quantity='Length',
dimension_type='spatial'))
data_set.set_dimension(1, sidpy.Dimension(np.arange(data_set.shape[0]),
'y', units='um', quantity='Length',
dimension_type='spatial'))
data_set.set_dimension(2, sidpy.Dimension(xvec,
name = 'bias',quantity = 'V', units = 'V', dimension_type='spectral'))
fitter = SidFitter(data_set, one_lin_func,num_workers=4,
threads=2, return_cov=False, return_fit=True, return_std=False,
km_guess=False,num_fit_parms = 2)
output = fitter.do_fit()
return data_set, output[0], output[1]

def get_spectrum(dtype=float):
x = np.array(np.random.normal(3, 2.5, size=1024), dtype=dtype)
Expand Down Expand Up @@ -434,11 +475,6 @@ def test_point_selection(self):
self.assertTrue(np.allclose(actual, expected, equal_nan=True, rtol=1e-05, atol=1e-08))







class Test4DImageStackPlot(unittest.TestCase):

def test_plot(self):
Expand Down Expand Up @@ -481,5 +517,25 @@ def test_plot_complex(self):
view = dataset.plot()
self.assertEqual(len(view.axes), 3)


class TestSpectralImageFitVisualizer(unittest.TestCase):

def test_plot_with_fit_parms(self):
original_dataset, fit_parameters, fitted_dataset = get_fit_dataset()
view = sidpy.viz.dataset_viz.SpectralImageFitVisualizer(original_dataset, fit_parameters)
self.assertEqual(len(view.axes), 2)

def test_plot_with_fitted_dataset(self):
original_dataset, fit_parameters, fitted_dataset = get_fit_dataset()
view = sidpy.viz.dataset_viz.SpectralImageFitVisualizer(original_dataset, fitted_dataset)
self.assertEqual(len(view.axes), 2)

def test_plot_with_custom_xvec(self):
original_dataset, fit_parameters, fitted_dataset = get_fit_dataset()
xvec = np.linspace(-1,2,32)
view = sidpy.viz.dataset_viz.SpectralImageFitVisualizer(original_dataset, fit_parameters, xvec = xvec)
self.assertEqual(len(view.axes), 2)


if __name__ == '__main__':
unittest.main()
Loading