Skip to content

Commit

Permalink
Merge branch 'master' into add_new_LM_minimiser
Browse files Browse the repository at this point in the history
  • Loading branch information
RichardWaiteSTFC committed Nov 1, 2024
2 parents b00c3d3 + 73f604a commit dd978e2
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 30 deletions.
28 changes: 26 additions & 2 deletions +sw_tests/+unit_tests/unittest_ndbase_cost_function_wrapper.m
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

properties (TestParameter)
bound_param_name = {'lb', 'ub'}
no_lower_bound = {[], [-inf, -inf]};
no_upper_bound = {[], [inf, inf]};
no_lower_bound = {[], [-inf, -inf], [NaN, -inf]};
no_upper_bound = {[], [inf, inf], [inf, NaN]};
errors = {ones(1,3), [], zeros(1,3), 'NoField'}
end

Expand Down Expand Up @@ -79,6 +79,30 @@ function test_init_with_fcost_both_bounds_with_fixed_param(testCase)
testCase.verify_val(cost_func_wrap.pars_fixed, 2.5);
end


function test_init_with_fcost_both_bounds_with_fixed_param_using_ifix(testCase)
% note second param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'lb', [1, 2], 'ub', [3, 2.5], 'ifix', [2]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, 0, 'abs_tol', 1e-4); % only first param free
testCase.verify_val(pbound, [2, 2.5], 'abs_tol', 1e-4);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
testCase.verify_val(cost_func_wrap.ifixed, 2);
testCase.verify_val(cost_func_wrap.ifree, 1);
testCase.verify_val(cost_func_wrap.pars_fixed, 2.5);
end

function test_init_with_fcost_no_bounds_with_fixed_param_using_ifix(testCase)
% note second param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'ifix', [2]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, testCase.params(1), 'abs_tol', 1e-4); % only first param free
testCase.verify_val(pbound, testCase.params, 'abs_tol', 1e-4);
testCase.verify_val(cost_func_wrap.ifixed, 2);
testCase.verify_val(cost_func_wrap.ifree, 1);
testCase.verify_val(cost_func_wrap.pars_fixed, testCase.params(2));
end

function test_init_with_data(testCase, errors)
% all errors passed lead to unweighted residuals (either as
% explicitly ones or the default weights if invalid errors)
Expand Down
4 changes: 2 additions & 2 deletions +sw_tests/+unit_tests/unittest_ndbase_optimisers.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
classdef unittest_ndbase_optimisers < sw_tests.unit_tests.unittest_super
% Runs through unit test for ndbase optimisers, atm only simplex passes
% these tests
% Runs through unit test for ndbase optimisers using bounded parameter
% transformations.

