diff --git a/sidpy/viz/dataset_viz.py b/sidpy/viz/dataset_viz.py index 0a49e856..2ed604cd 100644 --- a/sidpy/viz/dataset_viz.py +++ b/sidpy/viz/dataset_viz.py @@ -21,7 +21,8 @@ import ipywidgets from IPython.display import display import scipy - +import dill +import base64 # import matplotlib.animation as animation @@ -2058,29 +2059,73 @@ def set_legend(self, set_legend): def get_xy(self): return [self.x, self.y] - class SpectralImageFitVisualizer(SpectralImageVisualizer): - def __init__(self, original_dataset, fit_dataset, figure=None, horizontal=True): + def __init__(self, original_dataset, fit_dataset, xvec = None, figure=None, horizontal=True): ''' Visualizer for spectral image datasets, fit by the Sidpy Fitter This class is called by Sidpy Fitter for visualizing the raw/fit dataset interactively. Inputs: - original_dataset: sidpy.Dataset containing the raw data - - fit_dataset: sidpy.Dataset with the fitted data. This is returned by the - Sidpy Fitter after functional fitting. + - fit_dataset: sidpy.Dataset with the fitted parameters, or the sidpy.Dataset returned by SidpyFitter. + - xvec: Independent dimension vector, default is None (will be acquired from original_dataset if not provided). - figure: (Optional, default None) - handle to existing figure - horiziontal: (Optional, default True) - whether spectrum should be plotted horizontally ''' super().__init__(original_dataset, figure, horizontal) + + self.original_dataset = original_dataset + if xvec is not None: + self.xvec = xvec + else: + self.xvec = None + if fit_dataset.shape != original_dataset.shape: #check if we have an actual fitted dataset or just the parameters + self.fit_parameters = fit_dataset + self.fit_dset = self._return_fit_dataset() + else: + self.fit_parameters = None + self.fit_dset = fit_dataset - self.fit_dset = fit_dataset self.axes[1].clear() self.get_fit_spectrum() self.axes[1].plot(self.energy_scale, self.spectrum, 'bo') self.axes[1].plot(self.energy_scale, self.fit_spectrum, 'r-') + + def _return_fit_dataset(self): + #let's get back the fit function + fit_fn_packed = self.fit_parameters.metadata['fitting_functions'] + key_f = list(self.fit_parameters.metadata['fitting_functions'].keys())[0] + encoded_value = fit_fn_packed[key_f] + serialized_value = base64.b64decode(encoded_value) + self._fit_function = dill.loads(serialized_value) + + #Let's get the independent vector + if self.xvec is None: + ind_dims = [] + for ind, (shape1, shape2) in enumerate(zip(self.fit_parameters.shape, self.original_dataset.shape)): + if shape1!=shape2: + ind_dims.append(ind) + + #We need to get the vector. + if len(ind_dims)>1: + raise NotImplementedError("2 dimensional indepndent vectors are not implemented yet. TODO!") + else: + ind_vec = self.original_dataset._axes[ind_dims[0]].values + else: + ind_vec = self.xvec.copy() + + #create a copy of the original dataset + self.fitted_dataset = self.original_dataset.copy() + self.fitted_dataset = self.fitted_dataset.fold(method = 'spaspec') #TODO: this might not always be the case. + self.fit_parameters_folded = self.fit_parameters[:].reshape((self.fitted_dataset.shape[0],-1)) + + for ind in range(self.fitted_dataset.shape[0]): + self.fitted_dataset[ind,:] = self._fit_function(ind_vec, *self.fit_parameters_folded[ind]) + fitted_dataset = self.fitted_dataset.unfold() + + return fitted_dataset def get_fit_spectrum(self): @@ -2128,5 +2173,4 @@ def _update(self, ev=None): self.axes[1].set_ylabel(self.ylabel) self.fig.canvas.draw_idle() - - + \ No newline at end of file diff --git a/tests/viz/test_dataset_plot.py b/tests/viz/test_dataset_plot.py index 5644a110..681d3137 100644 --- a/tests/viz/test_dataset_plot.py +++ b/tests/viz/test_dataset_plot.py @@ -16,7 +16,48 @@ import numpy as np sys.path.insert(0, "../../sidpy/") import sidpy - +from sidpy.proc.fitter import SidFitter + + +def get_fit_dataset(dset_shape=(5,5,32)): + #Define the function we want each spectrum to + + def one_lin_func(xvec, *coeff): + a1,a2 = coeff + return a1*xvec + a2 + + + #create a dataset + xvec = np.linspace(0,1, dset_shape[-1]) + data_mat = np.zeros(shape=(dset_shape[0]*dset_shape[1], dset_shape[2])) + noise_level = 0.10 + + for xind in range(data_mat.shape[0]): + y_values = one_lin_func(xvec, *[np.random.uniform(0,1), np.random.normal()]) + \ + noise_level*np.random.normal(size=len(xvec)) + data_mat[xind] = y_values + + data_mat = data_mat.reshape(dset_shape) + + #make it a sidpy dataset + data_set = sidpy.Dataset.from_array(data_mat, name='test_dataset') + data_set.data_type = 'spectral_image' + data_set.units = 'nA' + data_set.quantity = 'Current' + + data_set.set_dimension(0, sidpy.Dimension(np.arange(data_set.shape[0]), + name='x', units='um', quantity='Length', + dimension_type='spatial')) + data_set.set_dimension(1, sidpy.Dimension(np.arange(data_set.shape[0]), + 'y', units='um', quantity='Length', + dimension_type='spatial')) + data_set.set_dimension(2, sidpy.Dimension(xvec, + name = 'bias',quantity = 'V', units = 'V', dimension_type='spectral')) + fitter = SidFitter(data_set, one_lin_func,num_workers=4, + threads=2, return_cov=False, return_fit=True, return_std=False, + km_guess=False,num_fit_parms = 2) + output = fitter.do_fit() + return data_set, output[0], output[1] def get_spectrum(dtype=float): x = np.array(np.random.normal(3, 2.5, size=1024), dtype=dtype) @@ -434,11 +475,6 @@ def test_point_selection(self): self.assertTrue(np.allclose(actual, expected, equal_nan=True, rtol=1e-05, atol=1e-08)) - - - - - class Test4DImageStackPlot(unittest.TestCase): def test_plot(self): @@ -481,5 +517,25 @@ def test_plot_complex(self): view = dataset.plot() self.assertEqual(len(view.axes), 3) + +class TestSpectralImageFitVisualizer(unittest.TestCase): + + def test_plot_with_fit_parms(self): + original_dataset, fit_parameters, fitted_dataset = get_fit_dataset() + view = sidpy.viz.dataset_viz.SpectralImageFitVisualizer(original_dataset, fit_parameters) + self.assertEqual(len(view.axes), 2) + + def test_plot_with_fitted_dataset(self): + original_dataset, fit_parameters, fitted_dataset = get_fit_dataset() + view = sidpy.viz.dataset_viz.SpectralImageFitVisualizer(original_dataset, fitted_dataset) + self.assertEqual(len(view.axes), 2) + + def test_plot_with_custom_xvec(self): + original_dataset, fit_parameters, fitted_dataset = get_fit_dataset() + xvec = np.linspace(-1,2,32) + view = sidpy.viz.dataset_viz.SpectralImageFitVisualizer(original_dataset, fit_parameters, xvec = xvec) + self.assertEqual(len(view.axes), 2) + + if __name__ == '__main__': unittest.main()