Skip to content

Commit

Permalink
Fix bug in grid search
Browse files Browse the repository at this point in the history
The parameters that are fitted when evaluating a grid point were not
updated. In particular, after the grid search was done, the parameters
to the best grid point were not returned.

This is now fixed.
  • Loading branch information
gbouvignies committed Aug 29, 2023
1 parent 1f55753 commit 55d1f7d
Show file tree
Hide file tree
Showing 4 changed files with 465 additions and 655 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.277
rev: v0.0.286
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
25 changes: 16 additions & 9 deletions chemex/optimize/gridding.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,13 @@ def run_group_grid(
shape = tuple(len(values) for values in group_grid.values())
grid_size = np.prod(shape)

basename = group.path if group.path != Path(".") else Path("grid")
basename = group.path if group.path != Path() else Path("grid")
filename = path / "Grid" / f"{basename}.out"
filename.parent.mkdir(parents=True, exist_ok=True)

best_chisqr = np.inf
best_params = group_params

with filename.open("w") as fileout:
fileout.write(print_header(group_grid))

Expand All @@ -75,13 +78,19 @@ def run_group_grid(

for values in track(grid_values, total=float(grid_size), description=" "):
_set_param_values(group_params, grid_ids, values)
group_params = minimize(group.experiments, group_params, fitmethod)
stats = calculate_statistics(group.experiments, group_params)
chisqr_list.append(stats.get("chisqr"))
fileout.write(print_values(values, chisqr_list[-1]))
optimized_params = minimize(group.experiments, group_params, fitmethod)
stats = calculate_statistics(group.experiments, optimized_params)
chisqr: float = stats.get("chisqr", np.inf)
chisqr_list.append(chisqr)
fileout.write(print_values(values, chisqr))
fileout.flush()

if chisqr < best_chisqr:
best_chisqr = chisqr
best_params = optimized_params

chisqr_array = np.array(chisqr_list).reshape(shape)
database.update_from_parameters(best_params)

return GridResult(group_grid, chisqr_array)

Expand Down Expand Up @@ -142,7 +151,7 @@ def combine_grids(
def set_params_from_grid(grids_1d: Iterable[GridResult]):
par_values = {}
for grid_result in grids_1d:
id_, values = list(grid_result.grid.items())[0]
id_, values = next(iter(grid_result.grid.items()))
par_values[id_] = values[grid_result.chisqr.argmin()]
database.set_param_values(par_values)

Expand Down Expand Up @@ -239,7 +248,7 @@ def run_grid(

groups = create_groups(experiments)

grid_results = []
grid_results: list[GridResult] = []
for group in groups:
if message := group.message:
print_group_name(message)
Expand All @@ -250,8 +259,6 @@ def run_grid(
grids_1d = make_grids_nd(grid, grids_combined, 1)
grids_2d = make_grids_nd(grid, grids_combined, 2)
set_params_from_grid(grids_1d)
params = database.build_lmfit_params(experiments.param_ids)
database.update_from_parameters(params)
plot_grid_1d(grids_1d, path / "Grid")
plot_grid_2d(grids_2d, path / "Grid")

Expand Down
2 changes: 1 addition & 1 deletion chemex/parameters/spin_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def name(self) -> str:

def match(self, other: Group) -> bool:
symbol = other.symbol == self.symbol or not self.symbol
number = other.number == self.number or self.number == self.NO_NUMBER
number = self.number in (other.number, self.NO_NUMBER)
suffix = other.suffix == self.suffix or not self.suffix
return number and symbol and suffix

Expand Down
Loading

0 comments on commit 55d1f7d

Please sign in to comment.