properties
rosenbrock = @(x) (1-x(1)).^2 + 100*(x(2) - x(1).^2).^2;
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/build_pyspinw.yml
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,14 @@ jobs:
run: |
pip install scipy
cd ${{ github.workspace }}/python
pip install build/*whl
pip install wheelhouse/*whl
cd tests
python -m unittest
- name: Create wheel artifact
uses: actions/upload-artifact@v4
with:
name: pySpinW Wheel
path: ${{ github.workspace }}/python/build/*.whl
path: ${{ github.workspace }}/python/wheelhouse/*.whl
- name: Upload release wheels
if: ${{ github.event_name == 'release' }}
run: |
Expand Down
74 changes: 50 additions & 24 deletions swfiles/+ndbase/cost_function_wrapper.m
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
% Optionally the parameters can be bound in which case the class will
% perform a transformation to convert the constrained optimization problem
% into an un-constrained problem, using the formulation devised
% (and documented) for MINUIT (and also used in lmfit).
% (and documented) for MINUIT [1] and also used in lmfit [2].
%
% [1] https://root.cern/root/htmldoc/guides/minuit2/Minuit2.pdf#section.2.3
% [2] https://lmfit.github.io/lmfit-py/bounds.html
%
% ### Input Arguments
%
Expand Down Expand Up @@ -37,15 +40,18 @@
%
% `lb`
% : Optional vector of doubles corresponding to the lower bound of the
% parameters. Empty vector [] or vector of -inf interpreted as no lower
% bound.
% parameters. Empty vector [] or vector of non-finite elements
% (e.g. -inf and NaN) are interpreted as no lower bound.
%
% `ub`
% : Optional vector of doubles corresponding to the upper bound of the
% parameters. Empty vector [] or vector of inf interpreted as no upper
% bound.
% parameters. Empty vector [] or vector of non-finite elements
% (e.g. inf and NaN) are interpreted as no upper bound.
%
% ### Examples
% `ifix`
% : Optional vector of ints corresponding of indices of parameters to fix
% (overides bounds if provided)

properties (SetObservable)
% data
cost_func
Expand All @@ -57,6 +63,10 @@
pars_fixed
end

properties (Constant)
fix_tol = 1e-10
end

methods
function obj = cost_function_wrapper(fhandle, params, options)
arguments
Expand All @@ -65,6 +75,7 @@
options.lb double = []
options.ub double = []
options.data struct = struct()
options.ifix = []
end
if ischar(fhandle)
fhandle = str2func(fhandle); % convert to fuction handle
Expand Down Expand Up @@ -101,41 +112,48 @@
"Upper bounds have to be larger than the lower bounds.");
end
% init bound parameters
obj.init_bound_parameter_transforms(params, lb, ub)
obj.init_bound_parameter_transforms(params, lb, ub, options.ifix);
end

function init_bound_parameter_transforms(obj, pars, lb, ub)
function init_bound_parameter_transforms(obj, pars, lb, ub, ifix)
% Note free parameters to be used externally in the
% optimisation (note in lmfit [2] pfree is called p_internal).
% Bound parameters are the original parameters
% pass into the constructor (that may or may not be bound or
% fixed).
obj.free_to_bound_funcs = cell(size(pars));
obj.bound_to_free_funcs = cell(size(pars));
obj.ifixed = [];
ipars = 1:numel(pars); % used later
for ipar = ipars
has_lb = ~isempty(lb) && lb(ipar) > -inf;
has_ub = ~isempty(ub) && ub(ipar) < inf;
has_lb = ~isempty(lb) && isfinite(lb(ipar));
has_ub = ~isempty(ub) && isfinite(ub(ipar));
is_fixed = any(uint8(ifix) == ipar);
if has_lb && has_ub
% both bounds specified and parameter not fixed
obj.free_to_bound_funcs{ipar} = @(p) obj.free_to_bound_has_lb_and_ub(p, lb(ipar), ub(ipar));
obj.bound_to_free_funcs{ipar} = @(p) obj.bound_to_free_has_lb_and_ub(p, lb(ipar), ub(ipar));
% check if fixed
if abs(ub(ipar) - lb(ipar)) < max(abs(ub(ipar)), 1)*1e-10
obj.ifixed = [obj.ifixed, ipar];
if pars(ipar) < lb(ipar)
pars(ipar) = lb(ipar);
elseif pars(ipar) > ub(ipar)
pars(ipar) = ub(ipar);
end
obj.pars_fixed = [obj.pars_fixed, pars(ipar)];
end
bounds_equal = abs(ub(ipar) - lb(ipar)) < max(abs(ub(ipar)), 1)*obj.fix_tol;
is_fixed = is_fixed || bounds_equal;
elseif has_lb
obj.free_to_bound_funcs{ipar} = @(p) obj.free_to_bound_has_lb(p, lb(ipar));
obj.bound_to_free_funcs{ipar} = @(p) obj.bound_to_free_has_lb(p, lb(ipar));
elseif has_ub
obj.free_to_bound_funcs{ipar} = @(p) obj.free_to_bound_has_ub(p, ub(ipar));
obj.bound_to_free_funcs{ipar} = @(p) obj.bound_to_free_has_ub(p, ub(ipar));
else
obj.free_to_bound_funcs{ipar} = @(p) p;
obj.bound_to_free_funcs{ipar} = @(p) p;
end
% check fixed parameters
if is_fixed
obj.ifixed = [obj.ifixed, ipar];
if has_lb && pars(ipar) < lb(ipar)
pars(ipar) = lb(ipar);
elseif has_ub && pars(ipar) > ub(ipar)
pars(ipar) = ub(ipar);
end
obj.pars_fixed = [obj.pars_fixed, pars(ipar)];
end

end
% get index of free parameters
obj.ifree = ipars(~ismember(1:numel(pars), obj.ifixed));
Expand All @@ -146,7 +164,11 @@ function init_bound_parameter_transforms(obj, pars, lb, ub)
pars_bound = zeros(size(obj.free_to_bound_funcs));
for ipar_free = 1:numel(pars)
ipar_bound = obj.ifree(ipar_free);
pars_bound(ipar_bound) = obj.free_to_bound_funcs{ipar_bound}(pars(ipar_free));
if isempty(obj.free_to_bound_funcs{ipar_bound})
pars_bound(ipar_bound) = pars(ipar_free); % no bounds
else
pars_bound(ipar_bound) = obj.free_to_bound_funcs{ipar_bound}(pars(ipar_free));
end
end
% add in fixed parameter values
pars_bound(obj.ifixed) = obj.pars_fixed;
Expand All @@ -155,7 +177,11 @@ function init_bound_parameter_transforms(obj, pars, lb, ub)
function pars = get_free_parameters(obj, pars_bound)
pars = zeros(size(pars_bound)); % to preserve par vector shape
for ipar = obj.ifree
pars(ipar) = obj.bound_to_free_funcs{ipar}(pars_bound(ipar));
if isempty(obj.bound_to_free_funcs{ipar})
pars(ipar) = pars_bound(ipar);
else
pars(ipar) = obj.bound_to_free_funcs{ipar}(pars_bound(ipar));
end
end
pars = pars(obj.ifree);
end
Expand Down

0 comments on commit dd978e2

Please sign in to comment.